// -*- C++ -*-
//
// Copyright (C) 1998, 1999, 2000, 2002  Los Alamos National Laboratory,
// Copyright (C) 1998, 1999, 2000, 2002  CodeSourcery, LLC
//
// This file is part of FreePOOMA.
//
// FreePOOMA is free software; you can redistribute it and/or modify it
// under the terms of the Expat license.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the Expat
// license for more details.
//
// You should have received a copy of the Expat license along with
// FreePOOMA; see the file LICENSE.
//

//-----------------------------------------------------------------------------
// Class:
// level-one Incomplete Choleski (IC(1))
// NinePointMatrix
//-----------------------------------------------------------------------------

#include <cmath>

#include "NinePointMatrix.h"
#include "NinePointIC1Stencil.h"

//-----------------------------------------------------------------------------
// ComputeIC1
//
// Patch functor that fills `values` with an incomplete Choleski factor L
// (A ~= L L^T) of the nine-point matrix A on a fixed sparsity pattern.
// Points are eliminated row by row (i1 outer, i0 inner, both increasing).
// The row of L belonging to point (i0, i1) holds the coefficients
//
//   sw  -> column (i0 - 1, i1 - 1)
//   s   -> column (i0,     i1 - 1)
//   se  -> column (i0 + 1, i1 - 1)
//   ese -> column (i0 + 2, i1 - 1)   (fill-in, no matching entry in A)
//   w2  -> column (i0 - 2, i1)       (fill-in, no matching entry in A)
//   w   -> column (i0 - 1, i1)
//   c   -> the diagonal (i0, i1)
//
// Each coefficient follows the usual Choleski recurrence
//
//   l_pq = (a_pq - sum_k l_pk * l_qk) / l_qq
//
// where the sum runs over the columns k that row p and row q have in common
// inside the pattern above, and
//
//   c = sqrt(a_pp - sum of squares of the off-diagonal coefficients of row p).
//
// Coefficients that would reach outside the domain are stored as zero.
//
// Preconditions visible from the indexing below: the domain must be at least
// four points wide in dimension 0 (the columns b0, b0+1, e0-1 and e0 are
// handled separately and read neighbours up to two columns away).
//
// After the factor is complete every stencil is replaced by the result of
// NinePointIC1Stencil::kludge(), which is meant to let the solve multiply
// instead of divide (see that class for the exact transformation).
//-----------------------------------------------------------------------------

struct ComputeIC1
{
  inline int lowerExtent(int) const { return 0; }
  inline int upperExtent(int) const { return 0; }

  // values: output, one factor stencil per point of A's domain.
  // A:      input nine-point matrix coefficients.
  template<class Tag>
  static void
  apply(
	const Array<2, NinePointIC1Stencil, Tag> &values,
	const Array<2, NinePoint, Tag> &A
	)
  {
    Interval<2> domain = A.domain();

    int i0, i1;
    const int b0 = domain[0].first();
    const int b1 = domain[1].first();
    const int e0 = domain[0].last();
    const int e1 = domain[1].last();

    double sw, s, se, ese, w2, w, c;

    // southwest corner (no legs)

    c = std::sqrt(A(b0, b1).center());
    values(b0, b1) = NinePointIC1Stencil (0, 0, 0, 0, 0, 0, c);

    // rest of the south edge (center and west only)
    //
    // w2 is zero here because every term of its numerator is zero:
    // all the l_pk * l_qk products involve a south-reaching leg, and
    // the south row has none.  For the same reason w needs no correction.

    i1 = b1;
    for (i0 = b0 + 1; i0 <= e0; i0++)
    {
      w = A(i0, i1).west() / values(i0 - 1, i1).center();
      c = std::sqrt(A(i0, i1).center() - w*w);
      values(i0, i1) = NinePointIC1Stencil (0, 0, 0, 0, 0, w, c);
    }

    // now do the north-south interior

    for (i1 = b1 + 1; i1 <= e1; i1++)
    {

      // west edge (no west, southwest, or w2)

      i0 = b0;

      s = (A(i0, i1).south() / values(i0, i1 - 1).center());
      se = (A(i0, i1).southeast() - s * values(i0 + 1, i1 - 1).west()) /
         values(i0 + 1, i1 - 1).center();
      // Row (i0 + 2, i1 - 1) shares column (i0 + 1, i1 - 1) with this row
      // through its west leg (paired with se) and column (i0, i1 - 1)
      // through its w2 leg (paired with s).
      ese = - (se * values(i0 + 2, i1 - 1).west() +
               s * values(i0 + 2, i1 - 1).w2()) /
         values(i0 + 2, i1 - 1).center();
      c = std::sqrt(A(i0, i1).center() - s*s - se*se - ese*ese);
      values(i0, i1) = NinePointIC1Stencil (0, s, se, ese, 0, 0, c);

      // first column from west edge (no w2)

      i0 = b0 + 1;

      sw = A(i0, i1).southwest() / values(i0 - 1, i1 - 1).center();
      s = (A(i0, i1).south() - sw * values(i0, i1 - 1).west()) /
         values(i0, i1 - 1).center();
      se = (A(i0, i1).southeast() - s * values(i0 + 1, i1 - 1).west() -
         sw * values(i0 + 1, i1 - 1).w2()) / values(i0 + 1,
         i1 - 1).center();
      ese = - (se * values(i0 + 2, i1 - 1).west() + s *
         values(i0 + 2, i1 - 1).w2()) / values(i0 + 2, i1 - 1).center();
      w = (A(i0, i1).west() - se * values(i0 - 1, i1).ese() -
         s * values(i0 - 1, i1).southeast() - sw * values(i0 - 1,
         i1).south()) / values(i0 - 1, i1).center();
      c = std::sqrt(A(i0, i1).center() - sw*sw - s*s - se*se - ese*ese -
         w*w);
      values(i0, i1) = NinePointIC1Stencil (sw, s, se, ese, 0, w, c);

      // interior (all legs present)

      for (i0 = b0 + 2; i0 <= e0 - 2; i0++)
      {
        sw = A(i0, i1).southwest() / values(i0 - 1, i1 - 1).center();
        s = (A(i0, i1).south() - sw * values(i0, i1 - 1).west()) /
           values(i0, i1 - 1).center();
        se = (A(i0, i1).southeast() - s * values(i0 + 1, i1 - 1).west() -
           sw * values(i0 + 1, i1 - 1).w2()) / values(i0 + 1,
           i1 - 1).center();
        ese = - (se * values(i0 + 2, i1 - 1).west() + s *
           values(i0 + 2, i1 - 1).w2()) / values(i0 + 2, i1 - 1).center();
        w2 = - (s * values(i0 - 2, i1).ese() + sw * values(i0 - 2,
           i1).southeast()) / values(i0 - 2, i1).center();
        // Row (i0 - 1, i1) also shares column (i0 - 2, i1) with this row:
        // its west leg pairs with our w2 leg.
        w = (A(i0, i1).west() - se * values(i0 - 1, i1).ese() -
           s * values(i0 - 1, i1).southeast() - sw * values(i0 - 1,
           i1).south() - w2 * values(i0 - 1, i1).west()) /
           values(i0 - 1, i1).center();
        c = std::sqrt(A(i0, i1).center() - sw*sw - s*s - se*se - ese*ese -
           w2*w2 - w*w);
        values(i0, i1) = NinePointIC1Stencil (sw, s, se, ese, w2, w, c);
      } // for i0

      // last column before east edge (no ese)

      i0 = e0 - 1;
      sw = A(i0, i1).southwest() / values(i0 - 1, i1 - 1).center();
      s = (A(i0, i1).south() - sw * values(i0, i1 - 1).west()) /
         values(i0, i1 - 1).center();
      se = (A(i0, i1).southeast() - s * values(i0 + 1, i1 - 1).west() -
         sw * values(i0 + 1, i1 - 1).w2()) / values(i0 + 1,
         i1 - 1).center();
      w2 = - (s * values(i0 - 2, i1).ese() + sw * values(i0 - 2,
         i1).southeast()) / values(i0 - 2, i1).center();
      w = (A(i0, i1).west() - se * values(i0 - 1, i1).ese() -
         s * values(i0 - 1, i1).southeast() - sw * values(i0 - 1,
         i1).south() - w2 * values(i0 - 1, i1).west()) /
         values(i0 - 1, i1).center();
      c = std::sqrt(A(i0, i1).center() - sw*sw - s*s - se*se - w2*w2 - w*w);
      values(i0, i1) = NinePointIC1Stencil (sw, s, se, 0, w2, w, c);

      // east edge (no southeast or ese)

      i0 = e0;
      sw = A(i0, i1).southwest() / values(i0 - 1, i1 - 1).center();
      s = (A(i0, i1).south() - sw * values(i0, i1 - 1).west()) /
         values(i0, i1 - 1).center();
      w2 = - (s * values(i0 - 2, i1).ese() + sw * values(i0 - 2,
         i1).southeast()) / values(i0 - 2, i1).center();
      w = (A(i0, i1).west() - s * values(i0 - 1, i1).southeast() -
         sw * values(i0 - 1, i1).south() -
         w2 * values(i0 - 1, i1).west()) / values(i0 - 1, i1).center();
      c = std::sqrt(A(i0, i1).center() - sw*sw - s*s - w2*w2 - w*w);
      values(i0, i1) = NinePointIC1Stencil (sw, s, 0, 0, w2, w, c);

    } // for i1

    // run through the IC factor and kludge() the coefficients.
    // this will allow the apply() to do a multiply instead of a divide

    for (i1 = b1; i1 <= e1; i1++)
      for (i0 = b0; i0 <= e0; i0++)
        values(i0, i1) = values(i0, i1).kludge();

    // but gosh what an ugly solution
    // a better solution would be to do an LDL^T decomposition

  }
};

struct ApplyIC1
{
  inline int lowerExtent(int) const { return 0; }
  inline int upperExtent(int) const { return 0; }

  template<class Tag>
  static void
  apply(
	const Array<2, double, Tag> &res,
	const Array<2, double, Tag> &rhs,
	const Array<2, NinePointIC1Stencil, Tag> &values
	)
  {
    Interval<2> domain = values.domain();

    Array<2, double, Brick> y(domain);

    int i0, i1;
    int b0 = domain[0].first();
    int b1 = domain[1].first();
    int e0 = domain[0].last();
    int e1 = domain[1].last();

    double c;

    // forward solve:

    // southwest corner (no legs)

    i0 = b0;
    i1 = b1;
    c = values(i0, i1).center();
    y(i0, i1) = c * rhs(i0, i1);

// next to the south edge (only one west leg)

    i0 = b0 + 1;
    i1 = b1;
    c = values(i0, i1).center();
    y(i0, i1) = c * rhs(i0, i1) -
       c * values(i0, i1).west() * y(i0 - 1, i1);

    // south edge and southeast corner (no south-reaching legs)

    i1 = b1;
    for (i0 = b0 + 2; i0 <= e0; i0++) {
      c = values(i0, i1).center();
      y(i0, i1) = c * rhs(i0, i1) -
         c * values(i0, i1).w2() * y(i0 - 2, i1) -
         c * values(i0, i1).west() * y(i0 - 1, i1); } // for i0

     // now do the north-south interior

     for (i1 = b1 + 1; i1 <= e1; i1++)
     {

       // west edge (no west, southwest, or w2 legs)

       i0 = b0;
       c = values(i0, i1).center();
       y(i0, i1) = c * rhs(i0, i1) -
          c * values(i0, i1).south() * y(i0, i1 - 1) -
          c * values(i0, i1).southeast() * y(i0 + 1, i1 - 1) -
          c * values(i0, i1).ese() * y(i0 + 2, i1 - 1);

// next column over (no w2 leg)

       i0 = b0 + 1;
       c = values(i0, i1).center();
       y(i0, i1) = c * rhs(i0, i1) -
          c * values(i0, i1).southwest() * y(i0 - 1, i1 - 1) -
          c * values(i0, i1).south() * y(i0, i1 - 1) -
          c * values(i0, i1).southeast() * y(i0 + 1, i1 - 1) -
          c * values(i0, i1).ese() * y(i0 + 2, i1 - 1) -
          c * values(i0, i1).west() * y(i0 - 1, i1);

       // interior

       for (i0 = b0 + 2; i0 <= e0 - 2; i0++) {
         c = values(i0, i1).center();
         y(i0, i1) = c * rhs(i0, i1) -
            c * values(i0, i1).southwest() * y(i0 - 1, i1 - 1) -
            c * values(i0, i1).south() * y(i0, i1 - 1) -
            c * values(i0, i1).southeast() * y(i0 + 1, i1 - 1) -
            c * values(i0, i1).ese() * y(i0 + 2, i1 - 1) -
            c * values(i0, i1).w2() * y(i0 - 2, i1) -
            c * values(i0, i1).west() * y(i0 - 1, i1); } // for i0

// next-to-last column (no ese)

       i0 = e0 - 1;
       c = values(i0, i1).center();
       y(i0, i1) = c * rhs(i0, i1) -
          c * values(i0, i1).southwest() * y(i0 - 1, i1 - 1) -
          c * values(i0, i1).south() * y(i0, i1 - 1) -
          c * values(i0, i1).southeast() * y(i0 + 1, i1 - 1) -
          c * values(i0, i1).w2() * y(i0 - 2, i1) -
          c * values(i0, i1).west() * y(i0 - 1, i1);

       // east edge (no east-reaching legs)

       i0 = e0;
       c = values(i0, i1).center();
       y(i0, i1) = c * rhs(i0, i1) -
          c * values(i0, i1).southwest() * y(i0 - 1, i1 - 1) -
          c * values(i0, i1).south() * y(i0, i1 - 1) -
          c * values(i0, i1).w2() * y(i0 - 2, i1) -
          c * values(i0, i1).west() * y(i0 - 1, i1);

     } // for i1

     // backsolve

     // northeast corner

     i1 = e1;
     i0 = e0;
     c = values(i0, i1).center();
     res(i0, i1) = c * y(i0, i1);

// next location over (center and west leg only)

     i1 = e1;
     i0 = i0 - 1;

     c = values(i0, i1).center();
     res(i0, i1) = c * y(i0, i1) -
        c * values(i0 + 1, i1).west() * res(i0 + 1, i1);

     // north edge and northwest corner

     i1 = e1;
     for (i0 = e0 - 2; i0 >= b0; i0--) {
       c = values(i0, i1).center();
       res(i0, i1) = c * y(i0, i1) -
          c * values(i0 + 2, i1).w2() * res(i0 + 2, i1) -
          c * values(i0 + 1, i1).west() * res(i0 + 1, i1); } // for i0

     // now do the north-south interior

     for (i1 = e1 - 1; i1 >= b1; i1--)
     {

       // east edge (no west, southwest, or w2 legs)

       i0 = e0;
       c = values(i0, i1).center();
       res(i0, i1) = c * y(i0, i1) -
          c * values(i0, i1 + 1).south() * res(i0, i1 + 1) -
          c * values(i0 - 1, i1 + 1).southeast() * res(i0 - 1, i1 + 1) -
          c * values(i0 - 2, i1 + 1).ese() * res(i0 - 2, i1 + 1);

// next-to-last column (no w2 leg)

       i0 = e0 - 1;
       c = values(i0, i1).center();
       res(i0, i1) = c * y(i0, i1) -
          c * values(i0 + 1, i1 + 1).southwest() * res(i0 + 1, i1 + 1) -
          c * values(i0, i1 + 1).south() * res(i0, i1 + 1) -
          c * values(i0 - 1, i1 + 1).southeast() * res(i0 - 1, i1 + 1) -
          c * values(i0 - 2, i1 + 1).ese() * res(i0 - 2, i1 + 1) -
          c * values(i0 + 1, i1).west() * res(i0 + 1, i1);

       // interior

       for (i0 = e0 - 2; i0 >= b0 + 2; i0--) {
         c = values(i0, i1).center();
         res(i0, i1) = c * y(i0, i1) -
            c * values(i0 + 1, i1 + 1).southwest() * res(i0 + 1, i1 + 1) -
            c * values(i0, i1 + 1).south() * res(i0, i1 + 1) -
            c * values(i0 - 1, i1 + 1).southeast() * res(i0 - 1, i1 + 1) -
            c * values(i0 - 2, i1 + 1).ese() * res(i0 - 2, i1 + 1) -
            c * values(i0 + 2, i1).w2() * res(i0 + 2, i1) -
            c * values(i0 + 1, i1).west() * res(i0 + 1, i1);
       } // for i0

// first interior column (no ese leg)

       i0 = b0 + 1;
       c = values(i0, i1).center();
       res(i0, i1) = c * y(i0, i1) -
          c * values(i0 + 1, i1 + 1).southwest() * res(i0 + 1, i1 + 1) -
          c * values(i0, i1 + 1).south() * res(i0, i1 + 1) -
          c * values(i0 - 1, i1 + 1).southeast() * res(i0 - 1, i1 + 1) -
          c * values(i0 + 2, i1).w2() * res(i0 + 2, i1) -
          c * values(i0 + 1, i1).west() * res(i0 + 1, i1);

       // west edge (no southeast or ese legs)

       i0 = b0;
       c = values(i0, i1).center();
       res(i0, i1) = c * y(i0, i1) -
          c * values(i0 + 1, i1 + 1).southwest() * res(i0 + 1, i1 + 1) -
          c * values(i0, i1 + 1).south() * res(i0, i1 + 1) -
          c * values(i0 + 2, i1).w2() * res(i0 + 2, i1) -
          c * values(i0 + 1, i1).west() * res(i0 + 1, i1);

     } // for i1

  } // apply()
};

//-----------------------------------------------------------------------------
// IC1Preconditioner
//
// Holds the IC(1) factor stencils for a nine-point matrix together with the
// domain over which they are applied.  The stencils themselves are filled in
// elsewhere through values(); operator() applies them via ApplyIC1.
//-----------------------------------------------------------------------------

template<class EngineTag>
class IC1Preconditioner
{
public:
  // Array type holding one factor stencil per grid point.
  typedef Array<2, NinePointIC1Stencil, EngineTag> Values_t;

  // Access to the stored factor stencils.
  inline Values_t &values() { return values_m; }
  inline const Values_t &values() const { return values_m; }

  // Domain over which the preconditioner is evaluated.
  inline const Interval<2> &domain() const { return domain_m; }

  // layout: used to construct the stencil array.
  // domain: stored and later passed to the evaluator.
  template<class Layout, class Domain>
  IC1Preconditioner(const Layout &layout, const Domain &domain)
    : values_m(layout), domain_m(domain)
  {
  }

  // The preconditioner asks for zero guard layers.
  static GuardLayers<2> guardLayers() { return GuardLayers<2>(0); }

  // Applies the preconditioner: x is the input, y receives the result.
  // The arrays are handed to the evaluator as (y, x, values_m), which
  // matches the (res, rhs, values) parameter order of ApplyIC1::apply.
  void operator()(
		  const Array<2, double, EngineTag> &x,
		  const Array<2, double, EngineTag> &y
		  ) const
  {
    GuardedPatchEvaluator<MainEvaluatorTag>::evaluate(y, x, values_m,
						      ApplyIC1(),
						      domain_m);
  }

private:

  Values_t values_m;
  Interval<2> domain_m;
};

// Primary template, declared only: specializations select how to build a
// preconditioner for a given matrix type and method tag.
template<class Matrix, class Method>
struct GeneratePreconditioner;

// Empty tag type used as the Method argument to select the IC(1)
// preconditioner for nine-point matrices.
struct NinePointIncompleteCholeski1Tag { };

// Specialization that pairs a NinePointMatrix with the IC(1) method tag.
// Type_t is the preconditioner type to construct; fill() computes its
// factor stencils from the matrix coefficients.
template<class EngineTag>
struct GeneratePreconditioner<NinePointMatrix<EngineTag>,
  NinePointIncompleteCholeski1Tag>
{
  typedef IC1Preconditioner<EngineTag> Type_t;
  typedef NinePointMatrix<EngineTag> Input_t;

  // Runs ComputeIC1 over the preconditioner's domain, reading the matrix
  // coefficients from input.values() and writing the factor into
  // preconditioner.values().
  static void fill(const Input_t &input, const Type_t &preconditioner)
  {
    typedef GuardedPatchEvaluator<MainEvaluatorTag> Evaluator_t;

    Evaluator_t::evaluate(preconditioner.values(),
                          input.values(),
                          ComputeIC1(),
                          preconditioner.domain());
  }
};

