#!/usr/bin/env python
#-*- coding: latin-1 -*-

# gdalhelper.py
# Copyright 2008 Gregorio Díaz-Marta Mateos
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

# hillshade inspired by hillShadingPIL.py by John Barratt, Langarson Pty.
# Ltd. http://www.langarson.com.au/
# I don't think this is a derived work though.

"""Stuff using gdal and ogr libraries."""

import os
from math import atan, atan2, cos, degrees, pi, radians, sin, sqrt
from osgeo import gdal, osr

def add_srs_direct(name, a_srs, gt=None):
    """Add projection information to a raster file, in place.

    Opens the dataset in update mode and calls SetProjection (and
    SetGeoTransform when a geotransform is given) on it. Not every format
    supports this.

    Arguments:
    name -- path of the raster file to modify.
    a_srs -- spatial reference in any form accepted by
             osr.SpatialReference.SetFromUserInput (e.g. 'EPSG:4326').
    gt -- optional geotransform passed to SetGeoTransform.

    If the file cannot be opened for update, the first attribute access on
    the opened dataset raises AttributeError (presumably because gdal.Open
    returned None, with GDAL exceptions disabled). add_srs() relies on that
    AttributeError to fall back to add_srs_through_vrt(), so the dataset must
    be touched before anything else is written.

    """

    ds = gdal.Open(name, gdal.GA_Update)
    driver = ds.GetDriver()
    srs = osr.SpatialReference()
    srs.SetFromUserInput(a_srs)
    wkt = srs.ExportToWkt()
    ds.SetProjection(wkt)
    if gt is not None:
        ds.SetGeoTransform(gt)

    if driver.ShortName == 'AAIGrid':
        # This uses PAM but gdal itself is unable to retrieve projection
        # from the xml file, so we also create a prj file.
        prjname = os.path.extsep.join((os.path.splitext(name)[0], 'prj'))
        # 'with' guarantees the file is closed even if the write fails.
        with open(prjname, 'w') as prjfile:
            prjfile.write('%s\n' % wkt)

def add_srs_through_vrt(name, a_srs, gt=None):
    """Add projection information to a raster file by rewriting it.

    Useful for formats that don't allow update of existing files. The raster
    is copied into an in-memory (MEM) dataset, SetProjection and
    SetGeoTransform are called on that copy, and the copy is written back
    over the original file with the original format's driver. (The name is
    historical: an earlier version used a VRT copy, which kept reading from
    the very file that was being overwritten.)

    Arguments:
    name -- path of the raster file to rewrite.
    a_srs -- spatial reference in any form accepted by
             osr.SpatialReference.SetFromUserInput.
    gt -- optional geotransform passed to SetGeoTransform.

    The whole raster is held in memory while the file is rewritten.

    """

    ds = gdal.Open(name, gdal.GA_ReadOnly)
    driver = ds.GetDriver()

    srs = osr.SpatialReference()
    srs.SetFromUserInput(a_srs)
    wkt = srs.ExportToWkt()

    # The intermediate copy must not reference the source file, because that
    # file is about to be overwritten. A VRT reads its source lazily, so it
    # would be reading from the file while it is being rewritten. A MEM copy
    # holds the pixel data itself.
    mem_ds = gdal.GetDriverByName('MEM').CreateCopy('', ds)
    ds = None  # Close the source before overwriting it.
    mem_ds.SetProjection(wkt)
    if gt is not None:
        mem_ds.SetGeoTransform(gt)

    out_ds = driver.CreateCopy(name, mem_ds)
    out_ds = None  # Dropping the reference closes and flushes the output.

def add_srs(name, a_srs, gt=None):
    """Attach a spatial reference (and optionally a geotransform) to a raster.

    The file is first modified in place with add_srs_direct(). If that
    raises AttributeError, the work is redone with add_srs_through_vrt(),
    which rewrites the file from a copy. The AttributeError presumably comes
    from the dataset not being openable in update mode.

    """

    try:
        add_srs_direct(name, a_srs, gt=gt)
    except AttributeError:
        updated_in_place = False
    else:
        updated_in_place = True

    if not updated_in_place:
        add_srs_through_vrt(name, a_srs, gt=gt)

def hillshadeband(src_band, dst_band, scale=1.0, az=315.0, alt=45.0):
    """Create a hill shade version of a DEM band.

    Reads the elevations of src_band and computes a shading value in the
    range 0-255 for every pixel that has a complete 3x3 neighbourhood. The
    result is written to dst_band. Border pixels (first/last row and column)
    keep whatever values dst_band already held.

    Arguments:
    src_band -- band holding the elevation data; expected to be at least as
                large as dst_band.
    dst_band -- band that receives the shading values.
    scale -- the 3x3 finite differences are divided by 8 * scale
             (presumably the pixel size in elevation units -- confirm).
    az -- illumination azimuth, in degrees.
    alt -- illumination altitude, in degrees.

    NOTE(review): nodata / NaN elevations are not treated specially.

    """
    import numpy  # ReadAsArray already depends on numpy.

    az = radians(az)
    alt = radians(alt)
    sin_alt = sin(alt)
    cos_alt = cos(alt)
    comp_az = az - pi / 2
    d_scale = 8.0 * scale

    xsize = dst_band.XSize
    ysize = dst_band.YSize
    # Work in float64. With integer DEMs (Int16, UInt16, Byte) the weighted
    # sums below would otherwise overflow, or wrap around on subtraction.
    z = numpy.asarray(src_band.ReadAsArray(), dtype=numpy.float64)
    z = z[:ysize, :xsize]
    dst_data = dst_band.ReadAsArray()

    # Shifted views of the grid. For every interior pixel (row, col) each one
    # holds the neighbour in the named direction (n/s: row -/+ 1,
    # w/e: col -/+ 1). For rasters smaller than 3x3 they are simply empty.
    nw, n, ne = z[:-2, :-2], z[:-2, 1:-1], z[:-2, 2:]
    w, e = z[1:-1, :-2], z[1:-1, 2:]
    sw, s, se = z[2:, :-2], z[2:, 1:-1], z[2:, 2:]

    # Sobel (1, 2, 1) weighted differences across the 3x3 window.
    dx = ((nw + 2 * n + ne) - (sw + 2 * s + se)) / d_scale
    dy = ((ne + 2 * e + se) - (nw + 2 * w + sw)) / d_scale
    slope = pi / 2 - numpy.arctan(numpy.sqrt(dx * dx + dy * dy))
    aspect = numpy.arctan2(dx, dy)
    value = (sin_alt * numpy.sin(slope)
             + cos_alt * numpy.cos(slope) * numpy.cos(comp_az - aspect))
    # value lies in [-1, 1], so the scaled result is non-negative and
    # astype(int) truncates exactly like int() on each pixel would.
    dst_data[1:-1, 1:-1] = (255 * (value + 1) / 2).astype(int)

    dst_band.WriteArray(dst_data)

def hillshade(src_name, dst_name=None, **kargs):
    """Create a hill shade version of a DEM file.

    The first band of src_name is shaded with hillshadeband(). The result is
    written to a new single-band Int16 GeoTIFF with the source geotransform
    and projection.

    Arguments:
    src_name -- path of the DEM raster.
    dst_name -- path of the GeoTIFF to create. When None or empty it is
                derived from src_name: 'dem.asc' becomes 'dem_shade.tif'.
    kargs -- keyword arguments (scale, az, alt) passed on to hillshadeband().

    Returns the path of the created file.

    """
    if dst_name is None or dst_name == '':
        root = os.path.splitext(src_name)[0]
        dst_name = os.path.extsep.join((root + '_shade', 'tif'))

    src_ds = gdal.Open(src_name)
    xsize = src_ds.RasterXSize
    ysize = src_ds.RasterYSize
    prj = src_ds.GetProjection()
    gt = src_ds.GetGeoTransform()

    driver = gdal.GetDriverByName('GTiff')
    dst_ds = driver.Create(dst_name, xsize, ysize, eType=gdal.GDT_Int16)
    dst_ds.SetGeoTransform(gt)
    dst_ds.SetProjection(prj)

    # Keyword arguments must be forwarded as keywords. Unpacking them with a
    # single '*' would pass the dict's keys positionally.
    hillshadeband(src_ds.GetRasterBand(1), dst_ds.GetRasterBand(1), **kargs)

    dst_ds = None  # Dropping the reference closes and flushes the GeoTIFF.
    return dst_name

def _copy_raster_contents(src_ds, dst_ds):
    """Copy geotransform, projection, pixels and nodata from src_ds to dst_ds.

    dst_ds must already exist with the same size and band count as src_ds.

    """
    dst_ds.SetGeoTransform(src_ds.GetGeoTransform())
    dst_ds.SetProjection(src_ds.GetProjection())
    for idx in range(1, src_ds.RasterCount + 1):
        src_band = src_ds.GetRasterBand(idx)
        dst_band = dst_ds.GetRasterBand(idx)
        dst_band.WriteArray(src_band.ReadAsArray())
        nodatavalue = src_band.GetNoDataValue()
        if nodatavalue is not None:
            dst_band.SetNoDataValue(nodatavalue)

def translate(src_name, dst_name, of='GTiff', ot=None):
    """Converts raster data between different formats.

    This function provides some of the features of the gdal_translate
    program. Every band of src_name is copied into a new dst_name dataset,
    together with its geotransform, projection and per-band nodata value.

    Arguments:
    src_name -- path of the source raster.
    dst_name -- path of the raster to create.
    of -- GDAL short name of the output format (default 'GTiff').
    ot -- output data type: a GDAL type code (e.g. gdal.GDT_Int16) or its
          name (e.g. 'Int16'). Defaults to the type of the first source band.

    Raises ValueError if ot is a name GDAL does not recognise.

    Drivers that cannot create datasets directly, only copy them, are handled
    by building the output in an in-memory (MEM) dataset first and then
    writing it with CreateCopy. Bands are read whole into memory.

    """

    # Fetch source dataset

    src_ds = gdal.Open(src_name)
    xsize = src_ds.RasterXSize
    ysize = src_ds.RasterYSize
    bands = src_ds.RasterCount
    if ot is None:
        ot = src_ds.GetRasterBand(1).DataType
    elif isinstance(ot, str):
        type_name = ot
        ot = gdal.GetDataTypeByName(type_name)
        if ot == gdal.GDT_Unknown:
            raise ValueError('unknown GDAL data type: %r' % type_name)

    # Create destination dataset

    dst_driver = gdal.GetDriverByName(of)
    if dst_driver.GetMetadataItem('DCAP_CREATE') == 'YES':
        dst_ds = dst_driver.Create(dst_name, xsize, ysize, bands=bands,
                                   eType=ot)
        _copy_raster_contents(src_ds, dst_ds)
    else:
        # CreateCopy-only format: stage the converted data in memory, where
        # the data type can be chosen freely, then copy it to the target.
        mem_ds = gdal.GetDriverByName('MEM').Create('', xsize, ysize,
                                                    bands=bands, eType=ot)
        _copy_raster_contents(src_ds, mem_ds)
        dst_ds = dst_driver.CreateCopy(dst_name, mem_ds)
        mem_ds = None

    dst_ds = None  # Dropping the reference closes and flushes the output.

