/*           Copyright (C) 1999, 2000, 2001, 2002, 2003 Stijn van Dongen
 *
 * This file is part of MCL.  You can redistribute and/or modify MCL under the
 * terms of the GNU General Public License; either version 2 of the License or
 * (at your option) any later version.  You should have received a copy of the
 * GPL along with MCL, in the file COPYING.
*/

#include <ctype.h>
#include <math.h>
#include <stdlib.h>

#include "matrix.h"
#include "io.h"

#include "util/compile.h"
#include "util/alloc.h"
#include "util/types.h"
#include "util/err.h"
#include "util/io.h"


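/* helper function: is map the identity mapping? */
/* fixme: FALSE is returned both for a non-identity map and for a matrix
 * that is not a mapping matrix at all; these cases should be distinguished */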

/* TRUE iff every column c of map holds exactly one entry and that entry's
 * index equals map->dom_cols->ivps[c].idx, i.e. map sends each element of
 * its column domain to itself.
*/
static mcxbool idMap
(  mclx  *map
)
   {  int c
   ;  for (c=0; c<N_COLS(map); c++)
      {  const mclv* col = map->cols+c
      ;  if (col->n_ivps != 1 || col->ivps[0].idx != map->dom_cols->ivps[c].idx)
         return FALSE
   ;  }
      return TRUE
;  }


/* helper function: validate a mapping matrix and compute the renamed domain */

/* Compute the image of dom under the mapping matrix map.
 * Column c of map must contain exactly one entry; its index is the new
 * name of dom->ivps[c].  On success the new names, in dom's order, are
 * handed to the caller through *ar_dompp, and the vector built from them
 * by mclvFromIvps is returned.  On failure an error is reported, *ar_dompp
 * is NULL and NULL is returned.  Failures are: a size mismatch, a column
 * without exactly one entry, or fewer distinct names than columns (this
 * relies on mclvFromIvps merging duplicate indices).
 * NOTE: only the sizes of dom and map->dom_cols are compared, not their
 * contents.
*/
static mclVector* pmtVector
(  mclv  *dom
,  mclx  *map
,  mclpAR** ar_dompp
)
   {  mclpAR*  targets  =  NULL
   ;  mclv*    image    =  NULL
   ;  int c

   ;  *ar_dompp = NULL

   ;  if (dom->n_ivps != map->dom_cols->n_ivps)
      {  mcxErr
         (  "mclxMapCheck"
         ,  "domains do not match (dom size %ld, map size %ld)"
         ,  (long) dom->n_ivps
         ,  (long) map->dom_cols->n_ivps
         )
      ;  goto bail
   ;  }

      targets = mclpARresize(NULL, dom->n_ivps)

   ;  for (c=0; c<N_COLS(map); c++)
      {  const mclv* col = map->cols+c
      ;  if (col->n_ivps != 1)
         {  mcxErr("mclxMapCheck", "not a mapping matrix")
         ;  goto bail
      ;  }
         targets->ivps[c].idx = col->ivps[0].idx
   ;  }

      image = mclvFromIvps(NULL, targets->ivps, targets->n_ivps)

      /* duplicates collapsed: two columns share a target */
   ;  if (image->n_ivps != targets->n_ivps)
      {  mcxErr("mclxMapCheck", "map is not bijective")
      ;  goto bail
   ;  }

      *ar_dompp = targets
   ;  return image

   ;  bail
   :  mclvFree(&image)
   ;  mclpARfree(&targets)
   ;  return NULL
;  }


/* Rename the columns of mx.  Column c of map holds the new name of
 * mx->dom_cols->ivps[c].  A NULL map renames the columns to their offsets
 * 0..N_COLS(mx)-1, and an identity map is a no-op.  The columns are
 * re-sorted by their new vid and the column domain is replaced.
 * Returns STATUS_FAIL, leaving mx untouched, if pmtVector rejects map.
*/
mcxstatus mclxMapCols
(  mclMatrix  *mx
,  mclMatrix  *map
)
   {  mclVector* dom_new   =  NULL
   ;  mclpAR*    newnames  =  NULL
   ;  int c

   ;  if (!map)
      dom_new = mclvCanonical(NULL, N_COLS(mx), 1.0)
   ;  else if (idMap(map))
      return STATUS_OK
   ;  else if (!(dom_new = pmtVector(mx->dom_cols, map, &newnames)))
      return STATUS_FAIL

   ;  for (c=0; c<N_COLS(mx); c++)
      mx->cols[c].vid = newnames ? newnames->ivps[c].idx : c

   ;  qsort(mx->cols, N_COLS(mx), sizeof(mclVector), mclvVidCmp)
   ;  mclvFree(&(mx->dom_cols))
   ;  mx->dom_cols = dom_new
   ;  mclpARfree(&newnames)
   ;  return STATUS_OK
;  }


/* Rename the row indices of mx.  Column c of map holds the new name of
 * the row index mx->dom_rows->ivps[c].  A NULL map renames each row index
 * to its offset in mx->dom_rows, i.e. to 0..N_ROWS(mx)-1, and an identity
 * map is a no-op.  Every column is rewritten and re-sorted, and the row
 * domain is replaced.
 * Returns STATUS_FAIL, leaving mx untouched, if pmtVector rejects map.
 * An entry whose index is missing from mx->dom_rows is fatal.
*/
mcxstatus  mclxMapRows
(  mclMatrix  *mx
,  mclMatrix  *map
)
#if 1
   {  mclVector* new_dom_rows = NULL
   ;  mclVector* vec = mx->cols
   ;  mclpAR* ar_dom = NULL

   ;  if (map && idMap(map))
      return STATUS_OK

   ;  if (map && !(new_dom_rows = pmtVector(mx->dom_rows, map, &ar_dom)))
      return STATUS_FAIL
   ;  else if (!map)
      /* the row domain has one element per row: N_ROWS, not N_COLS */
      new_dom_rows = mclvCanonical(NULL, N_ROWS(mx), 1.0)

   ;  while (vec < mx->cols + N_COLS(mx))
      {  mclIvp* rowivp    =  vec->ivps
      ;  mclIvp* rowivpmax =  rowivp + vec->n_ivps
      ;  int offset = -1

      ;  while (rowivp < rowivpmax)
         {  /* previous offset is passed along as a search hint */
            offset  =  mclvGetIvpOffset(mx->dom_rows, rowivp->idx, offset)
         ;  if (offset < 0)
            {  mcxErr
               (  "mclxMapRows PANIC"
               ,  "index <%ld> not in domain for <%ldx%ld> matrix"
               ,  (long) rowivp->idx
               ,  (long) N_COLS(mx)
               ,  (long) N_ROWS(mx)
               )
            ;  mcxExit(1)
         ;  }
            else
            rowivp->idx = ar_dom ? ar_dom->ivps[offset].idx : offset
         ;  rowivp++
      ;  }
         /* renaming may break index order */
         mclvSort(vec, mclpIdxCmp)
      ;  vec++
   ;  }

      mclvFree(&(mx->dom_rows))
   ;  mclpARfree(&ar_dom)
   ;  mx->dom_rows = new_dom_rows
   ;  return STATUS_OK
;  }
#else
   /* Disabled alternative: map the columns of the transpose, transpose back. */
   {  mclMatrix* tp = mclxTranspose(mx), *tptp
   ;  int i
   ;  if (mclxMapCols(tp, map) != STATUS_OK)
      {  mclxFree(&tp)
      ;  return STATUS_FAIL
   ;  }

      for (i=0;i<N_COLS(mx);i++)
      mcxFree(mx->cols[i].ivps)        /* release columns */
   ;  mcxFree(mx->cols)                /* free column array */
   ;  mclvFree(&(mx->dom_rows))        /* free row dom, keep col dom */
                                       /* KEEP root object */

   ;  tptp = mclxTranspose(tp)
   ;  mclxFree(&tp)

   ;  mx->cols = tptp->cols
   ;  mx->dom_rows = tptp->dom_rows    /* copy row dom, kept col dom */
   ;  mclvFree(&(tptp->dom_cols))      /* free col dom, not needed */
   ;  mcxFree(tptp)                   /* free root object */
   ;  return STATUS_OK
;  }
#endif


/* Apply mclvInflate with exponent power to every column of mx, in place.
*/
void mclxInflate
(  mclMatrix*   mx
,  double       power
)
   {  int c
   ;  for (c=0; c<N_COLS(mx); c++)
      mclvInflate(mx->cols+c, power)
;  }


/* Allocate a matrix with one empty column per element of dom_cols.
 * Column i gets vid dom_cols->ivps[i].idx.  The matrix stores dom_cols and
 * dom_rows as its domains without copying them, so mclxFree releases them.
 * A missing domain or a failed allocation is fatal.
*/
mclMatrix* mclxAllocZero
(  mclVector * dom_cols
,  mclVector * dom_rows
)
   {  int i
   ;  int n_cols
   ;  mclMatrix   *dst

      /* validate before the first dereference of dom_cols */
   ;  if (!dom_cols || !dom_rows)
      {  mcxErr("mclxAllocZero PBD", "missing domain(s)")
      ;  mcxExit(1)
   ;  }

      n_cols  =  dom_cols->n_ivps
   ;  dst     =  (mclMatrix*) mcxAlloc
                 (  sizeof(mclMatrix)
                 ,  EXIT_ON_FAIL
                 )
   ;  dst->cols   =  (mclVector*) mcxAlloc
                     (  n_cols * sizeof(mclVector)
                     ,  RETURN_ON_FAIL
                     )
      /* a NULL result is acceptable for an empty column domain */
   ;  if (!dst->cols && n_cols)
      {  mcxMemDenied(stderr, "mclxAllocZero", "mclVector", n_cols)
      ;  mcxExit(1)
   ;  }

      dst->dom_cols  =  dom_cols
   ;  dst->dom_rows  =  dom_rows

   ;  for (i=0; i<n_cols; i++)
      {  mclVector* vec = dst->cols+i
      ;  mclvInit(vec)
      ;  vec->vid    =  dom_cols->ivps[i].idx
   ;  }

      return dst
;  }


/* Return a matrix on dom_cols x dom_rows in which every column is a copy
 * of dom_rows passed through mclvMakeConstant with val.  Both domain
 * vectors are stored in the result as-is by mclxAllocZero.
*/
mclMatrix* mclxCartesian
(  mclVector*     dom_cols
,  mclVector*     dom_rows
,  double         val
)
   {  mclMatrix*  rect  =  mclxAllocZero(dom_cols, dom_rows)
   ;  mclVector*  col   =  rect->cols
   ;  mclVector*  end   =  col + N_COLS(rect)

   ;  for ( ; col < end; col++)
      {  mclvCopy(col, dom_rows)
      ;  mclvMakeConstant(col, val)
   ;  }
      return rect
;  }


mclMatrix*  mclxSub
(  const mclMatrix*  mx
,  const mclVector*  colSelect
,  const mclVector*  rowSelect
)
   {  mclIvp   *domivp, *domivpMax
   ;  mclVector *subvec = NULL, *mxvec = NULL
   ;  mclMatrix*  sub

   ;  mclVector*dom_cols= mclvCopy(NULL, colSelect)
   ;  mclVector*dom_rows= mclvCopy(NULL, rowSelect)
   ;  mcxbool     ok    =  TRUE

  /*  fixme: perhaps check for domain identity, and do nothing if equal */

   ;  if (!mcldEquate(dom_cols, mx->dom_cols, MCL_DOM_SUB))
  /*  nothing, fixme: optional warning? */
   ;  if (!mcldEquate(dom_rows, mx->dom_rows, MCL_DOM_SUB))
  /*  nothing, fixme: optional warning?*/

   ;  sub         =  mclxAllocZero(dom_cols, dom_rows) 
   ;  domivp      =  dom_cols->ivps;
   ;  domivpMax   =  domivp+dom_cols->n_ivps;

  /* fixme; vector check needed, domain must be conforming */

   ;  if (!ok)
      {  mclxFree(&sub)
      ;  return NULL
   ;  }

      while (domivp<domivpMax)
      {  subvec  =  mclxGetVector(sub, domivp->idx, EXIT_ON_FAIL, subvec)
      ;  mxvec   =  mclxGetVector(mx, domivp->idx, RETURN_ON_FAIL, mxvec)
      ;  if (mxvec)
         {  mcldMeet(mxvec, dom_rows, subvec)
         ;  mxvec++
      ;  }
         domivp++
      ;  subvec++
   ;  }

      return sub
;  }


/* Sum over all columns of mclvSelectValues(col, lft, rgt, equate).
 * lft, rgt and equate are forwarded unchanged; their meaning is defined
 * by mclvSelectValues.
*/
double mclxSelectValues
(  mclMatrix*  mx
,  double*     lft
,  double*     rgt
,  mcxbits     equate
)
   {  double      total =  0.0
   ;  mclVector*  col   =  mx->cols
   ;  mclVector*  end   =  col + N_COLS(mx)
   ;  while (col < end)
      total += mclvSelectValues(col++, lft, rgt, equate)
   ;  return total
;  }


/* Diagonal matrix on vec (built by mclxDiag, which keeps vec as its column
 * domain) with every stored entry passed through fltConstant with
 * argument c, presumably setting it to c.
*/
mclMatrix* mclxConstDiag
(  mclVector* vec
,  double c
)
   {  mclMatrix*  diag =  mclxDiag(vec)
   ;  double      cval =  c
   ;  mclxUnary(diag, fltConstant, &cval)
   ;  return diag
;  }


/* Return a square matrix whose i-th column holds the single entry
 * (vec->ivps[i].idx, vec->ivps[i].val).  vec itself becomes the column
 * domain of the result (it is stored, not copied); a copy of vec becomes
 * the row domain.
*/
mclMatrix* mclxDiag
(  mclVector* vec
)
   {  mclVector*  rowdom =  mclvCopy(NULL, vec)
   ;  mclMatrix*  diag   =  mclxAllocZero(vec, rowdom)
   ;  mclIvp*     ivp    =  vec->ivps
   ;  mclVector*  col    =  diag->cols
   ;  mclVector*  end    =  col + N_COLS(diag)

   ;  while (col < end)
      {  mclvInsertIdx(col, ivp->idx, ivp->val)
      ;  col++
      ;  ivp++
   ;  }
      return diag
;  }


/* Return a deep copy of src, including copies of both domains.  Returns
 * NULL if mclvRenew fails for some column; the partial copy is freed.
*/
mclMatrix* mclxCopy
(  const mclMatrix*     src
)
   {  mclMatrix*  dup  =   mclxAllocZero
                           (  mclvCopy(NULL, src->dom_cols)
                           ,  mclvCopy(NULL, src->dom_rows)
                           )
   ;  int c

   ;  for (c=0; c<N_COLS(src); c++)
      {  const mclVector* from = src->cols+c
      ;  if (!mclvRenew(dup->cols+c, from->ivps, from->n_ivps))
         {  mclxFree(&dup)
         ;  break
      ;  }
      }
      return dup
;  }


/* Release *mxpp: the entry arrays of all columns, both domains, the column
 * array and the matrix object.  *mxpp is then set to NULL.  A NULL *mxpp
 * is a no-op.
*/
void mclxFree
(  mclMatrix**             mxpp
)
   {  mclMatrix* mx = *mxpp
   ;  int c

   ;  if (!mx)
      return

      /* column count comes from dom_cols, so walk columns before freeing it */
   ;  for (c=0; c<N_COLS(mx); c++)
      mcxFree(mx->cols[c].ivps)

   ;  mclvFree(&(mx->dom_rows))
   ;  mclvFree(&(mx->dom_cols))
   ;  mcxFree(mx->cols)
   ;  mcxFree(mx)
   ;  *mxpp = NULL
;  }


/* Apply mclvNormalize to every column of mx, in place.
*/
void mclxMakeStochastic
(  mclMatrix* mx
)  
   {  int c
   ;  for (c=0; c<N_COLS(mx); c++)
      mclvNormalize(mx->cols+c)
;  }


/* Prune every column of m with mclvSelectHighest(col, maxDensity), then
 * re-sort it with mclvSort's default ordering (NULL comparator).
*/
void mclxMakeSparse
(  mclMatrix* m
,  int        maxDensity
)
   {  mclVector* col = m->cols
   ;  mclVector* end = col + N_COLS(m)

   ;  for ( ; col < end; col++)
      {  mclvSelectHighest(col, maxDensity)
      ;  mclvSort(col, NULL)
   ;  }
   }


/* Apply mclvUnary(col, operation, arg) to every column of src, in column
 * order.  arg is shared across all calls, so operations may accumulate
 * into it.
*/
void mclxUnary
(  mclMatrix*  src
,  double (*operation)(pval, void*)
,  void* arg
)
   {  int c
   ;  for (c=0; c<N_COLS(src); c++)
      mclvUnary(src->cols+c, operation, arg)
;  }


/* Combine m1 and m2 column by column with mclvBinary(.., operation).
 * The result's domains are mcldMerge of the operands' domains.  A column
 * missing from one operand is treated as empty.  Returns NULL, after
 * freeing the partial result, if mclvBinary fails on some column.
*/
mclMatrix* mclxBinary
(  const mclMatrix* m1
,  const mclMatrix* m2
,  double  (*operation)(pval, pval)
)
   {  mclVector *rows  =  mcldMerge(m1->dom_rows, m2->dom_rows, NULL)
   ;  mclVector *cols  =  mcldMerge(m1->dom_cols, m2->dom_cols, NULL)
   ;  mclMatrix *res   =  mclxAllocZero(cols, rows)
   ;  mclVector *hint1 =  m1->cols
   ;  mclVector *hint2 =  m2->cols
   ;  mclVector  none
   ;  int c

   ;  mclvInit(&none)

   ;  for (c=0; c<N_COLS(res); c++)
      {  mclVector* dst = res->cols+c
      ;  mclVector* v1  = mclxGetVector(m1, dst->vid, RETURN_ON_FAIL, hint1)
      ;  mclVector* v2  = mclxGetVector(m2, dst->vid, RETURN_ON_FAIL, hint2)

      ;  if (!mclvBinary(v1 ? v1 : &none, v2 ? v2 : &none, dst, operation))
         {  mclxFree(&res)
         ;  break
      ;  }
         /* after a hit search on from the next column; after a miss restart */
         hint1 = v1 ? v1+1 : NULL
      ;  hint2 = v2 ? v2+1 : NULL
   ;  }
      return res
;  }


/* Position in mx->cols of the column with the given vid, or -1 if it is
 * absent (under RETURN_ON_FAIL).  A positive offset is passed on to
 * mclxGetVector as a search hint; zero or negative searches from the start.
*/
int mclxGetVectorOffset
(  const mclMatrix* mx
,  long  vid
,  mcxOnFail ON_FAIL
,  long  offset
)
   {  const mclVector* hint = NULL
   ;  mclVector* vec

   ;  if (offset > 0)
      hint = mx->cols + offset

   ;  vec = mclxGetVector(mx, vid, ON_FAIL, hint)
   ;  if (!vec)
      return -1
   ;  return vec - mx->cols
;  }


/* Scan forward from offset, or from the first column when offset is NULL,
 * for the first column whose vid is >= vid, and return it if its vid
 * equals vid.  Otherwise return NULL under RETURN_ON_FAIL; under any other
 * ON_FAIL an error is reported and the program exits.
*/
mclVector* mclxGetNextVector
(  const mclMatrix* mx
,  long   vid
,  mcxOnFail ON_FAIL
,  const mclVector* offset
)
   {  const mclVector* max =  mx->cols + N_COLS(mx)

   ;  if (!offset)
      offset = mx->cols

   ;  while (offset < max && offset->vid < vid)
      offset++

   ;  if (offset < max && offset->vid == vid)
      return (mclVector*) offset

   ;  if (ON_FAIL == RETURN_ON_FAIL)
      return NULL

      /* %ld needs long arguments: cast the dimensions explicitly */
   ;  mcxErr
      (  "mclxGetNextVector PBD"
      ,  "did not find vector <%ld> in <%ld,%ld> matrix"
      ,  (long) vid
      ,  (long) N_COLS(mx)
      ,  (long) N_ROWS(mx)
      )
   ;  mcxExit(1)
   ;  return NULL
;  }


/* Return the column of mx whose vid equals vid, or NULL if there is none.
 * Under EXIT_ON_FAIL a miss is reported and the program exits.
 * If the first and last columns have vids 0 and N_COLS-1, the column is
 * found by direct indexing and offset is ignored.  Otherwise the columns
 * are binary-searched, which assumes they are sorted by vid.  In that case
 * a non-NULL offset limits the search to columns from offset onwards;
 * offset may point one past the last column.
*/
mclVector* mclxGetVector
(  const mclMatrix* mx
,  long   vid
,  mcxOnFail ON_FAIL
,  const mclVector* offset
)
   {  long n_cols = N_COLS(mx)
   ;  mclVector* found = NULL

      /* an empty matrix has no cols[0] / cols[n_cols-1] to inspect */
   ;  if (n_cols <= 0)
      found = NULL
   ;  else if (mx->cols[0].vid == 0 && mx->cols[n_cols-1].vid == n_cols-1)
      {  if (vid >= 0 && vid <= n_cols-1)
         found = mx->cols+vid
   ;  }
      else
      {  mclVector keyvec
      ;  long n_left
      ;  mclvInit(&keyvec)
      ;  keyvec.vid = vid

      ;  if (!offset)
         offset = mx->cols

      ;  n_left = n_cols - (offset - mx->cols)
         /* offset at or past the end: nothing to search, and bsearch must
          * not receive a negative count (it would become a huge size_t) */
      ;  if (n_left > 0)
         found =  bsearch
                  (  &keyvec
                  ,  offset
                  ,  n_left
                  ,  sizeof(mclVector)
                  ,  mclvVidCmp
                  )
   ;  }

      if (!found && ON_FAIL == EXIT_ON_FAIL)
      {  mcxErr
         (  "mclxGetVector PBD"
         ,  "did not find vector <%ld> in <%ld,%ld> matrix"
         ,  (long) vid
         ,  (long) N_COLS(mx)
         ,  (long) N_ROWS(mx)
         )
      ;  mcxExit(1)
   ;  }

      return found
;  }


/* Build the mapping matrix whose column for dom_cols->ivps[i].idx holds the
 * single entry (new_dom_cols->ivps[i].idx, 1.0).  Both vectors are stored
 * as the result's domains.  Returns NULL, leaving both vectors with the
 * caller, if their sizes differ.
*/
mclx* mclxMakeMap
(  mclVector*  dom_cols
,  mclVector*  new_dom_cols
)
   {  mclx* map = NULL
   ;  mclIvp* target
   ;  int c

   ;  if (dom_cols->n_ivps == new_dom_cols->n_ivps)
      {  map    = mclxAllocZero(dom_cols, new_dom_cols)
      ;  target = new_dom_cols->ivps
      ;  for (c=0; c<N_COLS(map); c++)
         mclvInsertIdx(map->cols+c, target[c].idx, 1.0)
   ;  }
      return map
;  }


/* Return the transpose of src as a new matrix.  Its column domain is a
 * copy of src->dom_rows and its row domain a copy of src->dom_cols; entry
 * (row i, column j) of src becomes entry (row j, column i) of the result.
 * Three passes: count the entries destined for each result column,
 * allocate each result column to that count, then fill the columns.
 * Returns 0 (NULL) if a column allocation fails.  An entry of src whose
 * index is not in src->dom_rows is fatal (EXIT_ON_FAIL lookup).
*/
mclMatrix* mclxTranspose
(  const mclMatrix*  src
)
   {  mclMatrix*   dst  =  mclxAllocZero
                           (  mclvCopy(NULL, src->dom_rows)
                           ,  mclvCopy(NULL, src->dom_cols)
                           )
   ;  const mclVector*  src_vec  =  src->cols
   ;  int               vec_ind  =  N_COLS(src)
   ;  mclVector*        dst_vec
   ;

      /*
       * Pre-calculate sizes of destination columns
       * fixme; if canonical domains do away with mclxGetVector.
       * n_ivps of each destination column is used as a counter here.
       * dst_vec is passed as a search hint: if the entries of a source
       * column ascend by index, the next target lies after the last one.
      */
      while (--vec_ind >= 0)
      {  int   src_n_ivps  =  src_vec->n_ivps
      ;  mclIvp*  src_ivp  =  src_vec->ivps
      ;  dst_vec           =  dst->cols

      ;  while (--src_n_ivps >= 0)
         {  dst_vec = mclxGetVector(dst, src_ivp->idx, EXIT_ON_FAIL, dst_vec)
         ;  dst_vec->n_ivps++
         ;  src_ivp++
         ;  dst_vec++   /* with luck we get immediate hit */
      ;  }
         src_vec++
   ;  }

      /*
       * Allocate
       * Each column is resized to the count from the first pass.
       * NOTE(review): n_ivps already equals the requested size at this
       * point; this relies on mclvResize allocating storage anyway.
       * TODO confirm against mclvResize.
      */
      dst_vec     =  dst->cols
   ;  vec_ind     =  N_COLS(dst)
   ;  while (--vec_ind >= 0)
      {  if (!mclvResize(dst_vec, dst_vec->n_ivps))
         {  mclxFree(&dst)
         ;  return 0
      ;  }
         dst_vec->n_ivps = 0    /* dirty: start over for write */
                                /* n_ivps now serves as the fill position */
      ;  dst_vec++
   ;  }

      /*
       * Write
       * Source columns are visited in storage order, so each destination
       * column receives its entries in the order of src's column vids
       * (sorted by index if src's columns are sorted by vid).
      */
      src_vec     =  src->cols
   ;  while (src_vec < src->cols+N_COLS(src))
      {  int   src_n_ivps  =  src_vec->n_ivps
      ;  mclIvp* src_ivp   =  src_vec->ivps
      ;  dst_vec           =  dst->cols

      ;  while (--src_n_ivps >= 0)
         {  dst_vec = mclxGetVector(dst, src_ivp->idx, EXIT_ON_FAIL, dst_vec)
         ;  dst_vec->ivps[dst_vec->n_ivps].idx = src_vec->vid
         ;  dst_vec->ivps[dst_vec->n_ivps].val = src_ivp->val
         ;  dst_vec->n_ivps++
         ;  dst_vec++
         ;  src_ivp++
      ;  }
         src_vec++
   ;  }

      return dst
;  }


/* Return a vector holding f(col) for every column col of m, indexed by the
 * column's vid.  With mode MCL_VECTOR_COMPLETE every column gets an entry;
 * otherwise columns for which f yields 0 are left out.  Returns NULL if
 * the initial mclvResize yields NULL (presumably allocation failure).
*/
mclVector* mclxColNums
(  const mclMatrix*  m
,  double           (*f)(const mclVector * vec)
,  mcxenum           mode  
)
   {  mclVector*  nums =  mclvResize(NULL, N_COLS(m))
   ;  int vec_ind =  0
   ;  int ivp_idx =  0
   
   ;  if (nums)
      {  while (vec_ind < N_COLS(m))
         {  const mclVector* col = m->cols + vec_ind
         ;  double val =  f(col)
            /* index by vid, not by position: the two coincide only for
             * canonical column domains */
         ;  if (val || mode == MCL_VECTOR_COMPLETE)
            mclpInstantiate(nums->ivps + (ivp_idx++), col->vid, val)
         ;  vec_ind++
      ;  }
         mclvResize(nums, ivp_idx)     /* shrink to the entries written */
   ;  }
      return nums
;  }


/* Per-column mclvSum values of m; see mclxColNums for the meaning of mode.
*/
mclVector* mclxColSums
(  const mclMatrix*  m
,  mcxenum     mode  
)
   {  mclVector* sums = mclxColNums(m, mclvSum, mode)
   ;  return sums
;  }


/* Per-column mclvSize values of m; see mclxColNums for the meaning of mode.
*/
mclVector* mclxColSizes
(  const mclMatrix*     m
,  mcxenum        mode
)
   {  mclVector* sizes = mclxColNums(m, mclvSize, mode)
   ;  return sizes
;  }


/* mclxMass of the submatrix of m selected by colSelect and rowSelect, or
 * 0.0 if mclxSub returns NULL (guarded like mclxSubNrofEntries).
*/
double mclxSubMass
(  const mclMatrix*     m
,  const mclVector*     colSelect
,  const mclVector*     rowSelect
)
   {  mclMatrix   *sub  =  mclxSub(m, colSelect, rowSelect)
  /*  fixme mem hog; can use mclvCountGiven with idx identity I believe */
   ;  double      mass  =  sub ? mclxMass(sub) : 0.0
   ;  mclxFree(&sub)
   ;  return mass
;  }


double mclxMass
(  const mclMatrix*     m
)
   {  int   c
   ;  double  mass  =  0
   ;  for (c=0;c<N_COLS(m);c++)
      mass += mclvSum(m->cols+c)
   ;  return mass
;  }


/* Number of stored entries in the submatrix of m selected by colSelect
 * and rowSelect, or 0 if mclxSub returns NULL.
*/
long mclxSubNrofEntries
(  const mclMatrix*     m
,  const mclVector*     colSelect
,  const mclVector*     rowSelect
)
   {  long        count =  0
   ;  mclMatrix*  sub   =  mclxSub(m, colSelect, rowSelect)
     /* fixme mem hog; can use mclvCountGiven with idx identity I believe */
   ;  if (sub)
      count = mclxNrofEntries(sub)
   ;  mclxFree(&sub)
   ;  return count
;  }


long mclxNrofEntries
(  const mclMatrix*     m
)
   {  int  c
   ;  long nr =  0
   ;  for (c=0;c<N_COLS(m);c++)
      nr += (m->cols+c)->n_ivps
   ;  return nr
;  }


/* Reorder the columns of m with cmp, then relabel them in their new order
 * with the vids listed in m->dom_cols.  Column contents move; the column
 * domain itself is unchanged.
*/
void  mclxColumnsRealign
(  mclMatrix* m
,  int (*cmp)(const void* vec1, const void* vec2)
)
   {  mclIvp* dom = m->dom_cols->ivps
   ;  int     n   = m->dom_cols->n_ivps
   ;  int     k
   ;  qsort(m->cols, N_COLS(m), sizeof(mclVector), cmp)
   ;  for (k=0; k<n; k++)
      m->cols[k].vid = dom[k].idx
;  }


double mclxMaxValue
(  const mclMatrix*        mx
) 
   {  double max_val  =  0.0
   ;  mclxUnary((mclMatrix*)mx, fltPropagateMax, &max_val)
   ;  return max_val
;  }


/* mclxConstDiag(vec, 1.0); vec becomes the column domain of the result.
*/
mclMatrix* mclxIdentity
(  mclVector* vec
)  
   {  mclMatrix* id = mclxConstDiag(vec, 1.0)
   ;  return id
;  }


/* Apply fltScale with argument f to every entry of mx, in place.
 * NOTE(review): mx is declared const but is modified; the qualifier is
 * cast away below.
*/
void mclxScale
(  const mclMatrix*  mx
,  double   f
) 
   {  double factor = f
   ;  mclxUnary((mclMatrix*) mx, fltScale, &factor)
;  }


/* Apply fltPower with argument power to every entry of mx, in place.
*/
void mclxHdp
(  mclMatrix*  mx
,  double  power
)  
   {  double exponent = power
   ;  mclxUnary(mx, fltPower, &exponent)
;  }


/* Former function versions of mclxRowCanonical / mclxColCanonical, kept
 * for reference only.  They are compiled out; the live definitions are
 * macros in matrix.h.
*/
#if 0    /* these are now macros in matrix.h */
mcxbool mclxRowCanonical
(  const mclMatrix*        mx
)
   {  return mcldIsCanonical(mx->dom_rows)
;  }
mcxbool mclxColCanonical
(  const mclMatrix*        mx
)
   {  return mcldIsCanonical(mx->dom_cols)
;  }
#endif


/* Apply fltConstant with argument 1.0 to every entry of mx, in place.
*/
void mclxMakeCharacteristic
(  mclMatrix*              mx
)  
   {  double unit = 1.0
   ;  mclxUnary(mx, fltConstant, &unit)
;  }


/* mclxBinary of m1 and m2 with fltMax as the entrywise operation.
*/
mclMatrix* mclxMax
(  const mclMatrix*        m1
,  const mclMatrix*        m2
)  
   {  mclMatrix* res = mclxBinary(m1, m2, fltMax)
   ;  return res
;  }


/* mclxBinary of m1 and m2 with fltAdd as the entrywise operation.
*/
mclMatrix* mclxAdd
(  const mclMatrix*        m1
,  const mclMatrix*        m2
)  
   {  mclMatrix* res = mclxBinary(m1, m2, fltAdd)
   ;  return res
;  }


/* mclxBinary of m1 and m2 with fltMultiply as the entrywise operation.
*/
mclMatrix* mclxHadamard
(  const mclMatrix*        m1
,  const mclMatrix*        m2
)
   {  mclMatrix* res = mclxBinary(m1, m2, fltMultiply)
   ;  return res
;  }


