!--------------------------------------- LICENCE BEGIN -----------------------------------
!Environment Canada - Atmospheric Science and Technology License/Disclaimer,
!                     version 3; Last Modified: May 7, 2008.
!This is free but copyrighted software; you can use/redistribute/modify it under the terms
!of the Environment Canada - Atmospheric Science and Technology License/Disclaimer
!version 3 or (at your option) any later version that should be found at:
!http://collaboration.cmc.ec.gc.ca/science/rpn.comm/license.html
!
!This software is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
!without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
!See the above mentioned License/Disclaimer for more details.
!You should have received a copy of the License/Disclaimer along with this software;
!if not, you can write to: EC-RPN COMM Group, 2121 TransCanada, suite 500, Dorval (Quebec),
!CANADA, H9P 1J3; or send e-mail to service.rpn@ec.gc.ca
!-------------------------------------- LICENCE END --------------------------------------

!--------------------------------------------------------------------------
! MODULE BmatrixEnsemble (Background-error Covariance Matrix estimated
!                         using ensemble members and spatial localization
!                         prefix="ben")
!
! Purpose: Performs transformation from control vector to analysis increment 
!          using the spatially localized ensemble covariance matrix
!
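! Subroutines:
!    ben_setup (public)
!    ben_finalize (public)
!    ben_BSqrt (public)
!    ben_BSqrtAd (public)
!    ben_reduceToMPILocal (public)
!    ben_expandToMPIGlobal (public)
!    ben_getScaleFactor (public)
!    setupLocalization
!    setupEnsemble
!    setupEnsemble_latlon
!    localizationSqrt
!    localizationSqrtAd
!    addEnsMember
!    addEnsMemberAd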
!
! Dependencies:
!    globalSpectralTransform
!    matsqrt
!--------------------------------------------------------------------------

MODULE BmatrixEnsemble 1,8
  use mpivar_mod
  use gridStateVector_mod
  use globalSpectralTransform
  use lamSpectralTransform_mod
  use horizontalCoord_mod
  use verticalCoord_mod
  use mathPhysConstants_mod
  use earthConstants_mod
  implicit none
  save
  private

  ! public procedures
  public :: ben_Setup,ben_BSqrt,ben_BSqrtAd,ben_reduceToMPILocal,ben_expandToMPIGlobal,ben_Finalize
  public :: ben_getScaleFactor

  logical             :: initialized = .false.
  integer,parameter   :: maxNumLevels=200
  real(8)             :: scaleFactor(maxNumLevels)
  real(8)             :: scaleFactorLQ(maxNumLevels)
  real(8),allocatable :: ensLocalCor(:,:)
  real(8),allocatable :: ensLocalVert(:,:)
  integer,allocatable :: nip1_M(:),nip1_T(:),verticalLevel(:,:),verticalLevelEns(:)
  integer,allocatable :: jn_vec(:)
  integer             :: nj,ni,lonPerPE,myLonBeg,myLonEnd,latPerPE,myLatBeg,myLatEnd
  integer             :: nLevInc_M,nLevInc_T,nkgdimInc,nLevEns_M,nLevEns_T,nkgdimEns,topLevIndex_M,topLevIndex_T
  integer             :: myMemBeg,myMemEnd,myMemCount
  integer             :: mymBeg,mymEnd,mymSkip,mymCount
  integer             :: ntrunc,nla_mpiglobal,nla_mpilocal,maxMyNla,nphase
  integer             :: nEns,cvDim_mpilocal,cvDim_mpiglobal
  integer             :: numTime
  integer             :: ngposituu,ngpositvv,ngposittt,ngpositq,ngpositps,ngposittg
  integer             :: gstID
  character(len=256)  :: enspathname,ensfilebasename
  integer,external    :: get_max_rss

  type :: struct_ens
    real(4), allocatable :: member_r4(:,:,:,:)
  end type struct_ens
  type(struct_ens), pointer :: ensPerturbations(:)

  real(8), parameter :: rsq2 = sqrt(2.0d0)

  logical                   :: is_staggered
  type(struct_hco), pointer :: hco_ben    ! Analysis horizontal grid parameters
  type(struct_lst)          :: lst_ben    ! Spectral transform Parameters

  integer, pointer    :: ilaList_mpiglobal(:)
  integer, pointer    :: ilaList_mpilocal(:)

CONTAINS

!--------------------------------------------------------------------------
! BEN_setup
!--------------------------------------------------------------------------

  SUBROUTINE BEN_setup(hco_in,vco_in,NUMTIME_IN,stamp_in,CVDIM_OUT) 1,22
    !
    ! Purpose: Set up the ensemble-based background-error covariance.
    !          The steps are:
    !          - read the NAMBEN namelist;
    !          - build the level-index mappings;
    !          - set up the spectral transform (global or LAM);
    !          - distribute the control vector over the mpi tasks;
    !          - set up the localization;
    !          - read the ensemble perturbations.
    !
    ! Arguments:
    !   hco_in     (in)  horizontal grid of the analysis increment
    !   vco_in     (in)  vertical grid of the analysis increment
    !   numTime_in (in)  number of time bins: 1, 3, 5 or 7
    !   stamp_in   (in)  date stamp at the centre of the time window; the
    !                    ensemble time bins are offsets from it
    !   cvDim_out  (out) size of the mpi-local control vector; 0 when every
    !                    scaleFactor is zero (the rest of the setup is then
    !                    skipped)
    !
    implicit none

    type(struct_hco), pointer, intent(in) :: hco_in
    type(struct_vco), pointer, intent(in) :: vco_in
    integer, intent(in)  :: numTime_in
    integer, intent(in)  :: stamp_in
    integer, intent(out) :: cvDim_out

    real(8)             :: hLocalize(2),vLocalize,zps
    real(8),allocatable :: pressureProfileEns_M(:)
    real(8),pointer     :: pressureProfileInc_M(:)

    integer :: jlev,jn,jm,ila,mpiMode,return_code,status,Vcode_anl
    integer :: fnom,fclos,ierr,nulnam
    integer :: mynBeg,mynEnd,mynSkip

    !namelist
    NAMELIST /NAMBEN/nEns,scaleFactor,scaleFactorLQ,ntrunc,enspathname,ensfilebasename, &
                     nLevEns_M,nLevEns_T,hLocalize,vLocalize

    call tmg_start(12,'BEN_SETUP')

    !
    !- 1.  Read namelist-dependent options
    !

    ! default values, overridden by whatever NAMBEN provides
    scaleFactor(:)  = 0.0d0
    scaleFactorLQ(:)= 1.0d0
    nEns            = 10
    ntrunc          = 31
    nLevEns_M       = vco_in%nLev_M
    nLevEns_T       = vco_in%nLev_T
    enspathname     = '***NOT_DEFINED***'
    ensfilebasename = ''
    hLocalize(1)    = 2800.0d0
    hLocalize(2)    = -1.0d0
    vLocalize       = 2.0d0

    nulnam = 0
    ierr = fnom(nulnam,'./flnml','FTN+SEQ+R/O',0)
    read(nulnam,nml=namben,iostat=ierr)
    if(ierr.ne.0) call abort3d('ben_setup: Error reading namelist')
    if(mpi_myid.eq.0) write(*,nml=namben)
    ierr = fclos(nulnam)

    ! scaleFactor and scaleFactorLQ are indexed by ensemble level and only
    ! hold maxNumLevels values
    if(nLevEns_M.gt.maxNumLevels .or. nLevEns_T.gt.maxNumLevels) then
      write(*,*) 'bmatrixEnsemble: nLevEns_M, nLevEns_T, maxNumLevels = ',nLevEns_M,nLevEns_T,maxNumLevels
      call abort3d('ben_setup: number of ensemble levels exceeds maxNumLevels')
    endif

    !
    !- 2.  Settings
    !
    hco_ben => hco_in
    ni = hco_ben % ni
    nj = hco_ben % nj

    ! Set the number of increment levels before the possible early return
    ! below, so that ben_getScaleFactor has a defined level count in that
    ! case as well.
    nLevInc_M = vco_in%nLev_M
    nLevInc_T = vco_in%nLev_T

    !- 2.1 Global or LAM?
    if(mpi_myid.eq.0) then
      write(*,*)
      if (hco_ben % global) then
        write(*,*) 'bmatrixEnsemble: GLOBAL mode activated'
      else
        write(*,*) 'bmatrixEnsemble: LAM mode activated'
      endif
    endif

    !- 2.2 Bmatrix weight: store the square root of the namelist values
    !      (non-positive values become zero)
    where(scaleFactor(1:nLevEns_T).gt.0.0d0)
      scaleFactor(1:nLevEns_T) = sqrt(scaleFactor(1:nLevEns_T))
    elsewhere
      scaleFactor(1:nLevEns_T) = 0.0d0
    end where

    if(.not.any(scaleFactor(1:nLevEns_T).gt.0.0d0)) then
      if(mpi_myid.eq.0) write(*,*) 'bmatrixEnsemble: scaleFactor=0, skipping rest of setup'
      cvdim_out = 0
      initialized = .true.
      call tmg_stop(12)   ! keep the timer started above balanced
      return
    endif

    ! setupEnsemble divides each perturbation by sqrt(nEns-1)
    if(nEns.lt.2) then
      write(*,*) 'bmatrixEnsemble: nEns = ',nEns
      call abort3d('ben_setup: at least 2 ensemble members are required')
    endif

    where(scaleFactorLQ(1:nLevEns_T).gt.0.0d0)
      scaleFactorLQ(1:nLevEns_T) = sqrt(scaleFactorLQ(1:nLevEns_T))
    elsewhere
      scaleFactorLQ(1:nLevEns_T) = 0.0d0
    end where

    !- 2.3 Levels
    status = vgd_get(vco_in%vgrid,key='ig_1 - vertical coord code',value=Vcode_anl)
    if(status.lt.0) call abort3d('bmatrixEnsemble: could not get the vertical coord code!')
    select case (Vcode_anl)
    case (5001)
      is_staggered = .false.
      if(nLevEns_T.ne.nLevEns_M) then
        write(*,*) 'bmatrixEnsemble: nLevEns_T, nLevEns_M = ',nLevEns_T,nLevEns_M
        call abort3d('bmatrixEnsemble: Vcode=5001, nLevEns_T must equal nLevEns_M!')
      endif
    case (5002)
      is_staggered = .true.
      if(nLevEns_T.ne.(nLevEns_M+1)) then
        write(*,*) 'bmatrixEnsemble: nLevEns_T, nLevEns_M = ',nLevEns_T,nLevEns_M
        call abort3d('bmatrixEnsemble: Vcode=5002, nLevEns_T must equal nLevEns_M+1!')
      endif
    case default
      write(*,*) 'Vcode_anl = ',Vcode_anl
      call abort3d('bmatrixEnsemble: unknown vertical coordinate type!')
    end select
    write(*,*) 'bmatrixEnsemble: vertical coord is_staggered = ',is_staggered

    if(nLevEns_M.gt.vco_in%nLev_M) then
      call abort3d('bmatrixEnsemble: ensemble has more levels than increment - not allowed!')
    endif

    if(nLevEns_M.lt.vco_in%nLev_M) then
      if(mpi_myid.eq.0) write(*,*) 'bmatrixEnsemble: ensemble has less levels than increment'
      if(mpi_myid.eq.0) write(*,*) '                 some levels near top will have zero increment'
    endif

    nkgdimInc = 2*nLevInc_M+2*nLevInc_T+2  ! assume 4 3d and 2 2d variables
    nkgdimEns = 2*nLevEns_M+2*nLevEns_T+2  ! assume 4 3d and 2 2d variables
    ! the ensemble covers the last nLevEns increment levels
    topLevIndex_M = nLevInc_M-nLevEns_M+1
    topLevIndex_T = nLevInc_T-nLevEns_T+1

    ! array to convert amplitude field level (1->nLevEns_M) for each variable into ens (1->nkgdimEns)
    allocate(verticalLevel(nLevEns_M,6),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=1',ierr
      call abort3d('aborting in ben_setup')
    endif
    verticalLevel(:,:) = 0
    do jlev = 1, nLevEns_M
      verticalLevel(jlev,1) =                       jlev ! UU
      verticalLevel(jlev,2) = 1*nLevEns_M          +jlev ! VV
      verticalLevel(jlev,3) = 2*nLevEns_M          +jlev ! TT
      verticalLevel(jlev,4) = 2*nLevEns_M+nLevEns_T+jlev ! HU
    enddo
    verticalLevel(nLevEns_M,5) = 2*nLevEns_M+2*nLevEns_T+1 ! P0
    verticalLevel(nLevEns_M,6) = 2*nLevEns_M+2*nLevEns_T+2 ! TG

    ! ensemble level (1->nLevEns_T) used to pick scaleFactor for each ens index
    allocate(verticalLevelEns(nkgdimEns),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=2',ierr
      call abort3d('aborting in ben_setup')
    endif
    do jlev = 1, nLevEns_M
      verticalLevelEns(            jlev) = jlev ! UU
      verticalLevelEns(nLevEns_M  +jlev) = jlev ! VV
    enddo
    if(is_staggered) then
      ! Shift the momentum levels by one so that the same scale factor is
      ! used at the surface (nLevEns_T) for all variables. Only the UU/VV
      ! entries are set at this point.
      verticalLevelEns(1:2*nLevEns_M) = verticalLevelEns(1:2*nLevEns_M)+1
    endif
    do jlev = 1, nLevEns_T
      verticalLevelEns(2*nLevEns_M          +jlev) = jlev ! TT
      verticalLevelEns(2*nLevEns_M+nLevEns_T+jlev) = jlev ! HU
    enddo
    verticalLevelEns(2*nLevEns_M+2*nLevEns_T+1) = nLevEns_T ! P0
    verticalLevelEns(2*nLevEns_M+2*nLevEns_T+2) = nLevEns_T ! TG

    allocate(nip1_M(nLevEns_M),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=3',ierr
      call abort3d('aborting in ben_setup')
    endif
    nip1_M(1:nLevEns_M) = vco_in%ip1_M(topLevIndex_M:nLevInc_M)

    allocate(nip1_T(nLevEns_T),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=4',ierr
      call abort3d('aborting in ben_setup')
    endif
    nip1_T(1:nLevEns_T) = vco_in%ip1_T(topLevIndex_T:nLevInc_T)

    !- 2.4 Spectral Transform
    if (hco_ben % global) then

      ! Global Mode
      nphase = 2
      nla_mpiglobal = (ntrunc+1)*(ntrunc+2)/2

      mpimode = 5
      gstID = gst_setup(ni,nj,ntrunc,mpiMode,nEns)
      if(mpi_myid.eq.0) write(*,*) 'BEN : returned value of gstID = ',gstID

      allocate(jn_vec(nla_mpiglobal),stat=ierr)
      if(ierr.ne.0) then
        write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=5',ierr
        call abort3d('aborting in ben_setup')
      endif
      ! total wavenumber n for every global spectral index
      do jn = 0, ntrunc
        do jm = 0, jn
          ila = gst_getnind(jm,gstID)+jn-jm
          jn_vec(ila) = jn
        enddo
      enddo

    else

      ! LAM mode
      call lst_Setup( lst_ben,                          & ! OUT
                      ni, nj, hco_ben % dlon, ntrunc,   & ! IN
                      'LatLonMLev', nEns )                ! IN

      if(mpi_myid.eq.0) write(*,*) 'BEN : returned value of lstID = ', lst_ben % id
      nphase       = lst_ben % nphase
      nla_mpilocal = lst_ben % nla
      nla_mpiglobal= nla_mpilocal !!!! Because the variables in the global-mode code
                                  !    still use the global size of nla

    endif

    !- 2.5 Distribute control vector over mpi processes according to member index and m
    call mpivar_setup_levels_npex(nEns,myMemBeg,myMemEnd,myMemCount)
    write(*,*) 'ben_setup: myMemBeg,End,Count=',myMemBeg,myMemEnd,myMemCount

    if ( myMemCount .le. 0 ) then
      write(*,*) 'ERROR: Number of MPI processes must be <= number of ensemble members'
      call abort3d('ben_setup')
    endif

    call mpivar_setup_m(ntrunc,mymBeg,mymEnd,mymSkip,mymCount)

    if (hco_ben % global) then
      ! all total wavenumbers "n" on each mpi task
      mynBeg = 0
      mynEnd = ntrunc
      mynSkip = 1

      ! compute arrays to facilitate conversions between ila_mpilocal and ila_mpiglobal (only split by m)
      call gst_ilaList_mpiglobal(ilaList_mpiglobal,nla_mpilocal,maxMyNla,gstID,mymBeg,mymEnd,mymSkip,mynBeg,mynEnd,mynSkip)
      call gst_ilaList_mpilocal(ilaList_mpilocal,gstID,mymBeg,mymEnd,mymSkip,mynBeg,mynEnd,mynSkip)
    end if

    !- 2.6 Domain Partitioning
    call mpivar_setup_latbands(nj,latPerPE,myLatBeg,myLatEnd)
    call mpivar_setup_lonbands(ni,lonPerPE,myLonBeg,myLonEnd)

    !- 2.7 Number of time bins
    select case (numTime_in)
    case (1,3,5,7)
      numTime = numTime_in
    case default
      call abort3d('Invalid value for NUMTIME (choose 1 or 3 or 5 or 7)!')
    end select

    !- 2.8 Localization
    zps = 101000.D0
    ! vgd_levels fills this pointer and it is deallocated below. Nullify it
    ! first so that its association status is defined when vgd_levels sees it.
    nullify(pressureProfileInc_M)
    status = vgd_levels( vco_in%vgrid, ip1_list=vco_in%ip1_M, levels=pressureProfileInc_M, &
                         sfc_field=zps, in_log=.false.)
    if(status.lt.0) call abort3d('bmatrixEnsemble: vgd_levels failed to compute the pressure profile!')

    allocate(pressureProfileEns_M(nLevEns_M),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=6',ierr
      call abort3d('aborting in ben_setup')
    endif
    pressureProfileEns_M(1:nLevEns_M) = pressureProfileInc_M(topLevIndex_M:nLevInc_M)
    call setupLocalization(hLocalize,vLocalize,pressureProfileEns_M)
    deallocate(pressureProfileEns_M)
    deallocate(pressureProfileInc_M)

    !
    !- 3.  Read/Process the Ensemble
    !
    call setupEnsemble_latlon(stamp_in,return_code)
    if ( return_code .lt. 0 ) then
      ! could not open local ensemble files, try global files
      call setupEnsemble(stamp_in)
    endif

    if (hco_ben % global) then
      cvDim_mpiglobal = (ntrunc+1)*(ntrunc+1)*nLevEns_M*nEns
      cvDim_mpilocal  = 0
      do jm = mymBeg, mymEnd, mymSkip
        do jn = jm, ntrunc
          if(jm.eq.0) then
            ! controlVector only contains real part for jm=0
            cvDim_mpilocal = cvDim_mpilocal + 1*myMemCount*nLevEns_M
          else
            ! controlVector contains real and imag parts for jm>0
            cvDim_mpilocal = cvDim_mpilocal + 2*myMemCount*nLevEns_M
          endif
        enddo
      enddo
    else
      cvDim_mpiglobal = lst_ben % nlaGlobal * nphase * nLevEns_M * nEns
      cvDim_mpilocal  = nla_mpilocal        * nphase * nLevEns_M * myMemCount
      write(*,*) 'cvDim_mpiglobal ', cvDim_mpiglobal, lst_ben % nlaGlobal, nphase, nLevEns_M, nEns
      write(*,*) 'cvDim_mpilocal  ', cvDim_mpilocal, nla_mpilocal, nphase, nLevEns_M, myMemCount
    endif
    cvDim_out = cvDim_mpilocal

    !
    !- 4.  Ending
    !
    initialized = .true.

    call tmg_stop(12)

  END SUBROUTINE BEN_setup

!--------------------------------------------------------------------------
! BEN_finalize
!--------------------------------------------------------------------------

  SUBROUTINE BEN_finalize() 1
    !
    ! Purpose: Release the memory held by this module: the ensemble
    !          perturbations, the localization arrays and the level and
    !          wavenumber mappings.
    !
    ! BEN_setup can return early (when every scaleFactor is zero) before it
    ! allocates any of these arrays. Every release below is therefore guarded.
    ! This also makes the routine safe to call twice, and allows BEN_setup to
    ! allocate the arrays again afterwards.
    !
    implicit none
    integer :: memberIndex

    write(*,*) 'ben_finalize: deallocating B_ensemble arrays'

    ! BEN_setup allocates verticalLevelEns before it reads the ensemble. When
    ! verticalLevelEns is absent, ensPerturbations was never associated.
    ! Test this proxy because the module pointer has no default
    ! initialization, so associated() cannot be trusted on it.
    if (allocated(verticalLevelEns)) then
      do memberIndex = 1, size(ensPerturbations)
        if (allocated(ensPerturbations(memberIndex)%member_r4)) &
          deallocate(ensPerturbations(memberIndex)%member_r4)
      enddo
      deallocate(ensPerturbations)
    endif

    if (allocated(ensLocalCor))      deallocate(ensLocalCor)
    if (allocated(ensLocalVert))     deallocate(ensLocalVert)
    if (allocated(verticalLevel))    deallocate(verticalLevel)
    if (allocated(verticalLevelEns)) deallocate(verticalLevelEns)
    if (allocated(nip1_M))           deallocate(nip1_M)
    if (allocated(nip1_T))           deallocate(nip1_T)
    if (allocated(jn_vec))           deallocate(jn_vec)

  END SUBROUTINE BEN_finalize



  subroutine ben_getScaleFactor(scaleFactor_out) 1
    !
    ! Purpose: Copy the module scaleFactor values into scaleFactor_out, one
    !          value per thermodynamic increment level (1:nLevInc_T).
    !          BEN_setup replaces scaleFactor(1:nLevEns_T) by its square root.
    !
    ! Argument:
    !   scaleFactor_out : receives the values. The copy never goes past the
    !                     end of scaleFactor_out or of scaleFactor; any
    !                     remaining elements of scaleFactor_out are left
    !                     unchanged.
    !
    ! NOTE(review): scaleFactor is indexed by ensemble level, but the copy
    ! runs over increment levels. The ensemble maps onto the last nLevEns_T
    ! increment levels (topLevIndex_T in BEN_setup). So when
    ! nLevEns_T < nLevInc_T the two indexings do not line up. Confirm that
    ! this is intended.
    !
    implicit none
    real(8), intent(inout) :: scaleFactor_out(:)
    integer :: numLev

    ! never write past either array, even if the caller's array is short
    numLev = min(nLevInc_T, size(scaleFactor_out), maxNumLevels)
    scaleFactor_out(1:numLev) = scaleFactor(1:numLev)

  end subroutine ben_getScaleFactor


!--------------------------------------------------------------------------
! setupEnsemble
!--------------------------------------------------------------------------

  SUBROUTINE setupEnsemble(stamp_in) 1,11
    !
    ! Purpose: Read the ensemble from global (one file per member) files. For
    !          each member, store in ensPerturbations the mpi-local tile of
    !          its deviation from the ensemble mean, multiplied by
    !          scaleFactor/sqrt(nEns-1). Humidity levels are also multiplied
    !          by scaleFactorLQ.
    !          Conversions on reading:
    !          - HU is stored as log(max(HU,MPC_MINIMUM_HU_R8));
    !          - P0 is multiplied by MPC_PA_PER_MBAR_R8;
    !          - UU/VV are multiplied by MPC_M_PER_S_PER_KNOT_R8.
    !
    ! Member m is read by mpi task mod(m-1,mpi_nprocs). The full fields are
    ! redistributed with an alltoall after each batch of mpi_nprocs members,
    ! and after the last member. Every task then keeps only its own lon/lat
    ! tile.
    !
    ! Argument:
    !   stamp_in : date stamp at the centre of the time window. File names use
    !              the date 6h before stamp_in; the time bins are offsets from
    !              stamp_in.
    !
    implicit none

    integer, intent(in) :: stamp_in

    real(4), allocatable :: gd2d_r4(:,:)
    real(8), allocatable :: gd2d(:,:)
    real(4), allocatable :: ensPerturbation1_r4(:,:,:)
    real(4), allocatable :: gd_send_r4(:,:,:,:)
    real(4), allocatable :: gd_recv_r4(:,:,:,:)
    integer, allocatable :: readFilePE(:)
    integer, allocatable :: allLatBeg(:),allLatEnd(:),allLonBeg(:),allLonEnd(:)
    real(8) :: dnens,dnens2
    integer :: jt,jk,memberIndex,memberIndex2,batchnum,nsize
    integer :: yourid,youridx,youridy
    integer :: jjFirst,jjLast,jjStep
    integer :: kulin,ndate,ntime
    integer :: stamp_last,newdate
    real(8) :: delhh
    logical :: lExists

    ! standard file variables
    integer :: ini,inj,ink,ip2(9),ip3,ierr,idateo(7),ikey
    character(len=2)   :: cltypvar
    character(len=4)   :: clnomvar
    character(len=12)  :: cletiket
    character(len=256) :: cflensin
    character(len=4)   :: censnumber
    character(len=8)   :: datestr_last
    character(len=2)   :: hourstr_last
    integer :: fstlir,fstfrm,fclos,fnom,fstouv,fstinf

    ! the perturbations are divided by sqrt(nEns-1) below
    if(nEns.lt.2) then
      write(*,*) 'bmatrixEnsemble: nEns = ',nEns
      call abort3d('ben setupEnsemble: at least 2 ensemble members are required')
    endif

    write(*,*) 'Memory Used: ',get_max_rss()/1024,'Mb'

    ! this should come from state vector object
    ngposituu = 1
    ngpositvv = 1+1*nLevEns_M
    ngposittt = 1+2*nLevEns_M
    ngpositq  = 1+2*nLevEns_M+1*nLevEns_T
    ngpositps = 1+2*nLevEns_M+2*nLevEns_T
    ngposittg = 2+2*nLevEns_M+2*nLevEns_T

    allocate(gd_send_r4(lonPerPE,latPerPE,nkgdimEns,mpi_nprocs),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=9.1',ierr
      call abort3d('aborting in ben setupEnsemble')
    endif
    allocate(gd_recv_r4(lonPerPE,latPerPE,nkgdimEns,mpi_nprocs),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=9.2',ierr
      call abort3d('aborting in ben setupEnsemble')
    endif

    ! setupEnsemble_latlon allocates ensPerturbations before it looks for its
    ! files. Release any storage it left behind rather than leaking it when
    ! the pointer is re-allocated below.
    ! NOTE(review): this assumes the pointer's association status is defined
    ! here. That holds when this routine is called from BEN_setup after
    ! setupEnsemble_latlon.
    if(associated(ensPerturbations)) then
      do memberIndex = 1, size(ensPerturbations)
        if(allocated(ensPerturbations(memberIndex)%member_r4)) &
          deallocate(ensPerturbations(memberIndex)%member_r4)
      enddo
      deallocate(ensPerturbations)
    endif

    allocate(ensPerturbations(nEns),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=7.1',ierr
      call abort3d('aborting in ben setupEnsemble')
    endif
    do memberIndex = 1, nEns
      allocate(ensPerturbations(memberIndex)%member_r4(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nkgdimEns,numTime),stat=ierr)
      if(ierr.ne.0) then
        write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=7.2',ierr,memberIndex
        call abort3d('aborting in ben setupEnsemble')
      endif
    enddo
    allocate(ensPerturbation1_r4(ni,nj,nkgdimEns),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=8',ierr
      call abort3d('aborting in ben setupEnsemble')
    endif

    write(*,*) 'Memory Used: ',get_max_rss()/1024,'Mb'

    ! read in raw ensemble (UU,VV,TT,P0,LQ (convert HU to LQ) - covariances)

    ! tile bounds of every task along each direction, indexed by its
    ! (0-based) position + 1 in the NS / EW process rows
    allocate(allLatBeg(mpi_npey),allLatEnd(mpi_npey))
    allocate(allLonBeg(mpi_npex),allLonEnd(mpi_npex))

    CALL rpn_comm_allgather(myLatBeg,1,"mpi_integer",       &
                            allLatBeg,1,"mpi_integer","NS",ierr)
    CALL rpn_comm_allgather(myLatEnd,1,"mpi_integer",       &
                            allLatEnd,1,"mpi_integer","NS",ierr)

    CALL rpn_comm_allgather(myLonBeg,1,"mpi_integer",       &
                            allLonBeg,1,"mpi_integer","EW",ierr)
    CALL rpn_comm_allgather(myLonEnd,1,"mpi_integer",       &
                            allLonEnd,1,"mpi_integer","EW",ierr)

    ! round-robin assignment of the members to the reading tasks
    allocate(readFilePE(nEns),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=10',ierr
      call abort3d('aborting in ben setupEnsemble')
    endif
    do memberIndex = 1, nEns
      readFilePE(memberIndex) = mod(memberIndex-1,mpi_nprocs)
    enddo

    delhh = -6.0d0
    call incdatr(stamp_last,stamp_in,delhh)
    ierr = newdate(stamp_last,ndate,ntime,-3)
    write(datestr_last,'(i8.8)') ndate
    write(hourstr_last,'(i2.2)') ntime/1000000
    if(mpi_myid.eq.0) write(*,*) 'DATE,TIME=',ndate,'  ,',ntime

    ! validity date of each time bin
    select case (numTime)
    case (1)
      call incdatr(idateo(1),stamp_in, 0.0d0)
    case (3)
      call incdatr(idateo(1),stamp_in,-3.0d0)
      call incdatr(idateo(2),stamp_in, 0.0d0)
      call incdatr(idateo(3),stamp_in, 3.0d0)
    case (5)
      call incdatr(idateo(1),stamp_in,-3.0d0)
      call incdatr(idateo(2),stamp_in,-1.5d0)
      call incdatr(idateo(3),stamp_in, 0.0d0)
      call incdatr(idateo(4),stamp_in, 1.5d0)
      call incdatr(idateo(5),stamp_in, 3.0d0)
    case (7)
      call incdatr(idateo(1),stamp_in,-3.0d0)
      call incdatr(idateo(2),stamp_in,-2.0d0)
      call incdatr(idateo(3),stamp_in,-1.0d0)
      call incdatr(idateo(4),stamp_in, 0.0d0)
      call incdatr(idateo(5),stamp_in, 1.0d0)
      call incdatr(idateo(6),stamp_in, 2.0d0)
      call incdatr(idateo(7),stamp_in, 3.0d0)
    case default
      write(*,*) 'bmatrixEnsemble: Problem with number of timesteps for ensemble=',numTime
      call abort3d('aborting in ben setupEnsemble')
    end select
    ip2 = -1
    ip3 = -1
    cltypvar = ' '
    cletiket = ' '

    if(mpi_myid.eq.0) write(*,*) 'idateo=',idateo(1:numTime)

    ! Global files store the latitudes in reverse order with respect to the
    ! analysis grid; LAM files do not.
    if (hco_ben % global) then
      jjFirst = nj
      jjLast  = 1
      jjStep  = -1
    else
      jjFirst = 1
      jjLast  = nj
      jjStep  = 1
    endif

    do jt = 1, numTime  ! read all timesteps
      do memberIndex = 1, nEns

        if(mpi_myid.eq.readFilePE(memberIndex)) then

          ! locate the file: first with a 4 digit member number, then with
          ! the (older) 3 digit member number
          write(censnumber,'(i4.4)') memberIndex
          cflensin = trim(enspathname) // '/' // trim(ensfilebasename) // &
                     trim(datestr_last) // trim(hourstr_last) // '_006_' // trim(censnumber)
          inquire(file=cflensin,exist=lExists)
          if(.not.lExists) then
            write(censnumber,'(i3.3)') memberIndex
            cflensin = trim(enspathname) // '/' // trim(ensfilebasename) // &
                       trim(datestr_last) // trim(hourstr_last) // '_006_' // trim(censnumber)
            inquire(file=cflensin,exist=lExists)
          endif
          if(.not.lExists) then
            write(*,*) 'filename=',trim(cflensin)
            call abort3d('ben setupEnsemble: Could not open ensemble file')
          endif

          kulin = 0
          ierr = fnom(kulin,cflensin,'RND+OLD+R/O',0)
          if(ierr.lt.0) then
            write(*,*) 'filename=',trim(cflensin)
            call abort3d('ben setupEnsemble: Could not open ensemble file')
          endif
          ierr = fstouv(kulin,'RND+OLD')
          if(ierr.lt.0) then
            write(*,*) 'filename=',trim(cflensin)
            call abort3d('ben setupEnsemble: Could not open ensemble file')
          endif

          ! get grid parameters by looking at P0
          if(.not.allocated(gd2d_r4)) then
            clnomvar = 'P0'
            ikey = fstinf(kulin,ini,inj,ink,idateo(1),cletiket,-1,ip2,ip3,cltypvar,clnomvar)
            if(ikey.lt.0) then
              write(*,*) idateo(1),cletiket,ip2,ip3,cltypvar,clnomvar
              call abort3d('ben setupEnsemble: Could not find P0 to get the grid dimensions')
            endif
            ! the copies below take the first ni x nj points of each record
            if(ini.lt.ni .or. inj.lt.nj) then
              write(*,*) 'ben_setupensemble: ini,inj,ni,nj = ',ini,inj,ni,nj
              call abort3d('ben setupEnsemble: ensemble grid is smaller than the analysis grid')
            endif
            write(*,*) 'ben_setupensemble: allocating temporary 2D buffer with ini,inj=',ini,inj
            allocate(gd2d_r4(ini,inj))
          endif

          write(*,*) 'Reading time slice ',jt,' for ensemble member:',trim(cflensin)
          write(*,*) 'Memory Used: ',get_max_rss()/1024,'Mb'

          ! read 1 member per mpi task, only 1 timestep

          clnomvar = 'P0'
          ikey = fstlir(gd2d_r4,kulin,ini,inj,ink,idateo(jt),cletiket,-1,ip2,ip3,cltypvar,clnomvar)
          if(ikey.lt.0) then
            call abort3d('SUENS: Problem with P0 ENS')
          endif
          ensPerturbation1_r4(:,:,ngpositps) = &
            sngl(gd2d_r4(1:ni,jjFirst:jjLast:jjStep)*MPC_PA_PER_MBAR_R8)

          do jk = 1, nLevEns_T
            clnomvar = 'TT'
            ikey = fstlir(gd2d_r4,kulin,ini,inj,ink,idateo(jt),cletiket,nip1_T(jk),ip2,ip3,cltypvar,clnomvar)
            if(ikey.lt.0) then
              write(*,*) idateo(jt),cletiket,nip1_T(jk),ip2,ip3,cltypvar,clnomvar
              call abort3d('SUENS: Problem with TT ENS')
            endif
            ensPerturbation1_r4(:,:,jk-1+ngposittt) = gd2d_r4(1:ni,jjFirst:jjLast:jjStep)
          enddo

          do jk = 1, nLevEns_T
            clnomvar = 'HU'
            ikey = fstlir(gd2d_r4,kulin,ini,inj,ink,idateo(jt),cletiket,nip1_T(jk),ip2,ip3,cltypvar,clnomvar)
            if(ikey.lt.0) then
              write(*,*) idateo(jt),cletiket,nip1_T(jk),ip2,ip3,cltypvar,clnomvar
              call abort3d('SUENS: Problem with HU ENS')
            endif
            ! humidity is stored as log(HU), floored at MPC_MINIMUM_HU_R8
            ensPerturbation1_r4(:,:,jk-1+ngpositq) = &
              sngl(log(max(real(gd2d_r4(1:ni,jjFirst:jjLast:jjStep),8),MPC_MINIMUM_HU_R8)))
          enddo

          do jk = 1, nLevEns_M
            clnomvar = 'UU'
            ikey = fstlir(gd2d_r4,kulin,ini,inj,ink,idateo(jt),cletiket,nip1_M(jk),ip2,ip3,cltypvar,clnomvar)
            if(ikey.lt.0) then
              write(*,*) idateo(jt),cletiket,nip1_M(jk),ip2,ip3,cltypvar,clnomvar
              call abort3d('SUENS: Problem with UU ENS')
            endif
            ensPerturbation1_r4(:,:,jk-1+ngposituu) = &
              sngl(gd2d_r4(1:ni,jjFirst:jjLast:jjStep)*MPC_M_PER_S_PER_KNOT_R8)
          enddo

          do jk = 1, nLevEns_M
            clnomvar = 'VV'
            ikey = fstlir(gd2d_r4,kulin,ini,inj,ink,idateo(jt),cletiket,nip1_M(jk),ip2,ip3,cltypvar,clnomvar)
            if(ikey.lt.0) then
              write(*,*) idateo(jt),cletiket,nip1_M(jk),ip2,ip3,cltypvar,clnomvar
              call abort3d('SUENS: Problem with VV ENS')
            endif
            ensPerturbation1_r4(:,:,jk-1+ngpositvv) = &
              sngl(gd2d_r4(1:ni,jjFirst:jjLast:jjStep)*MPC_M_PER_S_PER_KNOT_R8)
          enddo

          clnomvar = 'TG'
          ikey = fstlir(gd2d_r4,kulin,ini,inj,ink,idateo(jt),cletiket,-1,ip2,ip3,cltypvar,clnomvar)
          if(ikey.lt.0)  then
            write(*,*) idateo(jt),cletiket,ip2,ip3,cltypvar,clnomvar
            call abort3d('SUENS: Problem with TG ENS')
          endif
          ensPerturbation1_r4(:,:,ngposittg) = gd2d_r4(1:ni,jjFirst:jjLast:jjStep)

          ierr =  fstfrm(kulin)
          ierr =  fclos (kulin)

        endif

        ! do mpi communication once every task of the batch has read its member
        if(readFilePE(memberIndex).eq.(mpi_nprocs-1) .or. memberIndex.eq.nEns) then
          call tmg_start(13,'PRE_SUENS_COMM')
          batchnum = ceiling(dble(memberIndex)/dble(mpi_nprocs))

!$OMP PARALLEL DO PRIVATE(youridy,youridx,yourid)
          do youridy = 0, (mpi_npey-1)
            do youridx = 0, (mpi_npex-1)
              yourid = youridx + youridy*mpi_npex
              gd_send_r4(:,:,:,yourid+1) =  &
                ensPerturbation1_r4(allLonBeg(youridx+1):allLonEnd(youridx+1),  &
                                    allLatBeg(youridy+1):allLatEnd(youridy+1),:)
            enddo
          enddo
!$OMP END PARALLEL DO

          nsize = lonPerPE*latPerPE*nkgdimEns
          if(mpi_nprocs.gt.1) then
            call rpn_comm_alltoall(gd_send_r4,nsize,"mpi_real4",  &
                                   gd_recv_r4,nsize,"mpi_real4","GRID",ierr)
          else
            gd_recv_r4(:,:,:,1) = gd_send_r4(:,:,:,1)
          endif

!$OMP PARALLEL DO PRIVATE(memberIndex2,yourid)
          do memberIndex2 = 1+(batchnum-1)*mpi_nprocs, memberIndex
            yourid = readFilePE(memberIndex2)
            ensPerturbations(memberIndex2)%member_r4(:,:,:,jt) = gd_recv_r4(:,:,:,yourid+1)
          enddo
!$OMP END PARALLEL DO

          call tmg_stop(13)

        endif ! do communication

      enddo ! memberIndex

    enddo ! jt

    deallocate(ensPerturbation1_r4)
    deallocate(gd_send_r4)
    deallocate(gd_recv_r4)

    write(*,*) 'finished reading and communicating ensemble members...'

    ! Remove the ensemble mean and multiply by scaleFactor/sqrt(nEns-1).
    ! The humidity levels are further multiplied by scaleFactorLQ.
    allocate(gd2d(myLonBeg:myLonEnd,myLatBeg:myLatEnd))
    dnens = 1.0d0/dble(nEns)
    do jt = 1, numTime
!$OMP PARALLEL
!$OMP DO PRIVATE (jk,dnens2,gd2d,memberIndex)
      do jk = 1, nkgdimEns
        dnens2 = scaleFactor(verticalLevelEns(jk))/sqrt(dble(nEns-1))
        if(jk.ge.ngpositq .and. jk.lt.ngpositps) then  ! HU levels
          dnens2 = dnens2*scaleFactorLQ(verticalLevelEns(jk))
        endif
        ! ensemble mean of this level
        gd2d(:,:) = 0.0d0
        do memberIndex = 1, nEns
          gd2d(:,:) = gd2d(:,:) + dble(ensPerturbations(memberIndex)%member_r4(:,:,jk,jt))
        enddo
        gd2d(:,:) = gd2d(:,:)*dnens
        do memberIndex = 1, nEns
          ensPerturbations(memberIndex)%member_r4(:,:,jk,jt) =      &
            sngl((dble(ensPerturbations(memberIndex)%member_r4(:,:,jk,jt))-gd2d(:,:))*dnens2)
        enddo
      enddo
!$OMP END DO
!$OMP END PARALLEL
    enddo
    deallocate(gd2d)

    write(*,*) 'finished adjusting ensemble members...'

  END SUBROUTINE setupEnsemble

!--------------------------------------------------------------------------
! setupEnsemble_latlon
!--------------------------------------------------------------------------

  SUBROUTINE setupEnsemble_latlon(stamp_in,return_code) 1,16
    implicit none
 
    integer :: stamp_in,return_code
    real(8) :: gd2d(myLonBeg:myLonEnd,myLatBeg:myLatEnd)
    real(8) :: dnens,dnens2
    integer :: ji,jj,jt,jk,memberIndex,memberIndex2
    integer :: kulin,ndate,ntime
    integer :: stamp_last,newdate
    real(8) :: delhh
    logical :: lExists

    ! standard file variables
    integer :: ini,inj,ink,ip1,ip2(9),ip3,ierr,idateo(7),ikey
    character(len=2)   :: cltypvar
    character(len=1)   :: clgrtyp
    character(len=4)   :: clnomvar
    character(len=12)  :: cletiket
    character(len=256) :: cflensin
    character(len=4)   :: censnumber,latBandNumber,lonBandNumber
    character(len=8)   :: datestr_last
    character(len=2)   :: hourstr_last
    integer :: vfstlir,fstfrm,fclos,fnom,fstouv

    ! set OK value for return_code
    return_code = 0

    ! this should come from state vector object
    ngposituu = 1
    ngpositvv = 1+1*nLevEns_M
    ngposittt = 1+2*nLevEns_M
    ngpositq  = 1+2*nLevEns_M+1*nLevEns_T
    ngpositps = 1+2*nLevEns_M+2*nLevEns_T
    ngposittg = 2+2*nLevEns_M+2*nLevEns_T

    allocate(ensPerturbations(nEns),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=7.1',ierr
      call abort3d('aborting in ben setupEnsemble_latlon')
    endif
    do memberIndex = 1, nEns
      allocate(ensPerturbations(memberIndex)%member_r4(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nkgdimEns,numTime),stat=ierr)
      if(ierr.ne.0) then
        write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=7.2',ierr,memberIndex
        call abort3d('aborting in ben setupEnsemble_latlon')
      endif
    enddo

    ! read in raw ensemble (UU,VV,TT,P0,LQ (convert HU to LQ) - covariances)

    delhh = -6.0d0
    call incdatr(stamp_last,stamp_in,delhh)
    ierr = newdate(stamp_last,ndate,ntime,-3)
    write(datestr_last,'(i8.8)') ndate
    write(hourstr_last,'(i2.2)') ntime/1000000
    if(mpi_myid.eq.0) write(*,*) 'DATE,TIME=',ndate,'  ,',ntime

    if(numTime.eq.1) then
      call incdatr(idateo(1),stamp_in, 0.0d0)
    elseif(numTime.eq.3) then
      call incdatr(idateo(1),stamp_in,-3.0d0)
      call incdatr(idateo(2),stamp_in, 0.0d0)
      call incdatr(idateo(3),stamp_in, 3.0d0)
    elseif(numTime.eq.5) then
      call incdatr(idateo(1),stamp_in,-3.0d0)
      call incdatr(idateo(2),stamp_in,-1.5d0)
      call incdatr(idateo(3),stamp_in, 0.0d0)
      call incdatr(idateo(4),stamp_in, 1.5d0)
      call incdatr(idateo(5),stamp_in, 3.0d0)
    elseif(numTime.eq.7) then
      call incdatr(idateo(1),stamp_in,-3.0d0)
      call incdatr(idateo(2),stamp_in,-2.0d0)
      call incdatr(idateo(3),stamp_in,-1.0d0)
      call incdatr(idateo(4),stamp_in, 0.0d0)
      call incdatr(idateo(5),stamp_in, 1.0d0)
      call incdatr(idateo(6),stamp_in, 2.0d0)
      call incdatr(idateo(7),stamp_in, 3.0d0)
    else
      write(*,*) 'bmatrixEnsemble: Problem with number of timesteps for ensemble=',numTime
      call abort3d('aborting in ben setupEnsemble_latlon')
    endif
    ip2 = -1
    ip3 = -1
    cltypvar = ' '
    cletiket = ' '

    if(mpi_myid.eq.0) write(*,*) 'idateo=',idateo(1:numTime)

    if (hco_ben % global) then
      ! latitude band is reverse order as proc id due to north-south flip
      write(latBandNumber,'(i4.4)') mpi_npey-mpi_myidy
    else
      write(latBandNumber,'(i4.4)') mpi_myidy+1
    endif
    write(lonBandNumber,'(i4.4)') mpi_myidx+1

    do memberIndex = 1, nEns

      ! first try to open file with 4 digit member number
      write(censnumber,'(i4.4)') memberIndex

      !! try filename 'subdomain_0001_0024/2011020100_006_0013'
      cflensin = trim(enspathname) // '/subdomain_' // lonBandNumber // '_' // trim(latBandNumber) // '/' // &
                 trim(ensfilebasename) // trim(datestr_last) // trim(hourstr_last) // &
                 '_006_' // trim(censnumber)
      inquire(file=cflensin,exist=lExists)
      if(lExists) then
        kulin = 0
        ierr = fnom(kulin,cflensin,'RND+OLD+R/O',0)
      else

        write(censnumber,'(i3.3)') memberIndex
        !! try filename 'subdomain_0001_0024/2011020100_006_013'
        cflensin = trim(enspathname) // '/subdomain_' // lonBandNumber // '_' // trim(latBandNumber) // '/' // &
                   trim(ensfilebasename) // trim(datestr_last) // trim(hourstr_last) // &
                   '_006_' // trim(censnumber)
        inquire(file=cflensin,exist=lExists)
        if(lExists) then
          kulin = 0
          ierr = fnom(kulin,cflensin,'RND+OLD+R/O',0)
        else

          write(censnumber,'(i4.4)') memberIndex
          !! try filename 'latband_0024/2011020100_006_0013'
          cflensin = trim(enspathname) // '/latband_' // trim(latBandNumber) // '/' // &
                     trim(ensfilebasename) // trim(datestr_last) // trim(hourstr_last) // &
                     '_006_' // trim(censnumber)
          inquire(file=cflensin,exist=lExists)
          if(lExists) then
            kulin = 0
            ierr = fnom(kulin,cflensin,'RND+OLD+R/O',0)
          else
            write(censnumber,'(i3.3)') memberIndex
            !! try filename 'latband_0024/2011020100_006_013'
            cflensin = trim(enspathname) // '/latband_' // trim(latBandNumber) // '/' // &
                       trim(ensfilebasename) // trim(datestr_last) // trim(hourstr_last) // &
                       '_006_' // trim(censnumber)
            inquire(file=cflensin,exist=lExists)
            if(lExists) then
              kulin = 0
              ierr = fnom(kulin,cflensin,'RND+OLD+R/O',0)
            else
              write(*,*) 'filename=',trim(cflensin)
              write(*,*) '==========================================================================='
              write(*,*) 'ben setupEnsemble_latlon: Could not open ensemble file, trying global files'
              write(*,*) '==========================================================================='
              return_code = -1
              do memberIndex2 = 1, nEns
                deallocate(ensPerturbations(memberIndex2)%member_r4)
              enddo
              deallocate(ensPerturbations)
              return
            endif
          endif
        endif
      endif
      write(*,*) 'ben setupEnsemble_latlon: opening file ', cflensin
      ierr = fstouv(kulin,'RND+OLD')

      write(*,*) 'Reading ',numTime,' time slices for ensemble member:',trim(cflensin)
      write(*,*) 'reading member:',memberIndex

      do jt = 1, numTime  ! read all timesteps

        clnomvar = 'P0' 
        ikey = vfstlir(gd2d,kulin,ini,inj,ink,idateo(jt),cletiket,-1,ip2,ip3,cltypvar,clnomvar)
        if(ikey.lt.0) then
          call abort3d('SUENS: Problem with P0 ENS')
        endif
        call CheckEnsDim(ini,inj,ink,clnomvar) ! IN
        do jj = myLatBeg, myLatEnd
          do ji = myLonBeg, myLonEnd
            if (hco_ben % global) then
               ensPerturbations(memberIndex)%member_r4(ji,jj,ngpositps,jt) = sngl(gd2d(ji,myLatEnd+myLatBeg-jj)*MPC_PA_PER_MBAR_R8)
            else
               ensPerturbations(memberIndex)%member_r4(ji,jj,ngpositps,jt) = sngl(gd2d(ji,jj)*MPC_PA_PER_MBAR_R8)
            endif
          enddo
        enddo

        do jk = 1, nLevEns_T
          clnomvar = 'TT'
          ikey = vfstlir(gd2d,kulin,ini,inj,ink,idateo(jt),cletiket,nip1_T(jk),ip2,ip3,cltypvar,clnomvar)
          if(ikey.lt.0) then
            write(*,*) idateo(jt),cletiket,nip1_T(jk),ip2,ip3,cltypvar,clnomvar
            call abort3d('SUENS: Problem with TT ENS')
          endif
          do jj = myLatBeg, myLatEnd
            do ji = myLonBeg, myLonEnd
              if (hco_ben % global) then
                 ensPerturbations(memberIndex)%member_r4(ji,jj,jk-1+ngposittt,jt) = sngl(gd2d(ji,myLatEnd+myLatBeg-jj))
              else
                 ensPerturbations(memberIndex)%member_r4(ji,jj,jk-1+ngposittt,jt) = sngl(gd2d(ji,jj))
              endif
            enddo
          enddo
        enddo

        do jk = 1, nLevEns_T
          clnomvar = 'HU' 
          ikey = vfstlir(gd2d,kulin,ini,inj,ink,idateo(jt),cletiket,nip1_T(jk),ip2,ip3,cltypvar,clnomvar)
          if(ikey.lt.0) then
            write(*,*) idateo(jt),cletiket,nip1_T(jk),ip2,ip3,cltypvar,clnomvar
            call abort3d('SUENS: Problem with HU ENS')
          endif
          do jj = myLatBeg, myLatEnd
            do ji = myLonBeg, myLonEnd
              if (hco_ben % global) then
                 ensPerturbations(memberIndex)%member_r4(ji,jj,jk-1+ngpositq,jt) = sngl(log(max(gd2d(ji,myLatEnd+myLatBeg-jj),MPC_MINIMUM_HU_R8)))
              else
                 ensPerturbations(memberIndex)%member_r4(ji,jj,jk-1+ngpositq,jt) = sngl(log(max(gd2d(ji,jj),MPC_MINIMUM_HU_R8)))
              endif
            enddo
          enddo
        enddo

        do jk = 1, nLevEns_M
          clnomvar = 'UU' 
          ikey = vfstlir(gd2d,kulin,ini,inj,ink,idateo(jt),cletiket,nip1_M(jk),ip2,ip3,cltypvar,clnomvar)
          if(ikey.lt.0) then
            write(*,*) idateo(jt),cletiket,nip1_M(jk),ip2,ip3,cltypvar,clnomvar
            call abort3d('SUENS: Problem with UU ENS')
          endif
          do jj = myLatBeg, myLatEnd
            do ji = myLonBeg, myLonEnd
              if (hco_ben % global) then
                 ensPerturbations(memberIndex)%member_r4(ji,jj,jk-1+ngposituu,jt) = sngl(gd2d(ji,myLatEnd+myLatBeg-jj)*MPC_M_PER_S_PER_KNOT_R8)
              else
                 ensPerturbations(memberIndex)%member_r4(ji,jj,jk-1+ngposituu,jt) = sngl(gd2d(ji,jj)*MPC_M_PER_S_PER_KNOT_R8)
              endif
            enddo
          enddo
        enddo

        do jk = 1, nLevEns_M
          clnomvar = 'VV' 
          ikey = vfstlir(gd2d,kulin,ini,inj,ink,idateo(jt),cletiket,nip1_M(jk),ip2,ip3,cltypvar,clnomvar)
          if(ikey.lt.0) then
            write(*,*) idateo(jt),cletiket,nip1_M(jk),ip2,ip3,cltypvar,clnomvar
            call abort3d('SUENS: Problem with VV ENS')
          endif
          do jj = myLatBeg, myLatEnd
            do ji = myLonBeg, myLonEnd
              if (hco_ben % global) then
                 ensPerturbations(memberIndex)%member_r4(ji,jj,jk-1+ngpositvv,jt) = sngl(gd2d(ji,myLatEnd+myLatBeg-jj)*MPC_M_PER_S_PER_KNOT_R8)
              else
                 ensPerturbations(memberIndex)%member_r4(ji,jj,jk-1+ngpositvv,jt) = sngl(gd2d(ji,jj)*MPC_M_PER_S_PER_KNOT_R8)
              endif
            enddo
          enddo
        enddo

        clnomvar = 'TG' 
        ikey = vfstlir(gd2d,kulin,ini,inj,ink,idateo(jt),cletiket,-1,ip2,ip3,cltypvar,clnomvar)
        if(ikey.lt.0)  then
          write(*,*) idateo(jt),cletiket,ip2,ip3,cltypvar,clnomvar
          call abort3d('SUENS: Problem with TG ENS')
        else
          do jj = myLatBeg, myLatEnd
            do ji = myLonBeg, myLonEnd
              if (hco_ben % global) then
                 ensPerturbations(memberIndex)%member_r4(ji,jj,ngposittg,jt) = sngl(gd2d(ji,myLatEnd+myLatBeg-jj))
              else
                 ensPerturbations(memberIndex)%member_r4(ji,jj,ngposittg,jt) = sngl(gd2d(ji,jj))
              endif
            enddo
          enddo
        endif

      enddo

      ierr =  fstfrm(kulin)
      ierr =  fclos (kulin)

      write(*,*) 'done reading member ',memberIndex

    enddo

    write(*,*) 'finished reading ensemble members...'

    ! remove mean and divide by sqrt(2*(NENS-1)) - extra 2 is needed?
    dnens = 1.0d0/dble(nEns)
    do jt = 1, numTime
!$OMP PARALLEL
!$OMP DO PRIVATE (JK,dnens2,GD2D,MEMBERINDEX,JJ,JI)
      do jk = 1, nkgdimEns
        dnens2 = scaleFactor(verticalLevelEns(jk))/sqrt(1.0d0*dble(nEns-1))
        if(jk.ge.(1+2*nLevEns_M+nLevEns_T) .and. jk.le.(2*nLevEns_M+2*nLevEns_T)) then
          dnens2 = dnens2*scaleFactorLQ(verticalLevelEns(jk))
        endif
        gd2d(:,:) = 0.0d0
        do memberIndex = 1, nEns
          do jj = myLatBeg, myLatEnd
            do ji = myLonBeg, myLonEnd
              gd2d(ji,jj) = gd2d(ji,jj)+dble(ensPerturbations(memberIndex)%member_r4(ji,jj,jk,jt))
            enddo
          enddo
        enddo
        do jj = myLatBeg, myLatEnd
          do ji = myLonBeg, myLonEnd
            gd2d(ji,jj) = gd2d(ji,jj)*dnens
          enddo
        enddo
        do memberIndex = 1, nEns
          do jj = myLatBeg, myLatEnd
            do ji = myLonBeg, myLonEnd
              ensPerturbations(memberIndex)%member_r4(ji,jj,jk,jt) =      &
                sngl((dble(ensPerturbations(memberIndex)%member_r4(ji,jj,jk,jt))-gd2d(ji,jj))*dnens2)
            enddo
          enddo
        enddo
      enddo
!$OMP END DO
!$OMP END PARALLEL
    enddo

    write(*,*) 'finished adjusting ensemble members...'

  END SUBROUTINE setupEnsemble_latlon

!--------------------------------------------------------------------------
! CheckEnsDim
!--------------------------------------------------------------------------

  SUBROUTINE CheckEnsDim(niEns,njEns,nkEns,nomvar) 1,1
    !
    ! Purpose: verify that a 2D field read from an ensemble file has the
    !          dimensions expected for this MPI task's tile, namely
    !          lonPerPE x latPerPE x 1. Abort with a diagnostic otherwise.
    !
    ! Arguments:
    !   niEns, njEns, nkEns : dimensions of the record that was read
    !   nomvar              : variable name, only used in the diagnostic
    !
    implicit none

    integer,      intent(in) :: niEns, njEns, nkEns
    character(*), intent(in) :: nomvar

    logical :: dimsMatch

    dimsMatch = (niEns == lonPerPE) .and. (njEns == latPerPE) .and. (nkEns == 1)
    if (dimsMatch) return

    ! Report found vs. expected dimensions, then stop
    write(*,*) 'Variable :', trim(nomvar)
    write(*,*) 'i-dim = ', niEns, lonPerPE
    write(*,*) 'j-dim = ', njEns, latPerPE
    write(*,*) 'k-dim = ', nkEns, 1
    call abort3d('Ensemble dimensions are incompatible with the topology and/or the analysis grid')

  END SUBROUTINE CheckEnsDim

!--------------------------------------------------------------------------
! setupLocalization
!--------------------------------------------------------------------------

  SUBROUTINE setupLocalization(hLocalize,vLocalize,pressureProfile) 1,5
    !
    ! Purpose: allocate and fill the module arrays used for ensemble
    !          covariance localization:
    !            ensLocalCor(0:ntrunc,nLevEns_M)   : horizontal localization,
    !                                                filled by setupGlobalHLoc
    !                                                or setupLamHLoc
    !            ensLocalVert(nLevEns_M,nLevEns_M) : vertical localization
    !                                                correlation matrix,
    !                                                replaced by its square
    !                                                root (matsqrt)
    !
    ! Arguments:
    !   hLocalize(1:2)  : horizontal length scales. If hLocalize(2) < 0,
    !                     hLocalize(1) is used on every level. Otherwise the
    !                     scale varies linearly in log(P): hLocalize(2) on
    !                     level 1 and hLocalize(1) on level nLevEns_M.
    !   vLocalize       : vertical localization scale in units of ln(P). The
    !                     correlation is zero for |ln(P1)-ln(P2)| > vLocalize.
    !   pressureProfile : pressure of each of the nLevEns_M levels
    !
    implicit none

    real(8), intent(in) :: hLocalize(:)
    real(8), intent(in) :: vLocalize
    real(8), intent(in) :: pressureProfile(:)

    real(8) :: zlc, zr, zcorr
    real(8) :: local_length(nLevEns_M)
    integer :: jk, jk1, jk2, ierr

    !
    !- 1. Allocation
    !
    allocate(ensLocalCor(0:ntrunc,nLevEns_M),stat=ierr)
    if (ierr.ne.0 ) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=9',ierr
      call abort3d('aborting in ben: setupLocalization')
    endif

    allocate(ensLocalVert(nLevEns_M,nLevEns_M),stat=ierr)
    if (ierr.ne.0 ) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=10',ierr
      call abort3d('aborting in ben: setupLocalization')
    endif

    !
    !- 2.  Compute HORIZONTAL localization correlation matrix
    !

    !- 2.1 Determine localization length scale for each vertical level
    if(hLocalize(2).lt.0.0d0) then
      ! vertically constant horizontal localization
      local_length(:) = hLocalize(1)
    else
      ! vertically varying horizontal localization (linear in log P)
      do jk = 1, nLevEns_M
        local_length(jk) = ( hLocalize(1)*( log(pressureProfile(jk       ))-log(pressureProfile(1 )) ) +    &
                             hLocalize(2)*( log(pressureProfile(nLevEns_M))-log(pressureProfile(jk)) ) ) /  &
                           ( log(pressureProfile(nLevEns_M))-log(pressureProfile(1)) )
        if(mpi_myid.eq.0) write(*,*) 'bmatrixEnsemble: localization length scale (',jk,') = ',local_length(jk)
      enddo
    endif

    !- 2.2. Compute matrix
    if (hco_ben % global) then
       call setupGlobalHLoc(local_length) ! IN
    else
       call setupLamHLoc(local_length) ! IN
    endif

    !
    !- 3.  Compute VERTICAL localization correlation matrix
    !

    !- 3.1 5th-order correlation of the ln(P) separation, using the same
    !      function as the horizontal localization (halflength = vLocalize/2)
    zlc = vLocalize/2.0d0
    do jk1 = 1, nLevEns_M
      do jk2 = 1, nLevEns_M
        zr = abs(log(pressureProfile(jk2)) - log(pressureProfile(jk1)))
        zcorr = FifthOrderCorrelFunction(zr,zlc)
        ! clip tiny negative values that round-off can produce near zr = 2*zlc
        if(zcorr.lt.0.0d0) zcorr = 0.0d0
        ensLocalVert(jk1,jk2) = zcorr
      enddo
    enddo

    !- 3.2 Replace the matrix by its square root (always done)
    call matsqrt(ensLocalVert(1,1),nLevEns_M,1.0d0)

    !
    !- 4.  Ending
    !
    if(mpi_myid.eq.0) write(*,*)'done setting up localization function'

  END SUBROUTINE setupLocalization

!--------------------------------------------------------------------------
! FifthOrderCorrelFunction
!--------------------------------------------------------------------------

  pure function FifthOrderCorrelFunction(distance,halflength) result(correlation) 5
    !
    ! Purpose: compactly supported 5th-order piecewise rational correlation
    !          function (Gaspari and Cohn 1999, eq. 4.10). It is 1 at zero
    !          distance and exactly 0 for distance > 2*halflength.
    !
    ! Arguments:
    !   distance   : separation (expected >= 0), same units as halflength
    !   halflength : half of the support radius ("c" in Gaspari-Cohn)
    !
    ! A zero halflength is treated as the zero-width limit: 1 at zero distance
    ! and 0 elsewhere. Evaluating the polynomial in that case would give 0/0.
    !
    implicit none

    real(8), intent(in) :: distance, halflength
    real(8) :: correlation

    real(8) :: ratio

    if ( halflength == 0.0d0 ) then
       if ( distance == 0.0d0 ) then
          correlation = 1.0d0
       else
          correlation = 0.0d0
       endif
       return
    endif

    ratio = distance/halflength

    ! Branch on distance (not on ratio) so the interval tests are exact
    if ( distance <= halflength ) then
       correlation =        -0.250d0*ratio**5  &
                     +         0.5d0*ratio**4  &
                     +       0.625d0*ratio**3  &
                     - (5.0d0/3.0d0)*ratio**2  &
                     + 1.0d0
    else if ( distance <= (2.0d0*halflength) ) then
       correlation =  (1.0d0/12.0d0)*ratio**5  &
                     -         0.5d0*ratio**4  &
                     +       0.625d0*ratio**3  &
                     + (5.0d0/3.0d0)*ratio**2  &
                     -         5.0d0*ratio     &
                     + 4.0d0                   &
                     - (2.0d0/3.0d0)*(halflength/distance)
    else
       correlation = 0.d0
    endif

  end function FifthOrderCorrelFunction

!--------------------------------------------------------------------------
! setupGlobalHLoc
!--------------------------------------------------------------------------

  SUBROUTINE setupGlobalHLoc(local_length) 1,5
    !
    ! Purpose: fill the module array ensLocalCor(0:ntrunc,1:nLevEns_M) with the
    !          square root of the horizontal localization spectrum on the
    !          global grid.
    !
    ! Arguments:
    !   local_length : horizontal length scale for each level. It is multiplied
    !                  by 1000 before use (presumably km -> m to match ra;
    !                  confirm). If local_length(1) <= 0, no horizontal
    !                  localization is applied: ensLocalCor is 1 at wavenumber
    !                  0 and 0 elsewhere.
    !
    ! Method, per level:
    !   - Evaluate FifthOrderCorrelFunction of ra*acos(mu), with
    !     mu = gst_getrmu(jlat), in member slot 1 of a gridpoint array.
    !   - Transform it with the MPI spectral transform.
    !   - Keep the result from tasks with myMemBeg == 1 and sum it over the
    !     "GRID" communicator.
    ! The first ntrunc+1 coefficients of phase 1 are then used.
    ! NOTE(review): these are presumably the m=0 coefficients; confirm the gst
    ! index ordering. Their absolute values are rescaled so that
    ! sum_n c_n*sqrt((2n+1)/2) = 1 ("one at the pole"). They are then converted
    ! to the sqrt of the correlation spectrum.
    !
    ! The work arrays are allocatable rather than automatic: zgd_gst scales
    ! with nEns times the local tile size and could otherwise overflow the stack.
    !
    implicit none

    real(8), intent(in)  :: local_length(nLevEns_M)

    real(8) :: zlc, zr, zpole, zcorr

    ! NOTE: arrays passed to spectral transform are dimensioned as follows
    !       gd: lat/lon tiles and sp: member index
    real(8), allocatable :: zsp_gst(:,:,:)
    real(8), allocatable :: zgd_gst(:,:,:)
    real(8), allocatable :: zsp(:,:,:)
    real(8), allocatable :: my_zsp(:,:,:)

    integer :: jn, jlat, jlon, jk, nsize, ierr

    if(local_length(1).gt.0.0d0) then

      allocate(zsp_gst(nla_mpiglobal,nphase,myMemBeg:myMemEnd))
      allocate(zgd_gst(myLonBeg:myLonEnd,nEns,myLatBeg:myLatEnd))
      allocate(zsp(nla_mpiglobal,nphase,nLevEns_M))
      allocate(my_zsp(nla_mpiglobal,nphase,nLevEns_M))
      my_zsp(:,:,:) = 0.0d0
      ! only member slot 1 is filled below; every other slot stays zero
      zgd_gst(:,:,:) = 0.0d0

      do jk = 1, nLevEns_M

        ! Calculate 5th Order Correlation Functions in Physical Space
        zlc = 1000.0d0*local_length(jk)/2.0d0
        do jlat = myLatBeg, myLatEnd
          zr = ra * acos(gst_getrmu(jlat,gstID))
          zcorr = FifthOrderCorrelFunction(zr,zlc)
          do jlon = myLonBeg, myLonEnd
            zgd_gst(jlon,1,jlat) = zcorr
          enddo
        enddo

        ! Transform to spectral space (extra manipulation because of mpi spectral transform)
        zsp_gst(:,:,:) = 0.0d0
        call gst_setID(gstID)
        call gst_reespe5(zsp_gst,zgd_gst)
        if(myMemBeg.eq.1) then
          my_zsp(:,:,jk) = zsp_gst(:,:,myMemBeg)
        else
          my_zsp(:,:,jk) = 0.0d0
        endif
      enddo

      deallocate(zgd_gst)
      deallocate(zsp_gst)

      nsize = nla_mpiglobal*nphase*nLevEns_M
      call RPN_COMM_allreduce(my_zsp,zsp,nsize,"mpi_double_precision","mpi_sum","GRID",ierr)
      deallocate(my_zsp)

      ! Copy over to ensLocalCor, truncate to ntrunc and keep magnitudes only
      do jk = 1, nLevEns_M
        do jn = 0, ntrunc
          ensLocalCor(jn,jk) = abs(zsp(jn+1,1,jk))
        enddo
      enddo
      deallocate(zsp)

      ! Make sure it's one at the pole
      do jk = 1, nLevEns_M
        zpole = 0.d0
        do jn = 0, ntrunc
          zpole = zpole + ensLocalCor(jn,jk)*sqrt((2.d0*jn+1.d0)/2.d0)
        enddo
        if(zpole.le.0.d0) then
          write(*,*)'POLE VALUE NEGATIVE IN setupGlobalHLoc jk=',jk
          call abort3d('setupGlobalHLoc')
        endif
        do jn = 0, ntrunc
          ensLocalCor(jn,jk) = ensLocalCor(jn,jk)/zpole
        enddo
      enddo

      ! Convert back to correlations and take sqrt
      do jk = 1, nLevEns_M
        do jn = 0, ntrunc
          ensLocalCor(jn,jk) = sqrt(0.5d0*ensLocalCor(jn,jk)*((2.0d0/(2.0d0*jn+1.0d0))**0.5d0))
        enddo
      enddo

    else

       ! NO HORIZONTAL LOCALIZATION, set ensLocalCor to 1.0 for wavenumber 0
       do jk = 1, nLevEns_M
          ensLocalCor(:,jk) = 0.0d0
          ensLocalCor(0,jk) = 1.0d0
       enddo
    endif

  END SUBROUTINE setupGlobalHLoc

!--------------------------------------------------------------------------
! setupLamHLoc
!--------------------------------------------------------------------------

  SUBROUTINE setupLamHLoc(local_length) 1,10
    !
    ! Purpose: fill the module array ensLocalCor(0:ntrunc,1:nLevEns_M) with the
    !          square root of the band-averaged spectral density of the
    !          horizontal localization function on the LAM grid.
    !
    ! Arguments:
    !   local_length : horizontal length scale per level, passed to
    !                  CreateBiPerCorrelFunction. If local_length(1) <= 0, no
    !                  horizontal localization is applied: ensLocalCor is 1 for
    !                  total wavenumber 0 and 0 elsewhere.
    !
    ! The local spectral array sp is dimensioned with lst_hloc%nphase, the
    ! phase count of the transform it is passed to. The same bound is used in
    ! every loop over phases. From this routine's point of view, the
    ! module-level nphase is not guaranteed to have the same value.
    !
    implicit none

    real(8), intent(in)  :: local_length(nLevEns_M)

    real(8), allocatable :: sp(:,:,:)
    real(8), allocatable :: gd(:,:,:)
    real(8), allocatable :: SumWeight(:)

    real(8) :: weightedSum

    type(struct_lst)          :: lst_hloc    ! Spectral transform Parameters

    integer :: k, p, e, ila, totwvnb, iCenter, jCenter

    character(len=19)   :: transformKind

    if ( local_length(1) > 0.d0 ) then
      !
      !- 1. Enforce HORIZONTAL LOCALIZATION
      !

      !- 1.1 Setup a non-MPI spectral transform
      call lst_Setup( lst_hloc,                       & ! OUT
                      ni, nj, hco_ben % dlon, ntrunc, & ! IN
                      'NoMpi')                          ! IN

      !- 1.2 Create a correlation function in physical space
      allocate ( gd(ni,nj,nLevEns_M))

      call CreateBiPerCorrelFunction( gd,                     & ! OUT
                                      local_length, nLevEns_M ) ! IN

      !- 1.3 Transform to spectral space
      allocate ( sp(lst_hloc % nla, lst_hloc % nphase, nLevEns_M))

      transformKind = 'GridPointToSpectral'
      call lst_VarTransform( lst_hloc % id,           & ! IN
                             sp,                      & ! OUT
                             gd,                      & ! IN
                             transformKind, nLevEns_M ) ! IN

      !- 1.4 Compute band mean: weighted mean of |sp| over all coefficients
      !      and phases that share the same total wavenumber
      allocate(SumWeight(0:ntrunc))
      SumWeight(:) = 0.d0

      ensLocalCor(:,:) = 0.d0
      do totwvnb = 0, ntrunc
        do e = 1, lst_hloc % nePerK(totwvnb)
          ila = lst_hloc % ilaFromEK(e,totwvnb)
          do p = 1, lst_hloc % nphase
            SumWeight(totwvnb) = SumWeight(totwvnb) + lst_hloc % Weight(ila)
            do k = 1, nLevEns_M
              ensLocalCor(totwvnb,k) = ensLocalCor(totwvnb,k) + lst_hloc % Weight(ila) * abs(sp(ila,p,k))
            enddo
          enddo
        enddo
      enddo

      do totwvnb = 0, ntrunc
        if (SumWeight(totwvnb) /= 0.d0) then
          ensLocalCor(totwvnb,:) = ensLocalCor(totwvnb,:) / SumWeight(totwvnb)
        else
          ensLocalCor(totwvnb,:) = 0.d0
        endif
      enddo

      deallocate(SumWeight)

      !- 1.5 Normalization to one of correlation function from spectral densities: Part 1
      !      (divide by sum over totwvnb of totwvnb*density; the final
      !      scaling is fixed in Part 2)
!$OMP PARALLEL
!$OMP DO PRIVATE (totwvnb,k,weightedSum)
      do k = 1, nLevEns_M
        weightedSum = 0.0d0
        do totwvnb = 0, ntrunc
          weightedSum = weightedSum + real(totwvnb,8) * ensLocalCor(totwvnb,k)
        enddo
        do totwvnb = 0, ntrunc
          if ( weightedSum /= 0.0d0 ) then
            ensLocalCor(totwvnb,k) = ensLocalCor(totwvnb,k) / weightedSum
          else
            ensLocalCor(totwvnb,k) = 0.d0
          endif
        enddo
      enddo
!$OMP END DO
!$OMP END PARALLEL

      !- 1.6 Normalization to one of correlation function from spectral densities: Part 2

      !- 1.6.1 Spectral transform of a delta function (at the center of the domain)
      iCenter = ni/2
      jCenter = nj/2
      gd(:,:,:) = 0.d0
      gd(iCenter,jCenter,:) = 1.d0

      transformKind = 'GridPointToSpectral'
      call lst_VarTransform( lst_hloc % id,           & ! IN
                             sp,                      & ! OUT
                             gd,                      & ! IN
                             transformKind, nLevEns_M ) ! IN

      !- 1.6.2 Apply the correlation function
!$OMP PARALLEL
!$OMP DO PRIVATE (totwvnb,e,ila,p,k)
      do totwvnb = 0, ntrunc
        do e = 1, lst_hloc % nePerK(totwvnb)
          ila = lst_hloc % ilaFromEK(e,totwvnb)
          do p = 1, lst_hloc % nphase
            do k = 1, nLevEns_M
              sp(ila,p,k) = sp(ila,p,k) * ensLocalCor(totwvnb,k) * &
                            lst_hloc % NormFactor(ila,p) * lst_hloc % NormFactorAd(ila,p)
            enddo
          enddo
        enddo
      enddo
!$OMP END DO
!$OMP END PARALLEL

      !- 1.6.3 Move back to physical space
      transformKind = 'SpectralToGridPoint'
      call lst_VarTransform( lst_hloc % id,           & ! IN
                             sp,                      & ! IN
                             gd,                      & ! OUT
                             transformKind, nLevEns_M ) ! IN

      !- 1.6.4 Normalize to 1: scale ensLocalCor by the value of the filtered
      !        delta function at its centre point
      do k = 1, nLevEns_M
        if ( gd(iCenter,jCenter,k) <= 0.d0 ) then
          write(*,*) 'setupLamHLoc: Problem in normalization ',gd(iCenter,jCenter,k)
          call abort3d('aborting in setupLamHLoc')
        endif
        if ( mpi_myid == 0 ) then
          write(*,*) 'setupLamHLoc: Normalization factor = ', k, gd(iCenter,jCenter,k), 1.d0 / gd(iCenter,jCenter,k)
        endif
        ensLocalCor(:,k) = ensLocalCor(:,k) / gd(iCenter,jCenter,k)
      enddo

      !- 1.7 Take sqrt
      ensLocalCor(:,:) = sqrt(ensLocalCor(:,:))

      deallocate(sp)
      deallocate(gd)

    else
      !
      !- 2. NO HORIZONTAL LOCALIZATION: set ensLocalCor to 1.0 for wavenumber 0
      !
      ensLocalCor(:,:) = 0.0d0
      ensLocalCor(0,:) = 1.0d0
    endif

  END SUBROUTINE setupLamHLoc

!--------------------------------------------------------------------------
! CreateBiPerCorrelFunction
!--------------------------------------------------------------------------

  SUBROUTINE  CreateBiPerCorrelFunction(gridpoint,CorrelLength,nk) 1
    !
    ! Purpose: build, on the ni x nj LAM grid and for each of nk levels, a
    !          field that is the sum of four 5th-order correlation functions
    !          centred on the four grid corners. This yields a bi-periodic
    !          correlation function.
    !
    ! Arguments:
    !   gridpoint    (out) : resulting field, gridpoint(ni,nj,nk)
    !   CorrelLength (in)  : length scale per level. The function's halflength
    !                        is 1000*CorrelLength/2 (presumably km -> m; confirm).
    !   nk           (in)  : number of levels
    !
    ! Distances use the same spacing hco_ben%dlon*RA in both grid directions.
    ! NOTE(review): the corner images sit at indices 1 and ni (nj), not at
    ! ni+1 (nj+1). Confirm this matches the period assumed by the LAM spectral
    ! transform.
    !
    implicit none

    integer, intent(in)  :: nk

    real(8), intent(in)  :: CorrelLength(nk)

    real(8), intent(out) :: gridpoint(ni,nj,nk)

    integer          :: i, j, k, corner
    integer          :: iCorner(4), jCorner(4)
    real(8)          :: distance, distance_ref, halflength

    ! Corner order (also the accumulation order): lower-left, upper-left,
    ! lower-right, upper-right
    iCorner = (/ 1,  1, ni, ni /)
    jCorner = (/ 1, nj,  1, nj /)

    gridpoint(:,:,:) = 0.d0

    distance_ref = hco_ben % dlon * RA

    do corner = 1, 4
       do k = 1, nk
          halflength = 1000.d0*CorrelLength(k)/2.d0
          do j = 1, nj
             do i = 1, ni
                distance = distance_ref * sqrt( real((i-iCorner(corner))**2 + (j-jCorner(corner))**2,8) )
                gridpoint(i,j,k) = gridpoint(i,j,k) + FifthOrderCorrelFunction(distance,halflength)
             enddo
          enddo
       enddo
    enddo

  END SUBROUTINE CreateBiPerCorrelFunction

!--------------------------------------------------------------------------
! BEN_reduceToMPILocal
!--------------------------------------------------------------------------

  SUBROUTINE BEN_reduceToMPILocal(cv_mpilocal,cv_mpiglobal,cvDim_mpilocal_out) 1,3
    !
    ! Purpose: copy into cv_mpilocal the part of the MPI-global ensemble
    !          control vector that belongs to the current MPI task.
    !
    ! Arguments:
    !   cv_mpilocal        (out) : local control vector (length cvDim_mpilocal)
    !   cv_mpiglobal       (in)  : global control vector (length cvDim_mpiglobal)
    !   cvDim_mpilocal_out (out) : set to cvDim_mpilocal
    !
    ! Global grid layout:
    !   - Each (member, level) pair holds (ntrunc+1)**2 values, member outermost.
    !   - For m=0 only the real part of each coefficient is stored.
    !   - For m>0 the real and imaginary parts are stored consecutively.
    !   - The local vector visits members myMemBeg:myMemEnd, all levels, and
    !     the m values mymBeg:mymEnd:mymSkip owned by this task.
    !
    ! LAM layout:
    !   index = ((member-member0)*nLevEns_M + jlev-1)*nla*nphase + (ila-1)*nphase + p
    !   - Local vector: member0 = myMemBeg, nla = lst_ben%nla.
    !   - Global vector: member0 = 1, nla = lst_ben%nlaGlobal, and ila is
    !     mapped through lst_ben%ilaGlobal.
    !
    ! Both branches abort if an index would fall outside either vector.
    !
    implicit none
    real(8), intent(out) :: cv_mpilocal(cvDim_mpilocal)
    real(8), intent(in)  :: cv_mpiglobal(cvDim_mpiglobal)
    integer, intent(out) :: cvDim_mpilocal_out

    integer :: jdim_mpilocal, jdim_mpiglobal, ila_mpilocal, ila_mpiglobal
    integer :: jm, jn, memberIndex, jlev, p, ncomp

    cvDim_mpilocal_out = cvDim_mpilocal

    ! assign part of mpiglobal vector from current mpi process

    if (hco_ben % global) then

      ! Global
      jdim_mpilocal = 0
      do memberIndex = myMemBeg, myMemEnd

        do jlev = 1, nLevEns_M
          do jm = mymBeg, mymEnd, mymSkip
            do jn = jm, ntrunc

              ! figure out index into global control vector
              ila_mpiglobal = gst_getNIND(jm,gstID) + jn - jm
              if(jm.eq.0) then
                ! for jm=0 only real part
                jdim_mpiglobal = ila_mpiglobal
                ncomp = 1
              else
                ! for jm>0 both real and imaginary part
                jdim_mpiglobal = 2*ila_mpiglobal-1 - (ntrunc+1)
                ncomp = 2
              endif
              ! add offset for level
              jdim_mpiglobal = jdim_mpiglobal + (jlev-1) * (ntrunc+1)*(ntrunc+1)
              ! add offset for member index
              jdim_mpiglobal = jdim_mpiglobal + (memberIndex-1) * (ntrunc+1)*(ntrunc+1)*nLevEns_M

              ! guard both vectors, as the LAM branch below does
              if ( jdim_mpilocal + ncomp > cvDim_mpilocal ) then
                write(*,*) 'BEN_reduceToMPILocal: jdim_mpilocal > cvDim_mpilocal ',memberIndex,jlev,jm,jn
                call abort3d('BEN_reduceToMPILocal')
              end if
              if ( jdim_mpiglobal + ncomp - 1 > cvDim_mpiglobal ) then
                write(*,*) 'BEN_reduceToMPILocal: jdim_mpiglobal > cvDim_mpiglobal ',memberIndex,jlev,jm,jn,ila_mpiglobal
                call abort3d('BEN_reduceToMPILocal')
              end if

              ! real part (and, for jm>0, imaginary part) occupy consecutive
              ! slots in both vectors
              cv_mpilocal(jdim_mpilocal+1:jdim_mpilocal+ncomp) = &
                   cv_mpiglobal(jdim_mpiglobal:jdim_mpiglobal+ncomp-1)
              jdim_mpilocal = jdim_mpilocal + ncomp

            enddo
          enddo
        enddo

      enddo

    else

      ! LAM

      do memberIndex = myMemBeg, myMemEnd
         do jlev = 1, nLevEns_M
            do ila_mpilocal = 1, lst_ben%nla
               do p = 1, lst_ben%nphase

                  jdim_mpilocal = ( (memberIndex-myMemBeg) * nLevEns_M * lst_ben%nla * lst_ben%nphase ) + &
                                                            ( (jlev-1) * lst_ben%nla * lst_ben%nphase ) + &
                                                                  ( (ila_mpilocal-1) * lst_ben%nphase ) + p

                  ila_mpiglobal = lst_ben%ilaGlobal(ila_mpilocal)
                  jdim_mpiglobal = ( (memberIndex-1) * nLevEns_M * lst_ben%nlaGlobal * lst_ben%nphase ) + &
                                                      ( (jlev-1) * lst_ben%nlaGlobal * lst_ben%nphase ) + &
                                                                 ( (ila_mpiglobal-1) * lst_ben%nphase ) + p

                  if ( jdim_mpilocal  > cvDim_mpilocal ) then
                     write(*,*) 'BEN_reduceToMPILocal: jdim_mpilocal > cvDim_mpilocal ',memberIndex,jlev,ila_mpilocal,p
                     call abort3d('BEN_reduceToMPILocal')
                  end if
                  if ( jdim_mpiglobal > cvDim_mpiglobal) then
                     write(*,*) 'BEN_reduceToMPILocal: jdim_mpiglobal > cvDim_mpiglobal ',memberIndex,jlev,ila_mpilocal,p,ila_mpiglobal
                     call abort3d('BEN_reduceToMPILocal')
                  end if

                  cv_mpilocal(jdim_mpilocal) = cv_mpiglobal(jdim_mpiglobal)

               end do
            end do
         end do
      end do

    end if

  END SUBROUTINE BEN_reduceToMPILocal

!--------------------------------------------------------------------------
! BEN_expandToMPIGlobal
!--------------------------------------------------------------------------

  SUBROUTINE BEN_expandToMPIGlobal(cv_mpilocal,cv_mpiglobal,cvDim_mpiglobal_out) 1,3
    implicit none

    real(8), intent(in)  :: cv_mpilocal(cvDim_mpilocal)
    real(8), intent(out) :: cv_mpiglobal(cvDim_mpiglobal)
    integer, intent(out) :: cvDim_mpiglobal_out

    real(8), allocatable :: cv_maxmpilocal(:)
    real(8), pointer     :: cv_allmaxmpilocal(:,:) => null()
    integer, allocatable :: allmBeg(:),allmEnd(:),allmSkip(:)
    integer, allocatable :: allMemBeg(:),allMemEnd(:)
    integer :: jdim_mpilocal, jdim_mpiglobal, ila_mpiglobal, ila_mpilocal, cvDim_maxmpilocal
    integer :: jm, jn, jproc, memberIndex, jlev, ierr, p
    real(8), allocatable :: my_cv_mpiglobal(:)

    cvDim_mpiglobal_out = cvDim_mpiglobal

    ! gather all local control vectors onto mpi task 0

    if (hco_ben % global) then

       ! Global
       call rpn_comm_allreduce(cvDim_mpilocal,cvDim_maxmpilocal,1,"mpi_integer","mpi_max","GRID",ierr)

       allocate(cv_maxmpilocal(cvDim_maxmpilocal))
       if(mpi_myid.eq.0) allocate(cv_allmaxmpilocal(cvDim_maxmpilocal,mpi_nprocs))

       cv_maxmpilocal(:) = 0.0d0
       cv_maxmpilocal(1:cvDim_mpilocal) = cv_mpilocal(1:cvDim_mpilocal)

       call rpn_comm_gather(cv_maxmpilocal,    cvDim_maxmpilocal, "mpi_double_precision",  &
                            cv_allmaxmpilocal, cvDim_maxmpilocal, "mpi_double_precision", 0, "GRID", ierr )

       deallocate(cv_maxmpilocal)

       allocate(allMemBeg(mpi_nprocs))
       call rpn_comm_allgather(myMemBeg,1,"mpi_integer",       &
                               allMemBeg,1,"mpi_integer","GRID",ierr)
       allocate(allMemEnd(mpi_nprocs))
       call rpn_comm_allgather(myMemEnd,1,"mpi_integer",       &
                               allMemEnd,1,"mpi_integer","GRID",ierr)

       allocate(allmBeg(mpi_nprocs))
       call rpn_comm_allgather(mymBeg,1,"mpi_integer",       &
                               allmBeg,1,"mpi_integer","GRID",ierr)
       allocate(allmEnd(mpi_nprocs))
       call rpn_comm_allgather(mymEnd,1,"mpi_integer",       &
                               allmEnd,1,"mpi_integer","GRID",ierr)
       allocate(allmSkip(mpi_nprocs))
       call rpn_comm_allgather(mymSkip,1,"mpi_integer",       &
                               allmSkip,1,"mpi_integer","GRID",ierr)

       ! reorganize gathered mpilocal control vectors into the mpiglobal control vector
       if(mpi_myid.eq.0) then
         cv_mpiglobal(:) = 0.0d0

!$OMP PARALLEL DO PRIVATE(jproc,jdim_mpilocal,memberIndex,jlev,jm,jn,ila_mpiglobal,jdim_mpiglobal)
         do jproc = 0, (mpi_nprocs-1)
           jdim_mpilocal = 0
           do memberIndex = allMemBeg(jproc+1), allMemEnd(jproc+1)

             do jlev = 1, nLevEns_M
               do jm = allmBeg(jproc+1), allmEnd(jproc+1), allmSkip(jproc+1)
                 do jn = jm, ntrunc

                   ! figure out index into global control vector
                   ila_mpiglobal = gst_getNIND(jm,gstID) + jn - jm
                   if(jm.eq.0) then
                     ! for jm=0 only real part
                     jdim_mpiglobal = ila_mpiglobal
                   else
                     ! for jm>0 both real and imaginary part
                     jdim_mpiglobal = 2*ila_mpiglobal-1 - (ntrunc+1)
                   endif
                   ! add offset for level
                   jdim_mpiglobal = jdim_mpiglobal + (jlev-1) * (ntrunc+1)*(ntrunc+1)
                   ! add offset for member index
                   jdim_mpiglobal = jdim_mpiglobal + (memberIndex-1) * (ntrunc+1)*(ntrunc+1)*nLevEns_M

                   ! index into local control vector
                   if(jm.eq.0) then
                     ! only real component for jm=0
                     jdim_mpilocal = jdim_mpilocal + 1
                     cv_mpiglobal(jdim_mpiglobal) = cv_allmaxmpilocal(jdim_mpilocal,jproc+1)
                   else
                     ! both real and imaginary components for jm>0
                     jdim_mpilocal = jdim_mpilocal + 1
                     cv_mpiglobal(jdim_mpiglobal) = cv_allmaxmpilocal(jdim_mpilocal,jproc+1)
                     jdim_mpilocal = jdim_mpilocal + 1
                     cv_mpiglobal(jdim_mpiglobal+1) = cv_allmaxmpilocal(jdim_mpilocal,jproc+1)
                   endif

                   if(jdim_mpiglobal.gt.cvDim_mpiglobal)   &
                     write(*,*) 'ERROR: jdim,cvDim,mpiglobal=',jdim_mpiglobal,cvDim_mpiglobal,jlev,jn,jm

                 enddo
               enddo
             enddo
           enddo
         enddo ! jproc
!$OMP END PARALLEL DO

      endif ! myid .eq. 0 

      deallocate(allMemBeg)
      deallocate(allMemEnd)
      deallocate(allmBeg)
      deallocate(allmEnd)
      deallocate(allmSkip)
      if(mpi_myid.eq.0) deallocate(cv_allmaxmpilocal)

    else

      ! LAM
      allocate(my_cv_mpiglobal(cvDim_mpiglobal)) 
      my_cv_mpiglobal(:) = 0.0d0

      do memberIndex = myMemBeg, myMemEnd
         do jlev = 1, nLevEns_M
            do ila_mpilocal = 1, lst_ben%nla
               do p = 1, lst_ben%nphase

                  jdim_mpilocal = ( (memberIndex-myMemBeg) * nLevEns_M * lst_ben%nla * lst_ben%nphase ) + &
                                                            ( (jlev-1) * lst_ben%nla * lst_ben%nphase ) + &
                                                                  ( (ila_mpilocal-1) * lst_ben%nphase ) + p

                  ila_mpiglobal = lst_ben%ilaGlobal(ila_mpilocal)
                  jdim_mpiglobal = ( (memberIndex-1) * nLevEns_M * lst_ben%nlaGlobal * lst_ben%nphase ) + &
                                                      ( (jlev-1) * lst_ben%nlaGlobal * lst_ben%nphase ) + &
                                                                 ( (ila_mpiglobal-1) * lst_ben%nphase ) + p
  
                  if ( jdim_mpilocal  > cvDim_mpilocal ) then 
                     write(*,*) 'BEN_expandToMPIGlobal: jdim_mpilocal > cvDim_mpilocal ',memberIndex,jlev,ila_mpilocal,p
                     call abort3d('BEN_expandToMPIGlobal')
                  end if
                  if ( jdim_mpiglobal > cvDim_mpiglobal) then
                     write(*,*) 'BEN_expandToMPIGlobal: jdim_mpiglobal > cvDim_mpiglobal ',memberIndex,jlev,ila_mpilocal,p,ila_mpiglobal
                     call abort3d('BEN_expandToMPIGlobal')
                  end if
            
                  my_cv_mpiglobal(jdim_mpiglobal) = cv_mpilocal(jdim_mpilocal)

               end do
            end do
         end do
      end do

      call rpn_comm_allreduce(my_cv_mpiglobal,cv_mpiglobal,cvDim_mpiglobal,"mpi_double_precision","mpi_sum","GRID",ierr)
      deallocate(my_cv_mpiglobal) 
    end if

  end SUBROUTINE BEN_expandToMPIGlobal

!--------------------------------------------------------------------------
! BEN_BSqrt
!--------------------------------------------------------------------------

  SUBROUTINE BEN_BSqrt(controlVector_in,statevector) 1,11
    !
    ! Purpose: Apply the square root of the spatially localized ensemble
    !          covariance: mpi-local control vector -> analysis increment.
    !
    ! Arguments:
    !   controlVector_in (in)  : mpi-local control vector (cvDim_mpilocal)
    !   statevector      (out) : for every variable present (gsv_varExist), the
    !                            levels covered by the ensemble covariances are
    !                            overwritten with the increment
    !
    ! Returns without touching statevector if the module is not initialized or
    ! if sum(scaleFactor) is zero.
    !
    implicit none

    real(8)    :: controlVector_in(cvDim_mpilocal) 
    type(struct_gsv) :: statevector
    real(8),allocatable  :: incrementLocal(:,:,:,:)
    ! Spectral amplitudes for all levels. Its size scales with
    ! nLevEns_M*nla_mpiglobal*nphase*members, so it is heap-allocated with a stat
    ! check (like incrementLocal) rather than placed on the stack.
    real(8),allocatable  :: zsp_all(:,:,:,:)
    real(8)    :: ensAmplitude(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nEns)
    real(8)    :: ensAmplitude2(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nEns)
    real(8)    :: zsp1(nla_mpiglobal,nphase,myMemBeg:myMemEnd)
    real(8), pointer :: field(:,:,:,:)
    real(8), allocatable :: gd_out(:,:,:)
    character(len=19)   :: transformKind
    integer   :: jlev
    integer   :: ierr,jj,jk,ji,jt,jkInc,jvar,ilev1,ilev2,topLevOffset
    integer   :: memberIndex

    if(.not. initialized) then
      if(mpi_myid.eq.0) write(*,*) 'bMatrixEnsemble not initialized'
      return
    endif

    if(sum(scaleFactor).eq.0.0d0) then
      if(mpi_myid.eq.0) write(*,*) 'bMatrixEnsemble: scaleFactor=0, skipping bSqrt'
      return
    endif

    if(mpi_myid.eq.0) write(*,*) 'ben_bsqrt: starting'
    if(mpi_myid.eq.0) write(*,*) 'Memory Used: ',get_max_rss()/1024,'Mb'

    allocate(incrementLocal(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nkgdimEns,numTime),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=12',ierr
      call abort3d('aborting in ben_bsqrt')
    endif

    allocate(zsp_all(nLevEns_M,nla_mpiglobal,nphase,myMemBeg:myMemEnd),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'ben_bsqrt: Problem allocating memory for zsp_all!',ierr
      call abort3d('aborting in ben_bsqrt')
    endif

    ! Zero the increment one contiguous (lon,lat) slab at a time. The loop is
    ! still parallel over jk, as before.
!$OMP PARALLEL DO PRIVATE(jk,jt)
    do jk = 1, nkgdimEns
      do jt = 1, numTime
        incrementLocal(:,:,jk,jt) = 0.0d0
      enddo
    enddo
!$OMP END PARALLEL DO

!$OMP PARALLEL DO PRIVATE(memberIndex)
    do memberIndex = 1, nEns
      ensAmplitude2(:,:,memberIndex) = 0.0d0
    enddo
!$OMP END PARALLEL DO

    ! this was necessary to avoid FP overflow with 512 mpi tasks
    zsp_all(:,:,:,:) = 0.0d0

    call localizationSqrt(controlVector_in,zsp_all)

    ! gd_out is fully overwritten at every level, so allocate it only once.
    if (hco_ben % global) allocate( gd_out(myLonBeg:myLonEnd,nEns,myLatBeg:myLatEnd) )

    do jlev = 1, nLevEns_M ! loop over levels in amplitude field

      ! now transform amplitude to grid-point space
      call rpn_comm_barrier("GRID",ierr)
      call tmg_start(64,'BEN_SPECTRAL')

      zsp1(:,:,:) = zsp_all(jlev,:,:,:)

      if (hco_ben % global) then
        call gst_setID(gstID)
        call gst_speree5(zsp1,gd_out)

        !- reordering because of gd_out(i,j,k)
!$OMP PARALLEL DO PRIVATE(memberIndex,jj,ji)
        do memberIndex = 1, nEns
          do jj = myLatBeg, myLatEnd
            do ji = myLonBeg, myLonEnd
              ensAmplitude(ji,jj,memberIndex) = gd_out(ji,memberIndex,jj)
            enddo
          enddo
        enddo
!$OMP END PARALLEL DO
      else
        transformKind = 'SpectralToGridPoint'
        call lst_VarTransform( lst_ben % id,   & ! IN
                               zsp1,           & ! IN
                               ensAmplitude,   & ! OUT (i,j,k) !!!
                               transformKind, nEns ) ! IN
      endif

      call tmg_stop(64)

      call addEnsMember(ensAmplitude,ensAmplitude2,incrementLocal,jlev)

      ! staggered grids: keep this level's momentum amplitude so the next level
      ! can interpolate the thermo amplitude from it
      if(is_staggered) then
        ensAmplitude2(:,:,:) = ensAmplitude(:,:,:)
      endif

    enddo

    if (allocated(gd_out)) deallocate(gd_out)
    deallocate(zsp_all)

    ! copy the accumulated increment into the statevector fields
!$OMP PARALLEL
!$OMP DO PRIVATE(jj,jvar,field,ilev1,ilev2,topLevOffset,jt,jk,jkInc,ji)
    do jj = myLatBeg, myLatEnd
      do jvar = 1, vnl_numvarmax 
        if(gsv_varExist(vnl_varNameList(jvar))) then
          field => gsv_getField(statevector,vnl_varNameList(jvar))
          if(vnl_varNameList(jvar).eq.'UU  ') then
            ilev1 = ngposituu
          elseif(vnl_varNameList(jvar).eq.'VV  ') then
            ilev1 = ngpositvv
          elseif(vnl_varNameList(jvar).eq.'TT  ') then
            ilev1 = ngposittt
          elseif(vnl_varNameList(jvar).eq.'HU  ') then
            ilev1 = ngpositq
          elseif(vnl_varNameList(jvar).eq.'P0  ') then
            ilev1 = ngpositps
          elseif(vnl_varNameList(jvar).eq.'TG  ') then
            ilev1 = ngposittg
          else
            call abort3d('ben_bsqrt: No covariances available for variable:' // vnl_varNameList(jvar))
          endif
          if(vnl_vartypeFromVarname(vnl_varNameList(jvar)).eq.'SF') then
            ilev2 = ilev1
            topLevOffset = 1
          elseif(vnl_vartypeFromVarname(vnl_varNameList(jvar)).eq.'MM') then
            ilev2 = ilev1 - 1 + nlevEns_M
            topLevOffset = topLevIndex_M
          else
            ilev2 = ilev1 - 1 + nlevEns_T
            topLevOffset = topLevIndex_T
          endif
          do jt = 1, numTime
            do jk = ilev1, ilev2
              jkInc = jk-ilev1 + topLevOffset
              do ji = myLonBeg, myLonEnd
                field(ji,jkInc,jj,jt) = incrementLocal(ji,jj,jk,jt)
              enddo
            enddo
          enddo
        endif
      enddo
    enddo
!$OMP END DO
!$OMP END PARALLEL

    deallocate(incrementLocal)

    if(mpi_myid.eq.0) write(*,*) 'Memory Used: ',get_max_rss()/1024,'Mb'
    if(mpi_myid.eq.0) write(*,*) 'ben_bsqrt: done'

  END SUBROUTINE BEN_BSqrt

!--------------------------------------------------------------------------
! BEN_BSqrtAd
!--------------------------------------------------------------------------

  SUBROUTINE BEN_BSqrtAd(statevector,controlVector_out) 1,11
    !
    ! Purpose: Adjoint of BEN_BSqrt: analysis increment in statevector ->
    !          mpi-local control vector.
    !
    ! Arguments:
    !   statevector       (in)    : increment fields. Only the levels covered by the
    !                               ensemble covariances are read, for each variable
    !                               present (gsv_varExist).
    !   controlVector_out (inout) : the horizontal-localization adjoints add their
    !                               contributions to the incoming values
    !
    ! Returns without touching controlVector_out if the module is not initialized
    ! or if sum(scaleFactor) is zero.
    !
    implicit none

    real(8)    :: controlVector_out(cvDim_mpilocal) 
    type(struct_gsv) :: statevector

    ! Heap-allocated (stat-checked) because its size scales with
    ! nLevEns_M*nla_mpiglobal*nphase*members.
    real(8), allocatable :: zsp_all(:,:,:,:)
    real(8)    :: zsp1(nla_mpiglobal,nphase,myMemBeg:myMemEnd)
    real(8), pointer :: field(:,:,:,:)
    real(8),allocatable :: incrementLocal(:,:,:,:)
    real(8)    :: ensAmplitude(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nEns)
    real(8),allocatable    :: gd_in(:,:,:)

    integer   :: jlev,jj,jk,ji,jt,jkinc,ierr,jvar,ilev1,ilev2,topLevOffset
    integer   :: memberIndex

    character(len=19)   :: transformKind

    if(.not. initialized) then
      if(mpi_myid.eq.0) write(*,*) 'bMatrixEnsemble not initialized'
      return
    endif

    if(sum(scaleFactor).eq.0.0d0) then
      if(mpi_myid.eq.0) write(*,*) 'bMatrixEnsemble: scaleFactor=0, skipping bSqrtAd'
      return
    endif

    if(mpi_myid.eq.0) write(*,*) 'ben_bsqrtad: starting'
    if(mpi_myid.eq.0) write(*,*) 'Memory Used: ',get_max_rss()/1024,'Mb'

    allocate(incrementLocal(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nkgdimEns,numTime),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'bmatrixEnsemble: Problem allocating memory! id=14',ierr
      call abort3d('aborting in ben_bsqrtad')
    endif

    ! The copy below only fills the levels of variables that exist in statevector,
    ! but addEnsMemberAd reads the levels of all ensemble variables. Zero first, so
    ! that absent variables contribute nothing, exactly as in the forward
    ! (BEN_BSqrt zeros its increment too).
!$OMP PARALLEL DO PRIVATE(jk,jt)
    do jk = 1, nkgdimEns
      do jt = 1, numTime
        incrementLocal(:,:,jk,jt) = 0.0d0
      enddo
    enddo
!$OMP END PARALLEL DO

!$OMP PARALLEL
!$OMP DO PRIVATE(jj,jvar,field,ilev1,ilev2,topLevOffset,jt,jk,jkInc,ji)
    do jj = myLatBeg, myLatEnd
      do jvar = 1, vnl_numvarmax 
        if(gsv_varExist(vnl_varNameList(jvar))) then
          field => gsv_getField(statevector,vnl_varNameList(jvar))
          if(vnl_varNameList(jvar).eq.'UU  ') then
            ilev1 = ngposituu
          elseif(vnl_varNameList(jvar).eq.'VV  ') then
            ilev1 = ngpositvv
          elseif(vnl_varNameList(jvar).eq.'TT  ') then
            ilev1 = ngposittt
          elseif(vnl_varNameList(jvar).eq.'HU  ') then
            ilev1 = ngpositq
          elseif(vnl_varNameList(jvar).eq.'P0  ') then
            ilev1 = ngpositps
          elseif(vnl_varNameList(jvar).eq.'TG  ') then
            ilev1 = ngposittg
          else
            call abort3d('ben_bsqrtad: No covariances available for variable:' // vnl_varNameList(jvar))
          endif
          if(vnl_vartypeFromVarname(vnl_varNameList(jvar)).eq.'SF') then
            ilev2 = ilev1
            topLevOffset = 1
          elseif(vnl_vartypeFromVarname(vnl_varNameList(jvar)).eq.'MM') then
            ilev2 = ilev1 - 1 + nlevEns_M
            topLevOffset = topLevIndex_M
          else
            ilev2 = ilev1 - 1 + nlevEns_T
            topLevOffset = topLevIndex_T
          endif
          do jt = 1, numTime
            do jk = ilev1, ilev2
              jkInc = jk-ilev1 + topLevOffset
              do ji = myLonBeg, myLonEnd
                incrementLocal(ji,jj,jk,jt) = field(ji,jkInc,jj,jt)
              enddo
            enddo
          enddo
        endif
      enddo
    enddo
!$OMP END DO
!$OMP END PARALLEL

    allocate(zsp_all(nLevEns_M,nla_mpiglobal,nphase,myMemBeg:myMemEnd),stat=ierr)
    if(ierr.ne.0) then
      write(*,*) 'ben_bsqrtad: Problem allocating memory for zsp_all!',ierr
      call abort3d('aborting in ben_bsqrtad')
    endif

    ! gd_in is fully overwritten at every level, so allocate it only once.
    if (hco_ben % global) allocate( gd_in(myLonBeg:myLonEnd,nEns,myLatBeg:myLatEnd) )

    do jlev = 1, nLevEns_M ! loop over levels in amplitude field

      call addEnsMemberAd(incrementLocal,ensAmplitude,jlev)

      zsp1(:,:,:) = 0.0d0 ! needed, not all levels set
      call rpn_comm_barrier("GRID",ierr)
      call tmg_start(64,'BEN_SPECTRAL')

      if (hco_ben % global) then
        !- reordering because of gd_in(i,j,k)
!$OMP PARALLEL DO PRIVATE(memberIndex,jj,ji)
        do memberIndex = 1, nEns
          do jj = myLatBeg, myLatEnd
            do ji = myLonBeg, myLonEnd
              gd_in(ji,memberIndex,jj) = ensAmplitude(ji,jj,memberIndex)
            enddo
          enddo
        enddo
!$OMP END PARALLEL DO

        call gst_setID(gstID)
        call gst_reespe5(zsp1,gd_in)
      else
        transformKind = 'GridPointToSpectral'
        call lst_VarTransform( lst_ben % id,  & ! IN
                               zsp1,          & ! OUT
                               ensAmplitude,  & ! IN (i,j,k) !!!
                               transformKind, nEns ) ! IN
      endif

      zsp_all(jlev,:,:,:) = zsp1(:,:,:)
      call tmg_stop(64)

    enddo

    if (allocated(gd_in)) deallocate(gd_in)
    ! the grid-point increment is no longer needed: release it before localization
    deallocate(incrementLocal)

    call localizationSqrtAd(zsp_all,controlVector_out)

    deallocate(zsp_all)

    if(mpi_myid.eq.0) write(*,*) 'Memory Used: ',get_max_rss()/1024,'Mb'
    if(mpi_myid.eq.0) write(*,*) 'ben_bsqrtad: done'

  END SUBROUTINE BEN_BSqrtAd

!--------------------------------------------------------------------------
! addEnsMember
!--------------------------------------------------------------------------

  SUBROUTINE addEnsMember(ensAmplitude_in,ensAmplitude2_in,incrementLocal_out,levelIndex) 1
    !
    ! Purpose: For one amplitude level, add
    !            sum over members of  amplitude * ensemble perturbation
    !          to the increment, on the variable levels fed by that amplitude level
    !          (verticalLevel(levelIndex,:)).
    !
    ! Arguments:
    !   ensAmplitude_in    : amplitude fields at levelIndex, one per member
    !   ensAmplitude2_in   : staggered case only. On entry, the amplitude the caller
    !                        kept from the previous level (not read at levelIndex=1).
    !                        On exit, the amplitude used for the non-surface thermo
    !                        variables at this level.
    !   incrementLocal_out : increment being accumulated (added to, never reset)
    !   levelIndex         : amplitude level index (1..nLevEns_M)
    !
    ! Parallelisation: a single OpenMP region over latitude rows. Each thread owns
    ! incrementLocal_out(:,jj,:,:), and members are still added in the order
    ! 1..nEns for every element. The previous version re-opened a region per member
    ! and only parallelised over numTime. The results are therefore bitwise
    ! identical to that version.
    !
    implicit none

    integer,intent(in) :: levelIndex
    real(8)      :: ensAmplitude_in(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nEns)
    real(8)      :: ensAmplitude2_in(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nEns)
    real(8)      :: incrementLocal_out(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nkgdimEns,numTime)

    integer     :: jvar,jlev,memberIndex,jt,jj,ji,numVar

    call tmg_start(62,'ADDMEM')

    if(is_staggered) then

      if(levelIndex.eq.1) then
        ! use top momentum level amplitudes for top thermo levels
        ensAmplitude2_in(:,:,:) = ensAmplitude_in(:,:,:)
      else
        ! for other levels, interpolate momentum weights to get thermo amplitudes
        ensAmplitude2_in(:,:,:) = 0.5d0*( ensAmplitude2_in(:,:,:) +   &
                                          ensAmplitude_in(:,:,:) )
      endif

!$OMP PARALLEL DO PRIVATE(jj,memberIndex,jt,jvar,jlev,ji)
      do jj = myLatBeg, myLatEnd
        do memberIndex = 1, nEns
          do jt = 1, numTime

            ! momentum variables
            do jvar = 1, 2
              jlev = verticalLevel(levelIndex,jvar)
              do ji = myLonBeg, myLonEnd
                incrementLocal_out(ji,jj,jlev,jt) = incrementLocal_out(ji,jj,jlev,jt) +   &
                  ensAmplitude_in(ji,jj,memberIndex)*dble(ensPerturbations(memberIndex)%member_r4(ji,jj,jlev,jt))
              enddo
            enddo

            ! non-surface thermo variables (uses interpolated amplitude field)
            do jvar = 3, 4
              jlev = verticalLevel(levelIndex,jvar)
              do ji = myLonBeg, myLonEnd
                incrementLocal_out(ji,jj,jlev,jt) = incrementLocal_out(ji,jj,jlev,jt) +   &
                  ensAmplitude2_in(ji,jj,memberIndex)*dble(ensPerturbations(memberIndex)%member_r4(ji,jj,jlev,jt))
              enddo
            enddo

            if(levelIndex.eq.nLevEns_M) then
              ! surface thermo variables (uses sfc amplitude field)
              do jvar = 3, 4
                jlev = verticalLevel(levelIndex,jvar)+1
                do ji = myLonBeg, myLonEnd
                  incrementLocal_out(ji,jj,jlev,jt) = incrementLocal_out(ji,jj,jlev,jt) +   &
                    ensAmplitude_in(ji,jj,memberIndex)*dble(ensPerturbations(memberIndex)%member_r4(ji,jj,jlev,jt))
                enddo
              enddo

              ! other surface variables (uses sfc amplitude field)
              do jvar = 5, 6
                jlev = verticalLevel(levelIndex,jvar)
                do ji = myLonBeg, myLonEnd
                  incrementLocal_out(ji,jj,jlev,jt) = incrementLocal_out(ji,jj,jlev,jt) +   &
                    ensAmplitude_in(ji,jj,memberIndex)*dble(ensPerturbations(memberIndex)%member_r4(ji,jj,jlev,jt))
                enddo
              enddo
            endif

          enddo ! jt
        enddo ! memberIndex
      enddo ! jj
!$OMP END PARALLEL DO

    else ! not staggered

      if(levelIndex.eq.nLevEns_M) then
        ! surface level, 4 3D and 2 sfc fields
        numVar = 6
      else
        ! above the surface, only 4 3D fields
        numVar = 4
      endif

!$OMP PARALLEL DO PRIVATE(jj,memberIndex,jt,jvar,jlev,ji)
      do jj = myLatBeg, myLatEnd
        do memberIndex = 1, nEns
          do jt = 1, numTime
            do jvar = 1, numVar
              jlev = verticalLevel(levelIndex,jvar)
              do ji = myLonBeg, myLonEnd
                incrementLocal_out(ji,jj,jlev,jt) = incrementLocal_out(ji,jj,jlev,jt) +   &
                  ensAmplitude_in(ji,jj,memberIndex)*dble(ensPerturbations(memberIndex)%member_r4(ji,jj,jlev,jt))
              enddo
            enddo
          enddo
        enddo
      enddo
!$OMP END PARALLEL DO

    endif

    call tmg_stop(62)

  END SUBROUTINE addEnsMember

!--------------------------------------------------------------------------
! addEnsMemberAd
!--------------------------------------------------------------------------

  SUBROUTINE addEnsMemberAd(incrementLocal_in,ensAmplitude_out,levelIndex) 1
    !
    ! Purpose: Adjoint of addEnsMember for one amplitude level. For every member,
    !          ensAmplitude_out is reset to zero. It then accumulates, over time and
    !          over the variable levels fed by levelIndex, the products
    !            weight * increment * member perturbation.
    !
    ! Weights in the staggered case (adjoint of the 0.5 averaging in addEnsMember):
    !   - momentum levels and surface fields (jvar 5,6) : no weight
    !   - thermo level verticalLevel(levelIndex,jvar)   : 1 at levelIndex=1, else 0.5
    !   - thermo level verticalLevel(levelIndex,jvar)+1 : 1 at levelIndex=nLevEns_M, else 0.5
    ! In the non-staggered case, variables 1..4 are used (1..6 at the surface level),
    ! with no weight.
    !
    ! Each member's slice is handled by a single thread, and the per-element summation
    ! order is: time, then variable section, then variable.
    !
    implicit none

    integer,intent(in) :: levelIndex
    real(8)      :: ensAmplitude_out(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nEns)
    real(8)      :: incrementLocal_in(myLonBeg:myLonEnd,myLatBeg:myLatEnd,nkgdimEns,numTime)

    real(8)     :: wThermo, wThermoNext
    integer     :: im, it, iv, kl, ilat, ilon, lastVar

    call tmg_start(62,'ADDMEM')

    wThermo = 0.5d0
    if (levelIndex == 1) wThermo = 1.0d0

    wThermoNext = 0.5d0
    lastVar = 4
    if (levelIndex == nLevEns_M) then
      wThermoNext = 1.0d0
      lastVar = 6
    end if

!$OMP PARALLEL DO PRIVATE(im,it,iv,kl,ilat,ilon)
    memberLoop: do im = 1, nEns

      ensAmplitude_out(:,:,im) = 0.0d0

      timeLoop: do it = 1, numTime

        if (.not. is_staggered) then
          do iv = 1, lastVar
            kl = verticalLevel(levelIndex,iv)
            do ilat = myLatBeg, myLatEnd
              do ilon = myLonBeg, myLonEnd
                ensAmplitude_out(ilon,ilat,im) = ensAmplitude_out(ilon,ilat,im) +   &
                  incrementLocal_in(ilon,ilat,kl,it)*dble(ensPerturbations(im)%member_r4(ilon,ilat,kl,it))
              end do
            end do
          end do
          cycle timeLoop
        end if

        ! staggered: momentum variables
        do iv = 1, 2
          kl = verticalLevel(levelIndex,iv)
          do ilat = myLatBeg, myLatEnd
            do ilon = myLonBeg, myLonEnd
              ensAmplitude_out(ilon,ilat,im) = ensAmplitude_out(ilon,ilat,im) +   &
                incrementLocal_in(ilon,ilat,kl,it)*dble(ensPerturbations(im)%member_r4(ilon,ilat,kl,it))
            end do
          end do
        end do

        ! staggered: thermo variables on the same level index
        do iv = 3, 4
          kl = verticalLevel(levelIndex,iv)
          do ilat = myLatBeg, myLatEnd
            do ilon = myLonBeg, myLonEnd
              ensAmplitude_out(ilon,ilat,im) = ensAmplitude_out(ilon,ilat,im) +   &
                wThermo*incrementLocal_in(ilon,ilat,kl,it)*dble(ensPerturbations(im)%member_r4(ilon,ilat,kl,it))
            end do
          end do
        end do

        ! staggered: thermo variables on the next level index
        do iv = 3, 4
          kl = verticalLevel(levelIndex,iv) + 1
          do ilat = myLatBeg, myLatEnd
            do ilon = myLonBeg, myLonEnd
              ensAmplitude_out(ilon,ilat,im) = ensAmplitude_out(ilon,ilat,im) +   &
                wThermoNext*incrementLocal_in(ilon,ilat,kl,it)*dble(ensPerturbations(im)%member_r4(ilon,ilat,kl,it))
            end do
          end do
        end do

        ! staggered: surface variables, only at the lowest amplitude level
        if (levelIndex == nLevEns_M) then
          do iv = 5, 6
            kl = verticalLevel(levelIndex,iv)
            do ilat = myLatBeg, myLatEnd
              do ilon = myLonBeg, myLonEnd
                ensAmplitude_out(ilon,ilat,im) = ensAmplitude_out(ilon,ilat,im) +   &
                  incrementLocal_in(ilon,ilat,kl,it)*dble(ensPerturbations(im)%member_r4(ilon,ilat,kl,it))
              end do
            end do
          end do
        end if

      end do timeLoop
    end do memberLoop
!$OMP END PARALLEL DO

    call tmg_stop(62)

  END SUBROUTINE addEnsMemberAd

!--------------------------------------------------------------------------
! localizationSqrt
!--------------------------------------------------------------------------

  SUBROUTINE localizationSqrt(controlVector_in,zsp_all) 1,2
    !
    ! Purpose: Square root of the localization operator (forward).
    !          1. Horizontal: control vector -> spectral coefficients zsp_all
    !             (GlobalHLoc or LamHLoc, depending on hco_ben%global).
    !          2. Vertical:   zsp_all <- ensLocalVert * zsp_all, with zsp_all viewed
    !             as an nLevEns_M x (nla_mpiglobal*nphase*members) matrix.
    !
    ! Arguments:
    !   controlVector_in (in)  : mpi-local control vector
    !   zsp_all          (out) : localized spectral coefficients (level, coef, phase, member)
    !
    implicit none

    real(8)           :: controlVector_in(cvDim_mpilocal)
    real(8)           :: zsp_all(nLevEns_M,nla_mpiglobal,nphase,myMemBeg:myMemEnd)

    integer              :: numCols
    real(8), allocatable :: zsp(:,:,:,:)

    !
    !- 1.  Horizontal Localization
    !
    if (hco_ben % global) then
       call GlobalHLoc( zsp_all,         & ! OUT
                        controlVector_in ) ! IN
    else
       call LamHLoc( zsp_all,         & ! OUT
                     controlVector_in ) ! IN
    endif

    !
    !- 2.  Vertical localization
    !
    allocate(zsp(nLevEns_M,nla_mpiglobal,nphase,myMemBeg:myMemEnd)) 
    call tmg_start(63,'BEN_VLOC')

    ! Take the column count from the actual extent of zsp_all, so dgemm can never
    ! address past it. When the array is empty (e.g. myMemEnd < myMemBeg), skip the
    ! call: the element reference zsp_all(1,1,1,myMemBeg) would be out of bounds.
    numCols = size(zsp_all,2)*size(zsp_all,3)*size(zsp_all,4)
    if (nLevEns_M > 0 .and. numCols > 0) then
      call dgemm('N','N',nLevEns_M,numCols,nLevEns_M,1.0D0,                       &
                 ensLocalVert(1,1),nLevEns_M,zsp_all(1,1,1,myMemBeg),nLevEns_M,  &
                 0.0D0,zsp(1,1,1,myMemBeg),nLevEns_M)
    endif

    call tmg_stop(63)

    ! dgemm cannot work in place (C must not alias B), hence the temporary copy
    zsp_all(:,:,:,:) = zsp(:,:,:,:)
    deallocate(zsp)

  END SUBROUTINE localizationSqrt

!--------------------------------------------------------------------------
! GlobalHLoc
!--------------------------------------------------------------------------

  SUBROUTINE GlobalHLoc(zsp_all,controlVector_in) 1,2
    !
    ! Purpose: Global horizontal localization (forward). Unpacks the mpi-local
    !          control vector into spectral coefficients for the zonal wavenumbers
    !          jm = mymBeg:mymEnd:mymSkip handled by this task, scaling each one by
    !          ensLocalCor(jn_vec(ila),jlev).
    !          - jm = 0: only the real part is stored in the control vector. It is
    !            additionally scaled by rsq2, and the imaginary part is set to zero.
    !          - jm > 0: both real and imaginary parts are stored.
    !
    ! Arguments:
    !   zsp_all          (out) : spectral coefficients (level, coefficient, phase, member)
    !   controlVector_in (in)  : mpi-local control vector
    !
    implicit none

    real(8), intent(in)  :: controlVector_in(cvDim_mpilocal)
    real(8), intent(out) :: zsp_all(nLevEns_M,nla_mpiglobal,nphase,myMemBeg:myMemEnd)

    integer :: jlev, jm, jn, ila, jdim, memberIndex 

    call tmg_start(65,'BEN_HLOC')

    ! Only the coefficients of the jm owned by this task are written below, and
    ! intent(out) leaves every other element undefined on entry. Callers copy whole
    ! slices of zsp_all, so give those elements a defined zero value here instead
    ! of relying on the caller's pre-zeroing surviving the intent(out) dummy.
    zsp_all(:,:,:,:) = 0.0d0

    jdim = 0

    do memberIndex = myMemBeg, myMemEnd

      do jlev = 1, nLevEns_M
        do jm = mymBeg, mymEnd, mymSkip
          do jn = jm, ntrunc
            ila = gst_getnind(jm,gstID) + jn - jm
            if(jm.eq.0) then
              ! controlVector only contain real part for jm=0
              jdim = jdim + 1
              zsp_all(jlev,ila,1,memberIndex) = controlVector_in(jdim)*ensLocalCor(jn_vec(ila),jlev)*rsq2
              zsp_all(jlev,ila,2,memberIndex) = 0.0d0
            else
              ! controlVector contains real and imag parts for jm>0
              jdim = jdim + 1
              zsp_all(jlev,ila,1,memberIndex) = controlVector_in(jdim)*ensLocalCor(jn_vec(ila),jlev)
              jdim = jdim + 1
              zsp_all(jlev,ila,2,memberIndex) = controlVector_in(jdim)*ensLocalCor(jn_vec(ila),jlev)
            endif
          enddo
        enddo
      enddo
      if(jdim.gt.cvDim_mpilocal) then
        write(*,*) 'ben globalHLoc: jdim > cvDim_mpilocal! ',jdim,memberIndex,cvDim_mpilocal
        call abort3d('aborted in ben globalHLoc')
      endif

    enddo

    call tmg_stop(65)

  END SUBROUTINE GlobalHLoc

!--------------------------------------------------------------------------
! LamHLoc
!--------------------------------------------------------------------------

  SUBROUTINE LamHLoc(zsp_all,controlVector_in) 1,1
    !
    ! Purpose: LAM horizontal localization (forward): reshape + horizontal
    !          localization + Parseval scaling.
    !          The mpi-local control vector is unpacked member by member. Within a
    !          member, the phase index varies fastest, then the spectral
    !          coefficient, then the level. Each coefficient is multiplied by
    !          ensLocalCor(lst_ben%k(coef),level) and lst_ben%NormFactor(coef,phase).
    !
    ! Arguments:
    !   zsp_all          (out) : spectral coefficients (level, coefficient, phase, member)
    !   controlVector_in (in)  : mpi-local control vector
    !
    ! Aborts if more control-vector entries were consumed than cvDim_mpilocal.
    ! This is checked after each member.
    !
    implicit none

    real(8), intent(in)  :: controlVector_in(cvDim_mpilocal)
    real(8), intent(out) :: zsp_all(nLevEns_M,nla_mpiglobal,nphase,myMemBeg:myMemEnd)

    integer :: iMem, iLev, iCoef, iPhase, cvPos
    real(8) :: hLocWeight

    cvPos = 0

    memberLoop: do iMem = myMemBeg, myMemEnd

      call tmg_start(65,'BEN_CAIN')

      levelLoop: do iLev = 1, nLevEns_M
        coefLoop: do iCoef = 1, nla_mpiglobal
          ! horizontal localization weight is the same for every phase
          hLocWeight = ensLocalCor(lst_ben%k(iCoef),iLev)
          do iPhase = 1, nphase
            cvPos = cvPos + 1
            zsp_all(iLev,iCoef,iPhase,iMem) = controlVector_in(cvPos) * hLocWeight * &
                                              lst_ben % NormFactor(iCoef,iPhase)
          end do
        end do coefLoop
      end do levelLoop

      if (cvPos > cvDim_mpilocal) then
        write(*,*) 'BEN: LamHLoc: jdim > cvDim! ',cvPos,iMem,cvDim_mpilocal
        call abort3d('aborted in LamHLoc')
      end if

      call tmg_stop(65)

    end do memberLoop

  END SUBROUTINE LamHLoc

!--------------------------------------------------------------------------
! localizationSqrtAd
!--------------------------------------------------------------------------

  SUBROUTINE localizationSqrtAd(zsp_all,controlVector_out) 1,2
    !
    ! Purpose: Adjoint of localizationSqrt. The steps are applied in reverse order:
    !          2. Vertical:   zsp <- transpose(ensLocalVert) * zsp_all
    !          1. Horizontal: zsp -> control vector (GlobalHLocAd or LamHLocAd,
    !             which add their contributions to controlVector_out).
    !
    ! Arguments:
    !   zsp_all           (in)    : spectral coefficients (level, coef, phase, member)
    !   controlVector_out (inout) : mpi-local control vector that is accumulated into
    !
    implicit none

    real(8) :: controlVector_out(cvDim_mpilocal)
    real(8) :: zsp_all(nLevEns_M,nla_mpiglobal,nphase,myMemBeg:myMemEnd)

    integer              :: numCols
    real(8), allocatable :: zsp(:,:,:,:)

    !
    !- 2.  Vertical Localization (adjoint)
    !
    allocate(zsp(nLevEns_M,nla_mpiglobal,nphase,myMemBeg:myMemEnd) ) 

    call tmg_start(63,'BEN_VLOC')

    ! The forward applies ensLocalVert ('N'), so the adjoint applies its transpose
    ! ('T'). Using 'N' here would only be correct if ensLocalVert were exactly
    ! symmetric. With beta = 0, dgemm does not read zsp, so no pre-zeroing is needed.
    ! The column count and the empty-array guard mirror localizationSqrt.
    numCols = size(zsp_all,2)*size(zsp_all,3)*size(zsp_all,4)
    if (nLevEns_M > 0 .and. numCols > 0) then
      call dgemm('T','N',nLevEns_M,numCols,nLevEns_M,1.0D0,                        &
                 ensLocalVert(1,1),nLevEns_M, zsp_all(1,1,1,myMemBeg),nLevEns_M,  &
                 0.0D0,zsp(1,1,1,myMemBeg),nLevEns_M)
    endif

    call tmg_stop(63)

    !
    !- 1.  Horizontal Localization (adjoint)
    !
    if (hco_ben % global) then
      call GlobalHLocAd( zsp,              & ! IN
                         controlVector_out )  ! INOUT (accumulated)
    else
      call LamHLocAd( zsp,              & ! IN
                      controlVector_out )  ! INOUT (accumulated)
    endif

    deallocate(zsp)

  END SUBROUTINE localizationSqrtAd

!--------------------------------------------------------------------------
! GlobalHLocAd
!--------------------------------------------------------------------------

  SUBROUTINE GlobalHLocAd(zsp_all,controlVector_out) 1,2
    !
    ! Purpose: Adjoint of GlobalHLoc. For the zonal wavenumbers
    !          jm = mymBeg:mymEnd:mymSkip handled by this task, each spectral
    !          coefficient is multiplied by ensLocalCor(jn_vec(ila),jlev) and added
    !          to the matching entry of the control vector. The packing is the same
    !          as in GlobalHLoc (real part only for jm=0; real and imaginary parts
    !          for jm>0).
    !          Extra factors: rsq2 for jm=0, and 2 for jm>0.
    !          NOTE(review): the factor 2 presumably reflects the spectral
    !          inner-product convention of the global transform adjoint - confirm.
    !
    ! Arguments:
    !   zsp_all           (in)    : spectral coefficients (level, coefficient, phase, member)
    !   controlVector_out (inout) : contributions are ADDED to the incoming values. The
    !                               array is therefore read as well as written, so it
    !                               must be intent(inout); intent(out) would make the
    !                               incoming values undefined.
    !
    implicit none

    real(8), intent(inout) :: controlVector_out(cvDim_mpilocal)
    real(8), intent(in)    :: zsp_all(nLevEns_M,nla_mpiglobal,nphase,myMemBeg:myMemEnd)

    integer :: jlev, jm, jn, ila, jdim, memberIndex 

    call tmg_start(65,'BEN_HLOC')

    jdim = 0

    do memberIndex = myMemBeg, myMemEnd

       do jlev = 1, nLevEns_M
          do jm = mymBeg, mymEnd, mymSkip
            do jn = jm, ntrunc
              ila = gst_getnind(jm,gstID) + jn - jm
              if(jm.eq.0) then
                ! controlVector only contain real part for jm=0
                jdim = jdim + 1
                controlVector_out(jdim) = controlVector_out(jdim) +  &
                                          zsp_all(jlev,ila,1,memberIndex)*ensLocalCor(jn_vec(ila),jlev)*rsq2
              else
                ! controlVector contains real and imag parts for jm>0
                jdim = jdim + 1
                controlVector_out(jdim) = controlVector_out(jdim) +  &
                                          zsp_all(jlev,ila,1,memberIndex)*ensLocalCor(jn_vec(ila),jlev)*2.0d0
                jdim = jdim + 1
                controlVector_out(jdim) = controlVector_out(jdim) +  &
                                          zsp_all(jlev,ila,2,memberIndex)*ensLocalCor(jn_vec(ila),jlev)*2.0d0
              endif
           enddo
         enddo
       enddo
       if(jdim.gt.cvDim_mpilocal) then
          write(*,*) 'ben globalHLocAd: jdim > cvDim_mpilocal! ',jdim,memberIndex,cvDim_mpilocal
          call abort3d('aborted in ben globalHLocAd')
       endif
    
    enddo

    call tmg_stop(65)

  END SUBROUTINE GlobalHLocAd

!--------------------------------------------------------------------------
! LamHLocAd
!--------------------------------------------------------------------------

  SUBROUTINE LamHLocAd(zsp_all,controlVector_out) 1,1
    !
    ! Purpose: Adjoint of LamHLoc (reshape + horizontal localization + Parseval
    !          scaling). Each spectral coefficient is multiplied by
    !          ensLocalCor(lst_ben%k(jla),jlev) and lst_ben%NormFactorAd(jla,p), then
    !          added to the control vector. The packing order is the same as in
    !          LamHLoc: phase fastest, then coefficient, then level, then member.
    !
    ! Arguments:
    !   zsp_all           (in)    : spectral coefficients (level, coefficient, phase, member)
    !   controlVector_out (inout) : contributions are ADDED to the incoming values. The
    !                               array is therefore read as well as written, so it
    !                               must be intent(inout); intent(out) would make the
    !                               incoming values undefined.
    !
    implicit none

    real(8), intent(inout) :: controlVector_out(cvDim_mpilocal)
    real(8), intent(in)    :: zsp_all(nLevEns_M,nla_mpiglobal,nphase,myMemBeg:myMemEnd)

    integer :: jla, jlev, jdim, memberIndex, p

    !
    !- Reshape + Horizontal localization + Scaling (parseval)
    !
    jdim = 0

    do memberIndex = myMemBeg, myMemEnd

       call tmg_start(65,'BEN_CAIN')
       do jlev = 1, nLevEns_M
         do jla = 1, nla_mpiglobal
           do p = 1, nphase
             jdim = jdim + 1
             controlVector_out(jdim) = controlVector_out(jdim) +           &
                                       ( zsp_all(jlev,jla,p,memberIndex)  * &
                                         ensLocalCor(lst_ben%k(jla),jlev) * &
                                         lst_ben % NormFactorAd(jla,p)    )
           enddo
         enddo
       enddo
       if (jdim > cvDim_mpilocal ) then
          write(*,*) 'BEN: LamHLocAD: jdim > cvDim! ',jdim, memberIndex, cvDim_mpilocal
          call abort3d('aborted in LamHLocAd')
       endif
       call tmg_stop(65)

    enddo

  END SUBROUTINE LamHLocAd

END MODULE BMatrixEnsemble