LCOV - code coverage report
Current view: top level - hybrid - wavefproducts_aux.F90 (source / functions) Hit Total Coverage
Test: FLEUR test coverage Lines: 60 89 67.4 %
Date: 2024-05-02 04:21:52 Functions: 3 6 50.0 %

          Line data    Source code
       1             : module m_wavefproducts_aux
       2             :    use m_types_fftGrid
       3             :    use m_types
       4             : CONTAINS
       5          88 :    subroutine wavefproducts_IS_FFT(fi, ik, iq, g_t, jsp, bandoi, bandof, mpdata, hybdat, lapw, stars, nococonv, &
       6          88 :                                    ikqpt, z_k, z_kqpt_p, c_phase_kqpt, cprod)
       7             :       !$ use omp_lib
       8             :       use m_constants
       9             :       use m_judft
      10             :       use m_fft_interface
      11             :       use m_io_hybrid
      12             :       use m_juDFT
      13             : #ifdef CPP_MPI
      14             :       use mpi
      15             : #endif
      16             :       implicit NONE
      17             :       type(t_fleurinput), intent(in)  :: fi
      18             :       TYPE(t_nococonv), INTENT(IN)    :: nococonv
      19             :       TYPE(t_lapw), INTENT(IN)        :: lapw
      20             :       TYPE(t_mpdata), intent(in)      :: mpdata
      21             :       TYPE(t_hybdat), INTENT(INOUT)   :: hybdat
      22             :       type(t_stars), intent(in)       :: stars
      23             :       type(t_mat), intent(in)         :: z_k
      24             :       type(t_mat), intent(inout)      :: z_kqpt_p, cprod
      25             :       !     - scalars -
      26             :       INTEGER, INTENT(IN)      ::  ik, iq, jsp, g_t(3), bandoi, bandof
      27             :       INTEGER, INTENT(IN)      ::  ikqpt
      28             :       !     - arrays -
      29             :       complex, intent(inout)    :: c_phase_kqpt(hybdat%nbands(ikqpt,jsp))
      30             : 
      31          88 :       complex, allocatable  :: prod(:,:), psi_k(:, :), psi_kqpt(:,:)
      32             : 
      33          88 :       type(t_mat)     :: z_kqpt
      34          88 :       type(t_lapw)    :: lapw_ikqpt
      35         704 :       type(t_fft)     :: fft, wavef2rs_fft
      36        1144 :       type(t_fftgrid) :: stepf, grid
      37             : 
      38             : 
      39             :       integer, parameter :: blocksize = 512
      40             :       integer :: g(3), igptm, iob, n_omp, j, jstart, loop_length
      41             :       integer :: ok, nbasfcn, psize, iband, ierr, i, max_igptm
      42         176 :       integer, allocatable :: band_list(:), g_ptr(:)
      43             :       real    :: inv_vol, gcutoff, max_imag
      44             : 
      45             :       logical :: real_warned
      46             : 
      47          88 :       real_warned = .False.
      48             : 
      49          88 :       call timestart("wavef_IS_FFT")
      50          88 :       max_igptm = mpdata%n_g(iq)
      51             : 
      52          88 :       gcutoff = (2*fi%input%rkmax + fi%mpinp%g_cutoff) * fi%hybinp%fftcut
      53          88 :       inv_vol = 1/sqrt(fi%cell%omtil)
      54          88 :       psize = bandof - bandoi + 1
      55             :       !this is for the exact result. Christoph recommend 2*gmax+gcutm for later
      56          88 :       if (2*fi%input%rkmax + fi%mpinp%g_cutoff > fi%input%gmax) then
      57          44 :          write (*, *) "WARNING: not accurate enough: 2*kmax+gcutm >= fi%input%gmax"
      58             :          !call juDFT_error("not accurate enough: 2*kmax+gcutm >= fi%input%gmax")
      59             :       endif
      60             : 
      61          88 :       call stepf%init(fi%cell, fi%sym, gcutoff)
      62             :       block
      63             :          type(t_cell)         :: cell !unused 
      64          88 :          call stepf%putfieldOnGrid(stars, stars%ustep)
      65             :       end block
      66          88 :       call fft%init(stepf%dimensions, .false., batch_size=1, l_gpu=.True.)
      67             :       !$acc data copyin(stepf, stepf%grid, stepf%gridlength)
      68             :          ! after we transform psi_k*stepf*psi_kqpt back  to 
      69             :          ! G-space we have to divide by stepf%gridLength. We do this now
      70             : 
      71             :          !$acc kernels default(none) present(stepf, stepf%grid, stepf%gridLength)
      72      904574 :          stepf%grid = stepf%grid * inv_vol / stepf%gridLength
      73             :          !$acc end kernels
      74             : 
      75          88 :          call fft%exec(stepf%grid)
      76          88 :          call fft%free()
      77             :          
      78          88 :          call setup_g_ptr(mpdata, stepf, g_t, iq, g_ptr)
      79             :          
      80          88 :          CALL lapw_ikqpt%init(fi, nococonv, ikqpt)
      81             : 
      82          88 :          nbasfcn = lapw_ikqpt%hyb_num_bas_fun(fi)
      83          88 :          call z_kqpt%alloc(z_k%l_real, nbasfcn, psize)
      84          88 :          call z_kqpt_p%init(z_kqpt)
      85             : 
      86        2702 :          band_list = [(i, i=bandoi, bandof)]
      87             :          call read_z(fi%atoms, fi%cell, hybdat, fi%kpts, fi%sym, fi%noco, nococonv, fi%input, ikqpt, jsp, z_kqpt, &
      88          88 :                      c_phase=c_phase_kqpt, parent_z=z_kqpt_p, list=band_list)
      89             : #ifdef CPP_MPI
      90          88 :          call timestart("read_z barrier")
      91          88 :          call MPI_Barrier(MPI_COMM_WORLD, ierr)
      92          88 :          hybdat%max_q = hybdat%max_q - 1
      93          88 :          call timestop("read_z barrier")
      94             : #endif
      95             : 
      96         352 :          allocate(psi_kqpt(0:stepf%gridLength-1, psize), stat=ierr)
      97          88 :          if(ierr /= 0) call juDFT_error("can't alloc psi_kqpt")
      98             : 
      99             :          !$acc data create(psi_kqpt)
     100          88 :             call grid%init(fi%cell, fi%sym, gcutoff)
     101          88 :             call wavef2rs_fft%init(grid%dimensions, .false., batch_size=psize, l_gpu=.True.)
     102             :             !$acc data copyin(z_kqpt, z_kqpt%l_real, z_kqpt%data_r, z_kqpt%data_c, lapw_ikqpt, lapw_ikqpt%nv, lapw_ikqpt%gvec,&
     103             :             !$acc             jsp, bandoi, bandof, psize, grid, grid%dimensions)
     104          88 :                call timestart("1st wavef2rs")
     105          88 :                call wavef2rs(fi, lapw_ikqpt, z_kqpt, gcutoff, 1, psize, jsp, grid, wavef2rs_fft, psi_kqpt)
     106          88 :                call timestop("1st wavef2rs")
     107             : 
     108             :                !$acc kernels default(none) present(psi_kqpt, stepf, stepf%grid)
     109         930 :                do iob = 1, psize 
     110     9096578 :                   psi_kqpt(:,iob) = psi_kqpt(:,iob) * stepf%grid
     111             :                enddo
     112             :                !$acc end kernels
     113             :             !$acc end data
     114          88 :             call wavef2rs_fft%free()
     115             :             !call grid%free()
     116             : 
     117          88 :             call timestart("Big OMP loop")
     118             : #ifndef _OPENACC
     119             :             !$OMP PARALLEL default(none) &
     120             :             !$OMP private(iband, iob, g, igptm, prod, psi_k, ok, fft, wavef2rs_fft, max_imag, grid) &
     121             :             !$OMP shared(hybdat, psi_kqpt, cprod,  mpdata, iq, g_t, psize, gcutoff, max_igptm)&
     122          88 :             !$OMP shared(jsp, z_k, stars, lapw, fi, inv_vol, ik, real_warned, n_omp, bandoi, stepf, g_ptr)
     123             : #endif
     124             : 
     125             : !            call timestart("alloc&init")
     126             :             allocate (prod(0:stepf%gridLength - 1, psize), stat=ok)
     127             :             if (ok /= 0) call juDFT_error("can't alloc prod")
     128             :             allocate (psi_k(0:stepf%gridLength - 1, 1), stat=ok)
     129             :             if (ok /= 0) call juDFT_error("can't alloc psi_k")
     130             : 
     131             :             call fft%init(stepf%dimensions, .true., batch_size=psize, l_gpu=.True.)
     132             :             call grid%init(fi%cell, fi%sym, gcutoff)
     133             :             call wavef2rs_fft%init(grid%dimensions, .false., batch_size=1, l_gpu=.True.)
     134             : !            call timestop("alloc&init")
     135             : 
     136             :             !$acc data copyin(z_k, z_k%l_real, z_k%data_r, z_k%data_c, lapw, lapw%nv, lapw%gvec)&
     137             :             !$acc      copyin(hybdat, hybdat%nbasp, g_ptr, grid, grid%dimensions, jsp)&
     138             :             !$acc      create(psi_k, prod)
     139             : #ifndef _OPENACC
     140             :                !$OMP DO
     141             : #endif
     142             :                do iband = 1, hybdat%nbands(ik,jsp)
     143             :                   call wavef2rs(fi, lapw, z_k, gcutoff, iband, iband, jsp, grid, wavef2rs_fft, psi_k)
     144             :                   
     145             :                   !$acc kernels default(none) present(prod, psi_k, psi_kqpt, stepf, stepf%gridlength)               
     146             :                   do iob = 1, psize
     147             :                      do j = 0, stepf%gridlength-1
     148             :                         prod(j,iob) = conjg(psi_k(j, 1)) * psi_kqpt(j, iob)
     149             :                      enddo
     150             :                   enddo
     151             :                   !$acc end kernels
     152             : 
     153             :                   call fft%exec_batch(prod)
     154             :             
     155             :                   if (cprod%l_real) then
     156             :                      if (.not. real_warned) then
     157             :                         !$acc kernels present(prod) copyout(max_imag)
     158             :                         max_imag = maxval(abs(aimag(prod)))
     159             :                         !$acc end kernels
     160             :                         if(max_imag > 1e-8) then
     161             :                            write (*, *) "Imag part non-zero in too large"
     162             :                            real_warned = .True.
     163             :                         endif
     164             :                      endif
     165             :                         
     166             :                      !$acc kernels default(none) present(cprod, cprod%data_r, prod, g_ptr)
     167             :                      !$acc loop independent
     168             :                      do iob = 1, psize
     169             :                         !$acc loop independent
     170             :                         DO igptm = 1, max_igptm
     171             :                            cprod%data_r(hybdat%nbasp + igptm, iob + (iband - 1)*psize) = real(prod(g_ptr(igptm), iob))
     172             :                         enddo
     173             :                      enddo
     174             :                      !$acc end kernels
     175             :                   else
     176             :                      !$acc kernels default(none) present(cprod, cprod%data_c, prod, g_ptr)
     177             :                      !$acc loop independent
     178             :                      do iob = 1, psize
     179             :                         !$acc loop independent
     180             :                         DO igptm = 1, max_igptm
     181             :                            cprod%data_c(hybdat%nbasp + igptm, iob + (iband - 1)*psize) = prod(g_ptr(igptm), iob)
     182             :                         enddo
     183             :                      enddo
     184             :                      !$acc end kernels
     185             :                   endif
     186             :                enddo
     187             : #ifndef _OPENACC
     188             :                !$OMP END DO
     189             : #endif
     190             :             !$acc end data 
     191             :             call fft%free()
     192             :             !call grid%free()
     193             :             call wavef2rs_fft%free()
     194             :          !$acc end data ! psi_kqpt
     195             :          deallocate (prod, psi_k)
     196             :       !$acc end data ! stepf, stepf%grid
     197             : 
     198             : #ifndef _OPENACC
     199             :       !$OMP END PARALLEL
     200             : #endif
     201             :       !call stepf%free() 
     202             : 
     203          88 :       call timestop("Big OMP loop")
     204          88 :       deallocate(psi_kqpt)
     205          88 :       call timestop("wavef_IS_FFT")
     206          88 :    end subroutine wavefproducts_IS_FFT
     207             : 
     208          88 :    subroutine setup_g_ptr(mpdata, stepf, g_t, iq, g_out)
     209             :       implicit none
     210             :       type(t_mpdata), intent(in)          :: mpdata 
     211             :       type(t_fftgrid), intent(in)         :: stepf 
     212             :       integer, intent(in)                 :: g_t(:), iq
     213             :       integer, allocatable, intent(inout) :: g_out(:)
     214             : 
     215             :       integer :: igptm, g(3)
     216             : 
     217          88 :       if(allocated(g_out)) deallocate(g_out)
     218         264 :       allocate(g_out(mpdata%n_g(iq)))
     219             : 
     220       10168 :       DO igptm = 1, mpdata%n_g(iq)
     221       40320 :          g = mpdata%g(:, mpdata%gptm_ptr(igptm, iq)) - g_t
     222       10168 :          g_out(igptm) = stepf%g2fft(g)
     223             :       enddo
     224          88 :    end subroutine setup_g_ptr
     225             : 
     226        4596 :    subroutine wavef2rs(fi, lapw, zmat, gcutoff,  bandoi, bandof, jspin, grid, fft, psi)
     227             :       ! put block of wave functions through FFT
     228             : !$    use omp_lib
     229             :       use m_types
     230             :       use m_fft_interface
     231             :       implicit none
     232             :       type(t_fleurinput), intent(in) :: fi
     233             :       type(t_lapw), intent(in)       :: lapw
     234             :       type(t_mat), intent(in)        :: zmat
     235             :       integer, intent(in)            :: jspin, bandoi, bandof
     236             :       real, intent(in)               :: gcutoff
     237             :       type(t_fftgrid), intent(inout) :: grid
     238             :       type(t_fft), intent(inout)     :: fft
     239             :       complex, intent(inout)         :: psi(0:, bandoi:) ! (nv,ne)
     240             : 
     241             :       integer :: iv, nu, psize, dims(3)
     242             : 
     243             : #ifndef _OPENACC 
     244        4596 :       !$omp parallel do default(none) private(nu) shared(grid, bandoi, bandof, lapw, jspin, zMat, psi)
     245             : #endif
     246             :       do nu = bandoi, bandof
     247             :          call grid%put_state_on_external_grid(lapw, jspin, zMat, nu, psi(:,nu), l_gpu=.True.)
     248             :       enddo
     249             : #ifndef _OPENACC 
     250             :       !$omp end parallel do
     251             : #endif   
     252             : 
     253        4596 :       call fft%exec_batch(psi)
     254        4596 :    end subroutine wavef2rs
     255             : 
     256           0 :    subroutine prep_list_of_gvec(lapw, mpdata, g_bounds, g_t, iq, jsp, pointer, gpt0, ngpt0)
     257             :       use m_types
     258             :       use m_juDFT
     259             :       implicit none
     260             :       type(t_lapw), intent(in)    :: lapw
     261             :       TYPE(t_mpdata), intent(in)         :: mpdata
     262             :       integer, intent(in)    :: g_bounds(:), g_t(:), iq, jsp
     263             :       integer, allocatable, intent(inout) :: pointer(:, :, :), gpt0(:, :)
     264             :       integer, intent(inout) :: ngpt0
     265             : 
     266             :       integer :: ic, ig1, igptm, iigptm, ok, g(3)
     267             : 
     268             :       allocate (pointer(-g_bounds(1):g_bounds(1), &
     269             :                         -g_bounds(2):g_bounds(2), &
     270           0 :                         -g_bounds(3):g_bounds(3)), stat=ok)
     271           0 :       IF (ok /= 0) call juDFT_error('wavefproducts_noinv2: error allocation pointer')
     272           0 :       allocate (gpt0(3, size(pointer)), stat=ok)
     273           0 :       IF (ok /= 0) call juDFT_error('wavefproducts_noinv2: error allocation gpt0')
     274             : 
     275           0 :       call timestart("prep list of Gvec")
     276           0 :       pointer = 0
     277           0 :       ic = 0
     278           0 :       DO ig1 = 1, lapw%nv(jsp)
     279           0 :          DO igptm = 1, mpdata%n_g(iq)
     280           0 :             iigptm = mpdata%gptm_ptr(igptm, iq)
     281           0 :             g = lapw%gvec(:, ig1, jsp) + mpdata%g(:, iigptm) - g_t
     282           0 :             IF (pointer(g(1), g(2), g(3)) == 0) THEN
     283           0 :                ic = ic + 1
     284           0 :                gpt0(:, ic) = g
     285           0 :                pointer(g(1), g(2), g(3)) = ic
     286             :             END IF
     287             :          END DO
     288             :       END DO
     289           0 :       ngpt0 = ic
     290           0 :       call timestop("prep list of Gvec")
     291           0 :    end subroutine prep_list_of_gvec
     292             : 
     293           0 :    function calc_number_of_basis_functions(lapw, atoms, noco) result(nbasfcn)
     294             :       use m_types
     295             :       implicit NONE
     296             :       type(t_lapw), intent(in)  :: lapw
     297             :       type(t_atoms), intent(in) :: atoms
     298             :       type(t_noco), intent(in)  :: noco
     299             :       integer                   :: nbasfcn
     300             : 
     301           0 :       if (noco%l_noco) then
     302           0 :          nbasfcn = lapw%nv(1) + lapw%nv(2) + 2*atoms%nlotot
     303             :       else
     304           0 :          nbasfcn = lapw%nv(1) + atoms%nlotot
     305             :       endif
     306           0 :    end function calc_number_of_basis_functions
     307             : 
     308           0 :    function outer_prod(x, y) result(outer)
     309             :       implicit NONE
     310             :       complex, intent(in) :: x(:), y(:)
     311             :       complex :: outer(size(x), size(y))
     312             :       integer  :: i, j
     313             : 
     314           0 :       do j = 1, size(y)
     315           0 :          do i = 1, size(x)
     316           0 :             outer(i, j) = x(i)*y(j)
     317             :          enddo
     318             :       enddo
     319           0 :    end function outer_prod
     320             : end module m_wavefproducts_aux

Generated by: LCOV version 1.14