/*
    Theseus - maximum likelihood superpositioning of macromolecular structures

    Copyright (C) 2004-2010 Douglas L. Theobald

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the:

    Free Software Foundation, Inc.,
    59 Temple Place, Suite 330,
    Boston, MA  02111-1307  USA

    -/_|:|_|_\-
*/

#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "pdbUtils.h"
#include "pdbStats.h"
#include "CovMat.h"
#include "DLTmath.h"
#include "ProcGSLSVDOcc.h"


/* Occupancy-weighted squared norm of one coordinate set:

       sum_i  o[i] * (x[i]^2 + y[i]^2 + z[i]^2)

   summed over the first cds->vlen entries.  Returns 0.0 when vlen <= 0. */
static double
CalcInnProdNormOcc(const Cds *cds)
{
    const double   *xs  = (const double *) cds->x;
    const double   *ys  = (const double *) cds->y;
    const double   *zs  = (const double *) cds->z;
    const double   *occ = (const double *) cds->o;
    const int       len = cds->vlen;
    double          total = 0.0;
    int             n;

    for (n = 0; n < len; ++n)
        total += occ[n] * (xs[n] * xs[n] + ys[n] * ys[n] + zs[n] * zs[n]);

    return total;
}


/* Weighted "E0" term of the occupancy-aware least-squares fit:

       sum_i  w[i] * o1[i] * o2[i] * (|r1_i|^2 + |r2_i|^2)

   where r1_i / r2_i are the i-th (x, y, z) rows of cds1 / cds2, o1 / o2 their
   occupancies and w the caller-supplied per-atom weights.  The loop length is
   cds1->vlen.  NOTE(review): cds2 and weights are assumed to be at least that
   long; nothing here checks it. */
static double
CalcE0Occ(const Cds *cds1, const Cds *cds2,
         const double *weights)
{
    const double   *ax = (const double *) cds1->x,
                   *ay = (const double *) cds1->y,
                   *az = (const double *) cds1->z;
    const double   *bx = (const double *) cds2->x,
                   *by = (const double *) cds2->y,
                   *bz = (const double *) cds2->z;
    const double   *occ1 = (const double *) cds1->o,
                   *occ2 = (const double *) cds2->o;
    const int       len = cds1->vlen;
    double          total = 0.0;
    int             n;

    for (n = 0; n < len; ++n)
    {
        const double w  = weights[n] * occ1[n] * occ2[n];
        const double sx = ax[n] * ax[n] + bx[n] * bx[n];
        const double sy = ay[n] * ay[n] + by[n] * by[n];
        const double sz = az[n] * az[n] + bz[n] * bz[n];

        total += w * (sx + sy + sz);
    }

    return total;
}


/* Covariance-weighted "E0" term:

       sum_i  o1[i] * o2[i] * (c1_i . r1_i + c2_i . r2_i)

   where r_k,i = (x, y, z) and c_k,i = (covx, covy, covz) of cds_k.  The cov*
   arrays must already be filled (ProcGSLSVDCovOcc calls CalcCovCds on both
   structures first).  The loop length is cds1->vlen. */
static double
CalcE0CovOcc(const Cds *cds1, const Cds *cds2)
{
    const double   *px1 = (const double *) cds1->x,
                   *py1 = (const double *) cds1->y,
                   *pz1 = (const double *) cds1->z;
    const double   *px2 = (const double *) cds2->x,
                   *py2 = (const double *) cds2->y,
                   *pz2 = (const double *) cds2->z;
    const double   *qx1 = (const double *) cds1->covx,
                   *qy1 = (const double *) cds1->covy,
                   *qz1 = (const double *) cds1->covz;
    const double   *qx2 = (const double *) cds2->covx,
                   *qy2 = (const double *) cds2->covy,
                   *qz2 = (const double *) cds2->covz;
    const double   *occ1 = (const double *) cds1->o,
                   *occ2 = (const double *) cds2->o;
    const int       len = cds1->vlen;
    double          total = 0.0;
    int             n;

    for (n = 0; n < len; ++n)
    {
        const double dx = qx1[n] * px1[n] + qx2[n] * px2[n];
        const double dy = qy1[n] * py1[n] + qy2[n] * py2[n];
        const double dz = qz1[n] * pz1[n] + qz2[n] * pz2[n];

        total += occ1[n] * occ2[n] * (dx + dy + dz);
    }

    return total;
}


/* Weighted, occupancy-aware 3x3 cross-product matrix:

       Rmat[r][c] = sum_i  w[i] * o1[i] * o2[i] * b_r[i] * a_c[i]

   where a = (cds1->x, y, z) and b = (cds2->x, y, z), summed over cds1->vlen
   entries.  Row index selects the cds2 axis, column index the cds1 axis.
   All nine entries of Rmat are overwritten. */
static void
CalcROcc(const Cds *cds1, const Cds *cds2, double **Rmat,
         const double *weights)
{
    const double   *ref[3];     /* cds1 axes: x, y, z */
    const double   *mov[3];     /* cds2 axes: x, y, z */
    const double   *occ1 = (const double *) cds1->o,
                   *occ2 = (const double *) cds2->o;
    const int       len = cds1->vlen;
    double          acc[3][3] = {{0.0}};
    double          wv[3];
    double          w;
    int             n, r, c;

    ref[0] = (const double *) cds1->x;
    ref[1] = (const double *) cds1->y;
    ref[2] = (const double *) cds1->z;
    mov[0] = (const double *) cds2->x;
    mov[1] = (const double *) cds2->y;
    mov[2] = (const double *) cds2->z;

    for (n = 0; n < len; ++n)
    {
        w = weights[n] * occ1[n] * occ2[n];

        /* apply the combined weight once, to the cds2 side */
        for (r = 0; r < 3; ++r)
            wv[r] = w * mov[r][n];

        for (r = 0; r < 3; ++r)
            for (c = 0; c < 3; ++c)
                acc[r][c] += wv[r] * ref[c][n];
    }

    for (r = 0; r < 3; ++r)
        for (c = 0; c < 3; ++c)
            Rmat[r][c] = acc[r][c];
}


/* Occupancy-only 3x3 cross-product matrix (no per-atom weights array):

       Rmat[r][c] = sum_i  o1[i] * o2[i] * b_r[i] * a_c[i]

   where a = (cds1->x, y, z) and b = (cds2->x, y, z), summed over cds1->vlen
   entries.  All nine entries of Rmat are overwritten. */
static void
CalcRvanOcc(const Cds *cds1, const Cds *cds2, double **Rmat)
{
    const double   *ax = (const double *) cds1->x,
                   *ay = (const double *) cds1->y,
                   *az = (const double *) cds1->z;
    const double   *bx = (const double *) cds2->x,
                   *by = (const double *) cds2->y,
                   *bz = (const double *) cds2->z;
    const double   *occ1 = (const double *) cds1->o,
                   *occ2 = (const double *) cds2->o;
    const int       len = cds1->vlen;
    double          sxx = 0.0, sxy = 0.0, sxz = 0.0,
                    syx = 0.0, syy = 0.0, syz = 0.0,
                    szx = 0.0, szy = 0.0, szz = 0.0;
    int             n;

    for (n = 0; n < len; ++n)
    {
        const double w  = occ1[n] * occ2[n];
        const double mx = w * bx[n];
        const double my = w * by[n];
        const double mz = w * bz[n];
        const double rx = ax[n];
        const double ry = ay[n];
        const double rz = az[n];

        sxx += mx * rx;  sxy += mx * ry;  sxz += mx * rz;
        syx += my * rx;  syy += my * ry;  syz += my * rz;
        szx += mz * rx;  szy += mz * ry;  szz += mz * rz;
    }

    Rmat[0][0] = sxx;  Rmat[0][1] = sxy;  Rmat[0][2] = sxz;
    Rmat[1][0] = syx;  Rmat[1][1] = syy;  Rmat[1][2] = syz;
    Rmat[2][0] = szx;  Rmat[2][1] = szy;  Rmat[2][2] = szz;
}


/* Occupancy-weighted 3x3 cross-product matrix for the covariance-weighted fit:

       Rmat[r][c] = sum_i  o1[i] * o2[i] * q_r[i] * a_c[i]

   with q = (cds2->covx, covy, covz) and a = (cds1->x, y, z), summed over
   cds1->vlen entries.  The cov* arrays must already be populated
   (ProcGSLSVDCovOcc calls CalcCovCds on both structures before this).

   Every row carries the occupancy weight exactly once, matching CalcROcc and
   CalcRvanOcc.  (Previously the third row multiplied by the weight a second
   time, i.e. used o1*o2 squared for the z row only.)

   WtMat is kept for interface compatibility but is not read here.  The
   weighting it implies is presumably already folded into cds2->cov*
   (NOTE(review): confirm against CalcCovCds). */
static void
CalcRCovOcc(const Cds *cds1, const Cds *cds2, double **Rmat,
            const double **WtMat)
{
    int             i;
    double          weight;
    const double   *x2 = (const double *) cds2->covx,
                   *y2 = (const double *) cds2->covy,
                   *z2 = (const double *) cds2->covz;
    const double   *x1 = (const double *) cds1->x,
                   *y1 = (const double *) cds1->y,
                   *z1 = (const double *) cds1->z;
    const double   *o1 = (const double *) cds1->o,
                   *o2 = (const double *) cds2->o;
    double          x2i, y2i, z2i, x1i, y1i, z1i;
    double          Rmat00, Rmat01, Rmat02,
                    Rmat10, Rmat11, Rmat12,
                    Rmat20, Rmat21, Rmat22;

    (void) WtMat;   /* unused; see header comment */

    Rmat00 = Rmat01 = Rmat02 = Rmat10 = Rmat11 = Rmat12 =
    Rmat20 = Rmat21 = Rmat22 = 0.0;

    i = cds1->vlen;
    while (i-- > 0)
    {
        weight = *o1++ * *o2++;

        x1i = *x1++;
        y1i = *y1++;
        z1i = *z1++;

        /* the occupancy weight is applied here, once, for all three rows */
        x2i = weight * *x2++;
        y2i = weight * *y2++;
        z2i = weight * *z2++;

        Rmat00 += x2i * x1i;
        Rmat01 += x2i * y1i;
        Rmat02 += x2i * z1i;

        Rmat10 += y2i * x1i;
        Rmat11 += y2i * y1i;
        Rmat12 += y2i * z1i;

        Rmat20 += z2i * x1i;
        Rmat21 += z2i * y1i;
        Rmat22 += z2i * z1i;
    }

    Rmat[0][0] = Rmat00;
    Rmat[0][1] = Rmat01;
    Rmat[0][2] = Rmat02;
    Rmat[1][0] = Rmat10;
    Rmat[1][1] = Rmat11;
    Rmat[1][2] = Rmat12;
    Rmat[2][0] = Rmat20;
    Rmat[2][1] = Rmat21;
    Rmat[2][2] = Rmat22;
}


/* Singular value decomposition of the 3x3 matrix Rmat, repackaged so that
   callers get U in Umat, the singular values in sigma and V^t in VTmat.

   Rmat is used as scratch and is overwritten by svdGSLDest.
   NOTE(review): this assumes svdGSLDest follows the gsl_linalg_SV_decomp
   convention (input replaced by U, V returned untransposed, singular values in
   descending order); confirm against its definition.

   Always returns 1. */
static int
CalcGSLSVD(double **Rmat, double **Umat, double *sigma, double **VTmat)
{
    /* destructive SVD: Rmat <- U, VTmat <- V, sigma <- singular values */
    svdGSLDest(Rmat, 3, sigma, VTmat);

    /* V -> V^t in place */
    Mat3TransposeIp(VTmat);

    /* hand U back through Umat */
    Mat3Cpy(Umat, (const double **) Rmat);

    return 1;
}


/* Builds the optimal rotation from the SVD factors U (Umat) and V^t (Vtmat):

       rotmat = V * D * U^t,   D = diag(1, 1, s)

   with s = +1 when det(U) * det(V^t) > 0 and s = -1 otherwise.  Negating the
   index-2 term turns what would be a reflection into a proper rotation.  This
   relies on index 2 carrying the smallest singular value (NOTE(review):
   assumes descending order from the SVD routine).

   Returns 1 when no correction was needed and -1 when the reflection
   correction was applied.  Callers use the sign to pick +sigma[2] or -sigma[2].
   All nine entries of rotmat are overwritten. */
static int
CalcRotMat(double **rotmat, double **Umat, double **Vtmat)
{
    const double    detprod = Mat3Det((const double **) Umat) *
                              Mat3Det((const double **) Vtmat);
    const int       proper = (detprod > 0);
    double          acc;
    int             row, col, k;

    for (row = 0; row < 3; ++row)
    {
        for (col = 0; col < 3; ++col)
        {
            acc = 0.0;
            for (k = 0; k < 3; ++k)
            {
                /* (V U^t)[row][col], with the last term flipped if improper */
                if (k == 2 && !proper)
                    acc -= (Vtmat[k][row] * Umat[col][k]);
                else
                    acc += (Vtmat[k][row] * Umat[col][k]);
            }
            rotmat[row][col] = acc;
        }
    }

    return proper ? 1 : -1;
}


/* Occupancy-weighted superposition without a per-atom weights array.

   On return:
     rotmat   - rotation from CalcRotMat (reflection-corrected if needed)
     Rmat, Umat, VTmat, sigma - SVD scratch/results (Rmat is overwritten)
     *norm1   - occupancy-weighted squared norm of cds2
     *norm2   - occupancy-weighted squared norm of cds1
     *innprod - sigma[0] + sigma[1] +/- sigma[2] (sign from CalcRotMat)

   Returns *norm1 + *norm2 - 2 * *innprod, the sum of squared residuals E;
   rmsd = sqrt(E/atom_num).

   NOTE(review): norm1 comes from cds2 and norm2 from cds1.  The return value
   is symmetric, but callers reading the out-parameters individually should
   confirm that the crossed naming is intended.  Also, the norms weight each
   structure by its own occupancy, while Rmat uses the product o1*o2. */
double
ProcGSLSVDvanOcc(const Cds *cds1, const Cds *cds2, double **rotmat,
                    double **Rmat, double **Umat, double **VTmat, double *sigma,
                    double *norm1, double *norm2, double *innprod)
{
    double          handedness;
    double          trace;

    *norm1 = CalcInnProdNormOcc(cds2);
    *norm2 = CalcInnProdNormOcc(cds1);

    CalcRvanOcc(cds1, cds2, Rmat);
    CalcGSLSVD(Rmat, Umat, sigma, VTmat);
    handedness = CalcRotMat(rotmat, Umat, VTmat);

    /* the smallest singular value enters negatively when a reflection was corrected */
    trace = sigma[0] + sigma[1];
    *innprod = (handedness < 0) ? trace - sigma[2] : trace + sigma[2];

    return *norm1 + *norm2 - 2.0 * *innprod;
}



/* Weighted, occupancy-aware least-squares superposition via SVD.

   Fills rotmat with the optimal rotation and leaves the SVD pieces in Rmat
   (overwritten), Umat, VTmat and sigma.  Returns

       E0 - 2 * (sigma[0] + sigma[1] +/- sigma[2])

   where E0 comes from CalcE0Occ and the sign of the sigma[2] term follows
   CalcRotMat's reflection flag.  This is the weighted sum of squared
   residuals after superposition. */
double
ProcGSLSVDOcc(const Cds *cds1, const Cds *cds2, double **rotmat,
                 const double *weights,
                 double **Rmat, double **Umat, double **VTmat, double *sigma)
{
    double          e0, handedness, sv;

    e0 = CalcE0Occ(cds1, cds2, weights);
    CalcROcc(cds1, cds2, Rmat, weights);
    CalcGSLSVD(Rmat, Umat, sigma, VTmat);
    handedness = CalcRotMat(rotmat, Umat, VTmat);

    sv = sigma[0] + sigma[1];
    if (handedness < 0)
        sv -= sigma[2];
    else
        sv += sigma[2];

    return e0 - 2.0 * sv;
}


/* Covariance-weighted, occupancy-aware superposition via SVD.

   First calls CalcCovCds(cds, WtMat) on both structures; this modifies
   cds1 and cds2 and presumably fills their covx/covy/covz arrays.  Then it
   builds E0 and the cross-product matrix from those arrays and solves for
   rotmat.  Rmat is overwritten; Umat, VTmat and sigma receive the SVD results.

   Returns the sum of squared residuals E; rmsd = sqrt(E/atom_num). */
double
ProcGSLSVDCovOcc(Cds *cds1, Cds *cds2, double **rotmat,
                    const double **WtMat, double **Rmat,
                    double **Umat, double **VTmat, double *sigma)
{
    double          e0, handedness, sv;

    CalcCovCds(cds1, WtMat);
    CalcCovCds(cds2, WtMat);

    e0 = CalcE0CovOcc(cds1, cds2);
    CalcRCovOcc(cds1, cds2, Rmat, WtMat);
    CalcGSLSVD(Rmat, Umat, sigma, VTmat);
    handedness = CalcRotMat(rotmat, Umat, VTmat);

    sv = sigma[0] + sigma[1];
    sv = (handedness < 0) ? sv - sigma[2] : sv + sigma[2];

    return e0 - 2.0 * sv;
}
