/*  
    This code is written by <albanese@fbk.it>.
    (C) 2008 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.
    
    See DWT in the GSL Library.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/


#include <Python.h>
#include <numpy/arrayobject.h>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <string.h>
#include <gsl/gsl_errno.h>
#include <gsl/gsl_math.h>
#include <gsl/gsl_wavelet.h>


/*
 * dwt(x, wf, k) -> new 1D double array
 *
 * Python entry point for the forward discrete wavelet transform.
 *
 * Arguments (positional or keyword):
 *   x  - object convertible to a 1D array of doubles; it is copied, the
 *        caller's data is never modified.
 *   wf - single character selecting the family:
 *        'd' daubechies, 'h' haar, 'b' bspline.
 *   k  - integer member of the wavelet family.
 *
 * Returns a new reference to the transformed copy, or NULL with a Python
 * exception set:
 *   ValueError   - x is not 1D, its length is not a power of two, wf is
 *                  unknown, or GSL could not create the wavelet for k.
 *   MemoryError  - the GSL workspace could not be allocated.
 *   RuntimeError - GSL reported a failure during the transform.
 */
static PyObject *_dwt_dwt(PyObject *self, PyObject *args, PyObject *keywds)
{
  PyObject *x = NULL;
  PyObject *xcopy = NULL;
  PyObject *result = NULL;

  const gsl_wavelet_type *family = NULL;
  gsl_wavelet *w = NULL;
  gsl_wavelet_workspace *work = NULL;
  gsl_error_handler_t *saved_handler = NULL;

  char wf;
  int k;
  int status = GSL_FAILURE;
  size_t n;

  /* Parse Tuple */
  static char *kwlist[] = {"x", "wf", "k", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, keywds, "Oci", kwlist, &x, &wf, &k))
    return NULL;

  /* Build xcopy: a private, contiguous double copy transformed in place */
  xcopy = PyArray_FROM_OTF(x, NPY_DOUBLE, NPY_OUT_ARRAY | NPY_ENSURECOPY);
  if (xcopy == NULL)
    return NULL;

  /* Dimension 0 only exists for arrays with at least one axis, and the
     transform below treats the buffer as a single flat signal. */
  if (PyArray_NDIM((PyArrayObject *) xcopy) != 1)
    {
      PyErr_SetString(PyExc_ValueError, "x must be a 1D array");
      goto done;
    }

  n = (size_t) PyArray_DIM((PyArrayObject *) xcopy, 0);

  /* GSL only transforms signals whose length is a power of two. */
  if (n == 0 || (n & (n - 1)) != 0)
    {
      PyErr_SetString(PyExc_ValueError, "the length of x must be a power of two");
      goto done;
    }

  switch (wf)
    {
    case 'd':
      family = gsl_wavelet_daubechies;
      break;

    case 'h':
      family = gsl_wavelet_haar;
      break;

    case 'b':
      family = gsl_wavelet_bspline;
      break;

    default:
      PyErr_SetString(PyExc_ValueError, "invalid wavelet type (must be 'd', 'h', or 'b')");
      goto done;
    }

  /* GSL's default error handler aborts the process.  Switch it off for the
     duration of the GSL calls so failures come back as return values, then
     restore whatever handler was installed before. */
  saved_handler = gsl_set_error_handler_off ();

  w = gsl_wavelet_alloc (family, (size_t) k);
  if (w != NULL)
    work = gsl_wavelet_workspace_alloc (n);
  if (work != NULL)
    status = gsl_wavelet_transform_forward (w, (double *) PyArray_DATA((PyArrayObject *) xcopy),
                                            1, n, work);

  gsl_set_error_handler (saved_handler);

  if (w == NULL)
    PyErr_SetString(PyExc_ValueError,
                    "could not create wavelet (invalid k for the selected wavelet family)");
  else if (work == NULL)
    PyErr_NoMemory();
  else if (status != GSL_SUCCESS)
    PyErr_SetString(PyExc_RuntimeError, gsl_strerror (status));
  else
    {
      /* Hand the only reference to the caller. */
      result = xcopy;
      xcopy = NULL;
    }

  if (work != NULL)
    gsl_wavelet_workspace_free (work);
  if (w != NULL)
    gsl_wavelet_free (w);

 done:
  /* Still non-NULL only on failure paths: drop the copy so it is not leaked. */
  Py_XDECREF(xcopy);
  return result;
}


/*
 * idwt(X, wf, k) -> new 1D double array
 *
 * Python entry point for the inverse discrete wavelet transform.
 *
 * Arguments (positional or keyword):
 *   X  - object convertible to a 1D array of doubles; it is copied, the
 *        caller's data is never modified.
 *   wf - single character selecting the family:
 *        'd' daubechies, 'h' haar, 'b' bspline.
 *   k  - integer member of the wavelet family.
 *
 * Returns a new reference to the transformed copy, or NULL with a Python
 * exception set:
 *   ValueError   - X is not 1D, its length is not a power of two, wf is
 *                  unknown, or GSL could not create the wavelet for k.
 *   MemoryError  - the GSL workspace could not be allocated.
 *   RuntimeError - GSL reported a failure during the transform.
 */
static PyObject *_dwt_idwt(PyObject *self, PyObject *args, PyObject *keywds)
{
  PyObject *x = NULL;
  PyObject *xcopy = NULL;
  PyObject *result = NULL;

  const gsl_wavelet_type *family = NULL;
  gsl_wavelet *w = NULL;
  gsl_wavelet_workspace *work = NULL;
  gsl_error_handler_t *saved_handler = NULL;

  char wf;
  int k;
  int status = GSL_FAILURE;
  size_t n;

  /* Parse Tuple */
  static char *kwlist[] = {"X", "wf", "k", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, keywds, "Oci", kwlist, &x, &wf, &k))
    return NULL;

  /* Build xcopy: a private, contiguous double copy transformed in place */
  xcopy = PyArray_FROM_OTF(x, NPY_DOUBLE, NPY_OUT_ARRAY | NPY_ENSURECOPY);
  if (xcopy == NULL)
    return NULL;

  /* Dimension 0 only exists for arrays with at least one axis, and the
     transform below treats the buffer as a single flat signal. */
  if (PyArray_NDIM((PyArrayObject *) xcopy) != 1)
    {
      PyErr_SetString(PyExc_ValueError, "X must be a 1D array");
      goto done;
    }

  n = (size_t) PyArray_DIM((PyArrayObject *) xcopy, 0);

  /* GSL only transforms signals whose length is a power of two. */
  if (n == 0 || (n & (n - 1)) != 0)
    {
      PyErr_SetString(PyExc_ValueError, "the length of X must be a power of two");
      goto done;
    }

  switch (wf)
    {
    case 'd':
      family = gsl_wavelet_daubechies;
      break;

    case 'h':
      family = gsl_wavelet_haar;
      break;

    case 'b':
      family = gsl_wavelet_bspline;
      break;

    default:
      PyErr_SetString(PyExc_ValueError, "invalid wavelet type (must be 'd', 'h', or 'b')");
      goto done;
    }

  /* GSL's default error handler aborts the process.  Switch it off for the
     duration of the GSL calls so failures come back as return values, then
     restore whatever handler was installed before. */
  saved_handler = gsl_set_error_handler_off ();

  w = gsl_wavelet_alloc (family, (size_t) k);
  if (w != NULL)
    work = gsl_wavelet_workspace_alloc (n);
  if (work != NULL)
    status = gsl_wavelet_transform_inverse (w, (double *) PyArray_DATA((PyArrayObject *) xcopy),
                                            1, n, work);

  gsl_set_error_handler (saved_handler);

  if (w == NULL)
    PyErr_SetString(PyExc_ValueError,
                    "could not create wavelet (invalid k for the selected wavelet family)");
  else if (work == NULL)
    PyErr_NoMemory();
  else if (status != GSL_SUCCESS)
    PyErr_SetString(PyExc_RuntimeError, gsl_strerror (status));
  else
    {
      /* Hand the only reference to the caller. */
      result = xcopy;
      xcopy = NULL;
    }

  if (work != NULL)
    gsl_wavelet_workspace_free (work);
  if (w != NULL)
    gsl_wavelet_free (w);

 done:
  /* Still non-NULL only on failure paths: drop the copy so it is not leaked. */
  Py_XDECREF(xcopy);
  return result;
}


/* Doc strings: */
/* Module-level docstring, passed to Py_InitModule3 when the module is created. */
static char module_doc[] =
  "Discrete Wavelet Transform Module from GSL";

/* Docstring exposed to Python as dwt.__doc__ via the method table. */
static char _dwt_dwt_doc[] =
  "Discrete Wavelet Transform\n\n"
  "Input\n\n"
  "  * *x*  - [1D numpy array float] data (the length is restricted to powers of two)\n"
  "  * *wf* - [string] wavelet type ('d': daubechies, 'h': haar, 'b': bspline)\n"
  "  * *k*  - [integer] member of the wavelet family\n\n"
  "    * daubechies: k = 4, 6, ..., 20 with k even\n"
  "    * haar: the only valid choice of k is k = 2\n"
  "    * bspline: k = 103, 105, 202, 204, 206, 208, 301, 303, 305, 307, 309\n\n"
  "Output\n\n"
  "  * *X* - [1D numpy array float] discrete wavelet transform"
  ;

/* Docstring exposed to Python as idwt.__doc__ via the method table. */
static char _dwt_idwt_doc[] =
  "Inverse Discrete Wavelet Transform\n\n"
  "Input\n\n"
  "  * *X*  - [1D numpy array float] data (the length is restricted to powers of two)\n"
  "  * *wf* - [string] wavelet type ('d': daubechies, 'h': haar, 'b': bspline)\n"
  "  * *k*  - [integer] member of the wavelet family\n\n"
  "    * daubechies: k = 4, 6, ..., 20 with k even\n"
  "    * haar: the only valid choice of k is k = 2\n"
  "    * bspline: k = 103, 105, 202, 204, 206, 208, 301, 303, 305, 307, 309\n\n"
  "Output\n\n"
  "  * *x* - [1D numpy array float] inverse discrete wavelet transform"
  ;


/* Method table */
static PyMethodDef _dwt_methods[] = {
  {"dwt",
   (PyCFunction)_dwt_dwt,
   METH_VARARGS | METH_KEYWORDS,
   _dwt_dwt_doc},
  {"idwt",
   (PyCFunction)_dwt_idwt,
   METH_VARARGS | METH_KEYWORDS,
   _dwt_idwt_doc},
  {NULL, NULL, 0, NULL}
};

/* Init */
void init_dwt()
{
  Py_InitModule3("_dwt", _dwt_methods, module_doc);
  import_array();
}

