#include "helperfuncs.h"

// Streamwise average of un: copy every coefficient with x-index mx = 0
// (presumably the kx = 0 Fourier mode) for all components, y-modes and
// z-modes into unMean. All other modes of unMean are left at zero.
// Requires un to be spectral in x. unMean is labeled Spectral in x-z and
// given un's y-state.
void xavg(const FlowField& un, FlowField& unMean) {
  assert(un.xstate() == Spectral);

  // Zero first, then relabel: the relabel applies only to zero data.
  unMean.setToZero();
  unMean.setState(Spectral, un.ystate());

  const int ncomp = un.vectorDim();
  const int My    = un.numYmodes();
  const int Mz    = un.numZmodes();

  // Each (my, mz, c) copy is independent, so the loop order is immaterial.
  for (int mz = 0; mz < Mz; ++mz) {
    for (int my = 0; my < My; ++my) {
      for (int c = 0; c < ncomp; ++c) {
        unMean.cmplx(0, my, mz, c) = un.cmplx(0, my, mz, c);
      }
    }
  }
}

// Concatenate three strings in order, e.g. join("plots/u", lbl, "slice").
string join(const string& s0, const string& s1, const string& s2) {
  string joined;
  joined.reserve(s0.size() + s1.size() + s2.size());
  joined.append(s0).append(s1).append(s2);
  return joined;
}
    
void saveplots(const string& lbl, FlowField& un, FlowField& unMean, 
	       ChebyTransform& tr,
	       ostream& modestream, ostream& dragstream) { 
    
  cout << "plotting stuff ..." << flush;

  for (int kx=-2; kx<=2; ++kx) 
    for (int kz=0; kz<=2; ++kz) 
      un.saveProfile(un.nx(kx), un.nz(kz), string("plots/un") + i2s(kx) + i2s(kz), tr);

  int Ny = un.numYgridpts();
  xavg(un, unMean);
  un.makePhysical();
  un.saveSlice(Ny/2, 0, join("plots/u", lbl, "slice"));
  un.saveSlice(Ny/2, 1, join("plots/v", lbl, "slice")); 
  un.saveSlice(Ny/2, 2, join("plots/w", lbl, "slice")); 
  un.saveSliceXY(0, 0, join("plots/u", lbl, "sliceXY"));
  un.saveSliceXY(0, 1, join("plots/v", lbl, "sliceXY"));
  un.saveSliceXY(0, 2, join("plots/w", lbl, "sliceXY"));
  un.saveCrossSection(0, 0, join("plots/u", lbl, "xsection"));
  un.saveCrossSection(0, 1, join("plots/v", lbl, "xsection"));
  un.saveCrossSection(0, 2, join("plots/w", lbl, "xsection"));
  un.makeSpectral();

  unMean.makePhysical();
  unMean.saveCrossSection(0, 0, join("plots/u", lbl, "Mxsection"));
  unMean.saveCrossSection(0, 1, join("plots/v", lbl, "Mxsection"));
  unMean.saveCrossSection(0, 2, join("plots/w", lbl, "Mxsection"));
      
  un.saveSpectrum(join("plots/u", lbl,"spec"));
  //un.saveDissSpectrum("plots/undiss");

  const char sp = ' ';
  modestream << L2Norm2(un) <<sp<< un.energy(0,0) <<sp
	     << un.energy(un.nx(0),un.nz(1))+un.energy(un.nx(0),un.nz(-1))<<sp
	     << un.energy(un.nx(1),un.nz(0))+un.energy(un.nx(-1),un.nz(0))<<sp
	     << un.energy(un.nx(1),un.nz(1))+un.energy(un.nx(-1),un.nz(-1))<<sp
	     << un.energy(un.nx(0),un.nz(2))+un.energy(un.nx(0),un.nz(-2))<<sp
	     << un.energy(un.nx(2),un.nz(0))+un.energy(un.nx(-2),un.nz(0))<<sp
	     << un.energy(un.nx(1),un.nz(2))+un.energy(un.nx(-1),un.nz(-2))<<sp
	     << un.energy(un.nx(2),un.nz(1))+un.energy(un.nx(-2),un.nz(-1))<<sp
	     << un.energy(un.nx(2),un.nz(2))+un.energy(un.nx(-2),un.nz(-2))<<sp
	     <<endl;
  
  BasisFunc u00 = un.profile(0,0);
  BasisFunc u00y = ydiff(u00);
  u00y.makePhysical(tr);
  dragstream << Re(u00y[0].eval_a()) <<sp<< Re(u00y[0].eval_b()) <<endl;
  cout << "done" << endl;
}


void ucheck(const FlowField& un, const NSIntegrator& dns, Real dt, Real t, 
	    const string& label) {

  Real unorm2=L2Norm2(un);
  Real udiv = un.divergence();
  cout << "========================================================" << endl;
  cout << label << endl;
  cout << "  t == " << t << endl;
  cout << " dt == " << dt << endl;
  cout << "CFL == " << dns.CFL() << endl;
  cout << "L2Norm2(un) == " << unorm2 << endl;
  cout << "    div(un) == " << udiv << endl;  
  cout << "      dp/Dx == " << dns.dPdx() << endl;
  cout << "u.Ubulk-UbulkRef == " << dns.Ubulk()-dns.UbulkRef() << endl;

  if (unorm2 > 1 || udiv > 0.1) {
    cerr << label << ": Excessive norm or divergence at t=" << t << endl;
    exit(1);
  }
  if (unorm2 <0.0001) {
    cerr << label << ": un has collapsed at t=" << t << endl;
    exit(1);
  }
}

// Add (sign > 0) or subtract (sign <= 0) 0.5*(u0^2 + u1^2 + u2^2) to
// component 0 of p at every physical grid point of u.
// Both fields are moved to physical space for the update. Each is then
// returned to the xz/y state it had on entry.
void modifyPressure(FlowField& p, FlowField& u, int sign) {
  // Remember the incoming states so they can be restored on exit.
  const fieldstate u_xz = u.xzstate();
  const fieldstate u_y  = u.ystate();
  const fieldstate p_xz = p.xzstate();
  const fieldstate p_y  = p.ystate();

  u.makePhysical();
  p.makePhysical();

  const int direction = (sign > 0) ? 1 : -1;
  const int Nx = u.numXgridpts();
  const int Ny = u.numYgridpts();
  const int Nz = u.numZgridpts();

  // Every grid point is updated independently, so the loop order is free.
  for (int i = 0; i < Nx; ++i) {
    for (int j = 0; j < Ny; ++j) {
      for (int k = 0; k < Nz; ++k) {
        p(i,j,k,0) += direction*0.5*(square(u(i,j,k,0)) +
                                     square(u(i,j,k,1)) +
                                     square(u(i,j,k,2)));
      }
    }
  }

  u.makeState(u_xz, u_y);
  p.makeState(p_xz, p_y);
}


// Project umean in place as follows, in physical space:
//   umean[0].re[j] <- 0.5*(u[j] - u[Ny-1-j])   (u = original umean[0].re)
//   umean[0].im, umean[1], umean[2] <- 0
// The Chebyshev transform is built locally from umean.Ny(). umean is
// returned to Spectral state if it arrived Spectral.
// NOTE(review): despite the name, this keeps the part of u that is odd under
// grid-index reflection j -> Ny-1-j (presumably y -> -y on the Chebyshev grid).
// Confirm that is the intended symmetry.
void symmetrize(BasisFunc& umean) {

  fieldstate s = umean.state();

  int Ny = umean.Ny();
  ChebyTransform trans(Ny);
  umean.makePhysical(trans);

  // Copy first: the loop overwrites umean[0].re in place but still needs the
  // original value at the mirrored index.
  ChebyCoeff u = umean[0].re;
  for (int ny=0; ny<Ny; ++ny)
    umean[0].re[ny] = 0.5*(u[ny] - u[Ny-1-ny]);

  // Loop-invariant zeroing, hoisted out of the per-ny loop where it used to
  // be repeated Ny times.
  umean[0].im.setToZero();
  umean[1].re.setToZero();
  umean[1].im.setToZero();
  umean[2].re.setToZero();
  umean[2].im.setToZero();

  if (s == Spectral)
    umean.makeSpectral(trans);
}
