!-----------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations         !
!   Copyright (C) 2000 - 2010  CP2K developers group                          !
!-----------------------------------------------------------------------------!

! *****************************************************************************
!> \par History
!>      none
!> \author MI (20.12.2004)
! *****************************************************************************
MODULE kg_gpw_correction

  USE atomic_kind_types,               ONLY: atomic_kind_type,&
                                             get_atomic_kind_set
  USE cell_types,                      ONLY: cell_type
  USE cp_dbcsr_operations,             ONLY: cp_dbcsr_deallocate_matrix,&
                                             cp_dbcsr_from_sm,&
                                             sm_from_dbcsr
  USE cp_dbcsr_types,                  ONLY: cp_dbcsr_p_type
  USE cp_para_types,                   ONLY: cp_para_env_type
  USE distribution_2d_types,           ONLY: distribution_2d_type
  USE f77_blas
  USE input_constants,                 ONLY: sic_none,&
                                             xc_funct_no_shortcut
  USE input_section_types,             ONLY: &
       section_vals_create, section_vals_duplicate, section_vals_get, &
       section_vals_get_subs_vals, section_vals_release, &
       section_vals_set_subs_vals, section_vals_type, section_vals_val_get, &
       section_vals_val_set
  USE kg_gpw_collocate_den,            ONLY: integrate_mol_potential
  USE kg_gpw_fm_mol_types,             ONLY: fm_mol_blocks_type,&
                                             get_fm_mol_block,&
                                             get_kg_fm_mol_set,&
                                             kg_fm_mol_set_type
  USE kg_gpw_pw_env_types,             ONLY: get_molbox_env,&
                                             get_rho_mol_block,&
                                             kg_molbox_env_type,&
                                             rho_mol_blocks_type
  USE kinds,                           ONLY: dp
  USE message_passing,                 ONLY: mp_sum
  USE particle_types,                  ONLY: particle_type
  USE pw_env_types,                    ONLY: pw_env_get,&
                                             pw_env_type
  USE pw_pool_types,                   ONLY: pw_pool_give_back_pw,&
                                             pw_pool_type
  USE pw_types,                        ONLY: pw_p_type
  USE qs_environment_types,            ONLY: get_qs_env,&
                                             qs_environment_type
  USE qs_force_types,                  ONLY: qs_force_type
  USE qs_rho_types,                    ONLY: qs_rho_type
  USE sparse_matrix_types,             ONLY: allocate_matrix_set,&
                                             deallocate_matrix_set,&
                                             real_matrix_p_type
  USE termination,                     ONLY: stop_program
  USE timings,                         ONLY: timeset,&
                                             timestop
  USE xc,                              ONLY: xc_exc_calc,&
                                             xc_vxc_pw_create1
#include "cp_common_uses.h"

  IMPLICIT NONE

  PRIVATE

  CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'kg_gpw_correction'

  PUBLIC :: kg_gpw_ekin_mol

CONTAINS

! *****************************************************************************
!> \brief Kinetic-energy functional correction, evaluated molecule by molecule.
!>      For every molecule local to this process:
!>        - the molecular density is passed to the kinetic-energy functional
!>          (the single TF, TFW or KE_GGA subsection of DFT%XC%XC_FUNCTIONAL)
!>        - unless just_energy, the resulting potential is integrated in real
!>          space and the KS matrix is updated accordingly (block by block)
!>        - the forces coming from this contribution are stored if requested
!> \param qs_env the qs environment
!> \param molbox_env cell, internal positions, rho and pw_env for each
!>        molecule kind
!> \param fm_mol_set per-molecule-kind block info (atom and kind indices)
!> \param ks_global_b KS matrices (one per spin); converted to the sparse
!>        format, updated, and converted back (the matrices are reallocated)
!> \param p_global_b density matrices (one per spin); passed to
!>        integrate_mol_potential and converted back like ks_global_b
!> \param ekin_mol total energy of this term, summed over all processes
!> \param calculate_forces if .TRUE., the sign-reversed forces are stored in
!>        force(ikind)%kg_gpw_ekin_mol (default .FALSE.)
!> \param just_energy if .TRUE., only the energy is evaluated and no potential
!>        is integrated (default .FALSE.)
!> \param error variable to control error logging, stopping,...
!>        see module cp_error_handling
!> \author MI
! *****************************************************************************
  SUBROUTINE kg_gpw_ekin_mol(qs_env,molbox_env,fm_mol_set,ks_global_b,p_global_b,&
                             ekin_mol,calculate_forces,just_energy,error)

    TYPE(qs_environment_type), POINTER       :: qs_env
    TYPE(kg_molbox_env_type), DIMENSION(:), &
      POINTER                                :: molbox_env
    TYPE(kg_fm_mol_set_type), DIMENSION(:), &
      POINTER                                :: fm_mol_set
    TYPE(cp_dbcsr_p_type), DIMENSION(:), &
      POINTER                                :: ks_global_b, p_global_b
    REAL(dp), INTENT(OUT)                    :: ekin_mol
    LOGICAL, INTENT(IN), OPTIONAL            :: calculate_forces, just_energy
    TYPE(cp_error_type), INTENT(inout)       :: error

    CHARACTER(len=*), PARAMETER :: routineN = 'kg_gpw_ekin_mol', &
      routineP = moduleN//':'//routineN

    INTEGER :: handle, iat_kind, iat_mol, iatom, ikind, imol, imolecule_kind, &
      ispin, istat, ke_sections, nat_mol, natom, nmol_local, nmolecule_kind, &
      nspins, sic_method_id
    INTEGER, ALLOCATABLE, DIMENSION(:)       :: atom_of_kind
    INTEGER, DIMENSION(:), POINTER           :: i_atom, i_kind
    LOGICAL                                  :: failure, my_calculate_forces, &
                                                my_just_energy
    REAL(dp)                                 :: ekin_imol
    REAL(dp), DIMENSION(:, :), POINTER       :: forces_mol, r_mbox
    TYPE(atomic_kind_type), DIMENSION(:), &
      POINTER                                :: atomic_kind_set
    TYPE(cell_type), POINTER                 :: cell_mol
    TYPE(cp_para_env_type), POINTER          :: para_env
    TYPE(distribution_2d_type), POINTER      :: distribution_2d
    TYPE(fm_mol_blocks_type), DIMENSION(:), &
      POINTER                                :: fm_mol_blocks
    TYPE(fm_mol_blocks_type), POINTER        :: mol_block
    TYPE(kg_fm_mol_set_type), POINTER        :: fm_mol
    TYPE(kg_molbox_env_type), POINTER        :: molbox
    TYPE(particle_type), DIMENSION(:), &
      POINTER                                :: particle_set
    TYPE(pw_env_type), POINTER               :: pw_env_mol
    TYPE(pw_p_type), DIMENSION(:), POINTER   :: my_vxc_rho, my_vxc_tau, &
                                                rho_g, rho_r, tau, vxc_rho, &
                                                vxc_tau
    TYPE(pw_pool_type), POINTER              :: auxbas_pw_pool
    TYPE(qs_force_type), DIMENSION(:), &
      POINTER                                :: force_global
    TYPE(qs_rho_type), POINTER               :: rho_global, rho_mol
    TYPE(real_matrix_p_type), DIMENSION(:), &
      POINTER                                :: ks_global, p_global
    TYPE(rho_mol_blocks_type), &
      DIMENSION(:), POINTER                  :: rho_mol_blocks
    TYPE(rho_mol_blocks_type), POINTER       :: rho_block
    TYPE(section_vals_type), POINTER         :: input, xc_section, &
                                                xc_section_kg

!------------------------------------------------------------------------------

    failure=.FALSE.

    CALL timeset(routineN,handle)

    ! ekin_mol is INTENT(OUT): define it before any path can skip the work
    ekin_mol = 0.0_dp
    nmolecule_kind = 0

    ! every pointer below is later tested with ASSOCIATED or (re)targeted,
    ! so give each one a defined association status first
    NULLIFY(atomic_kind_set,particle_set,para_env,force_global,distribution_2d)
    NULLIFY(rho_global,rho_g,rho_r,tau,my_vxc_rho,my_vxc_tau,vxc_rho,vxc_tau)
    NULLIFY(input,xc_section,xc_section_kg,forces_mol,auxbas_pw_pool)

    NULLIFY(ks_global)!sm->dbcsr
    CALL get_qs_env(qs_env=qs_env,distribution_2d=distribution_2d,error=error)!sm->dbcsr
    CALL allocate_matrix_set( ks_global, SIZE(ks_global_b), error )!sm->dbcsr
    DO ispin=1,SIZE(ks_global)!sm->dbcsr
       CALL sm_from_dbcsr(ks_global(ispin)%matrix, ks_global_b(ispin)%matrix, distribution_2d,error)!sm->dbcsr
    ENDDO!sm->dbcsr

    NULLIFY(p_global)!sm->dbcsr
    CALL allocate_matrix_set( p_global, SIZE(p_global_b), error )!sm->dbcsr
    DO ispin=1,SIZE(p_global)!sm->dbcsr
       CALL sm_from_dbcsr(p_global(ispin)%matrix, p_global_b(ispin)%matrix, distribution_2d,error)!sm->dbcsr
    ENDDO!sm->dbcsr

    CALL get_qs_env(qs_env=qs_env, atomic_kind_set=atomic_kind_set,&
         input=input,rho=rho_global,para_env=para_env,error=error)

    ! XC section restricted to the kinetic-energy functional; exactly one
    ! KE subsection (TF, TFW or KE_GGA) is expected in the input
    xc_section => section_vals_get_subs_vals(input,"DFT%XC",error=error)
    CALL kg_create_ke_xc_section(xc_section,xc_section_kg,ke_sections,error)
    CPPostcondition(ke_sections==1,cp_failure_level,routineP,error,failure)

    ! Sic is not implemented
    CALL section_vals_val_get(input,"DFT%SIC%SIC_METHOD",&
         i_val=sic_method_id,error=error)
    IF(sic_method_id /= sic_none) &
       CALL  stop_program(routineN,moduleN,__LINE__,"KG_GPW with SIC  not implemented")

    nspins = SIZE(p_global)

    ! work arrays of grid pointers, retargeted for each molecule
    ALLOCATE(rho_r(nspins),STAT=istat)
    CPPostcondition(istat==0,cp_failure_level,routineP,error,failure)

    IF (rho_global%tau_r_valid) THEN
       ALLOCATE(tau(nspins),STAT=istat)
       CPPostcondition(istat==0,cp_failure_level,routineP,error,failure)
    END IF

    ! for gradient corrected functional the density in g space might be useful
    IF ( rho_global%rho_g_valid ) THEN
       ALLOCATE(rho_g(nspins),STAT=istat)
       CPPostcondition(istat==0,cp_failure_level,routineP,error,failure)
    END IF

    ! freshly allocated pointer components have undefined status until nullified
    DO ispin=1,nspins
       IF (ASSOCIATED(rho_r)) NULLIFY(rho_r(ispin)%pw)
       IF (ASSOCIATED(tau)) NULLIFY(tau(ispin)%pw)
       IF (ASSOCIATED(rho_g)) NULLIFY(rho_g(ispin)%pw)
    END DO

    CPPrecondition(ASSOCIATED(molbox_env),cp_failure_level,routineP,error,failure)
    CPPrecondition(ASSOCIATED(fm_mol_set),cp_failure_level,routineP,error,failure)
    IF (.NOT. failure) THEN
       ! SIZE of an unassociated pointer is undefined: only query it once checked
       nmolecule_kind = SIZE(molbox_env,1)
       CPPrecondition(SIZE(fm_mol_set,1)==nmolecule_kind,cp_failure_level,routineP,error,failure)
    END IF

    my_just_energy = .FALSE.
    IF(PRESENT(just_energy)) my_just_energy = just_energy
    my_calculate_forces = .FALSE.
    IF(PRESENT(calculate_forces)) my_calculate_forces = calculate_forces

    IF (my_calculate_forces) THEN
      ! Array of index within the kind
      ! to associate the forces to the right position in forces array
      CALL get_qs_env(qs_env=qs_env,particle_set=particle_set,error=error)
      natom = SIZE(particle_set)
      ALLOCATE (atom_of_kind(natom),STAT=istat)
      CPPostcondition(istat==0,cp_failure_level,routineP,error,failure)

      CALL get_atomic_kind_set(atomic_kind_set=atomic_kind_set,&
                               atom_of_kind=atom_of_kind)

      CALL get_qs_env(qs_env=qs_env, force=force_global,error=error)
    END IF

    IF(.NOT. failure) THEN

      ! Loop over the molecule kinds
      DO imolecule_kind = 1,nmolecule_kind

        NULLIFY(molbox,pw_env_mol,cell_mol,rho_mol_blocks)
        molbox => molbox_env(imolecule_kind)
        CALL get_molbox_env(molbox_env=molbox,natom=nat_mol,cell_mol=cell_mol,&
             nmolecule_local=nmol_local, pw_env_mol=pw_env_mol, rho_mol_blocks=rho_mol_blocks)

        ! nothing to do for kinds without molecules on this process
        IF (nmol_local <= 0) CYCLE

        CALL pw_env_get(pw_env=pw_env_mol,auxbas_pw_pool=auxbas_pw_pool,error=error)

        NULLIFY(fm_mol,fm_mol_blocks)
        fm_mol => fm_mol_set(imolecule_kind)
        CALL get_kg_fm_mol_set(kg_fm_mol_set=fm_mol,&
                               fm_mol_blocks=fm_mol_blocks)

        ! per-molecule force buffer (3 x atoms of this molecule kind)
        NULLIFY(forces_mol)
        IF(my_calculate_forces) THEN
          ALLOCATE(forces_mol(3,nat_mol),STAT=istat)
          CPPostcondition(istat==0,cp_failure_level,routineP,error,failure)
        END IF

        ! Loop over the molecules of one kind
        DO imol = 1,nmol_local

          NULLIFY(rho_block, rho_mol, r_mbox)
          rho_block => rho_mol_blocks(imol)
          CALL get_rho_mol_block(rho_block=rho_block, rho_mol=rho_mol,&
                                 r_in_molbox=r_mbox)

          NULLIFY(mol_block,i_atom,i_kind)
          mol_block => fm_mol_blocks(imol)
          CALL get_fm_mol_block(fm_mol_block = mol_block,&
                                index_atom = i_atom,&
                                index_kind = i_kind)

          ! point the work arrays at this molecule's grids; tau and rho_g
          ! exist only if the global density carries them
          DO ispin=1,nspins
            rho_r(ispin)%pw => rho_mol%rho_r(ispin)%pw
            IF (rho_mol%tau_r_valid .AND. ASSOCIATED(tau)) &
               tau(ispin)%pw => rho_mol%tau_r(ispin)%pw
            IF (rho_mol%rho_g_valid .AND. ASSOCIATED(rho_g)) &
               rho_g(ispin)%pw => rho_mol%rho_g(ispin)%pw
          END DO

          ekin_imol = 0.0_dp
          IF (my_just_energy) THEN
            ekin_imol = xc_exc_calc(rho_r=rho_r,tau=tau,&
                     rho_g=rho_g, xc_section=xc_section_kg,&
                     cell=cell_mol, pw_pool=auxbas_pw_pool,&
                     error=error)
          ELSE
            CALL xc_vxc_pw_create1(vxc_rho=my_vxc_rho,vxc_tau=my_vxc_tau, rho_r=rho_r,&
                                   rho_g=rho_g,tau=tau,exc=ekin_imol,&
                                   xc_section=xc_section_kg,&
                                   cell=cell_mol, pw_pool=auxbas_pw_pool,&
                                   error=error)
          END IF

          ekin_mol = ekin_mol + ekin_imol

          ! we have pw data for the xc, here we transfer to coeff
          ! pw -> coeff
          IF (ASSOCIATED(my_vxc_rho)) THEN
             ALLOCATE(vxc_rho(nspins),STAT=istat)
             CPPostcondition(istat==0,cp_failure_level,routineP,error,failure)
             DO ispin=1,nspins
               vxc_rho(ispin)%pw => my_vxc_rho(ispin)%pw
             END DO
             DEALLOCATE(my_vxc_rho,STAT=istat)
             CPPostcondition(istat==0,cp_failure_level,routineP,error,failure)
          END IF
          IF (ASSOCIATED(my_vxc_tau)) THEN
             ALLOCATE(vxc_tau(nspins),STAT=istat)
             CPPostcondition(istat==0,cp_failure_level,routineP,error,failure)
             DO ispin=1,nspins
               vxc_tau(ispin)%pw => my_vxc_tau(ispin)%pw
             END DO
             DEALLOCATE(my_vxc_tau,STAT=istat)
             CPPostcondition(istat==0,cp_failure_level,routineP,error,failure)
          END IF

          ! forces_mol accumulates over all spins and over the tau potential,
          ! so it is reset once per molecule (never inside the spin loop)
          IF(my_calculate_forces) forces_mol(1:3,1:nat_mol) = 0.0_dp

          ! Integrate the potential
          IF (.NOT. my_just_energy) THEN
             CPPostcondition(ASSOCIATED(vxc_rho),cp_failure_level,routineP,error,failure)
             IF (ASSOCIATED(vxc_rho)) THEN
               DO ispin=1,nspins

                 ! scale by the grid volume element before the real-space integration
                 vxc_rho(ispin)%pw%cr3d  =  vxc_rho(ispin)%pw%cr3d*vxc_rho(ispin)%pw%pw_grid%dvol

                 IF(my_calculate_forces) THEN
                   CALL integrate_mol_potential(qs_env=qs_env,vxc_mol=vxc_rho(ispin),&
                                               matrix_p_sm=p_global(ispin),&
                                               matrix_h_sm=ks_global(ispin),&
                                               pw_env=pw_env_mol,&
                                               atom=i_atom, kind=i_kind, ratom=r_mbox,&
                                               forces_mol=forces_mol,&
                                               distribution_2d=distribution_2d,&
                                               error=error)
                 ELSE
                   CALL integrate_mol_potential(qs_env=qs_env,vxc_mol=vxc_rho(ispin),&
                                               matrix_p_sm=p_global(ispin),&
                                               matrix_h_sm=ks_global(ispin),&
                                               pw_env=pw_env_mol,&
                                               atom=i_atom, kind=i_kind, ratom=r_mbox,&
                                               distribution_2d=distribution_2d,&
                                               error=error)
                 END IF

                 CALL pw_pool_give_back_pw(auxbas_pw_pool,vxc_rho(ispin)%pw,error=error)

               END DO
               DEALLOCATE(vxc_rho,STAT=istat)
               CPPostconditionNoFail(istat==0,cp_warning_level,routineP,error)
             END IF

             IF (ASSOCIATED(vxc_tau)) THEN
               DO ispin=1,nspins

                 vxc_tau(ispin)%pw%cr3d  =  vxc_tau(ispin)%pw%cr3d*vxc_tau(ispin)%pw%pw_grid%dvol

                 ! the tau potential (not the already released rho one) is integrated here
                 IF(my_calculate_forces) THEN
                   CALL integrate_mol_potential(qs_env=qs_env,vxc_mol=vxc_tau(ispin),&
                                               matrix_p_sm=p_global(ispin),&
                                               matrix_h_sm=ks_global(ispin),&
                                               pw_env=pw_env_mol,&
                                               atom=i_atom, kind=i_kind, ratom=r_mbox,&
                                               forces_mol=forces_mol,&
                                               compute_tau=.TRUE.,&
                                               distribution_2d=distribution_2d,&
                                               error=error)
                 ELSE
                   CALL integrate_mol_potential(qs_env=qs_env,vxc_mol=vxc_tau(ispin),&
                                               matrix_p_sm=p_global(ispin),&
                                               matrix_h_sm=ks_global(ispin),&
                                               pw_env=pw_env_mol,&
                                               atom=i_atom, kind=i_kind, ratom=r_mbox,&
                                               compute_tau=.TRUE.,&
                                               distribution_2d=distribution_2d,&
                                               error=error)
                 END IF
                 CALL pw_pool_give_back_pw(auxbas_pw_pool,vxc_tau(ispin)%pw,error=error)

               END DO
               DEALLOCATE(vxc_tau,STAT=istat)
               CPPostconditionNoFail(istat==0,cp_warning_level,routineP,error)
             END IF
          END IF

          ! Copy the forces in the global array
          IF(my_calculate_forces) THEN
             DO iat_mol = 1,nat_mol
               ikind = i_kind(iat_mol)
               iatom = i_atom(iat_mol)
               iat_kind = atom_of_kind(iatom)
               ! Change sign: it is the correction
               force_global(ikind)%kg_gpw_ekin_mol(1:3,iat_kind) = -forces_mol(1:3,iat_mol)
             END DO
          END IF

        END DO  ! imol

        IF(ASSOCIATED(forces_mol)) THEN
          DEALLOCATE(forces_mol,STAT=istat)
          CPPostcondition(istat==0,cp_failure_level,routineP,error,failure)
        END IF

      END DO ! imolecule_kind

      ! each process only handled its local molecules
      CALL mp_sum(ekin_mol,para_env%group)
    END IF ! failure

    ! Deallocate
    IF (ASSOCIATED(rho_r)) THEN
       DEALLOCATE(rho_r,STAT=istat)
       CPPostconditionNoFail(istat==0,cp_warning_level,routineP,error)
    END IF
    IF (ASSOCIATED(rho_g)) THEN
       DEALLOCATE(rho_g,STAT=istat)
       CPPostconditionNoFail(istat==0,cp_warning_level,routineP,error)
    END IF
    IF (ASSOCIATED(tau)) THEN
       DEALLOCATE(tau,STAT=istat)
       CPPostconditionNoFail(istat==0,cp_warning_level,routineP,error)
    END IF
    IF (ALLOCATED(atom_of_kind)) THEN
       DEALLOCATE(atom_of_kind,STAT=istat)
       CPPostconditionNoFail(istat==0,cp_warning_level,routineP,error)
    END IF

    CALL section_vals_release(xc_section_kg,error=error)

    DO ispin=1,SIZE(ks_global)!sm->dbcsr
       CALL cp_dbcsr_deallocate_matrix(ks_global_b(ispin)%matrix,error)
       ALLOCATE(ks_global_b(ispin)%matrix)
       CALL cp_dbcsr_from_sm(ks_global_b(ispin)%matrix, ks_global(ispin)%matrix, error)!sm->dbcsr
    ENDDO!sm->dbcsr
    CALL deallocate_matrix_set( ks_global, error )!sm->dbcsr

    DO ispin=1,SIZE(p_global)!sm->dbcsr
       CALL cp_dbcsr_deallocate_matrix(p_global_b(ispin)%matrix,error)
       ALLOCATE(p_global_b(ispin)%matrix)
       CALL cp_dbcsr_from_sm(p_global_b(ispin)%matrix, p_global(ispin)%matrix, error)!sm->dbcsr
    ENDDO!sm->dbcsr
    CALL deallocate_matrix_set( p_global, error )!sm->dbcsr

    CALL timestop(handle)

  END SUBROUTINE kg_gpw_ekin_mol

! *****************************************************************************
!> \brief Builds a duplicate of the DFT%XC section whose XC_FUNCTIONAL contains
!>        only the kinetic-energy functional subsection set in the input.
!>        The subsections TF, TFW and KE_GGA are tested in this order; when
!>        more than one is explicitly set, the last one found is kept.
!> \param xc_section the DFT%XC input section; its KE subsection is shared
!>        (referenced) by the new section, not copied
!> \param xc_section_kg on output, the new section (any previous association
!>        is discarded); the caller must release it
!> \param n_ke_sections number of KE subsections found explicitly set
!>        (exactly one is expected by the caller)
!> \param error variable to control error logging, stopping,...
!>        see module cp_error_handling
! *****************************************************************************
  SUBROUTINE kg_create_ke_xc_section(xc_section,xc_section_kg,n_ke_sections,error)

    TYPE(section_vals_type), POINTER         :: xc_section, xc_section_kg
    INTEGER, INTENT(OUT)                     :: n_ke_sections
    TYPE(cp_error_type), INTENT(inout)       :: error

    CHARACTER(len=6), DIMENSION(3), PARAMETER :: &
      ke_names = (/ "TF    ", "TFW   ", "KE_GGA" /)

    INTEGER                                  :: iname
    LOGICAL                                  :: is_set
    TYPE(section_vals_type), POINTER         :: kef_section, xc_fun_section, &
                                                xc_fun_section_kg

    n_ke_sections = 0
    NULLIFY(xc_section_kg,xc_fun_section,kef_section,xc_fun_section_kg)

    xc_fun_section => section_vals_get_subs_vals(xc_section,&
         "XC_FUNCTIONAL",error=error)
    CALL section_vals_duplicate(xc_section,xc_section_kg,error=error)

    DO iname = 1, SIZE(ke_names)
      NULLIFY(kef_section)
      kef_section => section_vals_get_subs_vals(xc_fun_section,&
           TRIM(ke_names(iname)),error=error)
      CALL section_vals_get(kef_section, explicit=is_set, error=error)
      IF (.NOT. is_set) CYCLE

      ! fresh XC_FUNCTIONAL (no shortcut) holding only this KE subsection
      NULLIFY(xc_fun_section_kg)
      CALL section_vals_create(xc_fun_section_kg,xc_fun_section%section,&
           error=error)
      CALL section_vals_val_set(xc_fun_section_kg,"_SECTION_PARAMETERS_",&
           i_val=xc_funct_no_shortcut,error=error)
      CALL section_vals_set_subs_vals(xc_fun_section_kg,TRIM(ke_names(iname)),&
           kef_section,error=error)
      CALL section_vals_set_subs_vals(xc_section_kg,"XC_FUNCTIONAL",&
           xc_fun_section_kg,error=error)
      ! xc_section_kg now references the new section; drop our own reference
      ! so that a replaced candidate is freed instead of leaked
      CALL section_vals_release(xc_fun_section_kg,error=error)
      n_ke_sections = n_ke_sections + 1
    END DO

  END SUBROUTINE kg_create_ke_xc_section

END MODULE kg_gpw_correction
