// -*- C++ -*-
#include "Rivet/Analysis.hh"
#include "Rivet/Projections/Beam.hh"
#include "Rivet/Projections/UnstableParticles.hh"

namespace Rivet {


  /// @brief Upsilon polarization at 7 and 8 TeV
  class LHCB_2017_I1621596 : public Analysis {
  public:

    /// Constructor
    RIVET_DEFAULT_ANALYSIS_CTOR(LHCB_2017_I1621596);


    /// @name Analysis methods
    /// @{

    /// Book histograms and initialise projections before the run
    void init() {
      // projections
      declare(Beam(), "Beams");
      declare(UnstableParticles(), "UFS");

      for (double eVal : allowedEnergies()) {
        const int en = round(eVal);
        if (isCompatibleWithSqrtS(eVal))  _sqs = en;

        int ih(en==8000);
        // histograms
        _ybins={2.2,3.0,3.5,4.5};
        for (size_t iups=0; iups<3; ++iups) {
          for (size_t iframe=0; iframe<3; ++iframe) {
            for (size_t imom=0; imom<3; ++imom) {
              for (size_t iy=0; iy<3; ++iy) {
                book(_p_Upsilon[ih][iups][iframe][iy][imom],
                     "TMP/UPS_"+toString(iups)+"_"+toString(iframe)+"_"+toString(iy)+"_"+toString(imom)+"_"+toString(ih),
                     refData(32*iups+4*ih+8*iframe+1,1,iy+1));
              }
              book(_p_Upsilon[ih][iups][iframe][3][imom],
                   "TMP/UPS_"+toString(iups)+"_"+toString(iframe)+"_3_"+toString(imom)+"_"+toString(ih),
                   refData(32*iups+4*ih+25,1,iframe+1));
            }
          }
        }
      }
      raiseBeamErrorIf(_sqs==0);
    }

    void findDecayProducts(const Particle& mother, unsigned int& nstable, Particles& mup, Particles& mum) const {
      for (const Particle& p : mother.children()) {
        int id = p.pid();
        if (id == PID::MUON) {
          ++nstable;
          mum += p;
        }
        else if (id == PID::ANTIMUON) {
          ++nstable;
          mup += p;
        }
        else if (id == PID::PI0 || id == PID::K0S || id == PID::K0L ) {
          ++nstable;
        }
        else if ( !p.children().empty() ) {
          findDecayProducts(p, nstable, mup, mum);
        }
        else {
          ++nstable;
        }
      }
    }

    /// Perform the per-event analysis
    void analyze(const Event& event) {
      // find the beams
      const ParticlePair & beams = apply<Beam>(event, "Beams").beams();
      // Final state of unstable particles to get particle spectra
      const UnstableParticles& ufs = apply<UnstableParticles>(event, "UFS");
      for (const Particle& p : ufs.particles(Cuts::pid==553 || Cuts::pid==100553 || Cuts::pid==200553)) {
      	// pT and rapidity
      	double rapidity = p.rapidity();
      	double xp = p.perp();
        if (rapidity<2.2 || rapidity >4.5) continue;
        // which upsilon
        unsigned int iups=p.pid()/100000;
        // polarization
      	unsigned int nstable=0;
      	Particles mup,mum;
      	findDecayProducts(p,nstable,mup,mum);
      	if (mup.size()!=1 || mum.size()!=1 || nstable!=2) continue;
        size_t iy=0;
        for (iy=0; iy<3; ++iy) {
          if (rapidity < _ybins[iy+1]) break;
        }
        // first the CS frame
        // first boost so upslion momentum =0 in z direction
        Vector3 beta = p.mom().betaVec();
        beta.setX(0.);beta.setY(0.);
        LorentzTransform boost = LorentzTransform::mkFrameTransformFromBeta(beta);
        FourMomentum pp = boost.transform(p.mom());
        // and then transverse so pT=0
        beta = pp.betaVec();
        LorentzTransform boost2 = LorentzTransform::mkFrameTransformFromBeta(beta);
        // get all the momenta in this frame
        Vector3 muDirn = boost2.transform(boost.transform(mup[0].mom())).p3().unit();
        FourMomentum p1 = boost2.transform(boost.transform(beams. first.mom()));
        FourMomentum p2 = boost2.transform(boost.transform(beams.second.mom()));
        if (beams.first.mom().z()<0.) swap(p1,p2);
        if (p.rapidity()<0.) swap(p1,p2);
        Vector3 axisy = (p1.p3().cross(p2.p3())).unit();
        Vector3 axisz(0.,0.,1.);
        Vector3 axisx = axisy.cross(axisz);
        double cTheta = axisz.dot(muDirn);
        double cPhi   = axisx.dot(muDirn);
        // fill the moments
        _p_Upsilon[_sqs==8000][iups][1][iy][0]->fill(xp, 1.25*(3.*sqr(cTheta)-1.));
        _p_Upsilon[_sqs==8000][iups][1][iy][1]->fill(xp, 1.25*(1.-sqr(cTheta))*(2.*sqr(cPhi)-1.));
        _p_Upsilon[_sqs==8000][iups][1][iy][2]->fill(xp, 2.5 *cTheta*sqrt(1.-sqr(cTheta))*cPhi);
        _p_Upsilon[_sqs==8000][iups][1][3 ][0]->fill(xp, 1.25*(3.*sqr(cTheta)-1.));
        _p_Upsilon[_sqs==8000][iups][1][3 ][1]->fill(xp, 1.25*(1.-sqr(cTheta))*(2.*sqr(cPhi)-1.));
        _p_Upsilon[_sqs==8000][iups][1][3 ][2]->fill(xp, 2.5 *cTheta*sqrt(1.-sqr(cTheta))*cPhi);
        // Gottfried-Jackson frame
        axisz = p1.p3().unit();
        axisx = axisy.cross(axisz);
        cTheta = axisz.dot(muDirn);
        cPhi   = axisx.dot(muDirn);
        // fill the moments
        _p_Upsilon[_sqs==8000][iups][2][iy][0]->fill(xp, 1.25*(3.*sqr(cTheta)-1.));
        _p_Upsilon[_sqs==8000][iups][2][iy][1]->fill(xp, 1.25*(1.-sqr(cTheta))*(2.*sqr(cPhi)-1.));
        _p_Upsilon[_sqs==8000][iups][2][iy][2]->fill(xp, 2.5 *cTheta*sqrt(1.-sqr(cTheta))*cPhi);
        _p_Upsilon[_sqs==8000][iups][2][3 ][0]->fill(xp, 1.25*(3.*sqr(cTheta)-1.));
        _p_Upsilon[_sqs==8000][iups][2][3 ][1]->fill(xp, 1.25*(1.-sqr(cTheta))*(2.*sqr(cPhi)-1.));
        _p_Upsilon[_sqs==8000][iups][2][3 ][2]->fill(xp, 2.5 *cTheta*sqrt(1.-sqr(cTheta))*cPhi);
        // now for the HX frame
        beta = p.mom().betaVec();
        boost = LorentzTransform::mkFrameTransformFromBeta(beta);
        axisz = pp.p3().unit();
        axisx = axisy.cross(axisz);
        cTheta = axisz.dot(muDirn);
        cPhi   = axisx.dot(muDirn);
        // fill the moments
        _p_Upsilon[_sqs==8000][iups][0][iy][0]->fill(xp, 1.25*(3.*sqr(cTheta)-1.));
        _p_Upsilon[_sqs==8000][iups][0][iy][1]->fill(xp, 1.25*(1.-sqr(cTheta))*(2.*sqr(cPhi)-1.));
        _p_Upsilon[_sqs==8000][iups][0][iy][2]->fill(xp, 2.5 *cTheta*sqrt(1.-sqr(cTheta))*cPhi);
        _p_Upsilon[_sqs==8000][iups][0][3 ][0]->fill(xp, 1.25*(3.*sqr(cTheta)-1.));
        _p_Upsilon[_sqs==8000][iups][0][3 ][1]->fill(xp, 1.25*(1.-sqr(cTheta))*(2.*sqr(cPhi)-1.));
        _p_Upsilon[_sqs==8000][iups][0][3 ][2]->fill(xp, 2.5 *cTheta*sqrt(1.-sqr(cTheta))*cPhi);
      }
    }

    /// Normalise histograms etc., after the run
    void finalize() {

      for (double eVal : allowedEnergies()) {
        const int en = round(eVal);

        int ih(en==8000);

        // loop over upslion
        for (size_t iups=0; iups<3; ++iups) {
          // loop over iframe
          for (size_t iframe=0; iframe<3; ++iframe) {
            size_t ibase = 32*iups+4*ih+8*iframe;
            size_t ibase2 = 32*iups+4*ih+24;
            // rapidity range
            for (size_t iy=0; iy<4; ++iy) {
              // book scatters
              Estimate1DPtr lTheta,lPhi,lThetaPhi,lTilde;
              if (iy<3) {
                book(lTheta   ,ibase+1,1,1+iy);
                book(lPhi     ,ibase+3,1,1+iy);
                book(lThetaPhi,ibase+2,1,1+iy);
                book(lTilde   ,ibase+4,1,1+iy);
              }
              else {
                book(lTheta   ,ibase2+1,1,1+iframe);
                book(lPhi     ,ibase2+3,1,1+iframe);
                book(lThetaPhi,ibase2+2,1,1+iframe);
                book(lTilde   ,ibase2+4,1,1+iframe);
              }
              // histos for the moments
              Profile1DPtr moment[3];
              for (size_t ix=0; ix<3; ++ix) {
                moment[ix] = _p_Upsilon[ih][iups][iframe][iy][ix];
              }
              // loop over bins
              for (size_t ibin=1; ibin<=moment[0]->numBins(); ++ibin) {
                // extract moments and errors
                double val[3], err[3];
                // m1 = lTheta/(3+lTheta), m2 = lPhi/(3+lTheta), m3 = lThetaPhi/(3+lTheta)
                for (size_t ix=0; ix<3; ++ix) {
                  val[ix] = moment[ix]->bin(ibin).effNumEntries()>0 ? moment[ix]->bin(ibin).mean(2)   : 0.;
                  err[ix] = moment[ix]->bin(ibin).effNumEntries()>1 ? moment[ix]->bin(ibin).stdErr(2) : 0.;
                }
                // values of the lambdas and their errors
                double l1 = 3.*val[0]/(1.-val[0]);
                double l2 = (3.+l1)*val[1];
                lTheta   ->bin(ibin).setVal(l1);
                lTheta   ->bin(ibin).setErr(3./sqr(1.-val[0])*err[0]);
                lPhi     ->bin(ibin).setVal(l2);
                lPhi     ->bin(ibin).setErr(3./sqr(1.-val[0])*sqrt(sqr(err[0]*val[1])+sqr(err[1]*(1.-val[0]))));
                lThetaPhi->bin(ibin).setVal((3.+l1)*val[2]);
                lThetaPhi->bin(ibin).setErr(3./sqr(1.-val[0])*sqrt(sqr(err[0]*val[1])+sqr(err[1]*(1.-val[0]))));
                lTilde   ->bin(ibin).setVal((l1+3.*l2)/(1.-l2));
                lTilde   ->bin(ibin).setErr(3./sqr(1.-val[0]-3*val[1])*sqrt(sqr(err[0])+9.*sqr(err[1])));
              }
            }
          }
        }
      }
    }

    /// @}


    /// @name Histograms
    /// @{
    Profile1DPtr _p_Upsilon[2][3][3][4][3];
    vector<double>  _ybins;
    int _sqs = 0;
    /// @}


  };


  // Declare LHCB_2017_I1621596 as a Rivet analysis plugin (macro provided by Rivet).
  RIVET_DECLARE_PLUGIN(LHCB_2017_I1621596);

}
