/*            Copyright (C) 2001, 2002, 2003 Stijn van Dongen
 *
 * This file is part of MCL.  You can redistribute and/or modify MCL under the
 * terms of the GNU General Public License; either version 2 of the License or
 * (at your option) any later version.  You should have received a copy of the
 * GPL along with MCL, in the file COPYING.
*/


#include <string.h>
#include <stdio.h>
#include <math.h>
#include <float.h>

#include "impala/compose.h"
#include "impala/matrix.h"
#include "impala/vector.h"
#include "impala/pval.h"
#include "impala/io.h"
#include "impala/iface.h"

#include "util/io.h"
#include "util/err.h"
#include "util/minmax.h"
#include "util/opt.h"
#include "util/types.h"

/* Program name; passed as the caller tag to mcxUsage, mcxTell and mcxErr. */
const char* me = "mcxarray";

/* Help text printed by -h (via mcxUsage); the list is NULL-terminated.
 * Every option accepted by main is listed here.
*/
const char* usagelines[] =
{  "Usage: mcxarray [options] <array data matrix>"
,  ""
,  "Options:"
,  "-o <fname>       write to file fname (default: array)"
,  "-co <num>        remove output scores not larger than num (default 0.4)"
,  "                 (-cutoff is a synonym)"
,  "-gq <num>        ignore data (input) values smaller than num"
,  "-lq <num>        ignore data (input) values larger than num"
,  "-t               work with the transpose"
,  "-c               work with the cosine"
,  "-p               work with Pearson correlation score (default)"
,  "-pi <num>        inflate the input columns"
,  "-tpi <num>       inflate the transposed columns (only with -t)"
,  "-digits <num>    number of digits used for output values (default 8)"
,  "-h               print this help"
,  NULL
}  ;



/* Replace mx by mx + transpose(mx), column by column.
 *
 * Column c of mx is added to column c of the transpose, i.e. columns are
 * paired by position, so this is only meaningful for a square matrix.
 * NOTE(review): it further assumes the row and column domains of mx list
 * the same indices in the same order -- TODO confirm for any new caller.
 * A non-square matrix is reported on stderr and left unchanged instead of
 * reading past the end of the transpose's column array.
 *
 * As with any A + A^T, diagonal entries end up doubled.
 * todo: check symmetry (?)
*/
void mclxAddTranspose
(  mclx* mx
)
   {  long c
   ;  mclx* mxt = mclxTranspose(mx)

                              /* the transpose has one column per row of mx */
   ;  if (N_COLS(mxt) != N_COLS(mx))
      {  fprintf
         (  stderr
         ,  "mclxAddTranspose: matrix is not square (%ld x %ld); not modified\n"
         ,  (long) N_COLS(mxt)
         ,  (long) N_COLS(mx)
         )
      ;  mclxFree(&mxt)
      ;  return
   ;  }

      for (c=0;c<N_COLS(mx);c++)
      {  mclvAdd(mx->cols+c, mxt->cols+c, mx->cols+c)
                              /* presumably frees this transposed column's
                               * entries early; mxt itself is freed below */
      ;  mclvRelease(mxt->cols+c)
   ;  }
      mclxFree(&mxt)
;  }


/* mcxarray: compute pairwise similarities between the columns of an array
 * data matrix and write them as a symmetric matrix.
 *
 * The input matrix must be the last argument; everything before it is an
 * option (see usagelines).  Scores are Pearson correlations (default, -p)
 * or cosines (-c) between columns, or between rows of the input when -t
 * is given.  Only scores strictly larger than the cutoff (-co, default 0.4)
 * are kept.  The result is written in ASCII to the -o file (default
 * "array").  Exits with status 1 on usage errors.
*/
int main
(  int                  argc
,  const char*          argv[]
)
   {  int a = 1, c, d
   ;  int digits = 8
   ;  double cutoff = 0.4, pi = 0.0, tpi = 0.0, lq = DBL_MAX, gq = -DBL_MAX
   ;  mclx* tbl, *res
   ;  mcxIO* xfin, *xfout
   ;  mclv* ssqs, *sums, *scratch
   ;  mcxbool transpose = FALSE
   ;  const char* out = "array"
   ;  int mode = 'p'             /* 'p' Pearson or 'c' cosine; not a boolean */
   ;  int n_mod

   ;  while (a<argc)
      {  if (!strcmp(argv[a], "-h"))
         {  mcxUsage(stdout, me, usagelines)
         ;  return 0
      ;  }
         else if (!strcmp(argv[a], "-t"))
         transpose = TRUE
      ;  else if (!strcmp(argv[a], "-p"))
         mode = 'p'
      ;  else if (!strcmp(argv[a], "-c"))
         mode = 'c'
      ;  else if (!strcmp(argv[a], "-cutoff") || !strcmp(argv[a], "-co"))
         {  if (a++ + 1 < argc)
            cutoff = atof(argv[a])
         ;  else goto arg_missing
      ;  }
         else if (!strcmp(argv[a], "-lq"))
         {  if (a++ + 1 < argc)
            lq = atof(argv[a])
         ;  else goto arg_missing
      ;  }
         else if (!strcmp(argv[a], "-gq"))
         {  if (a++ + 1 < argc)
            gq = atof(argv[a])
         ;  else goto arg_missing
      ;  }
         else if (!strcmp(argv[a], "-o"))
         {  if (a++ + 1 < argc)
            out = argv[a]
         ;  else goto arg_missing
      ;  }
         else if (!strcmp(argv[a], "-pi"))
         {  if (a++ + 1 < argc)
            pi = atof(argv[a])
         ;  else goto arg_missing
      ;  }
         else if (!strcmp(argv[a], "-tpi"))
         {  if (a++ + 1 < argc)
            tpi = atof(argv[a])
         ;  else goto arg_missing
      ;  }
         else if (!strcmp(argv[a], "-digits"))
         {  if (a++ + 1 < argc)
            digits = strtol(argv[a], NULL, 10)
         ;  else goto arg_missing
      ;  }
         else if (0)
         {  arg_missing
                              /* only reached when the flag was the last
                               * argument, so argv[argc-1] is that flag */
         :  mcxTell(me, "flag <%s> needs argument; see help (-h)", argv[argc-1])
         ;  mcxExit(1)
      ;  }
         else if (a == argc-1)
         break
      ;  else
         {  mcxErr(me, "not an option: <%s>", argv[a])
         ;  return 1
      ;  }
         a++
   ;  }

                              /* The loop stops at a == argc-1 only when a
                               * non-option (the input file) is left over.
                               * Otherwise there were no arguments or the
                               * options consumed all of them; without this
                               * check argv[0] or an option (value) would be
                               * opened as the input matrix. */
      if (a != argc-1)
      {  mcxErr(me, "no input file given; see help (-h)")
      ;  return 1
   ;  }

      xfin = mcxIOnew(argv[argc-1], "r")
   ;  xfout = mcxIOnew(out, "w")
   ;  mcxIOopen(xfin, EXIT_ON_FAIL)
   ;  mcxIOopen(xfout, EXIT_ON_FAIL)
   ;  tbl = mclxRead(xfin, EXIT_ON_FAIL)

   ;  if (lq < DBL_MAX)
      {  double mass = mclxMass(tbl)
      ;  double kept = mclxSelectValues(tbl, NULL, &lq, MCLX_LQ)
      ;  fprintf(stderr, "orig %.2f kept %.2f\n", mass, kept)
   ;  }

      if (gq > -DBL_MAX)
      {  double mass = mclxMass(tbl)
      ;  double kept = mclxSelectValues(tbl, &gq, NULL, MCLX_GQ)
      ;  fprintf(stderr, "orig %.2f kept %.2f\n", mass, kept)
   ;  }

      if (pi)
      mclxInflate(tbl, pi)

   ;  if (transpose)
      {  mclx* tblt = mclxTranspose(tbl)
      ;  mclxFree(&tbl)
      ;  tbl = tblt
      ;  if (tpi)
         mclxInflate(tbl, tpi)
   ;  }

                              /* per-column sum of squares and sum, indexed
                               * by column position */
      ssqs = mclvCopy(NULL, tbl->dom_cols)
   ;  sums = mclvCopy(NULL, tbl->dom_cols)
   ;  scratch = mclvCopy(NULL, tbl->dom_cols)

   ;  for (c=0;c<N_COLS(tbl);c++)
      {  double sumsq = mclvPowSum(tbl->cols+c, 2.0)
      ;  double sum = mclvSum(tbl->cols+c)
      ;  ssqs->ivps[c].val = sumsq
      ;  sums->ivps[c].val = sum
   ;  }

      res   =
      mclxAllocZero
      (  mclvCopy(NULL, tbl->dom_cols)
      ,  mclvCopy(NULL, tbl->dom_cols)
      )

                              /* print at most about 40 progress dots */
   ;  n_mod =  MAX(1+(N_COLS(tbl)-1)/40, 1)

   ;  {  double N  = MAX(N_ROWS(tbl), 1)

      ;  for (c=0;c<N_COLS(tbl);c++)
         {  mclvZeroValues(scratch)
                              /* only d >= c is computed; the other half is
                               * filled in by mclxAddTranspose below */
         ;  for (d=c;d<N_COLS(tbl);d++)
            {  double ip = mclvIn(tbl->cols+c, tbl->cols+d)
            ;  double score = 0.0
            ;  if (mode == 'c')
               {  double nom = sqrt(ssqs->ivps[c].val  * ssqs->ivps[d].val)
               ;  score = nom ? ip / nom : 0.0
            ;  }
               else if (mode == 'p')
               {  double s1 = sums->ivps[c].val
               ;  double sq1= ssqs->ivps[c].val
               ;  double s2 = sums->ivps[d].val
               ;  double sq2= ssqs->ivps[d].val
               ;  double num= ip - s1*s2/N
               ;  double f1 = sq1 - s1*s1/N    /* N times variance of c */
               ;  double f2 = sq2 - s2*s2/N    /* N times variance of d */
                              /* rounding can push the variance of a
                               * (near-)constant column slightly below zero;
                               * treat that as zero variance instead of
                               * taking sqrt of a negative number (NaN) */
               ;  double nom= (f1 > 0.0 && f2 > 0.0) ? sqrt(f1 * f2) : 0.0
               ;  score = nom ? (num / nom) : 0.0
            ;  }
                              /* the diagonal is counted twice once the
                               * transpose is added, so store half of it */
               if (score > cutoff)
               scratch->ivps[d].val = (d == c) ? score / 2.0 : score
         ;  }
            mclvAdd(scratch, res->cols+c, res->cols+c)
         ;  if ((c+1) % n_mod == 0)
               fputc('.', stderr)
            ,  fflush(NULL)
      ;  }
         fputc('\n', stderr)          /* terminate the progress-dot line */
   ;  }

      mclxAddTranspose(res)

   ;  mclxWriteAscii(res, xfout, digits, EXIT_ON_FAIL)
   ;  return 0
;  }

