
#include <lua.h>
#include <lauxlib.h>

#include "ode-solver.h"

/* Name under which the shared metatable of every ODE solver userdata is
   stored in the Lua registry (used with luaL_newmetatable,
   luaL_getmetatable and luaL_checkudata below).  */
static const char *const ode_solver_type_name = "GSL.ode_solver";

/* Forward declaration: the metamethod table below needs the address of
   the finalizer, which is defined further down in this file.  */
static int ode_solver_dealloc (lua_State *L);

/* Metamethods copied into the solver metatable by ode_solver_register.
   The {NULL, NULL} entry terminates the list.  */
static const struct luaL_Reg ode_solver_methods[] = {
  { "__gc", ode_solver_dealloc },
  { NULL,   NULL }
};

/* Create a new ODE solver userdata, push it on the Lua stack and return
   a pointer to it.

   TYPE is the GSL stepper type and DIM the dimension of the system.
   EPS_ABS and EPS_REL are the absolute and relative error bounds passed
   to gsl_odeiv_control_y_new.  When MULTIPLICITY is greater than 1, an
   extra vector of dim * dim * multiplicity elements is allocated in
   params->J; otherwise params->J is NULL.  (NOTE(review): J looks like
   Jacobian storage; confirm against callers.)

   If any GSL allocation fails, everything allocated so far is released
   and a Lua error is raised, so this function never returns NULL.  The
   metatable, and with it the __gc finalizer, is attached only once every
   resource exists, so the collector never sees a partially built
   solver.  */
struct ode_solver *
ode_solver_push_new (lua_State *L, const gsl_odeiv_step_type *type,
		     size_t dim, double eps_abs, double eps_rel,
		     size_t multiplicity)
{
  struct ode_solver *s;

  s = lua_newuserdata (L, sizeof (struct ode_solver));

  /* lua_newuserdata does not zero its memory: give every owned pointer a
     known value before anything can fail, so the cleanup path below is
     safe.  */
  s->params->L = L;
  s->params->n = dim;
  s->params->y = NULL;
  s->params->J = NULL;
  s->step = NULL;
  s->ctrl = NULL;
  s->evol = NULL;

  /* The GSL allocators return NULL on failure when the GSL error handler
     does not abort (for example for dim == 0).  */
  s->params->y = gsl_vector_alloc (dim);
  if (s->params->y == NULL)
    goto fail;

  if (multiplicity > 1)
    {
      s->params->J = gsl_vector_alloc (dim * dim * multiplicity);
      if (s->params->J == NULL)
	goto fail;
    }

  s->step = gsl_odeiv_step_alloc (type, dim);
  if (s->step == NULL)
    goto fail;

  s->ctrl = gsl_odeiv_control_y_new (eps_abs, eps_rel);
  if (s->ctrl == NULL)
    goto fail;

  s->evol = gsl_odeiv_evolve_alloc (dim);
  if (s->evol == NULL)
    goto fail;

  luaL_getmetatable (L, ode_solver_type_name);
  lua_setmetatable (L, -2);

  return s;

 fail:
  /* No metatable has been set yet, so __gc will not run for this
     userdata: release the partial allocations here.  */
  if (s->evol)
    gsl_odeiv_evolve_free (s->evol);
  if (s->ctrl)
    gsl_odeiv_control_free (s->ctrl);
  if (s->step)
    gsl_odeiv_step_free (s->step);
  if (s->params->J)
    gsl_vector_free (s->params->J);
  if (s->params->y)
    gsl_vector_free (s->params->y);

  luaL_error (L, "cannot allocate ODE solver");
  return NULL; /* not reached: luaL_error does not return */
}


/* Return the ode_solver userdata stored at stack slot INDEX.  If that
   value is not a userdata carrying the "GSL.ode_solver" metatable,
   luaL_checkudata raises a Lua argument error instead of returning.  */
struct ode_solver *
check_ode_solver (lua_State *L, int index)
{
  void *udata = luaL_checkudata (L, index, ode_solver_type_name);
  return (struct ode_solver *) udata;
}

/* __gc metamethod for "GSL.ode_solver" userdata.  Frees the GSL evolve,
   control and step objects and the parameter vectors owned by the
   solver.  Returns 0, so nothing is left on the Lua stack.

   The metatable has no __metatable field, so Lua code can reach it with
   getmetatable() and call __gc by hand before the collector calls it
   again.  Each pointer is therefore checked and reset to NULL after it
   is released, which makes repeated calls harmless instead of a double
   free.

   Declared static to match the forward declaration at the top of the
   file.  */
static int
ode_solver_dealloc (lua_State *L)
{
  struct ode_solver *s = check_ode_solver (L, 1);

  if (s->evol)
    {
      gsl_odeiv_evolve_free (s->evol);
      s->evol = NULL;
    }
  if (s->ctrl)
    {
      gsl_odeiv_control_free (s->ctrl);
      s->ctrl = NULL;
    }
  if (s->step)
    {
      gsl_odeiv_step_free (s->step);
      s->step = NULL;
    }

  if (s->params->y)
    {
      gsl_vector_free (s->params->y);
      s->params->y = NULL;
    }
  if (s->params->J)
    {
      gsl_vector_free (s->params->J);
      s->params->J = NULL;
    }

  return 0;
}

/* Create the "GSL.ode_solver" metatable in the Lua registry (if it does
   not already exist) and install the metamethods from ode_solver_methods
   into it.  The stack is left as it was on entry.  */
void
ode_solver_register (lua_State *L)
{
  /* ode solver declaration */
  luaL_newmetatable (L, ode_solver_type_name);
#if LUA_VERSION_NUM >= 502
  /* luaL_register is deprecated since Lua 5.2 and absent from later
     releases; luaL_setfuncs with zero upvalues fills the table on top of
     the stack in the same way.  */
  luaL_setfuncs (L, ode_solver_methods, 0);
#else
  luaL_register (L, NULL, ode_solver_methods);
#endif
  lua_pop (L, 1);
}

/* Configure the solver's own GSL system from SYS.  Only the right-hand
   side function and the Jacobian callback are taken from SYS.  The
   dimension and the params pointer always come from the solver itself
   (s->params->n and s->params), so sys->dimension and sys->params are
   ignored.  */
void
ode_solver_set (struct ode_solver *s, gsl_odeiv_system *sys)
{
  gsl_odeiv_system *target = s->system;

  target->params    = s->params;
  target->dimension = s->params->n;
  target->jacobian  = sys->jacobian;
  target->function  = sys->function;
}
