///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
/// 
/// =========================================================================

#include "rheolef/csr.h"
#include "rheolef/asr.h"

#include "rheolef/msg_util.h"
#include "rheolef/asr_to_csr.h"
#include "rheolef/csr_to_asr.h"
#include "rheolef/csr_amux.h"
#include "rheolef/csr_cumul_trans_mult.h"
using namespace std;
namespace rheolef {
// ----------------------------------------------------------------------------
// class member functions
// ----------------------------------------------------------------------------

template<class T>
// Build a matrix whose row/column distributions are given by the two
// distributors, with storage for nnz1 (row-index-free) entries.
//
// One row pointer per local row plus an end sentinel is allocated; as in
// resize(), the first one is anchored at the beginning of the data storage,
// so that begin()[0] is always valid after construction.
csr_seq_rep<T>::csr_seq_rep(const distributor& row_ownership, const distributor& col_ownership, size_type nnz1) 
 : vector_of_iterator<pair_type>(row_ownership.size()+1),
   _row_ownership (row_ownership),
   _col_ownership (col_ownership),
   _data (nnz1),
   _is_symmetric (false),
   _pattern_dimension (0)
{
   // first pointer points to the beginning of the data (same invariant as resize):
   vector_of_iterator<pair_type>::operator[](0) = _data.begin().operator->();
}
template<class T>
void
csr_seq_rep<T>::resize (const distributor& row_ownership, const distributor& col_ownership, size_type nnz1) 
{
   typedef vector_of_iterator<pair_type> row_pointers;
   // one pointer per row, plus a sentinel marking the end of the last row
   row_pointers::resize (row_ownership.size()+1);
   _row_ownership = row_ownership;
   _col_ownership = col_ownership;
   _data.resize (nnz1);
   // anchor the first row pointer at the start of the (possibly reallocated)
   // data storage; the other row pointers are not set here
   row_pointers::operator[](0) = _data.begin().operator->();
}
template<class T>
// Build a loc_nrow1 x loc_ncol1 matrix with storage for loc_nnz1 entries;
// the distributors are built with distributor::decide from the local sizes.
//
// As in resize(), the first row pointer is anchored at the beginning of the
// data storage so that begin()[0] is always valid after construction.
csr_seq_rep<T>::csr_seq_rep(size_type loc_nrow1, size_type loc_ncol1, size_type loc_nnz1) 
 : vector_of_iterator<pair_type> (loc_nrow1+1),
   _row_ownership (distributor::decide, communicator(), loc_nrow1),
   _col_ownership (distributor::decide, communicator(), loc_ncol1),
   _data (loc_nnz1),
   _is_symmetric (false),
   _pattern_dimension (0)
{
   // first pointer points to the beginning of the data (same invariant as resize):
   vector_of_iterator<pair_type>::operator[](0) = _data.begin().operator->();
}
template<class T>
void
csr_seq_rep<T>::resize (size_type loc_nrow1, size_type loc_ncol1, size_type loc_nnz1)
{
   typedef vector_of_iterator<pair_type> row_pointers;
   // one pointer per local row, plus an end-of-last-row sentinel
   row_pointers::resize (loc_nrow1+1);
   // rebuild both distributions from the local sizes
   // (distributor::decide presumably lets the distributor compute the global size)
   _row_ownership = distributor (distributor::decide, communicator(), loc_nrow1);
   _col_ownership = distributor (distributor::decide, communicator(), loc_ncol1);
   _data.resize (loc_nnz1);
   // anchor the first row pointer at the start of the data storage
   row_pointers::operator[](0) = _data.begin().operator->();
}
template<class T>
csr_seq_rep<T>::csr_seq_rep(const csr_seq_rep<T>& b)
 : vector_of_iterator<pair_type>(b.nrow()+1),
   _row_ownership (b.row_ownership()),
   _col_ownership (b.col_ownership()),
   _data(b._data),
   _is_symmetric (b._is_symmetric),
   _pattern_dimension (b._pattern_dimension)
{
  // _data now holds its own copy of b's entries: rebuild the row pointers so
  // that they address our storage, keeping each row offset identical to b's.
  const_iterator src_row = b.begin();
  iterator       dst_row = begin();
  const data_iterator own_first = _data.begin().operator->();
  dst_row[0] = own_first;
  for (size_type k = 1, last = b.nrow(); k <= last; ++k) {
    dst_row[k] = own_first + (src_row[k] - src_row[0]);
  }
}
template<class T>
// Build a csr matrix from an asr matrix: distributions and nnz come from a,
// and asr_to_csr is given the row-pointer array and the data storage as
// outputs, copying every entry (always_true filter, identity conversion).
//
// The symmetry flag and pattern dimension are reset to false/0, as in every
// other constructor; they were previously left uninitialized here.
csr_seq_rep<T>::csr_seq_rep(const asr_seq_rep<T>& a)
  : vector_of_iterator<pair_type>(a.nrow()+1),
   _row_ownership (a.row_ownership()),
   _col_ownership (a.col_ownership()),
   _data(a.nnz()),
   _is_symmetric (false),
   _pattern_dimension (0)
{
    typedef pair<size_type,T> pair_type;
    typedef typename asr_seq_rep<T>::row_type::value_type const_pair_type;
    
    asr_to_csr (
        a.begin(),
        a.end(), 
        always_true<const_pair_type>(), 
        pair_identity<const_pair_type,pair_type>(), 
        vector_of_iterator<pair_type>::begin(), 
        _data.begin());
}
template<class T>
void
csr_seq_rep<T>::to_asr(asr_seq_rep<T>& b) const
{
    typedef pair<size_type,T> pair_type;
    typedef typename asr_seq_rep<T>::row_type::value_type const_pair_type;
    typedef vector_of_iterator<pair_type> row_pointers;

    // scatter every csr row into the rows of b, converting each stored pair
    // with pair_identity; NOTE(review): b presumably must already have nrow()
    // rows, since only b.begin() is passed -- confirm with csr_to_asr.
    csr_to_asr (
        row_pointers::begin(),
        row_pointers::end(),
        _data.begin().operator->(),
        pair_identity<pair_type,const_pair_type>(),
        b.begin().operator->());
}
template<class T>
idiststream& 
csr_seq_rep<T>::get (idiststream& ps)
{
    typedef pair<size_type,T> pair_type;
    typedef typename asr_seq_rep<T>::row_type::value_type const_pair_type;
    typedef vector_of_iterator<pair_type> row_pointers;

    // read into a temporary asr matrix first ...
    asr_seq_rep<T> tmp;
    tmp.get(ps);
    // ... then size our storage to match it and convert every entry to csr
    resize (tmp.nrow(), tmp.ncol(), tmp.nnz());
    asr_to_csr (
        tmp.begin(),
        tmp.end(),
        always_true<const_pair_type>(),
        pair_identity<const_pair_type,pair_type>(),
        row_pointers::begin(),
        _data.begin());
    return ps;
}
template<class T>
// Write the matrix in Matrix Market coordinate format ("real general"):
// a header line, a "nrow ncol nnz" line, then one "i j value" line per stored
// entry with 1-based indices. istart is added to every row index.
//
// The caller's stream precision is restored on return, and the stream is
// flushed once at the end rather than after every line.
odiststream&
csr_seq_rep<T>::put (odiststream& ops, size_type istart) const
{
    std::ostream& os = ops.os();
    const std::streamsize old_precision = os.precision (std::numeric_limits<T>::digits10);
    const size_type base = 1; // Matrix Market indices are 1-based
    os << "%%MatrixMarket matrix coordinate real general" << '\n'
       << nrow() << " " << ncol() << " " << nnz() << '\n';
    const_iterator ia = begin();
    for (size_type i = 0, n = nrow(); i < n; i++) {
        for (const_data_iterator iter_jva = ia[i], last_jva = ia[i+1];
             iter_jva != last_jva; iter_jva++) {
            os << i+istart+base << " "
               << (*iter_jva).first+base << " "
               << (*iter_jva).second << '\n';
        }
    }
    os.precision (old_precision);
    os.flush();
    return ops;
}
template<class T>
// Debug helper: write the matrix to the file `name` through put(), with row
// indices shifted by istart. Reports on std::cerr whether the file could be
// created; on failure nothing is written.
void
csr_seq_rep<T>::dump (const string& name, size_type istart) const
{
    std::ofstream os (name.c_str());
    if (!os) {
        std::cerr << "! cannot create file \"" << name << "\"." << std::endl;
        return;
    }
    std::cerr << "! file \"" << name << "\" created." << std::endl;
    odiststream ops(os);
    // forward istart: it was previously ignored
    put (ops, istart);
}
// ----------------------------------------------------------------------------
// basic linear algebra
// ----------------------------------------------------------------------------

template<class T>
void
csr_seq_rep<T>::mult(
    const vec<T,sequential>& x,
    vec<T,sequential>&       y)
    const
{
    typedef vector_of_iterator<pair_type> row_pointers;
    // y = A*x, one row at a time; set_op presumably assigns (rather than
    // accumulates) each row product into y -- see csr_amux
    csr_amux (row_pointers::begin(), row_pointers::end(),
              x.begin(), set_op<T,T>(), y.begin());
}
template<class T>
void
csr_seq_rep<T>::trans_mult(
    const vec<T,sequential>& x,
    vec<T,sequential>&       y)
    const
{
    typedef vector_of_iterator<pair_type> row_pointers;
    // y = A^T*x: clear y, then let every row of A add its contribution
    const T zero = T(0);
    std::fill (y.begin(), y.end(), zero);
    csr_cumul_trans_mult (row_pointers::begin(), row_pointers::end(),
                          x.begin(), set_add_op<T,T>(), y.begin());
}
template<class T>
csr_seq_rep<T>&
csr_seq_rep<T>::operator*= (const T& lambda)
{
  // scale every stored entry lying between the first and the last row pointer
  iterator ia = begin();
  data_iterator entry = ia[0];
  const data_iterator stop = ia[nrow()];
  while (entry != stop) {
    entry->second *= lambda;
    ++entry;
  }
  return *this;
}
// ----------------------------------------------------------------------------
// expression c=a+b and c=a-b with a temporary c=*this
// ----------------------------------------------------------------------------
template<class T>
template<class BinaryOp>
// Compute *this = binop(a, b) entry-wise, e.g. c = a+b (std::plus) or
// c = a-b (std::minus). An entry missing from one operand is treated as 0,
// so a column present only in b yields binop(0, b_ij), i.e. -b_ij for a-b.
//
// Rows are merged by increasing column index, so rows of a and b are expected
// to be sorted. a and b must not alias *this: resize() reallocates _data
// while a's and b's row pointers are still being read.
void
csr_seq_rep<T>::assign_add (
    const csr_seq_rep<T>& a, 
    const csr_seq_rep<T>& b,
    BinaryOp binop)
{
    check_macro (a.nrow() == b.nrow() && a.ncol() == b.ncol(),
	"incompatible csr add(a,b): a("<<a.nrow()<<":"<<a.ncol()<<") and "
	"b("<<b.nrow()<<":"<<b.ncol()<<")");
    //
    // first pass: count the union of the two sparsity patterns, then resize
    //
    size_type nnz_c = 0;
    // sentinel column for an exhausted row: compares greater than any index
    const size_type infty = std::numeric_limits<size_type>::max();
    const_iterator ia = a.begin();
    const_iterator ib = b.begin();
    for (size_type i = 0, n = a.nrow(); i < n; i++) {
        for (const_data_iterator iter_jva = ia[i], last_jva = ia[i+1],
                                 iter_jvb = ib[i], last_jvb = ib[i+1];
             iter_jva != last_jva || iter_jvb != last_jvb; ) {

            size_type ja = iter_jva == last_jva ? infty : (*iter_jva).first;
            size_type jb = iter_jvb == last_jvb ? infty : (*iter_jvb).first;
            if (ja == jb) {
                iter_jva++;
                iter_jvb++;
            } else if (ja < jb) {
                iter_jva++;
            } else {
                iter_jvb++;
            }
            nnz_c++;
        }
    }
    resize (a.row_ownership(), b.col_ownership(), nnz_c);
    data_iterator iter_jvc = _data.begin().operator->();
    iterator ic = begin();
    *ic++ = iter_jvc;
    //
    // second pass: merge the rows, applying binop to every entry of c
    //
    for (size_type i = 0, n = a.nrow(); i < n; i++) {
        for (const_data_iterator iter_jva = ia[i], last_jva = ia[i+1],
                                 iter_jvb = ib[i], last_jvb = ib[i+1];
             iter_jva != last_jva || iter_jvb != last_jvb; ) {

            size_type ja = iter_jva == last_jva ? infty : (*iter_jva).first;
            size_type jb = iter_jvb == last_jvb ? infty : (*iter_jvb).first;
            if (ja == jb) {
                *iter_jvc++ = std::pair<size_type,T> (ja, binop((*iter_jva).second, (*iter_jvb).second));
                iter_jva++;
                iter_jvb++;
            } else if (ja < jb) {
                // entry only in a: c_ij = binop(a_ij, 0)
                *iter_jvc++ = std::pair<size_type,T> (ja, binop((*iter_jva).second, T(0)));
                iter_jva++;
            } else {
                // entry only in b: c_ij = binop(0, b_ij) -- must not be copied
                // verbatim, otherwise c = a-b would get +b_ij here
                *iter_jvc++ = std::pair<size_type,T> (jb, binop(T(0), (*iter_jvb).second));
                iter_jvb++;
            }
        }
        *ic++ = iter_jvc;
    }
    set_symmetry          (a.is_symmetric() && b.is_symmetric());
    set_pattern_dimension (std::max(a.pattern_dimension(), b.pattern_dimension()));
}
// ----------------------------------------------------------------------------
// trans(a)
// ----------------------------------------------------------------------------
/*@! 
 \vfill \pagebreak \mbox{} \vfill \begin{algorithm}[h]
  \caption{{\tt sort}: sort rows by increasing column order}
  \begin{algorithmic}
    \INPUT {the matrix in CSR format}
      ia(0:nrow), ja(0:nnz-1), a(0:nnz-1)
    \ENDINPUT
    \OUTPUT {the sorted CSR matrix}
      iao(0:nrow), jao(0:nnz-1), ao(0:nnz-1)
    \ENDOUTPUT
    \BEGIN 
      {\em first: reset iao} \\
      \FORTO {i := 0} {nrow}
	  iao(i) := 0;
      \ENDFOR
	
      {\em second: compute row lengths from column indices} \\
      \FORTO {i := 0} {nrow-1}
        \FORTO {p := ia(i)} {ia(i+1)-1}
	    iao (ja(p)+1)++;
        \ENDFOR
      \ENDFOR

      {\em third: compute pointers from lengths} \\
      \FORTO {i := 0} {nrow-1}
	  iao(i+1) += iao(i)
      \ENDFOR

      {\em fourth: copy values} \\
      \FORTO {i := 0} {nrow-1}
        \FORTO {p := ia(i)} {ia(i+1)-1}
          j := ja(p) \\
	  q := iao(j) \\
	  jao(q) := i \\
	  ao(q) := a(p) \\
	  iao (j)++
        \ENDFOR
      \ENDFOR

      {\em fifth: shift pointers back} \\
      \FORTOSTEP {i := nrow-1} {0} {-1}
	iao(i+1) := iao(i);
      \ENDFOR
      iao(0) := 0
    \END
 \end{algorithmic} \end{algorithm}
 \vfill \pagebreak \mbox{} \vfill
*/

// Build b as the transpose of *this.
//
// NOTE: the rows of the transposed matrix are always sorted by increasing
//       column index, even when the rows of the original matrix are not:
//       rows of *this are visited in increasing order i and each entry is
//       appended to row j of b with column index i.
//
// b must be a different object than *this: b is resized and its row pointers
// are overwritten while our own row pointers are still being read.
template<class T>
void
csr_seq_rep<T>::build_transpose (csr_seq_rep<T>& b) const
{
  b.resize (col_ownership(), row_ownership(), nnz()); 

  // first pass: set ib(*) to ib(0)
  iterator ib = b.begin();
  for (size_type j = 0, m = b.nrow(); j < m; j++) {
    ib[j+1] = ib[0];
  }
  // second pass: compute lengths of row(j) of a^T in ib(j+1)-ib(0)
  const_iterator ia = begin();
  for (size_type i = 0, n = nrow(); i < n; i++) {
    for (const_data_iterator p = ia[i], last_p = ia[i+1]; p != last_p; p++) {
      size_type j = (*p).first;
      ib [j+1]++;
    }
  }
  // third pass: compute pointers from lengths; ib(j+1) becomes the end of row j
  for (size_type j = 0, m = b.nrow(); j < m; j++) {
    ib [j+1] += (ib[j]-ib[0]);
  }
  // fourth pass: store values; ib(j) advances from the start to the end of row j
  data_iterator q0 = ib[0];
  for (size_type i = 0, n = nrow(); i < n; i++) {
    for (const_data_iterator p = ia[i], last_p = ia[i+1]; p != last_p; p++) {
      size_type j = (*p).first;
      data_iterator q = ib[j];
      (*q).first  = i;
      (*q).second = (*p).second;
      ib[j]++;
    }
  }
  // fifth pass: shift pointers one slot up (ib(j) now holds the end of row j,
  // i.e. the start of row j+1), then restore the start of row 0.
  // Unsigned countdown: no signed conversion, and no iteration for an empty b.
  for (size_type j = b.nrow(); j > 0; j--) {
    ib[j] = ib[j-1];
  }
  ib[0] = q0;
}
// ----------------------------------------------------------------------------
// instantiation in library
// ----------------------------------------------------------------------------
// Explicit instantiation of the class for the library's Float type, and of
// the assign_add member template for the two operations c=a+b and c=a-b.
template class csr_seq_rep<Float>;
template void csr_seq_rep<Float>::assign_add (
	const csr_seq_rep<Float>&, const csr_seq_rep<Float>&, std::plus<Float>);
template void csr_seq_rep<Float>::assign_add (
	const csr_seq_rep<Float>&, const csr_seq_rep<Float>&, std::minus<Float>);
} // namespace rheolef
