#include <iostream>
#include <iomanip>
#include "cfvector.h"
#include "chebyshev.h"
#include "flowfield.h"
#include "dns.h"

// Example driver: perturbs a plane Couette base flow U(y) = y, then time-steps
// the velocity field u and pressure field q with a DNS integrator from t = 0
// up to T, printing diagnostics and saving one Fourier profile before each
// group of n steps. The final u and q are written in binary form on exit.
int main() {

  // Define gridsize
  const int Nx=24;
  const int Ny=65;
  const int Nz=24;

  // Define box size
  const Real Lx=1.75*pi;
  const Real a= -1.0;
  const Real b=  1.0;
  const Real Lz=1.2*pi;

  // Define flow parameters
  const Real Reynolds = 400.0;
  const Real nu = 1.0/Reynolds;

  // Define integration parameters
  const int n = 1;     // number of time steps taken between printouts
  const Real dt = 0.02;
  const Real T  = 100.0;

  // Define size and smoothness of initial disturbance
  Real decay = 0.3;
  Real magnitude  = 0.01;
  int kxmax = 3;
  int kzmax = 3;

  // Construct base flow for plane Couette: U(y) = y
  ChebyCoeff U(Ny,a,b,Physical);
  Vector y = chebypoints(Ny, a,b);
  for (int ny=0; ny<Ny; ++ny)
    U[ny] = y[ny];
  U.save("U");
  y.save("y");

  // Construct data fields: 3d velocity and 1d pressure
  cout << "building velocity and pressure fields..." << flush;
  FlowField u(Nx,Ny,Nz,3,Lx,Lz,a,b);
  FlowField q(Nx,Ny,Nz,1,Lx,Lz,a,b);
  cout << "done" << endl;

  // Perturb velocity field, then rescale so that L2Norm(u) == magnitude
  u.addPerturbations(kxmax,kzmax,1.0,decay);
  u *= magnitude/L2Norm(u);

  // Construct Navier-Stokes integrator
  cout << "building DNS..." << flush;
  DNSFlags flags;
  flags.timestepping = SBDF3;
  flags.initstepping = CNRK2;
  flags.nonlinearity = SkewSymmetric;
  //flags.dealiasing = DealiasXZ;
  flags.constraint  = PressureGradient; // enforce constant pressure gradient

  DNS dns(u, U, nu, dt, flags);
  cout << "done" << endl;

  // NOTE(review): psolve is constructed but never used below, and p is never
  // reassigned after this copy of the initial q. The L2Dist(q,p) diagnostic
  // therefore compares q against that initial copy rather than against an
  // independently solved pressure -- presumably p was meant to be recomputed
  // with psolve each printout; confirm against the PressureSolver interface.
  PressureSolver psolve(u, U, nu, flags.nonlinearity);
  FlowField p = q;

  // Drive the loop with an integer step counter and derive t from it. Using a
  // floating-point t accumulated via t += n*dt collects rounding error over
  // thousands of additions, which can cause one extra (or one missing) pass
  // through the loop and shift the int(t) labels used for the file names.
  for (int step=0; step*n*dt < T; ++step) {
    const Real t = step*n*dt;

    cout << "===============================================" << endl;
    cout << "         t == " << t << endl;
    cout << "       CFL == " << dns.CFL() << endl;
    cout << " L2Norm(u) == " << L2Norm(u) << endl;
    cout << "divNorm(u) == " << divNorm(u) << endl;
    cout << " bcNorm(u) == " << bcNorm(u) << endl;
    cout << " L2Dist(q,p) == " << L2Dist(q,p) << endl;

    // Save the kx=1,kz=2 Fourier profile of the velocity field. The file name
    // carries only the integer part of t, so every printout that falls within
    // the same unit of time targets the same file name.
    BasisFunc u12 = u.profile(1,2);
    u12.save("u12_"+i2s(int(t)));

    // Take n steps of length dt
    dns.advance(u, q, n);

  }

  // Write the final velocity and pressure fields, labelled by the end time T
  u.binarySave("u"+i2s(int(T)));
  q.binarySave("q"+i2s(int(T)));
}

