///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
/// 
/// =========================================================================

#ifdef _RHEOLEF_HAVE_MPI
// Note: this file is recursively included by "polymorphic_array_mpi.cc" 
// with growing "N" and in the namespace "rheolef"

/*F:
NAME: mpi_polymorphic_scatter_begin -- gather/scatter start (@PACKAGE@ @VERSION@)
DESCRIPTION:
  Start communication for distributed to sequential scatter context.
IMPLEMENTATION NOTE:
  Same as the @code{mpi_scatter_begin} code, except for a first pass that
  determines the size for each variant of the base type: all exchanges
  are then performed separately for each variant type, in homogeneous
  non-polymorphic buffers, since this mode of exchange is faster with MPI.
AUTHORS:
    LMC-IMAG, 38041 Grenoble cedex 9, France
    | Pierre.Saramito@imag.fr
DATE:   23 march 1999
END:
*/

// Functor starting the inter-process part of a polymorphic scatter, for a
// base type with N variants (N is a macro supplied by the including file).
//
// operator() proceeds in two parts:
//  1. header: each process tells its peers, with tag `tag_header`, how many
//     values of each variant it is about to send. The function waits for
//     this exchange to complete and sizes from.values and to.values from it.
//  2. body: for each variant k, non-blocking receives and sends are posted
//     with tag `tag[k]`. These requests are stored in from.requests[k] and
//     to.requests[k]; they are NOT waited on in this function.
//
// Parameters:
//  x           source container, read through x[i] and x.variant(i)
//  from        receive-side message: n_proc() and procs are read;
//              start_variant, values and requests are written
//  to          send-side message: n_proc(), procs, starts and indices are
//              read; start_variant, values and requests are written
//  tag_header  message tag used for the per-variant count exchange
//  tag         one message tag per variant, used for the value exchange
//  comm        communicator on which all requests are posted
//
// NOTE(review): from.start_variant and to.start_variant are indexed up to
// [n_proc()] but are not resized here; presumably the caller has already
// sized them to n_proc()+1 entries -- confirm.
template <class Container, class Message, class Tag>
struct mpi_polymorphic_scatter_begin_global_t<Container,Message,Tag,N> {
  void operator() (
    const Container&           x,
    Message& 	               from,
    Message& 	               to,
    const Tag&                 tag_header,
    const boost::array<Tag,N>& tag,
    const communicator&        comm) const
  {
    typedef typename Container::size_type size_type;
    const size_type _n_variant = N;
    // ==========================================================
    // part 1. header: exchange sizes by variants
    // ==========================================================
    // 1.1) header post receives
    // ----------------------------------------------------------
    std::list<mpi::request> from_request_variant;
    // one array of N counters per process we receive from; each array is
    // the receive buffer of one irecv and is only read after the wait in 1.4
    std::vector<boost::array<size_type,_n_variant> > from_ndata_variant (from.n_proc());
    for (size_type i_recv = 0, n_recv = from_ndata_variant.size(); i_recv < n_recv; i_recv++) {
      mpi::request i_req = comm.irecv (
	    from.procs[i_recv],
	    tag_header,
	    from_ndata_variant[i_recv].begin(),
      	    _n_variant);
      from_request_variant.push_back (i_req);
    }
    // ----------------------------------------------------------
    // 1.2) prepare send: count variants
    // ----------------------------------------------------------
    // for each destination process, count how many of the entries
    // to.indices[to.starts[i_send] .. to.starts[i_send+1]) fall in each variant
    std::vector<boost::array<size_type,_n_variant> > to_ndata_variant (to.n_proc());
    for (size_type i_send = 0, n_send = to_ndata_variant.size(); i_send < n_send; i_send++) {
      to_ndata_variant[i_send].assign(0);
      for (size_type i = to.starts[i_send], n = to.starts[i_send+1]; i < n; i++) {
        size_type k = x.variant (to.indices[i]);
	to_ndata_variant[i_send][k]++;
      }
    }
    // ----------------------------------------------------------
    // 1.3) do send
    // ----------------------------------------------------------
    std::list<mpi::request> to_request_variant;
    for (size_type i_send = 0, n_send = to_ndata_variant.size(); i_send < n_send; i_send++) {
      mpi::request i_req = comm.isend(
            to.procs [i_send],
	    tag_header,
    	    to_ndata_variant[i_send].begin(),
	    _n_variant); 
      to_request_variant.push_back (i_req);
    }
    // ----------------------------------------------------------
    // 1.4) wait on receive & send
    // ----------------------------------------------------------
    // the received counts are read in 1.5, and the local count arrays are
    // the buffers of the pending sends: both directions must complete here
    mpi::wait_all (from_request_variant.begin(), from_request_variant.end());
    mpi::wait_all (to_request_variant.begin(),   to_request_variant.end());
    from_request_variant.clear();
    to_request_variant.clear();
    // ----------------------------------------------------------
    // 1.5) for {from,to}: compute start indexes by variants
    //      and total sizes by variants
    // ----------------------------------------------------------
    boost::array<size_type,_n_variant>               from_total_size_variant;
    boost::array<size_type,_n_variant>               to_total_size_variant;
    from.start_variant [0].assign(0);
    to.start_variant   [0].assign(0);
    from_total_size_variant.assign(0);
    to_total_size_variant.assign(0);
    // running sum per variant: start_variant[i+1][k] is the offset just past
    // the data of process i for variant k
    for (size_type i_recv = 0, n_recv = from_ndata_variant.size(); i_recv < n_recv; i_recv++) {
      for (size_type k = 0; k < _n_variant; k++) {
        from.start_variant [i_recv+1][k] = from.start_variant [i_recv][k] + from_ndata_variant [i_recv][k];
        from_total_size_variant[k] += from_ndata_variant [i_recv][k];
      }
    }
    for (size_type i_send = 0, n_send = to_ndata_variant.size(); i_send < n_send; i_send++) {
      for (size_type k = 0; k < _n_variant; k++) {
        to.start_variant [i_send+1][k] = to.start_variant [i_send][k] + to_ndata_variant [i_send][k];
        to_total_size_variant[k] += to_ndata_variant [i_send][k];
      }
    }
    // ----------------------------------------------------------
    // 1.6) resize exactly the mpi {from,to} buffers
    // ----------------------------------------------------------
    // resize() receives the array of N per-variant totals computed in 1.5
    from.values.resize (from_total_size_variant);
    to.values.resize   (to_total_size_variant);
    // ==========================================================
    // part 2. body: exchange values by variants
    // ==========================================================
    // 2.1) post receives
    // ----------------------------------------------------------
    // The macro below is expanded by BOOST_PP_REPEAT once per variant
    // k = 0..N-1, so that `_stack_##k` names the member _stack_0, _stack_1, ...
    // of from.values. For each process with a non-zero count for variant k,
    // a non-blocking receive is posted into that stack at the running offset
    // ik_start, and the pair (process index, request) is appended to
    // from.requests[k]. Processes with nothing to exchange are skipped.
#define _RHEOLEF_post_recv(z,k,unused)                     		\
    from.requests[k].clear();						\
    for (size_type i_recv = 0, n_recv = from.n_proc(), ik_start = 0; i_recv < n_recv; i_recv++) { \
      size_type ik_size = from.start_variant[i_recv+1][k] - from.start_variant [i_recv][k]; \
      if (ik_size == 0) continue;					\
      mpi::request ik_req = comm.irecv(					\
            from.procs [i_recv],					\
	    tag[k],							\
    	    from.values._stack_##k.begin().operator->() + ik_start,	\
	    ik_size); 							\
        ik_start += ik_size;						\
        from.requests[k].push_back (std::make_pair(i_recv, ik_req));	\
    } 
    BOOST_PP_REPEAT(N, _RHEOLEF_post_recv, ~)
#undef _RHEOLEF_post_recv
    // ----------------------------------------------------------
    // 2.2) prepare send: apply right permutation
    // ----------------------------------------------------------
    // i_start is the running sum of the per-process sizes, so i walks the
    // same positions of to.indices as in step 1.2, provided to.starts[0] == 0
    // (assumed -- confirm). The meaning of the first two arguments of
    // to.values.set() is defined by the buffer class, not visible here.
    for (size_type i_send = 0, n_send = to.n_proc(), i_start = 0; i_send < n_send; i_send++) {
      size_type i_size = to.starts[i_send+1] - to.starts[i_send];
      for (size_type i = i_start, n = i_start + i_size; i < n; i++) {
	to.values.set (i, i-i_start, x [to.indices[i]], x.variant (to.indices[i]));
      }
      i_start += i_size;
    }
    // ----------------------------------------------------------
    // 2.3) do sends
    // ----------------------------------------------------------
    // Same scheme as in 2.1, on the send side: one expansion per variant k,
    // one non-blocking send per process with a non-zero count, taken from
    // to.values._stack_k at the running offset ik_start; the pair
    // (process index, request) is appended to to.requests[k].
    // The requests are left pending when the function returns.
#define _RHEOLEF_pre_send(z,k,unused)                     		\
    to.requests[k].clear();						\
    for (size_type i_send = 0, n_send = to.n_proc(), ik_start = 0; i_send < n_send; i_send++) { \
      size_type ik_size = to.start_variant [i_send+1][k] - to.start_variant [i_send][k]; \
      if (ik_size == 0) continue;					\
      mpi::request ik_req = comm.isend(					\
            to.procs [i_send],						\
	    tag[k],							\
    	    to.values._stack_##k.begin().operator->() + ik_start,	\
	    ik_size); 							\
        ik_start += ik_size;						\
        to.requests[k].push_back (std::make_pair(i_send, ik_req));	\
    }
    BOOST_PP_REPEAT(N, _RHEOLEF_pre_send, ~)
#undef _RHEOLEF_pre_send
  }
};
#ifdef TODO
// Local part of the scatter start, generic version for any operator `op`.
// Only compiled when TODO is defined.
//
// Simply forwards to msg_both_permutation_apply with:
//  - the range [to.local_slots.begin(), to.local_slots.end()) and the input x,
//  - the operator op,
//  - from.local_slots.begin() and the output y.
// The effect of msg_both_permutation_apply itself is defined elsewhere.
template <class InputIterator, class OutputIterator, class SetOp, class T>
void
mpi_polymorphic_scatter_begin_local (
    InputIterator                       x,
    OutputIterator                      y,
    scatter_message<std::vector<T> >&   from,
    scatter_message<std::vector<T> >&   to,
    SetOp                               op)
{
    msg_both_permutation_apply (
        to.local_slots.begin(), to.local_slots.end(), x,
        op,
        from.local_slots.begin(), y);
}
// Local part of the scatter start: overload chosen when the operator is
// set_op<T,T> (local insert). Only compiled when TODO is defined.
//
// Three cases are handled, in this order:
//  1. to.local_is_copy: one contiguous std::copy of to.local_copy_length
//     entries, from x + to.local_copy_start to y + from.local_copy_start;
//  2. x and y compare equal and to.local_nonmatching_computed is set:
//     permutation restricted to the local_slots_nonmatching lists;
//  3. otherwise: permutation over the complete local_slots lists.
// Beforehand, fatal_macro is invoked when x and y compare equal while
// to.local_nonmatching_computed is not set (case not yet supported).
template <class InputIterator, class OutputIterator, class T>
void
mpi_polymorphic_scatter_begin_local (
    InputIterator                       x,
    OutputIterator                      y,
    scatter_message<std::vector<T> >&   from,
    scatter_message<std::vector<T> >&   to,
    set_op<T,T>                         op)
{
    const bool in_place = (y == x);
    if (in_place && ! to.local_nonmatching_computed) {
        // scatter_local_optimize(to,from);
        fatal_macro ("y == x: adress matches in scatter: not yet -- sorry");
    }
    // case 1: the local exchange reduces to a single contiguous copy
    if (to.local_is_copy) {
        std::copy (
            x + to.local_copy_start,
            x + to.local_copy_start + to.local_copy_length,
            y + from.local_copy_start);
        return;
    }
    // case 2: in-place scatter, only the non-matching slots are processed
    if (in_place && to.local_nonmatching_computed) {
        msg_both_permutation_apply (
            to.local_slots_nonmatching.begin(), to.local_slots_nonmatching.end(), x,
            op,
            from.local_slots_nonmatching.begin(), y);
        return;
    }
    // case 3: general permutation over all the local slots
    msg_both_permutation_apply (
        to.local_slots.begin(), to.local_slots.end(), x,
        op,
        from.local_slots.begin(), y);
}
#endif // TODO
// Entry point of the polymorphic scatter start, for a base type with N
// variants (N is a macro supplied by the including file).
//
// Runs the inter-process part through
// mpi_polymorphic_scatter_begin_global_t, then checks with check_macro that
// there is no local (same-process) data: to.n_local() != 0 is not yet
// supported. The buffer y is only referenced by the disabled TODO code.
template <class Container, class Message, class Buffer, class Tag>
struct mpi_polymorphic_scatter_begin_t<Container,Message,Buffer,Tag,N> {
  void operator() (
    const Container&           x,
    Buffer&                    y,
    Message&                   from,
    Message&                   to,
    const Tag&                 tag_header,
    const boost::array<Tag,N>& tag,
    const communicator&        comm) const
  {
    typedef mpi_polymorphic_scatter_begin_global_t<Container,Message,Tag,N> global_scatter_type;
    // inter-process exchanges: header pass, then non-blocking value exchange
    global_scatter_type() (x, from, to, tag_header, tag, comm);
    // the local exchange is not implemented yet: refuse any local data
    check_macro (to.n_local() == 0, "n_local = "<<to.n_local()<<" != 0 not yet supported ");
#ifdef TODO
    // NOTE(review): `op` is not a parameter of this operator(); unless it is
    // declared elsewhere, this disabled code would not compile as-is.
    if (to.n_local() == 0) return;
    mpi_polymorphic_scatter_begin_local  (x, y, from, to, op);
#endif // TODO
  }
};
#endif // _RHEOLEF_HAVE_MPI
