// -*- C++ -*-
//
// Copyright (C) 1998, 1999, 2000, 2002  Los Alamos National Laboratory,
// Copyright (C) 1998, 1999, 2000, 2002  CodeSourcery, LLC
//
// This file is part of FreePOOMA.
//
// FreePOOMA is free software; you can redistribute it and/or modify it
// under the terms of the Expat license.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the Expat
// license for more details.
//
// You should have received a copy of the Expat license along with
// FreePOOMA; see the file LICENSE.
//

//-----------------------------------------------------------------------------
// Class:
// Symmetric Gauss-Seidel (SGS)
// SymmetricFivePointMatrix
//-----------------------------------------------------------------------------

#include <cmath>

#include "FivePointMatrix.h"

struct ComputeSGS
{
  inline int lowerExtent(int) const { return 0; }
  inline int upperExtent(int) const { return 0; }  

  template<class Tag>
  static void
  apply(
	const Array<2, FivePoint, Tag> &values,
	const Array<2, FivePoint, Tag> &A
	)
  {
    Interval<2> domain = A.domain();

    int i0, i1;
    int b0 = domain[0].first();
    int b1 = domain[1].first();
    int e0 = domain[0].last();
    int e1 = domain[1].last();

    for (i0 = b0; i0 <= e0; i0++)
       for (i1 = b1; i1 <= e1; i1++)
          values(i0, i1) = A(i0, i1).jacobi();
  }
};

struct ApplySGS
{
  static void set_steps (const int in) {steps_m = in;}

  inline int lowerExtent(int) const { return 0; }
  inline int upperExtent(int) const { return 0; }  

  template<class Tag>
  static void
  apply(
	const Array<2, double, Tag> &res,
	const Array<2, double, Tag> &rhs,
	const Array<2, FivePoint, Tag> &values
	)
  {
    Interval<2> domain = values.domain();

    Array<2, double, Brick> y(domain);

    int i0, i1;
    int b0 = domain[0].first();
    int b1 = domain[1].first();
    int e0 = domain[0].last();
    int e1 = domain[1].last();

    if (steps_m == 1) {

       res = 0;

// do a forward sweep

       i1 = b1;
       i0 = b0;
       res(i0, i1) = values(i0, i1).center() * rhs(i0, i1) +
          values(i0, i1).east() * res(i0 + 1, i1) +
          values(i0, i1).north() * res(i0, i1 + 1);

       i1 = b1;
       for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
          res(i0, i1) = values(i0, i1).west() * res(i0 - 1, i1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).east() * res(i0 + 1, i1) +
             values(i0, i1).north() * res(i0, i1 + 1);  } // for i0

       i1 = b1;
       i0 = e0;
       res(i0, i1) = values(i0, i1).west() * res(i0 - 1, i1) +
          values(i0, i1).center() * rhs(i0, i1) +
          values(i0, i1).north() * res(i0, i1 + 1);

       for (i1 = b1 + 1; i1 <= e1 - 1; i1++) {
          i0 = b0;
          res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).east() * res(i0 + 1, i1) +
             values(i0, i1).north() * res(i0, i1 + 1);

          for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
             res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
                values(i0, i1).west() * res(i0 - 1, i1) +
                values(i0, i1).center() * rhs(i0, i1) +
                values(i0, i1).east() * res(i0 + 1, i1) +
                values(i0, i1).north() * res(i0, i1 + 1); } // for i0 

          i0 = e0;
          res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
             values(i0, i1).west() * res(i0 - 1, i1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).north() * res(i0, i1 + 1); } // for i1 

       i1 = e1;
       i0 = b0;
       res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
          values(i0, i1).center() * rhs(i0, i1) +
          values(i0, i1).east() * res(i0 + 1, i1);

       i1 = e1;
       for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
          res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
             values(i0, i1).west() * res(i0 - 1, i1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).east() * res(i0 + 1, i1); } // for i0

       i1 = e1;
       i0 = e0;
       res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
          values(i0, i1).west() * res(i0 - 1, i1) +
          values(i0, i1).center() * rhs(i0, i1);

// do a backward-facing sweep

       i1 = e1;
       i0 = e0;
       res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
          values(i0, i1).west() * res(i0 - 1, i1) +
          values(i0, i1).center() * rhs(i0, i1);

       i1 = e1;
       for (i0 = e0 - 1; i0 >= b0 + 1; i0--) {
          res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
             values(i0, i1).west() * res(i0 - 1, i1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).east() * res(i0 + 1, i1); } // for i0

       i1 = e1;
       i0 = b0;
       res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
          values(i0, i1).center() * rhs(i0, i1) +
          values(i0, i1).east() * res(i0 + 1, i1);

       for (i1 = e1 - 1; i1 >= b1 + 1; i1--) {
          i0 = e0;
          res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
             values(i0, i1).west() * res(i0 - 1, i1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).north() * res(i0, i1 + 1);
          for (i0 = e0 - 1; i0 >= b0 + 1; i0--) {
             res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
                values(i0, i1).west() * res(i0 - 1, i1) +
                values(i0, i1).center() * rhs(i0, i1) +
                values(i0, i1).east() * res(i0 + 1, i1) +
                values(i0, i1).north() * res(i0, i1 + 1); } // for i0 
          i0 = b0;
          res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).east() * res(i0 + 1, i1) +
             values(i0, i1).north() * res(i0, i1 + 1); } // for i1 

       i1 = b1;
       i0 = e0;
       res(i0, i1) = values(i0, i1).west() * res(i0 - 1, i1) +
          values(i0, i1).center() * rhs(i0, i1) +
          values(i0, i1).north() * res(i0, i1 + 1);

       i1 = b1;
       for (i0 = e0 - 1; i0 >= b0 + 1; i0--) {
          res(i0, i1) = values(i0, i1).west() * res(i0 - 1, i1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).east() * res(i0 + 1, i1) +
             values(i0, i1).north() * res(i0, i1 + 1);  } // for i0

       i1 = b1;
       i0 = b0;
       res(i0, i1) = values(i0, i1).center() * rhs(i0, i1) +
          values(i0, i1).east() * res(i0 + 1, i1) +
          values(i0, i1).north() * res(i0, i1 + 1);

    } else if (steps_m > 1) {

       res = 0;

       for (int count = 1; count <= steps_m; count++) {

// do a forward sweep

          i1 = b1;
          i0 = b0;
          res(i0, i1) = values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).east() * res(i0 + 1, i1) +
             values(i0, i1).north() * res(i0, i1 + 1);
   
          i1 = b1;
          for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
             res(i0, i1) = values(i0, i1).west() * res(i0 - 1, i1) +
                values(i0, i1).center() * rhs(i0, i1) +
                values(i0, i1).east() * res(i0 + 1, i1) +
                values(i0, i1).north() * res(i0, i1 + 1);  } // for i0
   
          i1 = b1;
          i0 = e0;
          res(i0, i1) = values(i0, i1).west() * res(i0 - 1, i1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).north() * res(i0, i1 + 1);
   
          for (i1 = b1 + 1; i1 <= e1 - 1; i1++) {
             i0 = b0;
             res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
                values(i0, i1).center() * rhs(i0, i1) +
                values(i0, i1).east() * res(i0 + 1, i1) +
                values(i0, i1).north() * res(i0, i1 + 1);
   
             for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
                res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
                   values(i0, i1).west() * res(i0 - 1, i1) +
                   values(i0, i1).center() * rhs(i0, i1) +
                   values(i0, i1).east() * res(i0 + 1, i1) +
                   values(i0, i1).north() * res(i0, i1 + 1); } // for i0 
   
             i0 = e0;
             res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
                values(i0, i1).west() * res(i0 - 1, i1) +
                values(i0, i1).center() * rhs(i0, i1) +
                values(i0, i1).north() * res(i0, i1 + 1); } // for i1 
   
          i1 = e1;
          i0 = b0;
          res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).east() * res(i0 + 1, i1);
   
          i1 = e1;
          for (i0 = b0 + 1; i0 <= e0 - 1; i0++) {
             res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
                values(i0, i1).west() * res(i0 - 1, i1) +
                values(i0, i1).center() * rhs(i0, i1) +
                values(i0, i1).east() * res(i0 + 1, i1); } // for i0
   
          i1 = e1;
          i0 = e0;
          res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
             values(i0, i1).west() * res(i0 - 1, i1) +
             values(i0, i1).center() * rhs(i0, i1);

// do a backward-facing sweep

          i1 = e1;
          i0 = e0;
          res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
             values(i0, i1).west() * res(i0 - 1, i1) +
             values(i0, i1).center() * rhs(i0, i1);

          i1 = e1;
          for (i0 = e0 - 1; i0 >= b0 + 1; i0--) {
             res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
                values(i0, i1).west() * res(i0 - 1, i1) +
                values(i0, i1).center() * rhs(i0, i1) +
                values(i0, i1).east() * res(i0 + 1, i1); } // for i0

          i1 = e1;
          i0 = b0;
          res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).east() * res(i0 + 1, i1);

          for (i1 = e1 - 1; i1 >= b1 + 1; i1--) {
             i0 = e0;
             res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
                values(i0, i1).west() * res(i0 - 1, i1) +
                values(i0, i1).center() * rhs(i0, i1) +
                values(i0, i1).north() * res(i0, i1 + 1);

             for (i0 = e0 - 1; i0 >= b0 + 1; i0--) {
                res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
                   values(i0, i1).west() * res(i0 - 1, i1) +
                   values(i0, i1).center() * rhs(i0, i1) +
                   values(i0, i1).east() * res(i0 + 1, i1) +
                   values(i0, i1).north() * res(i0, i1 + 1); } // for i0 

             i0 = b0;
             res(i0, i1) = values(i0, i1).south() * res(i0, i1 - 1) +
                values(i0, i1).center() * rhs(i0, i1) +
                values(i0, i1).east() * res(i0 + 1, i1) +
                values(i0, i1).north() * res(i0, i1 + 1); } // for i1 

          i1 = b1;
          i0 = e0;
          res(i0, i1) = values(i0, i1).west() * res(i0 - 1, i1) +
             values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).north() * res(i0, i1 + 1);

          i1 = b1;
          for (i0 = e0 - 1; i0 >= b0 + 1; i0--) {
             res(i0, i1) = values(i0, i1).west() * res(i0 - 1, i1) +
                values(i0, i1).center() * rhs(i0, i1) +
                values(i0, i1).east() * res(i0 + 1, i1) +
                values(i0, i1).north() * res(i0, i1 + 1);  } // for i0

          i1 = b1;
          i0 = b0;
          res(i0, i1) = values(i0, i1).center() * rhs(i0, i1) +
             values(i0, i1).east() * res(i0 + 1, i1) +
             values(i0, i1).north() * res(i0, i1 + 1);

       } // for count
    } // steps_m > 1

    else if (steps_m == 0) {
       for (i0 = b0; i0 <= e0; i0++)
          for (i1 = b1; i1 <= e1; i1++)
             res(i0, i1) = rhs(i0, i1); }

  } // apply()

private:
   static int steps_m;

};

// SGSPreconditioner: owns the five-point coefficient array used by ApplySGS
// together with the domain over which the preconditioner is evaluated.
// The coefficient array is filled from outside through values().
template<class EngineTag>
class SGSPreconditioner
{
public:
  typedef Array<2, FivePoint, EngineTag> Values_t;

  // Access to the stored coefficient array.
  inline Values_t &values() { return values_m; }
  inline const Values_t &values() const { return values_m; }

  // Domain passed to the patch evaluator when the preconditioner is applied.
  inline const Interval<2> &domain() const { return domain_m; }

  // layout : used to construct the coefficient array
  // domain : stored and later handed to the patch evaluator
  template<class Layout, class Domain>
  SGSPreconditioner(const Layout &layout, const Domain &domain)
    : values_m(layout), domain_m(domain), zero_m(0.0)
  {
  }

  // The preconditioner asks for guard layers of width zero.
  static GuardLayers<2> guardLayers() { return GuardLayers<2>(0); }

  // Applies the preconditioner.  Note the argument order: x is handed to
  // ApplySGS::apply as the right-hand side and y as the result array, so
  // y is the output and x the input.
  void operator()(
		  const Array<2, double, EngineTag> &x,
		  const Array<2, double, EngineTag> &y
		  ) const
  {
    GuardedPatchEvaluator<MainEvaluatorTag>::evaluate(y, x, values_m,
						      ApplySGS(),
						      domain_m);
  }

private:

  Values_t values_m;
  Interval<2> domain_m;

  // Not read anywhere in this class.  It is given a definite value in the
  // constructor so that copying a preconditioner never reads an
  // uninitialized double.
  double zero_m;
};

// Primary template, declared but not defined here: a preconditioner
// generator is selected by its (Matrix, Method) pair and has to be provided
// by a specialization.
template<class Matrix, class Method>
struct GeneratePreconditioner;

// Empty tag type; used as the Method argument that selects the
// GeneratePreconditioner specialization for FivePointMatrix.
struct FivePointSymmetricGaussSeidelTag { };

// Specialization chosen for a FivePointMatrix combined with
// FivePointSymmetricGaussSeidelTag: the generated preconditioner type is
// SGSPreconditioner with the same engine tag.
template<class EngineTag>
struct GeneratePreconditioner<FivePointMatrix<EngineTag>,
  FivePointSymmetricGaussSeidelTag>
{
  typedef SGSPreconditioner<EngineTag> Type_t;
  typedef FivePointMatrix<EngineTag> Input_t;

  // Forwards the sweep count to ApplySGS, where it is stored in a static
  // member and therefore applies to every SGS preconditioner.
  static void set_steps (const int steps) { ApplySGS::set_steps(steps); }

  // Fills the preconditioner's coefficient array from the matrix values by
  // running ComputeSGS over the preconditioner's domain.
  static void fill(
		   const Input_t &input,
		   const Type_t &preconditioner
		   )
  {
    const typename Type_t::Values_t &coefficients = preconditioner.values();
    const Interval<2> &region = preconditioner.domain();

    GuardedPatchEvaluator<MainEvaluatorTag>::evaluate(coefficients,
						      input.values(),
						      ComputeSGS(),
						      region);
  }

};

// Definition of the shared sweep-count setting.  It starts at 1 and is
// changed through ApplySGS::set_steps().
// NOTE(review): this is a non-inline definition at namespace scope; if this
// file is included from more than one translation unit the variable would
// be defined more than once -- confirm how the file is included.
int ApplySGS::steps_m = 1;

