#include "Rivet/Analysis.hh"
#include "Rivet/Projections/FinalState.hh"
#include "Rivet/Projections/HeavyHadrons.hh"
#include "Rivet/Projections/UnstableParticles.hh"
#include "Rivet/Tools/Random.hh"

namespace Rivet {

    /// @brief Muon neutrino cross-section and flux at 13.6 TeV
    /// @brief Muon neutrino cross-section and flux at 13.6 TeV
    ///
    /// Light long-lived hadrons from the "UP" projection are tracked through the
    /// beam pipe (drift, dipole, quadrupole and kicker elements) and decayed along
    /// their trajectory; charm hadrons from the "HH" projection are decayed at the
    /// origin. The resulting muon (anti)neutrinos are extrapolated to the FASER
    /// interface plane at L = 480 m and used to fill the rate and flux histograms.
    class FASER_2024_I2855783 : public Analysis {
    public:

    // Default analysis constructor
    RIVET_DEFAULT_ANALYSIS_CTOR(FASER_2024_I2855783);

    /// @name Analysis methods
    ///@{

    /// Book histograms and initialise projections for run
    void init() {

      // long-lived light hadrons, propagated through the beam pipe before decaying
      declare(UnstableParticles(Cuts::abspid==PID::K0S        || Cuts::abspid==PID::K0L       ||
                                Cuts::abspid==PID::PIPLUS     || Cuts::abspid==PID::KPLUS     ||
                                Cuts::abspid==PID::SIGMAMINUS || Cuts::abspid==PID::SIGMAPLUS ||
                                Cuts::abspid==PID::XI0        || Cuts::abspid==PID::XIMINUS   ||
                                Cuts::abspid==PID::LAMBDA), "UP");

      // charm hadrons, decayed at the origin
      declare(HeavyHadrons(Cuts::abspid==PID::DPLUS  || Cuts::abspid==PID::D0 ||
                           Cuts::abspid==PID::DSPLUS || Cuts::abspid==PID::LAMBDACPLUS), "HH");

      // Initialise histograms for the muon-neutrino analysis.
      // The groups are binned in -sign(vpid)/E_nu [1/GeV] (see save_neutrino):
      //   bin 1: nu_mu     with 100 < E < 1000 GeV
      //   bin 2: E > 1 TeV (nu_mu and nu_mu-bar, clamped into this bin)
      //   bin 3: nu_mu-bar with 100 < E < 1000 GeV
      const vector<double> edges{-1./100., -1./1000., 1./1000., 1./100.};
      book(_e_num, 9, 1, 2);
      book(_h_num, "_aux_num", refData(9, 1, 2));

      book(_h_int, edges);
      book(_h_int->bin(1), 10, 1, 2);
      book(_h_int->bin(2), 12, 1, 2);
      book(_h_int->bin(3), 11, 1, 2);

      book(_h_flux, edges);
      book(_h_flux->bin(1), 22, 1, 1);
      book(_h_flux->bin(2), 24, 1, 1);
      book(_h_flux->bin(3), 23, 1, 1);

      // auxiliary input tables shipped with the reference data:
      // apertures, cross sections, branching ratios, neutrino-energy CDFs and fields
      _aper = refData<YODA::BinnedEstimate<int,double>>("_aux_aper");
      _xsint = refData<YODA::BinnedEstimate<int,double>>("_aux_xsint");
      _decay_branchings = refData<YODA::BinnedEstimate<int,int>>("_aux_decay_branchings");
      for (PdgId pid : {130, 211, 310, 321, 3112, 3122, 3312, 3322, 411, 421, 431, 4122}) {
        _energy_cdf[pid] = refData<YODA::Estimate1D>("_aux_energy_"+to_string(pid));
      }
      _bfield = refData<YODA::Estimate1D>("_aux_bfield");
      _bprime = refData<YODA::Estimate1D>("_aux_bprime");
      _kickers = refData<YODA::Estimate1D>("_aux_kickers");

      // initialisation of particle properties and decay CDFs
      // hadron masses [GeV]
      _mass = { {PID::PIPLUS,     0.13957}, {PID::KPLUS,      0.49368},
                {PID::K0S,        0.49761}, {PID::K0L,        0.49761},
                {PID::LAMBDA,     1.11568}, {PID::SIGMAPLUS,  1.18937},
                {PID::SIGMAMINUS, 1.19745}, {PID::XI0,        1.31486},
                {PID::XIMINUS,    1.32171}, {PID::OMEGAMINUS, 1.67245} };

      // c*tau: lifetimes written in seconds, multiplied by c_light
      // NOTE(review): confirm the resulting unit matches the metre-based positions
      _ctau = { {PID::PIPLUS,     c_light*2.603e-08}, {PID::KPLUS,      c_light*1.238e-08},
                {PID::K0S,        c_light*8.954e-11}, {PID::K0L,        c_light*5.116e-08},
                {PID::LAMBDA,     c_light*2.60e-10},  {PID::SIGMAPLUS,  c_light*8.018e-11},
                {PID::SIGMAMINUS, c_light*1.479e-10}, {PID::XI0,        c_light*2.90e-10},
                {PID::XIMINUS,    c_light*1.639e-10}, {PID::OMEGAMINUS, c_light*8.21e-11} };

      // Cumulative sampling probabilities of the decay channels. The axis bin
      // index hit by a uniform random number selects the channel:
      // 0 ve, 1 vm, 2 vebar, 3 vmbar (see decay_particle).
      _decay_cdf = { {  211, YODA::Axis<double>({0.1,1.0})},
                     { -211, YODA::Axis<double>({-1.0,0.0,0.1,1.0})},
                     {  321, YODA::Axis<double>({0.5,1.0})},
                     { -321, YODA::Axis<double>({-1.0,0.0,0.5,1.0})},
                     {  310, YODA::Axis<double>({0.25,0.50,0.75,1.0})},
                     {  130, YODA::Axis<double>({0.25,0.50,0.75,1.0})},
                     { 3122, YODA::Axis<double>({-1.0,0.0,0.9,1.0})},
                     {-3122, YODA::Axis<double>({0.9,1.0})},
                     { 3222, YODA::Axis<double>({1.0})},
                     {-3222, YODA::Axis<double>({-1.0,0.0,1.0})},
                     { 3112, YODA::Axis<double>({0.5,1.0})},
                     {-3112, YODA::Axis<double>({-1.0,0.0,0.5,1.0})},
                     { 3322, YODA::Axis<double>({0.5,1.0})},
                     {-3322, YODA::Axis<double>({-1.0,0.0,0.5,1.0})},
                     { 3312, YODA::Axis<double>({-1.0,0.0,0.5,1.0})},
                     {-3312, YODA::Axis<double>({0.5,1.0})}, };
    }

    /// Per-event processing: light hadrons are tracked, heavy hadrons decayed.
    void analyze(const Event& event) {
      // light hadrons
      for (const Particle& hadron : apply<UnstableParticles>(event, "UP").particles()) {
        analyze_light(hadron);
      }
      // heavy hadrons
      for (const Particle& hadron : apply<HeavyHadrons>(event, "HH").particles()) {
        analyze_heavy(hadron);
      }
    }

    /// Select a light hadron, transport it through the beam pipe and decay it.
    void analyze_light(const Particle& hadron) {

      // get hadron pid, position [m] and momentum [GeV]
      const int hpid = hadron.pid();
      Vector3 x = hadron.origin().vector3()/meter;
      Vector3 p = hadron.momentum().vector3()/GeV;

      // rotate with beam crossing angle
      p = rotate_crossing_angle(p);
      x = rotate_crossing_angle(x);

      // reject hadrons produced outside of pipe or moving backwards
      if (p.z() < 0 || !inpipe(x)) return;

      // reject displaced hadrons from charged parent hadron decay for z>22m
      // (a missing parent record is tolerated instead of indexing an empty list)
      if (x.z() > 22.) {
        const Particles parents = hadron.parents();
        if (!parents.empty() && parents[0].isCharged()) return;
      }

      // process particle
      process_particle(x, p, hpid, hadron.charge());
    }

    /// Re-decay a heavy hadron many times at the origin into muon (anti)neutrinos.
    void analyze_heavy(const Particle& hadron) {

      // read particle
      const int pid  = hadron.pid();
      const int apid = hadron.abspid();

      // rotate with beam crossing angle (mass is preserved)
      FourMomentum phad = hadron.momentum();
      const Vector3 prot = rotate_crossing_angle(phad.vector3());
      phad.setXYZM(prot.x(), prot.y(), prot.z(), phad.mass());

      // boost from the hadron rest frame to the lab; identical for every re-decay
      LorentzTransform ltf;
      ltf.setBetaVec(phad.betaVec());

      // re-decay rate
      const size_t nrepeat = 2500;

      // loop through all possible decays
      for (PdgId vpid : {14, -14}) {

        // branching ratio and per-decay weight do not depend on the repetition
        const double br = _decay_branchings.binAt(pid,vpid).val();
        if (br <= 0) continue;
        const double weight = br/double(nrepeat);
        if (isnan(weight))  continue;
        const YODA::Estimate1D& ecdf = _energy_cdf.at(apid);

        // repeat decays nrepeat times for enhanced statistics
        for (size_t irepeat=0; irepeat<nrepeat; ++irepeat) {

          // construct neutrino in hadron rest frame (energy sampled from the CDF)
          const double enu = ecdf.binAt(rand01()).val();
          const double phi = rand01() * TWOPI;
          const double costh = 0.9999999 * (2.*rand01() - 1.);

          // boost neutrino in lab frame
          FourMomentum vrest;
          vrest.setThetaPhiME(acos(costh), phi, 0, enu);
          const FourMomentum p = ltf.transform(vrest);

          // save forward-going neutrinos above 10 GeV
          if (p.z()>0 && p.E()>10) {
            save_neutrino(pid, vpid, Vector3(0,0,0), p, weight);
          }
        }
      }
    }

    /// Rotate a lab-frame vector about the x axis by the beam crossing angle
    /// (-160 microradians).
    static Vector3 rotate_crossing_angle(const Vector3& v) {
      const double angle = -160./1000./1000.;
      const double cosang = cos(angle);
      const double sinang = sin(angle);
      return Vector3(v.x(), v.y()*cosang+v.z()*sinang, v.z()*cosang-v.y()*sinang);
    }

    /// Linearly interpolate the tabulated cross section for neutrino species
    /// @a vpid at energy @a y, between the value of the bin containing @a y and
    /// the value of the following bin on the energy axis.
    double linear_interpolation(int vpid, double y) const {
      const auto& b = _xsint.binAt(vpid,y);
      const auto& xaxis = _xsint.binning().axis<0>();
      const auto& yaxis = _xsint.binning().axis<1>();
      const size_t iy = yaxis.index(y);
      // no following bin exists past the last (overflow) bin: use the bin value
      if (iy + 1 >= yaxis.numBins(true))  return b.val();
      const auto& bplus = _xsint.bin(xaxis.index(vpid), iy+1);
      const double frac = (y - b.yMin())/b.yWidth();
      const double dsigma = bplus.val() - b.val();
      return b.val() + frac * dsigma;
    }

    /// Horizontal offset of the beam-pipe centre as a function of z.
    /// Returns 0 for z < 139.3, 0.097 for z > 158, and in between the value of
    /// the line that is 0 at z = 75 and 0.097 at z = 158.
    /// NOTE(review): this jumps from 0 to ~0.075 at z = 139.3; presumably the
    /// point where the pipe splits — confirm.
    double get_offset(double z) const {
      if (z < 139.3) return 0;
      if (z > 158)   return 0.097;
      return (z-75.)*0.097 / (158.-75.);
    }

    /// Check whether position @a x lies inside the beam-pipe aperture.
    /// Positions with any NaN coordinate, or with z outside [-0.1, 220], are outside.
    bool inpipe(const Vector3& x) const {
      if (isnan(x.x()) || isnan(x.y()) || isnan(x.z())) return false;
      if (x.z()<-0.1 || x.z()>220.) return false;
      const double offset = get_offset(x.z());
      // rectangular limits (apertures 1, 2) and elliptical limit (apertures 3, 4)
      if (abs(abs(x.x())-offset) > _aper.binAt(1,x.z()).val()) return false;
      if (abs(x.y()) > _aper.binAt(2,x.z()).val()) return false;
      if ( sqr((abs(x.x())-offset)/_aper.binAt(3,x.z()).val()) +
           sqr(x.y()/_aper.binAt(4,x.z()).val()) > 1) return false;
      return true;
    }

    /// Advance position @a x and momentum @a p by one step of @a length through a
    /// drift (no field), dipole (@a bfield) or quadrupole (@a bprime) element, then
    /// apply any kicker crossed during the step to charged particles.
    /// @return {new position, new momentum}
    vector<Vector3> integration_step(Vector3 x, Vector3 p, double charge, double length,
                                     double bfield=0., double bprime=0.) const {

      const double zstart = x.z();
      const double absbprime = abs(bprime);
      const double absbfield = abs(bfield);

      // warning if B and B' are both != 0
      if (absbprime > 0. && absbfield > 0.) {
        MSG_WARNING("Warning: bprime and bfield are non-zero: B="
                    << bfield << " and B'=" << bprime << " at z=" <<x.z());
      }

      // drift tube: straight line
      if (isZero(absbfield) && isZero(absbprime)) {
        x = x + p.unit() * length;
      }

      // dipole magnets: circular arc bending in the x-z plane
      if (absbfield>0 && isZero(absbprime)) {
        const double rho   = add_quad(p.x(), p.z()) / bfield;
        const double theta = length / rho;
        const Vector3 punit = p.unit();
        const Vector3 funit = (punit.cross(Vector3(0,1,0))).unit();
        const double pmod = p.mod();
        x = x + rho * ( sin(theta) * punit + (1.-cos(theta)) * funit );
        p = pmod * cos(theta) * punit + pmod * sin(theta) * funit;
      }

      // quadrupole magnets: thick-lens transfer relative to the pipe centre
      if (absbprime>0 && isZero(absbfield)) {
        const double kappa = 0.299 * bprime / p.mod();
        const double rk = sqrt(abs(kappa));
        const double rkh = length*rk;

        // transverse position relative to the (possibly offset) pipe centre
        double offset = get_offset(x.z());
        if (offset > 0 && x.x()<0) offset = -offset;
        const double x0 = x.x()-offset;
        const double y0 = x.y();

        const double cosrkh  = cos(rkh);
        const double sinrkh  = sin(rkh);
        const double coshrkh = cosh(rkh);
        const double sinhrkh = sinh(rkh);
        const Vector3 punit = p.unit();
        double xx, px, xy, py;
        if (kappa>0) {
          // oscillatory in x, hyperbolic in y
          xx =  x0*cosrkh     + punit.x()*sinrkh/rk;
          px = -x0*sinrkh*rk  + punit.x()*cosrkh;
          xy =  y0*coshrkh    + punit.y()*sinhrkh/rk;
          py =  y0*sinhrkh*rk + punit.y()*coshrkh;
        }
        else {
          // hyperbolic in x, oscillatory in y
          xx =  x0*coshrkh    + punit.x()*sinhrkh/rk;
          px =  x0*sinhrkh*rk + punit.x()*coshrkh;
          xy =  y0*cosrkh     + punit.y()*sinrkh/rk;
          py = -y0*sinrkh*rk  + punit.y()*cosrkh;
        }
        const double pz = sqrt(1.-px*px-py*py);
        const double xz = x.z() + length;
        x = Vector3(xx+offset,xy,xz);
        p = p.mod()*Vector3(px,py,pz);
      }

      // kickers: vertical momentum kick when the step crosses a kicker's start
      if (!isZero(charge)) {
        for (const auto& b : _kickers.bins(true)) {
          const double zkick = b.xMin();
          if (zstart < zkick && x.z() >= zkick) {
            const double py = p.y() + 6800. * b.val() * charge;
            p = Vector3(p.x(),py,p.z());
          }
        }
      }
      return {x,p};
    }

    /// Step a particle through the beam pipe until it leaves the aperture.
    /// @return {positions, momenta, step lengths stored as (0,0,length)},
    ///         one entry per step, recorded before the step is taken
    vector<vector<Vector3>> get_trajectory(Vector3 x, Vector3 p, double charge) const {

      vector<Vector3> array_x, array_p, array_s;
      const double stepsize = 0.1;
      double length = stepsize;

      while (inpipe(x)) {
        // get fields at the current position
        const auto& bfieldbin = _bfield.binAt(x.z());
        const auto& bprimebin = _bprime.binAt(x.z());
        const double bfield = bfieldbin.val();
        const double bprime = bprimebin.val();
        // optimise stepsize: finer steps inside fields, never step far past a boundary
        if (!isZero(charge)) {
          length = (abs(bfield)>0 || abs(bprime)>0) ? stepsize/4. : stepsize; // divide by field factor
          const double distance_to_boundary = min({ bfieldbin.xMax(),
                                                    bprimebin.xMax(),
                                                    _kickers.binAt(x.z()).xMax() }) - x.z();
          if (distance_to_boundary<length) length = distance_to_boundary+0.001;
        }
        // record position
        array_x.push_back(x);
        array_p.push_back(p);
        array_s.push_back(Vector3(0,0,length));
        // integration step (fields scaled by the particle charge)
        const vector<Vector3> step = integration_step(x, p, charge, length, bfield*charge, bprime*charge);
        x = step[0];
        p = step[1];
      }
      return {array_x, array_p, array_s};
    }

    /// Decay a hadron at position @a x with momentum @a p [GeV] into a muon
    /// (anti)neutrino, sampling the channel from _decay_cdf, and save it with
    /// weight @a wdecay times the importance-sampling branching weight.
    void decay_particle(int hpid, Vector3 x, Vector3 p, double wdecay) {

      // create hadron momentum
      FourMomentum phadron;
      phadron.setXYZM(p.x(),p.y(),p.z(),_mass[abs(hpid)]);

      // get decay channel: 0 ve, 1 vm, 2 vebar, 3 vmbar; electron channels are skipped
      const YODA::Axis<double>& cdf = _decay_cdf[hpid];
      const size_t nuidx = cdf.index(rand01());
      if (nuidx % 2 == 0)  return;
      const PdgId vpid = nuidx==1? 14 : -14;
      // branching ratio divided by the sampling probability of the chosen channel
      const double wbranching = _decay_branchings.binAt(hpid,vpid).val()/cdf.width(nuidx);
      if (isnan(wbranching))  return;

      // get neutrino in hadron rest frame
      const double phi = rand01() * TWOPI;
      const double costh = 0.9999999 * (2.*rand01() - 1.);
      const double enu = _energy_cdf.at(abs(hpid)).binAt(rand01()).val();
      FourMomentum vrest;
      vrest.setThetaPhiME(acos(costh), phi, 0, enu);

      // boost neutrino in lab frame
      LorentzTransform ltf;
      ltf.setBetaVec( phadron.betaVec() );
      const FourMomentum pneutrino = ltf.transform(vrest);

      // save forward-going neutrinos above 10 GeV
      if (pneutrino.z()>0. && pneutrino.E()>10.) {
        save_neutrino(hpid, vpid, x, pneutrino, wdecay*wbranching);
      }
    }

    /// Fill histos for a neutrino of species @a vpid produced at @a x with
    /// momentum @a p, if it reaches the FASER acceptance at the interface plane.
    void save_neutrino(int hpid, int vpid, Vector3 x, FourMomentum p, double weight) {
      // get position at FASER interface plane at L=480m
      const Vector3 position = x + p.vector3()/p.vector3().z() * (480.-x.z());
      if ( add_quad(position.x(), position.y()+0.012) >= 0.1 )  return;
      const double en = p.E()/GeV;
      const double xs_cc = linear_interpolation(vpid, en) * en;
      // signed inverse energy: negative for nu_mu, positive for nu_mu-bar
      double val = -vpid/abs(vpid)/en;
      _h_num->fill(val, weight*xs_cc*_facGeo2);
      // E > 1 TeV: move both species into the middle group bin; the fill
      // coordinate then becomes 1/0.0009 GeV
      if (abs(val) < 0.001)  val = 0.0009;
      _h_int->fill(val, 1./abs(val), weight*xs_cc*_facGeo2);
      _h_flux->fill(val, 1./abs(val), weight/314.15e-3);
    }

    /// Obtain the trajectory of a light hadron and decay it at every step,
    /// weighting each step by the probability to decay within it.
    void process_particle(Vector3 x0, Vector3 p0, int hpid, double charge) {

      // get decay length (c*tau * p/m)
      const double decaylength = _ctau[abs(hpid)] * p0.mod() / _mass[abs(hpid)];

      // get trajectory
      const vector<vector<Vector3>> trajectory = get_trajectory(x0, p0, charge);

      // loop over trajectory and decay hadron
      for (size_t itraj=0; itraj<trajectory[0].size(); ++itraj) {
        const Vector3& x = trajectory[0][itraj];
        const Vector3& p = trajectory[1][itraj];
        const double s = trajectory[2][itraj].z();
        // survival up to this z times decay probability within the step
        const double wdecay = exp(-(x.z()-x0.z()) / decaylength) * (1. - exp(- s / decaylength));
        if (isnan(wdecay))  continue;
        decay_particle(hpid, x, p, wdecay);
      }
    }

    /// Normalise to the cross section per event and convert the rate histogram.
    void finalize() {
      const double sf = crossSection()/picobarn/sumOfWeights();
      scale(_h_num, sf);
      barchart(_h_num, _e_num);
      scale({_h_int, _h_flux}, sf);
    }

    ///@}

    private:

    /// @name Histograms
    ///@{

    // Histos
    Histo1DPtr _h_num;
    Estimate1DPtr _e_num;
    Histo1DGroupPtr _h_int, _h_flux;

    ///@}

    //length[cm] * density[g/cm3] / protonmass [g] * lumi [ipb]
    const double _facGeo2= 1.022e27 * 65.6 * 1000.;

    // other input
    map<PdgId, double> _mass, _ctau;
    map<PdgId, YODA::Axis<double>> _decay_cdf;
    map<PdgId, YODA::Estimate1D> _energy_cdf;
    YODA::BinnedEstimate<int,int> _decay_branchings;
    YODA::BinnedEstimate<int,double> _xsint, _aper;
    YODA::Estimate1D _bfield, _bprime, _kickers;

    };

  // Expose the analysis class as a Rivet plugin
  RIVET_DECLARE_PLUGIN(FASER_2024_I2855783);

}
