///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
/// 
/// =========================================================================

#ifdef _RHEOLEF_HAVE_MPI
// Note: this file is recursively included by "polymorphic_array_mpi.cc" 
// with growing "N" and in the namespace "rheolef"

// function template partial specialization is not allowed: use class-function
template <class PointerArray, class Container1, class Container2, class Message, class Tag>
struct mpi_polymorphic_scatter_init_t <PointerArray,Container1,Container2,Message,Tag,N> {
  typedef typename Container1::size_type size_type;
  void operator() (
// input:
    const Container1&	idx,
    const Container2&	idy,
    const PointerArray& ptr,
    const distributor&  ownership,
    const Tag&          tag_header,
// output:
    Message&            from,
    Message&            to) const
  {
    const size_type _n_variant = N;
    const size_type unset = std::numeric_limits<size_type>::max();
    communicator comm = ownership.comm();
    size_type  my_proc = comm.rank();
    size_type  nproc   = comm.size();
    size_type  idy_maxval = ownership.dis_size();
    // ------------------------------------------------------- 
    // 1) first count number of contributors to each processor
    // ------------------------------------------------------- 
    std::vector<size_type> msg_size(nproc, 0);
    std::vector<size_type> msg_mark(nproc, 0);
    std::vector<size_type> owner   (idx.size(), unset);

    boost::array<size_type,_n_variant> send_total_size_poly;
    boost::array<size_type,_n_variant> send_nproc_poly;
    boost::array<std::vector<size_type>,_n_variant> msg_mark_poly;
    boost::array<std::vector<size_type>,_n_variant> msg_size_poly;
    boost::array<std::vector<size_type>,_n_variant> owner_poly;
    send_total_size_poly.assign (0);
    send_nproc_poly.assign (0);
    for (size_type k = 0; k < _n_variant; k++) {
      msg_mark_poly[k].resize (nproc);
      msg_size_poly[k].resize (nproc);
      owner_poly[k].resize (idx.size());
      std::fill (msg_mark_poly[k].begin(), msg_mark_poly[k].end(), 0);
      std::fill (msg_size_poly[k].begin(), msg_size_poly[k].end(), 0);
      std::fill (owner_poly[k].begin(),    owner_poly[k].end(),    unset);
    }
    size_type send_nproc = 0;
    {
      size_type iproc = 0;
      for (size_type i = 0, n = idx.size(); i < n; i++) {
        size_type k = ptr[i].variant(); // the variant k of polymorphic_array[i]
        for (; iproc < nproc; iproc++) {
          if (idx[i] >= ownership[iproc] && idx[i] < ownership[iproc+1]) {
	    owner[i] = iproc;
            owner_poly [k][send_total_size_poly[k]] = iproc;
            send_total_size_poly[k]++;
            msg_size [iproc]++;
            msg_size_poly [k][iproc]++;
            if (!msg_mark[iproc]) {
                 msg_mark[iproc] = 1;
                 send_nproc++;
            }
            if (!msg_mark_poly[k][iproc]) {
                 msg_mark_poly[k][iproc] = 1;
                 send_nproc_poly[k]++;
	    }
            break;
          }
        }
        assert_macro (iproc != nproc, "bad data: processor range error (0:nproc-1).");
      }
    } // end block
    // ------------------------------------------------------- 
    // 2) avoid to send message to my-proc in counting
    // ------------------------------------------------------- 
    size_type n_local  = msg_size[my_proc]; 
    if (n_local != 0) {
        msg_size [my_proc] = 0;
        msg_mark [my_proc] = 0;
        send_nproc--;
    }
    // ----------------------------------------------------------------
    // 3) compute number of messages to be send to my_proc
    // ----------------------------------------------------------------
    std::vector<size_type> work(nproc, unset);
    mpi::all_reduce (
	comm, 
        msg_mark.begin().operator->(),
	nproc,
	work.begin().operator->(),
	std::plus<size_type>());
    size_type receive_nproc = work [my_proc];
    // ----------------------------------------------------------------
    // 4) compute messages max size to be send to my_proc
    // ----------------------------------------------------------------
    mpi::all_reduce (
        comm,
        msg_size.begin().operator->(),
        nproc,
	work.begin().operator->(),
        mpi::maximum<size_type>());
    size_type receive_max_size = work [my_proc];
    // ----------------------------------------------------------------
    // 5) post receive: exchange the buffer adresses between processes
    // ----------------------------------------------------------------
    std::list<std::pair<size_type,mpi::request> > receive_waits;
    std::vector<std::pair<size_type,size_type> >  receive_data (
	receive_nproc*receive_max_size, std::make_pair(unset, unset)); // pair(idx,variant)
    for (size_type i_receive = 0; i_receive < receive_nproc; i_receive++) {
      mpi::request i_req = comm.irecv (
	  mpi::any_source,
	  tag_header,
          receive_data.begin().operator->() + i_receive*receive_max_size,
	  receive_max_size);
      receive_waits.push_back (std::make_pair(i_receive, i_req));
    }
    // ---------------------------------------------------------------------------
    // 6) compute the send indexes
    // ---------------------------------------------------------------------------
    // comme idx est trie, on peut faire une copie de idx dans send_data
    // et du coup owner et send_data_ownership sont inutiles
    // TODO: ajouter le type k de la variante
    std::vector<std::pair<size_type,size_type> >  send_data (idx.size(), std::make_pair(unset, unset)); // pair(idx,variant)
    for (size_type i = 0, n = idx.size(); i < n; i++) {
        send_data[i].first  = idx[i];
        send_data[i].second = ptr[i].variant(); // the variant k of polymorphic_array[i]
    }
    // ---------------------------------------------------------------------------
    // 7) do send
    // ---------------------------------------------------------------------------
    std::list<std::pair<size_type,mpi::request> > send_waits;
    {
      size_type i_send = 0;
      size_type i_start = 0;
      for (size_type iproc = 0; iproc < nproc; iproc++) {
        size_type i_msg_size = msg_size[iproc];
        if (i_msg_size == 0) continue;
        mpi::request i_req = comm.isend (
	    iproc,
	    tag_header, 
            send_data.begin().operator->() + i_start,  
            i_msg_size);
        send_waits.push_back(std::make_pair(i_send,i_req));
        i_send++;
        i_start += i_msg_size;
      }
    } // end block
    // ---------------------------------------------------------------------------
    // 8) wait on receives indexes
    // ---------------------------------------------------------------------------
    // note: for wait_all, build an iterator adapter that scan the pair.second in [index,request]
    // and then get an iterator in the pair using iter.base(): retrive the corresponding index
    // for computing the position in the receive.data buffer
    typedef boost::transform_iterator<select2nd<size_t,mpi::request>, std::list<std::pair<size_t,mpi::request> >::iterator>
            request_iterator;
    std::vector<size_type> receive_size (receive_nproc);
    std::vector<size_type> receive_proc (receive_nproc);
    size_type receive_total_size = 0;
    boost::array<size_type,_n_variant> receive_total_size_poly;
    boost::array<size_type,_n_variant> receive_nproc_poly;
    boost::array<std::vector<size_type>,_n_variant> receive_proc_poly;
    boost::array<std::vector<size_type>,_n_variant> receive_proc_mark_poly;
    boost::array<std::vector<size_type>,_n_variant> receive_data_poly;
    boost::array<std::vector<size_type>,_n_variant> receive_size_poly;
    for (size_type k = 0; k < _n_variant; k++) {
      receive_proc_mark_poly[k].resize (receive_nproc);
      receive_proc_poly[k].resize      (receive_nproc);
      receive_size_poly[k].resize      (receive_nproc);
      receive_data_poly[k].resize      (receive_max_size);
      std::fill (receive_proc_mark_poly[k].begin(), receive_proc_mark_poly[k].end(), 0);
      std::fill (receive_proc_poly[k].begin(),      receive_proc_poly[k].end(),      0);
      std::fill (receive_size_poly[k].begin(),      receive_size_poly[k].end(),      0);
      std::fill (receive_data_poly[k].begin(),      receive_data_poly[k].end(),      unset);
    }
    std::fill (receive_total_size_poly.begin(), receive_total_size_poly.end(), 0);
    std::fill (receive_nproc_poly.begin(),      receive_nproc_poly.end(),      0);
    while (receive_waits.size() != 0) {
        typedef std::pair<size_type,size_type> data_type; // exchanged data type
        request_iterator iter_r_waits (receive_waits.begin(), select2nd<size_t,mpi::request>()),
                         last_r_waits (receive_waits.end(),   select2nd<size_t,mpi::request>());
	// waits on any receive...
        std::pair<mpi::status,request_iterator> pair_status = mpi::wait_any (iter_r_waits, last_r_waits);
	// check status
	boost::optional<int> i_msg_size_opt = pair_status.first.count<data_type>();
	check_macro (i_msg_size_opt, "receive wait failed");
    	int iproc = pair_status.first.source();
	check_macro (iproc >= 0, "receive: source iproc = "<<iproc<<" < 0 !");
	// get size of receive and number in data
	size_type i_msg_size = (size_t)i_msg_size_opt.get();
        std::list<std::pair<size_t,mpi::request> >::iterator i_pair_ptr = pair_status.second.base();
        size_type i_receive = (*i_pair_ptr).first;
        receive_proc [i_receive] = iproc;
        receive_size [i_receive] = i_msg_size;
        receive_total_size += i_msg_size;
        receive_waits.erase (i_pair_ptr);
        size_type i_start = i_receive*receive_max_size;
        for (size_type j = 0; j < i_msg_size; j++) {
            size_type idx_j = receive_data [i_start + j].first;
            size_type k     = receive_data [i_start + j].second; // k=variant
            receive_data_poly[k][i_start + receive_size_poly[k][i_receive]++] = idx_j;
            receive_total_size_poly[k]++;
	    if (! receive_proc_mark_poly[k][i_receive]) {
	          receive_proc_mark_poly[k][i_receive] = 1;
        	  receive_proc_poly[k][receive_nproc_poly[k]] = iproc;
	          receive_nproc_poly[k]++;
	    }
        }
    }
    // ---------------------------------------------------------------------------
    // 9) allocate the entire send(to) scatter context
    // ---------------------------------------------------------------------------
    to.resize (receive_total_size_poly, receive_nproc_poly);
    // ---------------------------------------------------------------------------
    // 10) compute the permutation of values that gives the sorted source[] sequence
    // ---------------------------------------------------------------------------
    boost::array<std::vector<size_type>,_n_variant> perm;
    for (size_type k = 0; k < _n_variant; k++) {
      // init: perm[i] = i
      perm[k].resize (receive_nproc_poly[k]);
      copy(index_iterator<size_type>(), index_iterator<size_type>(receive_nproc_poly[k]), perm[k].begin());
      // compute perm such that receive_proc[perm[]] is sorted by increasing proc number
      sort_with_permutation (
        receive_nproc_poly[k],
        receive_proc_poly[k].begin().operator->(),
        perm[k].begin().operator->());
    }
    // ---------------------------------------------------------------------------
    // 11) Computes the receive compresed message pattern for send(to)
    // ---------------------------------------------------------------------------
    size_type first_dis_idx = ownership.first_index();
    for (size_type k = 0; k < _n_variant; k++) {
      msg_to_context (
        perm[k].begin(),
        perm[k].end(),
        receive_proc_poly[k].begin(),
        receive_size_poly[k].begin(),
        receive_data_poly[k].begin(),
        receive_max_size,  // TODO: not optimal, could depend upon k: receive_max_size_poly[k]
        first_dis_idx,
        to.procs[k].begin(),
        to.starts[k].begin(),
        to.indices[k].begin());
    }
    // ---------------------------------------------------------------------------
    // 12) allocate the entire receive(from) scatter context
    // ---------------------------------------------------------------------------
    from.resize(send_total_size_poly, send_nproc_poly);
    // ---------------------------------------------------------------------------
    // 13) Computes the receive compresed message pattern for receive(from)
    // ---------------------------------------------------------------------------
    boost::array<std::vector<size_type>,_n_variant> proc2from_proc;
    for (size_type k = 0; k < _n_variant; k++) {
      proc2from_proc[k].resize (nproc);
      msg_from_context_pattern (
        msg_size_poly[k].begin(),
        msg_size_poly[k].end(),
        from.procs[k].begin(),
        from.starts[k].begin(),
        proc2from_proc[k].begin());
    }
    // ---------------------------------------------------------------------------
    // 14) Computes the receive compresed message indices for receive(from)
    // ---------------------------------------------------------------------------
    // assume that indices are sorted by increasing order
    for (size_type k = 0; k < _n_variant; k++) {
      std::vector<size_type> start (from.starts[k].size());
      copy (from.starts[k].begin(), from.starts[k].end(), start.begin());
      msg_from_context_indices (
        owner_poly[k].begin(),
        owner_poly[k].begin() + send_total_size_poly[k],
        idy,
        proc2from_proc[k].begin(),
        my_proc,
        idy_maxval,
        start.begin(),
        from.indices[k].begin());
    }
    // ---------------------------------------------------------------------------
    // 15) wait on sends
    // ---------------------------------------------------------------------------
    request_iterator iter_s_waits (send_waits.begin(), select2nd<size_type,mpi::request>()),
                     last_s_waits (send_waits.end(),   select2nd<size_type,mpi::request>());
    mpi::wait_all (iter_s_waits, last_s_waits);
    // ---------------------------------------------------------------------------
    // 16) Computes the receive compresed message local pattern,
    // i.e. the only part that does not requires communication.
    // ---------------------------------------------------------------------------
    from.local_slots.resize(n_local);
    to.local_slots.resize(n_local);
    size_type last_dis_idx = ownership.last_index();
    msg_local_context (
	idx.begin(),
    	idx.end(),
    	idy.begin(),
    	idy_maxval,
        first_dis_idx,
    	last_dis_idx,
        to.local_slots.begin(),
        to.local_slots.end(),
        from.local_slots.begin());

    // ---------------------------------------------------------------------------
    // 17) Optimize local exchanges during gatter/scatter
    // ---------------------------------------------------------------------------
    bool has_opt = msg_local_optimize (
        to.local_slots.begin(),
        to.local_slots.end(),
        from.local_slots.begin());

    if (has_opt && n_local != 0) {
        to.local_is_copy       = true; 
        to.local_copy_start    = to.local_slots[0]; 
        to.local_copy_length   = n_local;
        from.local_is_copy     = true;
        from.local_copy_start  = from.local_slots[0];
        from.local_copy_length = n_local;
    }
  }
};
#endif // _RHEOLEF_HAVE_MPI
