!--------------------------------------- LICENCE BEGIN -----------------------------------
!Environment Canada - Atmospheric Science and Technology License/Disclaimer,
!                     version 3; Last Modified: May 7, 2008.
!This is free but copyrighted software; you can use/redistribute/modify it under the terms
!of the Environment Canada - Atmospheric Science and Technology License/Disclaimer
!version 3 or (at your option) any later version that should be found at:
!http://collaboration.cmc.ec.gc.ca/science/rpn.comm/license.html
!
!This software is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
!without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!See the above mentioned License/Disclaimer for more details.
!You should have received a copy of the License/Disclaimer along with this software;
!if not, you can write to: EC-RPN COMM Group, 2121 TransCanada, suite 500, Dorval (Quebec),
!CANADA, H9P 1J3; or send e-mail to service.rpn@ec.gc.ca
!-------------------------------------- LICENCE END --------------------------------------


module obsSpaceDiag_mod
  !
  ! MODULE obsSpaceDiag_mod (prefix "osd")
  !
  ! Observation-space diagnostics for estimating background-error inflation.
  ! For each observation family/element listed in namelist NAMOSD, the
  ! squared O-P innovations (OBS_OMP) and the squared values of H B^1/2 xi
  ! for one random control vector xi (computed separately for the Bhi and
  ! the Bensemble parts of B) are binned on a lat/lon/vertical grid, summed
  ! over all MPI tasks, converted to RMS values and written to text files.
  !
  use topLevelControl_mod
  use mpivar_mod
  use bufr
  use codtyp_mod
  use EarthConstants_mod
  use MathPhysConstants_mod
  use horizontalCoord_mod
  use timeCoord_mod
  use controlVector_mod
  use obsSpaceData_mod
  use columnData_mod
  use gridStateVector_mod
  use bMatrix_mod
  implicit none
  save
  private

  ! public procedures
  public :: osd_calcInflation

  ! namelist variables (NAMOSD); defaults are set in osd_setup
  real*8 :: deltaLat,deltaLon   ! horizontal bin sizes (degrees)
  real*8 :: deltaPressure       ! vertical bin size for VCO=2, in OBS_PPP units (presumably Pa)
  real*8 :: deltaHeight         ! vertical bin size for VCO=1, in OBS_PPP units (presumably m)
  integer :: numFamily          ! number of valid entries in familyList
  character(len=2) :: familyList(20)
  integer :: numElement         ! number of valid entries in elementList
  integer :: elementList(20)    ! BUFR element numbers (OBS_VNM values)
  integer :: nrandseed          ! random seed; 999 means "derive from the valid date/time"

contains


  !--------------------------------------------------------------------------
  ! osd_calcInflation
  !
  ! Compute binned statistics used to diagnose B-matrix inflation.
  !
  !  1. Read NAMOSD; return immediately if no valid namelist is found.
  !  2. Compute H B^1/2 xi for one random control vector xi, once with only
  !     the Bhi part of B and once with only the Bensemble part.
  !  3. For every family/element of the namelist, accumulate on each task,
  !     per lat/lon/vertical bin, the count and the sums of squares of
  !     OBS_OMP and of both H B^1/2 xi values for assimilated observations
  !     (OBS_ASS == 1).
  !  4. Sum over all tasks, convert to sqrt(sum/count) in populated bins and
  !     let task 0 write the result files (see osd_writeFiles).
  !
  ! Arguments:
  !   obsSpaceData : observations; OBS_WORK is overwritten and the current
  !                  body list is changed
  !   columng      : reference columns passed to col_*/oda_L/oda_H
  !--------------------------------------------------------------------------
  subroutine osd_calcInflation(obsSpaceData,columng)
    implicit none
    type(struct_obs) :: obsSpaceData
    type(struct_columnData) :: columng

    type(struct_gsv) :: statevector
    type(struct_columnData) :: column
    type(struct_hco), pointer :: hco_anl

    logical :: nmlExists
    integer :: familyIndex,elementIndex,bodyIndex,latIndex,lonIndex,verticalIndex
    integer :: maxLat,maxLon,maxVertical,numBins
    integer :: ierr,ivco,my_ivco
    integer :: dateprnt,timeprnt,newdate
    real*8  :: omp
    ! global (all tasks) statistics
    real*8, allocatable  :: innovStd(:,:,:),bmatHiStd(:,:,:),bmatEnStd(:,:,:)
    integer, allocatable :: counts(:,:,:)
    ! partial statistics of this task
    real*8, allocatable  :: my_innovStd(:,:,:),my_bmatHiStd(:,:,:),my_bmatEnStd(:,:,:)
    integer, allocatable :: my_counts(:,:,:)
    real*8, allocatable  :: HxBhi(:),HxBen(:)
    real*8, allocatable  :: scaleFactorBhi(:),scaleFactorBen(:)

    write(*,*) 'osd_calcInflation: Starting'

    call osd_setup(nmlExists)
    if(.not.nmlExists) return

    ierr = newdate(tim_getDatestamp(),dateprnt,timeprnt,-3)
    ! if the seed was not set by the namelist, derive it from the valid date/time
    if(nrandseed.eq.999) nrandseed=100*dateprnt + int(timeprnt/100.0)
    write(*,*) 'osd_calcInflation: random seed set to ',nrandseed

    ! number of bins per direction; the longitude count assumes longitudes
    ! in [0,360) degrees (TODO confirm OBS_LON range)
    maxLat = 1 + nint(180.0d0/deltaLat)
    maxLon = 1 + nint(360.0d0/deltaLon)
    maxVertical = max(1+nint(110000.0d0/deltaPressure),1+nint(80000.0d0/deltaHeight),200)
    numBins = maxLat*maxLon*maxVertical

    write(*,*) 'osd_calcInflation: Compute random realization of background error'

    ! H B^1/2 xi for the 2 parts of B, one value per observation body
    allocate(HxBhi(obs_numbody(obsSpaceData)))
    allocate(HxBen(obs_numbody(obsSpaceData)))

    ! per-level scale factors of the 2 parts of B
    allocate(scaleFactorBhi(col_getNumLev(columng,'MM')))
    allocate(scaleFactorBen(col_getNumLev(columng,'MM')))
    call bhi_getScaleFactor(scaleFactorBhi)
    call ben_getScaleFactor(scaleFactorBen)

    ! columnData object for the increment
    call col_setVco(column,col_getVco(columng))
    call col_allocate(column,col_getNumCol(columng),mpi_local=.true.)
    call col_copyLatLon(columng,column)

    ! gridStateVector object for the increment
    call gsv_setVco(statevector,col_getVco(columng))
    hco_anl => hco_get('Analysis')
    call gsv_setHco(statevector,hco_anl)
    call gsv_allocate(statevector,tim_nstepobsinc,mpi_local=.true.)

    ! Bhi only: the Bensemble part (sub-vector 2) of the control vector is zeroed
    call osd_computeRandomHx(obsSpaceData,columng,column,statevector,2,scaleFactorBhi,HxBhi)
    ! Bensemble only: the Bhi part (sub-vector 1) of the control vector is zeroed
    call osd_computeRandomHx(obsSpaceData,columng,column,statevector,1,scaleFactorBen,HxBen)

    call col_deallocate(column)
    call gsv_deallocate(statevector)
    deallocate(scaleFactorBhi)
    deallocate(scaleFactorBen)

    write(*,*) 'osd_calcInflation: Compute innovation and B-matrix variances'
    flush(6)

    allocate(my_innovStd(maxLat,maxLon,maxVertical))
    allocate(my_bmatHiStd(maxLat,maxLon,maxVertical))
    allocate(my_bmatEnStd(maxLat,maxLon,maxVertical))
    allocate(my_counts(maxLat,maxLon,maxVertical))

    allocate(innovStd(maxLat,maxLon,maxVertical))
    allocate(bmatHiStd(maxLat,maxLon,maxVertical))
    allocate(bmatEnStd(maxLat,maxLon,maxVertical))
    allocate(counts(maxLat,maxLon,maxVertical))

    FAMILY: do familyIndex = 1, numFamily
      ELEMENT: do elementIndex = 1, numElement

        ! reset the partial sums of this task
        my_ivco = -999
        my_innovStd(:,:,:)  = 0.0d0
        my_bmatHiStd(:,:,:) = 0.0d0
        my_bmatEnStd(:,:,:) = 0.0d0
        my_counts(:,:,:)    = 0

        call obs_set_current_body_list(obsSpaceData,familyList(familyIndex))
        BODY: do
          bodyIndex = obs_getBodyIndex(obsSpaceData)
          if (bodyIndex < 0) exit BODY

          ! keep only assimilated observations of the requested element
          if(obs_bodyElem_i(obsSpaceData,OBS_VNM,bodyIndex).ne.elementList(elementIndex)) cycle BODY
          if(obs_bodyElem_i(obsSpaceData,OBS_ASS,bodyIndex).ne.1) cycle BODY

          call osd_getIndices(obsSpaceData,bodyIndex,latIndex,lonIndex,verticalIndex)
          ! -1 flags observations that osd_getIndices chose to ignore
          if(verticalIndex.eq.-1) cycle BODY
          if(latIndex.lt.1      .or. latIndex.gt.maxLat .or. &
             lonIndex.lt.1      .or. lonIndex.gt.maxLon .or. &
             verticalIndex.lt.1 .or. verticalIndex.gt.maxVertical) then
            write(*,*) 'osd_calcInflation: index out of range: lat,lon,vertical=',latIndex,lonIndex,verticalIndex
            call abort3d('osd_calcInflation')
          endif

          my_ivco = obs_bodyElem_i(obsSpaceData,OBS_VCO,bodyIndex)
          omp = obs_bodyElem_r(obsSpaceData,OBS_OMP,bodyIndex)

          my_counts(latIndex,lonIndex,verticalIndex) =                   &
               my_counts(latIndex,lonIndex,verticalIndex) + 1
          my_innovStd(latIndex,lonIndex,verticalIndex) =                 &
               my_innovStd(latIndex,lonIndex,verticalIndex) + omp*omp
          my_bmatHiStd(latIndex,lonIndex,verticalIndex) =                &
               my_bmatHiStd(latIndex,lonIndex,verticalIndex) + HxBhi(bodyIndex)*HxBhi(bodyIndex)
          my_bmatEnStd(latIndex,lonIndex,verticalIndex) =                &
               my_bmatEnStd(latIndex,lonIndex,verticalIndex) + HxBen(bodyIndex)*HxBen(bodyIndex)
        enddo BODY

        ! combine the partial results of all tasks; send and receive buffers
        ! are kept distinct (rpn_comm_allreduce presumably wraps MPI_Allreduce,
        ! which forbids aliased buffers)
        call rpn_comm_allreduce(my_ivco,ivco,1,"MPI_INTEGER","MPI_MAX","GRID",ierr)
        call rpn_comm_allreduce(my_counts,counts,numBins,"MPI_INTEGER","MPI_SUM","GRID",ierr)
        call rpn_comm_allreduce(my_innovStd,innovStd,numBins,"MPI_DOUBLE_PRECISION","MPI_SUM","GRID",ierr)
        call rpn_comm_allreduce(my_bmatHiStd,bmatHiStd,numBins,"MPI_DOUBLE_PRECISION","MPI_SUM","GRID",ierr)
        call rpn_comm_allreduce(my_bmatEnStd,bmatEnStd,numBins,"MPI_DOUBLE_PRECISION","MPI_SUM","GRID",ierr)

        ! sums of squares -> root-mean-square values in populated bins
        where(counts > 0)
          innovStd  = sqrt(innovStd /counts)
          bmatHiStd = sqrt(bmatHiStd/counts)
          bmatEnStd = sqrt(bmatEnStd/counts)
        end where

        if(mpi_myid.eq.0 .and. sum(counts).gt.0) then
          call osd_writeFiles(familyList(familyIndex),elementList(elementIndex),dateprnt,ivco, &
                              innovStd,bmatHiStd,bmatEnStd,counts)
        endif

      enddo ELEMENT
    enddo FAMILY

    deallocate(my_counts)
    deallocate(my_innovStd)
    deallocate(my_bmatHiStd)
    deallocate(my_bmatEnStd)

    deallocate(innovStd)
    deallocate(bmatHiStd)
    deallocate(bmatEnStd)
    deallocate(counts)
    deallocate(HxBhi)
    deallocate(HxBen)

    write(*,*) 'osd_calcInflation: Finished'
    flush(6)

  end subroutine osd_calcInflation


  !--------------------------------------------------------------------------
  ! osd_computeRandomHx
  !
  ! Compute Hx = H B^1/2 xi for one random control vector xi restricted to
  ! one part of B, and copy it per observation body into Hx.
  !
  ! The random generator is re-seeded from nrandseed on every call, so every
  ! call draws the same xi before sub-vector zeroedSubVector is set to zero.
  ! After B^1/2 is applied, the per-level scale factor is divided out
  ! (NOTE: this is not correct for 2D variables, whose single level is
  ! divided by scaleFactor(1)). Levels beyond size(scaleFactor) are left
  ! unscaled instead of being read out of bounds. oda_L/oda_H leave H dx in
  ! OBS_WORK, from which Hx is filled.
  !--------------------------------------------------------------------------
  subroutine osd_computeRandomHx(obsSpaceData,columng,column,statevector,zeroedSubVector,scaleFactor,Hx)
    implicit none
    type(struct_obs) :: obsSpaceData
    type(struct_columnData) :: columng, column
    type(struct_gsv) :: statevector
    integer, intent(in) :: zeroedSubVector
    real*8, intent(in)  :: scaleFactor(:)
    real*8, intent(out) :: Hx(:)

    integer :: iseed,jj,jvar,jlev,bodyIndex
    real*8  :: zdum,gasdev
    real*8, pointer :: cvZeroed(:),field(:,:,:,:)

    ! random control vector xi (re-seeding restarts the sequence)
    iseed = abs(nrandseed)
    zdum = gasdev(-iseed)
    do jj = 1,cvm_nvadim
      cvm_vazx(jj) = gasdev(1)
    enddo

    ! keep only the requested part of B
    cvZeroed => cvm_getSubVector(cvm_vazx,zeroedSubVector)
    cvZeroed(:) = 0.0d0

    ! statevector = B^1/2 xi
    call bmat_sqrtB(cvm_vazx,cvm_nvadim,statevector)

    ! undo the scale factor level by level
    do jvar = 1,vnl_numvarmax
      if(.not.gsv_varExist(vnl_varNameList(jvar))) cycle
      field => gsv_getField(statevector,vnl_varNameList(jvar))
      do jlev = 1,gsv_getNumLev(statevector,vnl_vartypeFromVarname(vnl_varNameList(jvar)))
        if(jlev.gt.size(scaleFactor)) exit
        if(scaleFactor(jlev).gt.0.0d0) field(:,jlev,:,:) = field(:,jlev,:,:)/scaleFactor(jlev)
      enddo
    enddo

    ! multiply by H
    call oda_L(statevector,column,columng,obsSpaceData)  ! put in column H_horiz dx
    call oda_H(column,columng,obsSpaceData)              ! save as OBS_WORK: H_vert H_horiz dx = Hdx
    do bodyIndex = 1,obs_numBody(obsSpaceData)
      Hx(bodyIndex) = obs_bodyElem_r(obsSpaceData,OBS_WORK,bodyIndex)
    enddo

  end subroutine osd_computeRandomHx


  !--------------------------------------------------------------------------
  ! osd_writeFiles
  !
  ! Write the binned statistics of one family/element to four text files,
  ! innov*, bmathi*, bmaten* and count*, each named
  ! <prefix><YYYYMMDD>_<family>_<element>.dat. Every file starts with the
  ! grid description; then, for each vertical bin holding at least one
  ! observation, a "verticalIndex,vco" record followed by one line per
  ! latitude bin (all longitudes).
  !--------------------------------------------------------------------------
  subroutine osd_writeFiles(familyName,elementNum,dateprnt,ivco,innovStd,bmatHiStd,bmatEnStd,counts)
    implicit none
    character(len=*), intent(in) :: familyName
    integer, intent(in) :: elementNum,dateprnt,ivco
    real*8,  intent(in) :: innovStd(:,:,:),bmatHiStd(:,:,:),bmatEnStd(:,:,:)
    integer, intent(in) :: counts(:,:,:)

    integer :: nulinnov,nulBmatHi,nulBmatEn,nulcount,ierr,fnom,fclos
    integer :: maxLat,maxLon,latIndex,verticalIndex
    character(len=128) :: innovFileName,bmatHiFileName,bmatEnFileName,countFileName,suffix
    character(len=6)   :: elementStr
    character(len=8)   :: dateStr

    maxLat = size(counts,1)
    maxLon = size(counts,2)

    ! determine file names
    write(dateStr,'(i8.8)') dateprnt
    write(elementStr,'(i6.6)') elementNum
    suffix = dateStr // '_' // trim(familyName) // '_' // trim(elementStr) // '.dat'
    innovFileName  = 'innov'  // trim(suffix)
    bmatHiFileName = 'bmathi' // trim(suffix)
    bmatEnFileName = 'bmaten' // trim(suffix)
    countFileName  = 'count'  // trim(suffix)

    ! open files
    nulinnov  = 0
    nulBmatHi = 0
    nulBmatEn = 0
    nulcount  = 0
    ierr = fnom(nulinnov,innovFileName,'FMT+R/W',0)
    ierr = fnom(nulBmatHi,bmatHiFileName,'FMT+R/W',0)
    ierr = fnom(nulBmatEn,bmatEnFileName,'FMT+R/W',0)
    ierr = fnom(nulcount,countFileName,'FMT+R/W',0)

    ! grid description
    write(nulinnov,*) '***maxLon,maxLat,deltaLon,deltaLat,deltaPressure,deltaHeight='
    write(nulinnov,*) maxLon,maxLat,deltaLon,deltaLat,deltaPressure,deltaHeight
    write(nulBmatHi,*)  '***maxLon,maxLat,deltaLon,deltaLat,deltaPressure,deltaHeight='
    write(nulBmatHi,*)  maxLon,maxLat,deltaLon,deltaLat,deltaPressure,deltaHeight
    write(nulBmatEn,*)  '***maxLon,maxLat,deltaLon,deltaLat,deltaPressure,deltaHeight='
    write(nulBmatEn,*)  maxLon,maxLat,deltaLon,deltaLat,deltaPressure,deltaHeight
    write(nulcount,*) '***maxLon,maxLat,deltaLon,deltaLat,deltaPressure,deltaHeight='
    write(nulcount,*) maxLon,maxLat,deltaLon,deltaLat,deltaPressure,deltaHeight

    ! one section per populated vertical bin
    do verticalIndex = 1,size(counts,3)
      if(sum(counts(:,:,verticalIndex)).le.0) cycle
      write(nulinnov,*) '***verticalIndex,vco='
      write(nulinnov,*) verticalIndex,ivco
      write(nulBmatHi,*)  '***verticalIndex,vco='
      write(nulBmatHi,*)  verticalIndex,ivco
      write(nulBmatEn,*)  '***verticalIndex,vco='
      write(nulBmatEn,*)  verticalIndex,ivco
      write(nulcount,*) '***verticalIndex,vco='
      write(nulcount,*) verticalIndex,ivco
      do latIndex = 1,maxLat
        write(nulinnov,*) innovStd(latIndex,:,verticalIndex)
        write(nulBmatHi ,*) bmatHiStd(latIndex,:,verticalIndex)
        write(nulBmatEn ,*) bmatEnStd(latIndex,:,verticalIndex)
        write(nulcount,*) counts(latIndex,:,verticalIndex)
      enddo
    enddo

    ! close files
    ierr = fclos(nulinnov)
    ierr = fclos(nulBmatHi)
    ierr = fclos(nulBmatEn)
    ierr = fclos(nulcount)

  end subroutine osd_writeFiles


  !--------------------------------------------------------------------------
  ! osd_getIndices
  !
  ! Map one observation body to its bin indices (no range check here; the
  ! caller validates them). OBS_LAT/OBS_LON are converted from radians.
  !   latIndex = 1 + nint((90 + lat[deg]) / deltaLat)
  !   lonIndex = 1 + nint(lon[deg] / deltaLon)
  !   verticalIndex, by OBS_VCO:
  !     1 (height)   : 1 + nint(OBS_PPP / deltaHeight)
  !     2 (pressure) : 1 + nint(OBS_PPP / deltaPressure)
  !     3 (channel)  : AMSU-A: nint(OBS_PPP) - 27; other instruments: -1
  !     other        : -1 (a message is printed)
  ! verticalIndex = -1 means "skip this observation".
  ! obsSpaceData has no intent because the accessor intents are not visible here.
  !--------------------------------------------------------------------------
  subroutine osd_getIndices(obsSpaceData,bodyIndex,latIndex,lonIndex,verticalIndex)
    implicit none
    type(struct_obs) :: obsSpaceData
    integer, intent(in)  :: bodyIndex
    integer, intent(out) :: latIndex,lonIndex,verticalIndex

    ! channel offset subtracted for AMSU-A (presumably the channel-numbering
    ! offset used in OBS_PPP; confirm)
    integer, parameter :: amsuaChannelOffset = 27
    integer :: headerIndex,vco

    ! codtypes for tovs: 164(AMSUA) 168 180 181 182 183 185 186 192 193

    headerIndex = obs_bodyElem_i(obsSpaceData,OBS_HIND,bodyIndex)
    latIndex = 1 + nint( (90.0d0 + obs_headElem_r(obsSpaceData,OBS_LAT,headerIndex)*MPC_DEGREES_PER_RADIAN_R8)/deltaLat)
    lonIndex = 1 + nint(obs_headElem_r(obsSpaceData,OBS_LON,headerIndex)*MPC_DEGREES_PER_RADIAN_R8/deltaLon)

    vco = obs_bodyElem_i(obsSpaceData,OBS_VCO,bodyIndex)
    select case(vco)
      case(1)
        ! height coordinate
        verticalIndex = 1 + nint(obs_bodyElem_r(obsSpaceData,OBS_PPP,bodyIndex)/deltaHeight)
      case(2)
        ! pressure coordinate
        verticalIndex = 1 + nint(obs_bodyElem_r(obsSpaceData,OBS_PPP,bodyIndex)/deltaPressure)
      case(3)
        ! channel number; only AMSU-A is handled for now
        if(obs_headElem_i(obsSpaceData,OBS_ITY,headerIndex).eq.CODTYP_AMSUA) then
          verticalIndex = nint(obs_bodyElem_r(obsSpaceData,OBS_PPP,bodyIndex)) - amsuaChannelOffset
        else
          verticalIndex = -1
        endif
      case default
        ! unknown vertical coordinate
        write(*,*) 'osd_getIndices: Unknown VCO! ',vco
        verticalIndex = -1
    end select

  end subroutine osd_getIndices


  !--------------------------------------------------------------------------
  ! osd_setup
  !
  ! Set the NAMOSD defaults and read the namelist from ./flnml.
  ! nmlExists is .false. when no valid NAMOSD could be read (the diagnostics
  ! are then skipped), .true. otherwise. Values that would overrun
  ! familyList/elementList or give non-positive bin sizes abort the run.
  !--------------------------------------------------------------------------
  subroutine osd_setup(nmlExists)
    implicit none
    logical, intent(out) :: nmlExists

    integer :: nulnam,ierr,fnom,fclos
    namelist /namosd/nrandseed,deltaLat,deltaLon,deltaPressure,deltaHeight,numFamily,familyList,numElement,elementList

    ! default values for namelist variables
    nrandseed = 999    ! sentinel: seed is then derived from the valid date/time
    deltaLat = 10.0d0
    deltaLon = 10.0d0
    deltaPressure = 10000.0d0
    deltaHeight = 5000.0d0

    numFamily = 7
    familyList(:) = 'XX'
    familyList(1:7) = (/ 'UA','AI','SC','RO','TO','SW','SF' /)

    numElement = 11
    elementList(:) = 0
    elementList(1) = BUFR_NETT
    elementList(2) = BUFR_NEUU
    elementList(3) = BUFR_NEVV
    elementList(4) = BUFR_NEES
    elementList(5) = BUFR_NEUS
    elementList(6) = BUFR_NEVS
    elementList(7) = BUFR_NBT1
    elementList(8) = BUFR_NBT2
    elementList(9) = BUFR_NBT3
    elementList(10)= BUFR_NERF
    elementList(11)= BUFR_NEPS

    nulnam = 0
    ierr = fnom(nulnam,'./flnml','FTN+SEQ+R/O',0)
    read(nulnam,nml=namosd,iostat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'osd_setup: No valid namelist NAMOSD found, skipping diagnostics'
      nmlExists = .false.
      ierr = fclos(nulnam)
      return
    endif
    nmlExists = .true.
    if(mpi_myid.eq.0) write(*,nml=namosd)
    ierr = fclos(nulnam)

    ! reject settings that would index past the lists or divide by zero
    if(numFamily.lt.0 .or. numFamily.gt.size(familyList) .or. &
       numElement.lt.0 .or. numElement.gt.size(elementList)) then
      write(*,*) 'osd_setup: numFamily/numElement out of range: ',numFamily,numElement
      call abort3d('osd_setup')
    endif
    if(deltaLat.le.0.0d0 .or. deltaLon.le.0.0d0 .or. &
       deltaPressure.le.0.0d0 .or. deltaHeight.le.0.0d0) then
      write(*,*) 'osd_setup: bin sizes must be positive: ',deltaLat,deltaLon,deltaPressure,deltaHeight
      call abort3d('osd_setup')
    endif

  end subroutine osd_setup

end module obsSpaceDiag_mod