!--------------------------------------- LICENCE BEGIN -----------------------------------
!Environment Canada - Atmospheric Science and Technology License/Disclaimer,
!                     version 3; Last Modified: May 7, 2008.
!This is free but copyrighted software; you can use/redistribute/modify it under the terms
!of the Environment Canada - Atmospheric Science and Technology License/Disclaimer
!version 3 or (at your option) any later version that should be found at:
!http://collaboration.cmc.ec.gc.ca/science/rpn.comm/license.html
!
!This software is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
!without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!See the above mentioned License/Disclaimer for more details.
!You should have received a copy of the License/Disclaimer along with this software;
!if not, you can write to: EC-RPN COMM Group, 2121 TransCanada, suite 500, Dorval (Quebec),
!CANADA, H9P 1J3; or send e-mail to service.rpn@ec.gc.ca
!-------------------------------------- LICENCE END --------------------------------------


module BMatrix_mod
  !
  !- BMatrix_mod - Driver layer for the background-error covariance (B)
  !                operators.  The procedures below dispatch to the
  !                homogeneous/isotropic (bhi_*, lbhi_*) and ensemble (ben_*)
  !                implementations and use the control-vector helpers (cvm_*)
  !                to split the control vector between them.
  !
  use bMatrixHI
  use bMatrixEnsemble
  use controlVector_mod
  use gridStateVector_mod
  use LAMbMatrixHI_mod
  use horizontalCoord_mod
  use timeCoord_mod
  use gaussGrid_mod
  implicit none
  save
  private
  
  ! public procedures
  public :: bmat_setup, bmat_finalize, bmat_sqrtB, bmat_sqrtBT
  public :: bmat_reduceToMPILocal, bmat_expandToMPIGlobal
  ! public procedures through inheritance
  ! (not defined in this module; presumably provided by bMatrixHI and
  !  bMatrixEnsemble respectively - confirm against those modules)
  public :: bhi_getScaleFactor,ben_getScaleFactor


  ! Horizontal coordinate of the analysis grid.  Associated in bmat_setup;
  ! explicitly nullified here so that its association status is well defined
  ! before bmat_setup has been called.
  type(struct_hco), pointer :: hco_anl => null()

contains
  
!--------------------------------------------------------------------------
! bmat_setup
!--------------------------------------------------------------------------

  subroutine bmat_setup(hco_anl_in, vco_anl_in)
    !
    !- bmat_setup - Initializes the analysis Background term for the 
    !               specific analysis configuration used.
    !
    !  Arguments:
    !    hco_anl_in - horizontal coordinate of the analysis grid; the module
    !                 pointer hco_anl is associated with the same target
    !    vco_anl_in - vertical coordinate, handed unchanged to the HI and
    !                 ensemble setup routines
    !
    !  Side effects: associates the module variable hco_anl, sets up the
    !  HI and ensemble B matrices and the control vector module, and writes
    !  progress/memory information to standard output.
    !
    implicit none

    type(struct_hco), pointer :: hco_anl_in
    type(struct_vco), pointer :: vco_anl_in

    ! dimensions of the control-vector segments returned by the setup routines
    integer :: cvdimens, cvdimhi
    ! external function, not provided by any of the used modules
    ! NOTE(review): the division by 1024 before printing 'Mb' suggests the
    ! value is returned in Kb - confirm against its implementation
    integer, external :: get_max_rss

    !
    !- 1.  Get/Check the analysis grid info
    !

    !- 1.1 Horizontal Grid info (kept at module level for the other routines)
    hco_anl => hco_anl_in

    !
    !- 2.  Set the B matrices
    !
    cvdimhi  = 0
    cvdimens = 0

    !- 2.1 Time-Mean Homogeneous and Isotropic...
    !      (global or limited-area version, selected by the grid type)
    if ( hco_anl % global ) then
      write(*,*)
      write(*,*) 'Setting up the modular GLOBAL HI covariances...'
      call bhi_Setup( hco_anl, vco_anl_in, & ! IN
                      cvdimhi )              ! OUT
    else
      write(*,*)
      write(*,*) 'Setting up the modular LAM HI covariances...'
      call lbhi_Setup( hco_anl, vco_anl_in, & ! IN
                       cvdimhi )              ! OUT
    end if

    write(*,*) 'Memory Used: ',get_max_rss()/1024,'Mb'
    write(*,*) 'Dimension of HI  control vector returned:',cvdimhi

    !- 2.2 Flow-dependent Ensemble-Based
    write(*,*)
    write(*,*) 'Setting up the modular ENSEMBLE covariances...'
    call ben_Setup( hco_anl,             & ! IN
                    vco_anl_in,          & ! IN
                    tim_nstepobsinc,     & ! IN
                    tim_getDatestamp(),  & ! IN
                    cvdimens )             ! OUT

    write(*,*) 'Memory Used: ',get_max_rss()/1024,'Mb'
    write(*,*) 'Dimension of ENS control vector returned:',cvdimens

    !
    !- 3.  Setup the control vector
    !      (segment 1 = HI part, segment 2 = ensemble part)
    !
    call cvm_Setup( cvdimhi, cvdimens ) ! IN

    write(*,*) 'Memory Used: ',get_max_rss()/1024,'Mb'
    write(*,*) 'Dimension of TOTAL control vector:',cvm_nvadim
    
  end subroutine bmat_setup

!--------------------------------------------------------------------------
! bmat_sqrtB
!-------------------------------------------------------------------------- 

  subroutine bmat_sqrtB(da_v,na_vdim,statevector)
    implicit none
    !
    !- Purpose: Transforms model state from error covariance space
    !           to grid point space.
    !
    !  Arguments:
    !    da_v        - control vector (length na_vdim)
    !    na_vdim     - dimension of the control vector
    !    statevector - on return, the sum of the HI and ensemble
    !                  contributions to the increment
    !
    integer, intent(in) :: na_vdim

    real(8)          :: da_v(na_vdim)
    type(struct_gsv) :: statevector

    ! segments of da_v belonging to the HI (1) and ensemble (2) B matrices;
    ! a segment that is not in use comes back unassociated
    real(8), pointer :: cvBhi(:), cvBen(:)

    ! work state vector receiving the ensemble contribution
    type(struct_gsv) :: statevector_temp

    !
    !- 1.  Set analysis increment to zero
    !
    call gsv_zero(statevector) 

    !
    !- 2.  Compute the analysis increment
    !

    !- 2.1 Allocate and set to zero another temporary statevector
    !      (same vertical coordinate and number of time steps as statevector,
    !       on the analysis horizontal grid)
    call gsv_setVco(statevector_temp, gsv_getVco(statevector))
    call gsv_setHco(statevector_temp, hco_anl)
    call gsv_allocate(statevector_temp,&
                      statevector % numStep, mpi_local=.true.)
    call gsv_zero(statevector_temp)

    !- 2.2 Compute 3D contribution to increment from BmatrixHI
    call tmg_start(50,'B_HI')
    cvBhi => cvm_getSubVector(da_v,1)

    if ( associated(cvBhi) ) then
      if ( statevector % hco % global ) then
        !- 2.2.1 Global mode 
        call bhi_bsqrt( cvBhi,      & ! IN
                        statevector ) ! OUT
      else
        !- 2.2.2 LAM mode
        call lbhi_bSqrt( cvBhi,      & ! IN 
                         statevector ) ! OUT
      end if
      !- copy 3D increment to other timesteps to create 4D increment
      call gsv_3dto4d(statevector)
    end if
    call tmg_stop(50)

    !- 2.3 compute 4D contribution to increment from BmatrixEnsemble
    call tmg_start(60,'B_ENS')
    cvBen => cvm_getSubVector(da_v,2)

    if ( associated(cvBen) ) call ben_bsqrt(cvBen, statevector_temp)
    call tmg_stop(60)
      
    !- 2.4 Add the two contributions together, result in statevector
    call gsv_add(statevector_temp,statevector)

    !- 2.5 Release the work state vector
    call gsv_deallocate(statevector_temp)

  end subroutine bmat_sqrtB

!--------------------------------------------------------------------------
! bmat_sqrtBT
!--------------------------------------------------------------------------

  subroutine bmat_sqrtBT(da_v,na_vdim,statevector)
    implicit none
    !
    !- Purpose: Transforms model state from grid point space 
    !           to error covariance space.
    !
    !  Arguments:
    !    da_v        - control-space vector (length na_vdim); zeroed here and
    !                  then filled segment by segment by the adjoint operators
    !    na_vdim     - dimension of the control vector
    !    statevector - grid-point state; NOTE it is modified in place by
    !                  adjnorm (global grids) and by gsv_3dto4dAdj
    !
    integer, intent(in) :: na_vdim
    real(8)             :: da_v(na_vdim)
    type(struct_gsv)    :: statevector

    ! segments of da_v belonging to the HI (1) and ensemble (2) B matrices;
    ! a segment that is not in use comes back unassociated
    real(8), pointer :: cvBhi(:), cvBen(:)

    if ( statevector % hco % global ) then
      !- 1.1 Adjoint of the identity (change of norm)
      call adjnorm(statevector)
    end if

    !- 1.2 set gradient to zero
    da_v(:) = 0.0d0
    
    !- 1.3 Add contribution to gradient from BmatrixEnsemble
    !      (done before the HI part, i.e. before gsv_3dto4dAdj alters
    !       statevector below)
    call tmg_start(61,'B_ENS_T')
    cvBen => cvm_getSubVector(da_v,2)

    if ( associated(cvBen) ) call ben_bsqrtad(statevector,cvBen)
    call tmg_stop(61)

    !- 1.4 add contribution to gradient from BmatrixHI
    call tmg_start(51,'B_HI_T')
    cvBhi => cvm_getSubVector(da_v,1)
    if ( associated(cvBhi) ) then
      !- 1.4.1 adjoint of copy 3D increment to 4D increment
      call gsv_3dto4dAdj(statevector)

      if ( statevector % hco % global ) then
        !- 1.4.2 add contribution to gradient from GLOBAL BmatrixHI
        call bhi_bsqrtad( statevector, & ! IN
                          cvBhi )        ! OUT
      else
        !- 1.4.3 add contribution to gradient from LAM BmatrixHI
        call lbhi_bSqrtAdj( statevector, & ! IN
                            cvBhi )        ! OUT
      end if
    end if
    call tmg_stop(51)

  end subroutine bmat_sqrtBT

!--------------------------------------------------------------------------
! adjnorm
!--------------------------------------------------------------------------

  subroutine adjnorm(statevector)
    !
    !- Adjoint of the identity (change of norm)
    !
    !  Multiplies, in place, every value of the state vector field on the
    !  local lon/lat tile by  ni / gaus_RWT(jlat),  where ni is taken from
    !  the state vector and gaus_RWT is indexed by latitude.
    !  NOTE(review): gaus_RWT is presumably the Gaussian quadrature weight
    !  from gaussGrid_mod - confirm.
    !
    implicit none
    type(struct_gsv) :: statevector

    integer :: jlev,jlat,jlon,lon1,lon2,lat1,lat2,jstep

    ! latitude-dependent scaling factor
    real(8) :: rwtinv

    ! field values, indexed as (lon, lev, lat, step) in the loop below
    real(8), pointer :: ptr(:,:,:,:)
    
    ! bounds of the lon/lat tile held in this state vector
    lon1 = statevector % myLonBeg
    lon2 = statevector % myLonEnd
    lat1 = statevector % myLatBeg
    lat2 = statevector % myLatEnd

    ptr => gsv_getField(statevector)

    ! Latitudes are independent of each other, so the outer loop is shared
    ! among threads; the innermost loop runs over the first array index.
!$OMP PARALLEL
!$OMP DO PRIVATE (jlat,jstep,jlev,jlon,rwtinv)
    do jlat = lat1, lat2
       rwtinv = real(statevector % ni,8) / gaus_RWT(jlat)
       do jstep = 1, statevector % numStep
          do jlev = 1, statevector % nk
             do jlon = lon1, lon2
                ptr(jlon,jlev,jlat,jstep) = rwtinv * ptr(jlon,jlev,jlat,jstep)
             end do
          end do
       end do
    end do
!$OMP END DO
!$OMP END PARALLEL

  end subroutine adjnorm

!--------------------------------------------------------------------------
! bmat_finalize
!--------------------------------------------------------------------------

  subroutine bmat_finalize(da_v)
    implicit none
    !
    !- Purpose: Releases memory used by B matrices.
    !
    !  Arguments:
    !    da_v - control vector; only used to find out which segments (HI,
    !           ensemble) are present, its values are not read here
    !
    real(8) :: da_v(:)

    ! segments of da_v; an unassociated result means that the
    ! corresponding B matrix is not in use and has nothing to release
    real(8), pointer :: cvBhi(:), cvBen(:)

    !- 1.  Homogeneous and isotropic B matrix (global or LAM version)
    cvBhi => cvm_getSubVector(da_v,1)
    if ( associated(cvBhi) ) then
      if ( hco_anl % global ) then
        call bhi_finalize()
      else
        call lbhi_finalize()
      end if
    end if

    !- 2.  Ensemble B matrix
    cvBen => cvm_getSubVector(da_v,2)
    if ( associated(cvBen) ) call ben_finalize() 

  end subroutine bmat_finalize

!--------------------------------------------------------------------------
! BMAT_reduceToMPILocal
!--------------------------------------------------------------------------

  subroutine BMAT_reduceToMPILocal(cv_mpilocal,cv_mpiglobal,cvDim_mpilocal_out)
    implicit none
    !
    !- Purpose: Splits the MPI-local and MPI-global control vectors into
    !           their HI and ensemble segments and calls the corresponding
    !           *_reduceToMPILocal routine on each segment that is present.
    !
    !  Arguments:
    !    cv_mpilocal        - MPI-local control vector
    !    cv_mpiglobal       - MPI-global control vector
    !    cvDim_mpilocal_out - sum of the local dimensions reported by the
    !                         HI and ensemble routines
    !
    real(8) :: cv_mpilocal(:)
    real(8) :: cv_mpiglobal(:)
    integer, intent(out) :: cvDim_mpilocal_out

    integer :: cvDim_Bhi_mpilocal,cvDim_Ben_mpilocal
    real(8), pointer :: cvBhi_mpilocal(:),cvBen_mpilocal(:)
    real(8), pointer :: cvBhi_mpiglobal(:),cvBen_mpiglobal(:)

    ! A segment that is absent contributes nothing to the total.  Without
    ! this initialisation the sum below would read undefined variables
    ! whenever one of the guarded calls is skipped.
    cvDim_Bhi_mpilocal = 0
    cvDim_Ben_mpilocal = 0

    !- 1.  Homogeneous and isotropic segment (global or LAM version)
    cvBhi_mpilocal  => cvm_getSubVector(cv_mpilocal,1)
    cvBhi_mpiglobal => cvm_getSubVector_mpiglobal(cv_mpiglobal,1)
    if ( associated(cvBhi_mpilocal) ) then
      if ( hco_anl % global ) then 
         call bhi_reduceToMPILocal(cvBhi_mpilocal,cvBhi_mpiglobal,cvDim_Bhi_mpilocal)
      else
         call lbhi_reduceToMPILocal(cvBhi_mpilocal,cvBhi_mpiglobal,cvDim_Bhi_mpilocal)
      end if
    end if

    !- 2.  Ensemble segment
    cvBen_mpilocal  => cvm_getSubVector(cv_mpilocal,2)
    cvBen_mpiglobal => cvm_getSubVector_mpiglobal(cv_mpiglobal,2)
    if ( associated(cvBen_mpilocal) ) then
       call ben_reduceToMPILocal(cvBen_mpilocal,cvBen_mpiglobal,cvDim_Ben_mpilocal)
    end if

    !- 3.  Total local dimension
    cvDim_mpilocal_out = cvDim_Bhi_mpilocal + cvDim_Ben_mpilocal

  end subroutine BMAT_reduceToMPILocal

!--------------------------------------------------------------------------
! BMAT_expandToMPIGlobal
!--------------------------------------------------------------------------

  subroutine BMAT_expandToMPIGlobal(cv_mpilocal,cv_mpiglobal,cvDim_mpiglobal_out)
    implicit none
    !
    !- Purpose: Splits the MPI-local and MPI-global control vectors into
    !           their HI and ensemble segments and calls the corresponding
    !           *_expandToMPIGlobal routine on each segment that is present.
    !
    !  Arguments:
    !    cv_mpilocal         - MPI-local control vector
    !    cv_mpiglobal        - MPI-global control vector
    !    cvDim_mpiglobal_out - sum of the global dimensions reported by the
    !                          HI and ensemble routines
    !
    real(8) :: cv_mpilocal(:)
    real(8) :: cv_mpiglobal(:)
    integer, intent(out) :: cvDim_mpiglobal_out

    integer :: cvDim_Bhi_mpiglobal,cvDim_Ben_mpiglobal
    real(8), pointer :: cvBhi_mpilocal(:),cvBen_mpilocal(:)
    real(8), pointer :: cvBhi_mpiglobal(:),cvBen_mpiglobal(:)

    ! A segment whose routine is not called contributes nothing to the
    ! total.  Without this initialisation the sum below would read undefined
    ! variables whenever one of the guarded calls is skipped.
    ! NOTE(review): the guards test the LOCAL segment; confirm that a
    ! process can never hold an unassociated local segment while the
    ! global one is non-empty.
    cvDim_Bhi_mpiglobal = 0
    cvDim_Ben_mpiglobal = 0

    !- 1.  Homogeneous and isotropic segment (global or LAM version)
    cvBhi_mpilocal  => cvm_getSubVector(cv_mpilocal,1)
    cvBhi_mpiglobal => cvm_getSubVector_mpiglobal(cv_mpiglobal,1)
    if ( associated(cvBhi_mpilocal) ) then
      if ( hco_anl % global ) then
         call bhi_expandToMPIGlobal(cvBhi_mpilocal,cvBhi_mpiglobal,cvDim_Bhi_mpiglobal)
      else
         call lbhi_expandToMPIGlobal(cvBhi_mpilocal,cvBhi_mpiglobal,cvDim_Bhi_mpiglobal)
      end if
    end if

    !- 2.  Ensemble segment
    cvBen_mpilocal  => cvm_getSubVector(cv_mpilocal,2)
    cvBen_mpiglobal => cvm_getSubVector_mpiglobal(cv_mpiglobal,2)
    if ( associated(cvBen_mpilocal) ) then
       call ben_expandToMPIGlobal(cvBen_mpilocal,cvBen_mpiglobal,cvDim_Ben_mpiglobal)
    end if

    !- 3.  Total global dimension
    cvDim_mpiglobal_out = cvDim_Bhi_mpiglobal + cvDim_Ben_mpiglobal

  end subroutine BMAT_expandToMPIGlobal

END MODULE BMatrix_mod