// -*- C++ -*-
//
// Copyright (C) 1998, 1999, 2000, 2002  Los Alamos National Laboratory,
// Copyright (C) 1998, 1999, 2000, 2002  CodeSourcery, LLC
//
// This file is part of FreePOOMA.
//
// FreePOOMA is free software; you can redistribute it and/or modify it
// under the terms of the Expat license.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the Expat
// license for more details.
//
// You should have received a copy of the Expat license along with
// FreePOOMA; see the file LICENSE.
//

//-----------------------------------------------------------------------------
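// Classes:
// ComputeJacobi
// ApplyJacobi
// JacobiPreconditioner
// GeneratePreconditioner<NinePointMatrix<EngineTag>, NinePointJacobiTag>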
//-----------------------------------------------------------------------------

#include <cmath>

#include "NinePointMatrix.h"

struct ComputeJacobi
{
  inline int lowerExtent(int) const { return 0; }
  inline int upperExtent(int) const { return 0; }

  template<class Tag>
  static void
  apply(
	const Array<2, NinePoint, Tag> &values,
	const Array<2, NinePoint, Tag> &A
	)
  {
    Interval<2> domain = A.domain();

    int i0, i1;
    int b0 = domain[0].first();
    int b1 = domain[1].first();
    int e0 = domain[0].last();
    int e1 = domain[1].last();

    for (i0 = b0; i0 <= e0; i0++)
       for (i1 = b1; i1 <= e1; i1++)
          values(i0, i1) = A(i0, i1).jacobi();
  }
};

struct ApplyJacobi
{
  static void set_steps (const int in) {steps_m = in;}

  inline int lowerExtent(int) const { return 0; }
  inline int upperExtent(int) const { return 0; }

  static GuardLayers<2> guardLayers() { return GuardLayers<2>(0); }

  template<class Tag>
  static void
  apply(
	const Array<2, double, Tag> &res,
	const Array<2, double, Tag> &rhs,
	const Array<2, NinePoint, Tag> &values
	)
  {
    Interval<2> domain = values.domain();

    Array<2, double, Brick> y(domain);

    int i0, i1;
    int b0 = domain[0].first();
    int b1 = domain[1].first();
    int e0 = domain[0].last();
    int e1 = domain[1].last();

    if (steps_m == 1) {
       for (i0 = b0; i0 <= e0; i0++)
          for (i1 = b1; i1 <= e1; i1++)
             res(i0, i1) = values(i0, i1).center() * rhs(i0, i1);
    } else if (steps_m > 1) {

// allocate some working vectors
       Array <2, double> t(Interval<1>(b0, e0), Interval<1>(b1, e1));
       t = 0;
       Array <2, double> u(Interval<1>(b0, e0), Interval<1>(b1, e1));
       u = 0;

// u is the result after one Jacobi step (a diagonal scaling)

       for (i0 = b0; i0 <= e0; i0++)
          for (i1 = b1; i1 <= e1; i1++)
             u(i0, i1) = values(i0, i1).center() * rhs(i0, i1);
       res = u;

       for (int count = 2; count <= steps_m; count++) {
          if (count%2 == 0) {

             i1 = b1;
             i0 = b0;
             t(i0, i1) = u(i0, i1) + 
                values(i0, i1).center() * res(i0, i1) + 
                values(i0, i1).east() * res(i0 + 1, i1) +
                values(i0, i1).north() * res(i0, i1 + 1) +
                values(i0, i1).northeast() * res(i0 + 1, i1 + 1);
      
             i1 = b1;
             for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
                t(i0, i1) = u(i0, i1) + 
                   values(i0, i1).west() * res(i0 - 1, i1) +
                   values(i0, i1).center() * res(i0, i1) + 
                   values(i0, i1).east() * res(i0 + 1, i1) +
                   values(i0, i1).northwest() * res(i0 - 1, i1 + 1) +
                   values(i0, i1).north() * res(i0, i1 + 1) +
                   values(i0, i1).northeast() * res(i0 + 1, i1 + 1);
             } // for i0
      
             i1 = b1;
             i0 = e0;
             t(i0, i1) = u(i0, i1) + 
                values(i0, i1).west() * res(i0 - 1, i1) +
                values(i0, i1).center() * res(i0, i1) + 
                values(i0, i1).northwest() * res(i0 - 1, i1 + 1) +
                values(i0, i1).north() * res(i0, i1 + 1);
      
             for (i1 = b1 + 1; i1 <= e1 - 1; i1++) {
                i0 = b0;
                t(i0, i1) = u(i0, i1) + 
                   values(i0, i1).south() * res(i0, i1 - 1) +
                   values(i0, i1).southeast() * res(i0 + 1, i1 - 1) +
                   values(i0, i1).center() * res(i0, i1) + 
                   values(i0, i1).east() * res(i0 + 1, i1) +
                   values(i0, i1).north() * res(i0, i1 + 1) +
                   values(i0, i1).northeast() * res(i0 + 1, i1 + 1);
      
                for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
                   t(i0, i1) = u(i0, i1) + 
                      values(i0, i1).southwest() * res(i0 - 1, i1 - 1) +
                      values(i0, i1).south() * res(i0, i1 - 1) +
                      values(i0, i1).southeast() * res(i0 + 1, i1 - 1) +
                      values(i0, i1).west() * res(i0 - 1, i1) +
                      values(i0, i1).center() * res(i0, i1) + 
                      values(i0, i1).east() * res(i0 + 1, i1) +
                      values(i0, i1).northwest() * res(i0 - 1, i1 + 1) +
                      values(i0, i1).north() * res(i0, i1 + 1) +
                      values(i0, i1).northeast() * res(i0 + 1, i1 + 1);
                   } // for i0 
      
                i0 = e0;
                t(i0, i1) = u(i0, i1) + 
                   values(i0, i1).southwest() * res(i0 - 1, i1 - 1) +
                   values(i0, i1).south() * res(i0, i1 - 1) +
                   values(i0, i1).west() * res(i0 - 1, i1) +
                   values(i0, i1).center() * res(i0, i1) + 
                   values(i0, i1).northwest() * res(i0 - 1, i1 + 1) +
                   values(i0, i1).north() * res(i0, i1 + 1);
                } // for i1 
      
             i1 = e1;
             i0 = b0;
             t(i0, i1) = u(i0, i1) + 
                values(i0, i1).south() * res(i0, i1 - 1) +
                values(i0, i1).southeast() * res(i0 + 1, i1 - 1) +
                values(i0, i1).center() * res(i0, i1) + 
                values(i0, i1).east() * res(i0 + 1, i1);
      
             i1 = e1;
             for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
                t(i0, i1) = u(i0, i1) + 
                   values(i0, i1).southwest() * res(i0 - 1, i1 - 1) +
                   values(i0, i1).south() * res(i0, i1 - 1) +
                   values(i0, i1).southeast() * res(i0 + 1, i1 - 1) +
                   values(i0, i1).west() * res(i0 - 1, i1) +
                   values(i0, i1).center() * res(i0, i1) + 
                   values(i0, i1).east() * res(i0 + 1, i1); } // for i0
      
             i1 = e1;
             i0 = e0;
             t(i0, i1) = u(i0, i1) + 
                values(i0, i1).southwest() * res(i0 - 1, i1 - 1) +
                values(i0, i1).south() * res(i0, i1 - 1) +
                values(i0, i1).west() * res(i0 - 1, i1) +
                values(i0, i1).center() * res(i0, i1);

          } else {

             i1 = b1;
             i0 = b0;
             res(i0, i1) = u(i0, i1) +
                values(i0, i1).center() * t(i0, i1) + 
                values(i0, i1).east() * t(i0 + 1, i1) +
                values(i0, i1).north() * t(i0, i1 + 1) +
                values(i0, i1).northeast() * t(i0 + 1, i1 + 1);
      
             i1 = b1;
             for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
                res(i0, i1) = u(i0, i1) + 
                   values(i0, i1).west() * t(i0 - 1, i1) +
                   values(i0, i1).center() * t(i0, i1) + 
                   values(i0, i1).east() * t(i0 + 1, i1) +
                   values(i0, i1).northwest() * t(i0 - 1, i1 + 1) +
                   values(i0, i1).north() * t(i0, i1 + 1) +
                   values(i0, i1).northeast() * t(i0 + 1, i1 + 1); } // for i0
      
             i1 = b1;
             i0 = e0;
             res(i0, i1) = u(i0, i1) +
                values(i0, i1).west() * t(i0 - 1, i1) +
                values(i0, i1).center() * t(i0, i1) + 
                values(i0, i1).northwest() * t(i0 - 1, i1 + 1) +
                values(i0, i1).north() * t(i0, i1 + 1);

      
             for (i1 = b1 + 1; i1 <= e1 - 1; i1++) {
                i0 = b0;
                res(i0, i1) = u(i0, i1) +
                   values(i0, i1).south() * t(i0, i1 - 1) +
                   values(i0, i1).southeast() * t(i0 + 1, i1 - 1) +
                   values(i0, i1).center() * t(i0, i1) + 
                   values(i0, i1).east() * t(i0 + 1, i1) +
                   values(i0, i1).north() * t(i0, i1 + 1) +
                   values(i0, i1).northeast() * t(i0 + 1, i1 + 1);
      
                for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
                   res(i0, i1) = u(i0, i1) +
                      values(i0, i1).southwest() * t(i0 - 1, i1 - 1) +
                      values(i0, i1).south() * t(i0, i1 - 1) +
                      values(i0, i1).southeast() * t(i0 + 1, i1 - 1) +
                      values(i0, i1).west() * t(i0 - 1, i1) +
                      values(i0, i1).center() * t(i0, i1) + 
                      values(i0, i1).east() * t(i0 + 1, i1) +
                      values(i0, i1).northwest() * t(i0 - 1, i1 + 1) +
                      values(i0, i1).north() * t(i0, i1 + 1) +
                      values(i0, i1).northeast() * t(i0 + 1, i1 + 1);
                   } // for i0 
      
                i0 = e0;
                res(i0, i1) = u(i0, i1) +
                   values(i0, i1).southwest() * t(i0 - 1, i1 - 1) +
                   values(i0, i1).south() * t(i0, i1 - 1) +
                   values(i0, i1).west() * t(i0 - 1, i1) +
                   values(i0, i1).center() * t(i0, i1) + 
                   values(i0, i1).northwest() * t(i0 - 1, i1 + 1) +
                   values(i0, i1).north() * t(i0, i1 + 1); } // for i1 

             i1 = e1;
             i0 = b0;
             res(i0, i1) = u(i0, i1) +
                values(i0, i1).south() * t(i0, i1 - 1) +
                values(i0, i1).southeast() * t(i0 + 1, i1 - 1) +
                values(i0, i1).center() * t(i0, i1) + 
                values(i0, i1).east() * t(i0 + 1, i1);
      
             i1 = e1;
             for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
                res(i0, i1) = u(i0, i1) +
                   values(i0, i1).southwest() * t(i0 - 1, i1 - 1) +
                   values(i0, i1).south() * t(i0, i1 - 1) +
                   values(i0, i1).southeast() * t(i0 + 1, i1 - 1) +
                   values(i0, i1).west() * t(i0 - 1, i1) +
                   values(i0, i1).center() * t(i0, i1) + 
                   values(i0, i1).east() * t(i0 + 1, i1); } // for i0
      
             i1 = e1;
             i0 = e0;
             res(i0, i1) = u(i0, i1) +
                values(i0, i1).southwest() * t(i0 - 1, i1 - 1) +
                values(i0, i1).south() * t(i0, i1 - 1) +
                values(i0, i1).west() * t(i0 - 1, i1) +
                values(i0, i1).center() * t(i0, i1);

          } // else
       } // for count
       if (steps_m%2 == 0) res = t;  // clean up even case
    } // if steps_m > 1

    else if (steps_m == 0) {

// copy the rhs into the outgoing result

       for (i0 = b0; i0 <= e0; i0++)
          for (i1 = b1; i1 <= e1; i1++)
             res(i0, i1) = rhs(i0, i1); }

  } // apply()

private:
   static int steps_m;
};

//-----------------------------------------------------------------------------
// JacobiPreconditioner
//
// Holds an array of NinePoint coefficients together with the domain on
// which they are applied.  operator() runs ApplyJacobi over that domain
// through GuardedPatchEvaluator.
//-----------------------------------------------------------------------------

template<class EngineTag>
class JacobiPreconditioner
{
public:
  typedef Array<2, NinePoint, EngineTag> Values_t;

  // Access to the stored coefficient array.
  inline Values_t &values() { return values_m; }
  inline const Values_t &values() const { return values_m; }

  // Domain handed to the evaluator by operator().
  inline const Interval<2> &domain() const { return domain_m; }

  // layout: used to construct the coefficient array.
  // domain: stored and later passed to the evaluator.
  template<class Layout, class Domain>
  JacobiPreconditioner(const Layout &layout, const Domain &domain)
    : values_m(layout), domain_m(domain), zero_m(0.0)
  {
  }

  static GuardLayers<2> guardLayers() { return GuardLayers<2>(0); }

  // Applies the preconditioner.  The arrays are handed to the evaluator in
  // the order (y, x, values_m), matching ApplyJacobi::apply(res, rhs,
  // values); so x is presumably the input and y receives the result --
  // this relies on GuardedPatchEvaluator forwarding its arguments in order.
  void operator()(
		  const Array<2, double, EngineTag> &x,
		  const Array<2, double, EngineTag> &y
		  ) const
  {
    GuardedPatchEvaluator<MainEvaluatorTag>::evaluate(y, x, values_m,
						      ApplyJacobi(),
						      domain_m);
  }

private:

  Values_t values_m;
  Interval<2> domain_m;

  // Not read by any member of this class; initialised in the constructor
  // so that copying a JacobiPreconditioner never copies an indeterminate
  // value.
  double zero_m;
};

// Primary template: declared only.  A definition is supplied per
// (Matrix, Method) pair through specialization.
template<class Matrix, class Method>
struct GeneratePreconditioner;

// Empty tag type; used as the Method argument that selects the
// GeneratePreconditioner specialization for NinePointMatrix.
struct NinePointJacobiTag { };

//-----------------------------------------------------------------------------
// GeneratePreconditioner<NinePointMatrix<EngineTag>, NinePointJacobiTag>
//
// Builds a JacobiPreconditioner from a NinePointMatrix.
//-----------------------------------------------------------------------------

template<class EngineTag>
struct GeneratePreconditioner<NinePointMatrix<EngineTag>, NinePointJacobiTag>
{
  typedef NinePointMatrix<EngineTag> Input_t;
  typedef JacobiPreconditioner<EngineTag> Type_t;

  // Forwards the step count to ApplyJacobi::set_steps().
  static void set_steps (const int steps)
  {
    ApplyJacobi::set_steps(steps);
  }

  // Runs ComputeJacobi over preconditioner.domain(), with
  // preconditioner.values() as the first array and input.values() as the
  // second.
  static void fill(const Input_t &input, const Type_t &preconditioner)
  {
    GuardedPatchEvaluator<MainEvaluatorTag>::evaluate(
      preconditioner.values(),
      input.values(),
      ComputeJacobi(),
      preconditioner.domain());
  }
};

    // Definition of the shared step count; one Jacobi step unless
    // ApplyJacobi::set_steps() is called.
    int ApplyJacobi::steps_m = 1;
