LCOV - code coverage report
Current view: top level - hybrid - spmm_noinv.F90 (source / functions) Hit Total Coverage
Test: FLEUR test coverage Lines: 96 98 98.0 %
Date: 2024-05-02 04:21:52 Functions: 1 1 100.0 %

          Line data    Source code
       1             : module m_spmm_noinv
       2             :    use iso_c_binding
       3             :    use m_spmm
       4             : #ifdef _OPENACC
       5             :    USE cublas
       6             : #define CPP_zgemm cublaszgemm
       7             : #define CPP_dgemm cublasdgemm
       8             : #define CPP_zgemv cublaszgemv
       9             : #define CPP_dgemv cublasdgemv
      10             : #define CPP_mtir_c mtir_tmp
      11             : #define CPP_mtir_r mtir_tmp
      12             : #else
      13             : #define CPP_zgemm zgemm
      14             : #define CPP_dgemm dgemm
      15             : #define CPP_zgemv zgemv
      16             : #define CPP_dgemv dgemv
      17             : #define CPP_mtir_c hybdat%coul(ikpt)%mtir%data_c
      18             : #define CPP_mtir_r hybdat%coul(ikpt)%mtir%data_r
      19             : #endif
      20             : contains
      21          22 :    subroutine spmm_noinvs(fi, mpdata, hybdat, ikpt, conjg_mtir, mat_in, mat_out)
      22             :       use m_juDFT
      23             :       use m_types
      24             :       use m_reorder
      25             :       use m_constants
      26             :       use m_calc_l_m_from_lm
      27             : 
      28             :       implicit none
      29             :       type(t_fleurinput), intent(in)    :: fi
      30             :       type(t_mpdata), intent(in)        :: mpdata
      31             :       type(t_hybdat), intent(inout)     :: hybdat
      32             :       integer, intent(in)               :: ikpt
      33             :       logical, intent(in)               :: conjg_mtir
      34             :       complex, intent(inout)            :: mat_in(:,:)
      35             :       complex, intent(inout)            :: mat_out(:,:)
      36             : 
      37             :       integer :: n_vec, i_vec, ibasm, iatom, itype, ieq, l, m, n_size
      38             :       integer :: indx0, indx1, indx2, indx3, n, iatom1, ieq1, ishift, itype1
      39             :       integer :: ishift1, indx4, lm, idx1_start, idx3_start, ld_mt1_tmp
      40             :       integer :: iat2, it2, l2, iat, ierr, irank, i, sz_mtir, sz_in, sz_out, max_l_cut
      41             :       integer(C_SIZE_T) :: free_mem, tot_mem
      42          22 :       integer, allocatable :: new_order(:)
      43          22 :       complex, allocatable :: mt1_tmp(:,:,:,:), mt2_tmp(:,:,:,:), mt3_tmp(:,:,:), mat_in_line(:)
      44             : #ifdef _OPENACC
      45             :       complex, allocatable :: mtir_tmp(:,:)
      46             : #endif
      47             : 
      48          22 :       call timestart("spmm_noinvs")
      49          22 :       call timestart("copy mt2_c")
      50       21824 :       mt2_tmp = hybdat%coul(ikpt)%mt2_c
      51          22 :       call timestop("copy mt2_c")
      52             : 
      53          22 :       sz_in  = size(mat_in, 1)
      54          22 :       sz_out  = size(mat_out, 1)
      55          22 :       n_vec = size(mat_in, 2)
      56             : 
      57          66 :       allocate(mat_in_line(size(mat_in,2)))
      58             : 
      59          22 :       call timestart("copyin gpu")
      60             :       !$acc data copyin(mt2_tmp) copy(mat_in) copyout(mat_out) create(mat_in_line)
      61             :          !$acc wait
      62          22 :          call timestop("copyin gpu")
      63             : 
      64             :          !$acc kernels present(mat_in_line, mat_in)
      65       15136 :          mat_in_line = mat_in(hybdat%nbasp + 1, :)
      66             :          !$acc end kernels
      67             :          
      68          22 :          call timestart("reorder forw")
      69          66 :          allocate(new_order(size(mat_in,1)))
      70          22 :          call forw_order(fi%atoms, fi%hybinp%lcutm1, mpdata%num_radbasfn, new_order)
      71          22 :          call reorder(new_order, mat_in)
      72          22 :          call timestop("reorder forw")
      73             : 
      74          22 :          ibasm = calc_ibasm(fi, mpdata)
      75             : 
      76             :          ! compute vecout for the indices from 0:ibasm
      77          22 :          call timestart("0 > ibasm: small matricies")
      78          22 :          call timestart("alloc&cpy mt1_tmp")
      79         132 :          allocate(mt1_tmp, mold=hybdat%coul(ikpt)%mt1_c, stat=ierr)
      80          22 :          ld_mt1_tmp = size(mt1_tmp,dim=2) ! special multiplication
      81          22 :          if(ierr /= 0) call judft_error("can't alloc mt1_tmp")
      82         110 :          call zcopy(size(mt1_tmp), hybdat%coul(ikpt)%mt1_c, 1, mt1_tmp, 1)
      83          22 :          call timestop("alloc&cpy mt1_tmp")
      84             : 
      85             : 
      86             :          !$acc kernels present(mat_out)
      87     7792982 :          mat_out = cmplx_0
      88             :          !$acc end kernels
      89             : 
      90             :          !$acc data copyin(mt1_tmp)
      91             : #ifndef _OPENACC
      92             :             !$OMP PARALLEL DO default(none) schedule(dynamic)&
      93             :             !$OMP private(iatom, itype, idx1_start, iat2, it2, l2, indx1, idx3_start, indx3)&
      94             :             !$OMP private(lm, l, m, n_size, i_vec)&
      95             :             !$OMP lastprivate(indx2)&
      96          22 :             !$OMP shared(ibasm, mat_in, hybdat, mat_out, fi, mpdata, n_vec, ikpt, ld_mt1_tmp, sz_out, sz_in, mt1_tmp, mt2_tmp)
      97             : #endif
      98             :             do iatom = 1, fi%atoms%nat
      99             :                itype = fi%atoms%itype(iatom)
     100             : 
     101             :                idx1_start = 0
     102             :                do iat2 =1,iatom-1
     103             :                   it2 = fi%atoms%itype(iat2)
     104             :                   do l2 = 0, fi%hybinp%lcutm1(it2)
     105             :                      idx1_start = idx1_start + (mpdata%num_radbasfn(l2, it2)-1) * (2*l2+1)
     106             :                   enddo
     107             :                enddo
     108             :                indx1 = idx1_start
     109             : 
     110             :                idx3_start = ibasm
     111             :                do iat2 = 1,iatom-1
     112             :                   it2 = fi%atoms%itype(iat2)
     113             :                   idx3_start = idx3_start + (fi%hybinp%lcutm1(it2)+1)**2
     114             :                enddo
     115             :                indx3 = idx3_start 
     116             :                do lm = 1, (fi%hybinp%lcutm1(itype) + 1)**2
     117             :                   call calc_l_m_from_lm(lm, l, m)
     118             :                   indx1 = indx1 + 1
     119             :                   indx2 = indx1 + mpdata%num_radbasfn(l, itype) - 2
     120             :                   indx3 = indx3 + 1
     121             : 
     122             :                   n_size = mpdata%num_radbasfn(l, itype) - 1
     123             : 
     124             :                   !$acc host_data use_device(mt1_tmp, mat_in, mat_out)
     125             :                   call CPP_zgemm("N","N", n_size, n_vec, n_size, cmplx_1, mt1_tmp(1,1,l,itype), ld_mt1_tmp,&
     126             :                               mat_in(indx1,1), sz_in, cmplx_0, mat_out(indx1,1), sz_out)
     127             :                   !$acc end host_data
     128             : 
     129             :                   !$acc kernels present(mat_out, mt2_tmp, mat_in)
     130             :                   do i_vec = 1, n_vec
     131             :                      do i = 0, indx2-indx1
     132             :                         mat_out(indx1+i,i_vec) = mat_out(indx1+i,i_vec) + mt2_tmp(i+1, m, l, iatom) * mat_in(indx3, i_vec)
     133             :                      enddo
     134             :                   enddo
     135             :                   !$acc end kernels
     136             : 
     137             :                   indx1 = indx2
     138             :                END DO
     139             :             END DO
     140             : #ifndef _OPENACC
     141             :             !$OMP END PARALLEL DO
     142             : #endif
     143             :          !$acc end data
     144             :          !$acc wait
     145          22 :          deallocate(mt1_tmp)
     146          22 :          call timestop("0 > ibasm: small matricies")
     147             : 
     148          22 :          IF (indx2 /= ibasm) call judft_error('spmvec: error counting basis functions')
     149             : 
     150          22 :          IF (ikpt == 1) THEN
     151           6 :             call timestart("gamma point 1 noinv")
     152           6 :             call timestart("cpy mt3_tmp")
     153          30 :             allocate(mt3_tmp, mold=hybdat%coul(ikpt)%mt3_c, stat=ierr)
     154           6 :             if(ierr /= 0 ) call judft_error("can't alloc mt3_tmp")
     155          24 :             call zcopy(size(mt3_tmp), hybdat%coul(ikpt)%mt3_c, 1, mt3_tmp, 1)
     156           6 :             call timestop("cpy mt3_tmp")
     157             : 
     158          18 :             max_l_cut = maxval(fi%hybinp%lcutm1)
     159             : #ifdef _OPENACC
     160             :             !$acc data copyin(mt3_tmp)
     161             : #else
     162             :             !$OMP PARALLEL DO default(none) schedule(dynamic)&
     163             :             !$OMP private(iatom, itype, indx0, l, m, indx1, indx2, iatom1, indx3) &
     164             :             !$OMP private(indx4, i_vec, n_size, itype1, ishift1,ieq1) &
     165           6 :             !$OMP shared(fi, n_vec, mpdata, hybdat, ibasm, mat_out, mat_in, ikpt, mat_in_line, mt3_tmp, mt2_tmp, max_l_cut)
     166             : #endif
     167             :                do iatom = 1,fi%atoms%nat
     168             :                   itype = fi%atoms%itype(iatom)
     169             :                   indx0 = 0
     170             :                   do iat = 1,iatom-1
     171             :                      indx0 = indx0 + sum([((2*l + 1)*(mpdata%num_radbasfn(l, fi%atoms%itype(iat)) - 1), l=0, fi%hybinp%lcutm1(fi%atoms%itype(iat)))])
     172             :                   enddo
     173             :                   l = 0
     174             :                   m = 0
     175             : 
     176             :                   indx1 = indx0 + 1
     177             :                   indx2 = indx1 + mpdata%num_radbasfn(l, itype) - 2
     178             : 
     179             :                   iatom1 = 0
     180             :                   indx3 = ibasm
     181             :                   n_size = mpdata%num_radbasfn(l, itype) - 1
     182             :                   DO itype1 = 1, fi%atoms%ntype
     183             :                      ishift1 = (fi%hybinp%lcutm1(itype1) + 1)**2
     184             :                      DO ieq1 = 1, fi%atoms%neq(itype1)
     185             :                         iatom1 = iatom1 + 1
     186             :                         indx4 = indx3 + (ieq1 - 1)*ishift1 + 1
     187             :                         if (iatom /= iatom1) then
     188             :                            !$acc kernels present(mat_out, mt3_tmp, mat_in) default(none)
     189             :                            do i_vec = 1, n_vec
     190             :                               mat_out(indx1:indx2, i_vec) = mat_out(indx1:indx2, i_vec) + mt3_tmp(:n_size, iatom1, iatom)*mat_in(indx4, i_vec)
     191             :                            enddo
     192             :                            !$acc end kernels
     193             :                         endif
     194             :                      END DO
     195             :                      indx3 = indx3 + fi%atoms%neq(itype1)*ishift1
     196             :                   END DO
     197             :                   IF (indx3 /= hybdat%nbasp) call judft_error('spmvec: error counting index indx3')
     198             : 
     199             :                   n_size = mpdata%num_radbasfn(l, itype) - 1
     200             :                   !$acc kernels present(mat_out, mt2_tmp, mat_in_line) default(none)
     201             :                   do i_vec = 1, n_vec
     202             :                      mat_out(indx1:indx2, i_vec) = mat_out(indx1:indx2, i_vec) + mt2_tmp(:n_size, 0, max_l_cut + 1, iatom)*mat_in_line(i_vec)
     203             :                   enddo
     204             :                   !$acc end kernels
     205             :                END DO
     206             : #ifdef _OPENACC
     207             :             !$acc end data !(mt3_tmp)
     208             : #else
     209             :             !$OMP END PARALLEL DO
     210             : #endif
     211           6 :             call timestop("gamma point 1 noinv")
     212             :          END IF
     213             :          ! compute vecout for the index-range from ibasm+1:nbasm
     214             : 
     215          22 :          call timestart("calc indx1")
     216             :          indx1 = sum([(((2*l + 1)*fi%atoms%neq(itype), l=0, fi%hybinp%lcutm1(itype)), &
     217         506 :                      itype=1, fi%atoms%ntype)]) + mpdata%n_g(ikpt)
     218          22 :          call timestop("calc indx1")
     219             : 
     220             : #ifdef _OPENACC
     221             :          call timestart("copy mtir_tmp")
     222             :          allocate(mtir_tmp(hybdat%coul(ikpt)%mtir%matsize1, hybdat%coul(ikpt)%mtir%matsize2), stat=ierr)
     223             :          if(ierr /= 0) call judft_error("can't alloc mtir_tmp")
     224             :          call zlacpy("N", size(mtir_tmp,1), size(mtir_tmp,2), hybdat%coul(ikpt)%mtir%data_c, &
     225             :                      size(hybdat%coul(ikpt)%mtir%data_c,1), mtir_tmp, size(mtir_tmp,1))
     226             :          call timestop("copy mtir_tmp")
     227             : #endif
     228             : 
     229          22 :          call timestart("acc kernels")
     230             :          !$acc enter data copyin(mtir_tmp)
     231          22 :          if(conjg_mtir) then
     232             :             !$acc kernels present(mtir_tmp)
     233           0 :             CPP_mtir_c = conjg(CPP_mtir_c)
     234             :             !$acc end kernels
     235             :          endif
     236          22 :          call timestop("acc kernels")
     237             : 
     238          22 :          call timestart("ibasm+1->nbasm: zgemm")
     239          22 :          sz_mtir = size(CPP_mtir_c,1)
     240             : 
     241             :          !$acc host_data use_device(CPP_mtir_c, mat_in, mat_out)
     242             :          call CPP_zgemm("N", "N", indx1, n_vec, indx1, cmplx_1, CPP_mtir_c, sz_mtir, &
     243          22 :                      mat_in(ibasm + 1, 1), sz_in, cmplx_0, mat_out(ibasm + 1, 1), sz_out)
     244             :          !$acc end host_data
     245             :          !$acc exit data delete(CPP_mtir_c)
     246             : #ifdef _OPENACC
     247             :          deallocate(mtir_tmp)
     248             : #else       
     249          22 :          if(conjg_mtir) then
     250           0 :             CPP_mtir_c = conjg(CPP_mtir_c)
     251             :          endif
     252             : #endif
     253             :          !$acc wait
     254          22 :          call timestop("ibasm+1->nbasm: zgemm")
     255             : 
     256          22 :          call timestart("dot prod")
     257             :          !$acc kernels present(mt2_tmp)
     258       21714 :          mt2_tmp = conjg(mt2_tmp)
     259             :          !$acc end kernels
     260             : 
     261          22 :          iatom = 0
     262          22 :          indx1 = ibasm; indx2 = 0; indx3 = 0
     263          66 :          DO itype = 1, fi%atoms%ntype
     264         110 :             DO ieq = 1, fi%atoms%neq(itype)
     265          44 :                iatom = iatom + 1
     266         308 :                DO l = 0, fi%hybinp%lcutm1(itype)
     267         220 :                   n = mpdata%num_radbasfn(l, itype)
     268        1364 :                   DO m = -l, l
     269        1100 :                      indx1 = indx1 + 1
     270        1100 :                      indx2 = indx2 + 1
     271        1100 :                      indx3 = indx3 + n - 1
     272             : 
     273             :                      !$acc host_data use_device(mat_in, mt2_tmp, mat_out)
     274             :                      call CPP_zgemv("T", n-1, n_vec, cmplx_1, mat_in(indx2,1), sz_in, mt2_tmp(1, m, l, iatom), 1, &
     275        1100 :                      cmplx_1, mat_out(indx1,1), sz_out)
     276             :                      !$acc end host_data
     277             : 
     278        1320 :                      indx2 = indx3
     279             :                   END DO
     280             : 
     281             :                END DO
     282             :             END DO
     283             :          END DO
     284          22 :          call timestop("dot prod")
     285             : 
     286          22 :          IF (ikpt == 1) THEN
     287           6 :             call timestart("gamma point 2 noinv")
     288           6 :             iatom = 0
     289           6 :             indx0 = 0
     290             : 
     291          18 :             max_l_cut = maxval(fi%hybinp%lcutm1)
     292          18 :             DO itype = 1, fi%atoms%ntype
     293         144 :                ishift = sum([((2*l + 1)*(mpdata%num_radbasfn(l, itype) - 1), l=0, fi%hybinp%lcutm1(itype))])
     294          30 :                DO ieq = 1, fi%atoms%neq(itype)
     295          12 :                   iatom = iatom + 1
     296          12 :                   indx1 = indx0 + 1
     297          12 :                   indx2 = indx1 + mpdata%num_radbasfn(0, itype) - 2
     298          12 :                   n_size = mpdata%num_radbasfn(0, itype) - 1
     299             : 
     300             :                   !$acc host_data use_device(mat_in, mt2_tmp, mat_out)
     301             :                   call CPP_zgemv("T", n_size, n_vec, cmplx_1, mat_in(indx1,1), sz_in, &
     302          12 :                      mt2_tmp(1,0,max_l_cut + 1, iatom), 1, cmplx_1, mat_out(hybdat%nbasp + 1, 1), sz_out)
     303             :                   !$acc end host_data
     304          24 :                   indx0 = indx0 + ishift
     305             :                END DO
     306             :             END DO
     307             : 
     308             :             !$acc data copyin(mt3_tmp)
     309             :                !$acc kernels present(mt3_tmp)
     310         234 :                mt3_tmp = conjg(mt3_tmp)
     311             :                !$acc end kernels
     312             : #ifndef _OPENACC
     313             :                !$OMP PARALLEL DO default(none) &
     314             :                !$OMP private(iatom, itype, indx1, iatom1, indx2, itype1, ishift1, indx3, indx4, n_size) &
     315           6 :                !$OMP shared(fi, mpdata, hybdat,mat_out, mat_in, ibasm, ikpt, n_vec, mt3_tmp, sz_out, sz_in)
     316             : #endif
     317             :                do iatom = 1, fi%atoms%nat 
     318             :                   itype = fi%atoms%itype(iatom)
     319             :                   indx1 = ibasm + sum([((fi%hybinp%lcutm1(fi%atoms%itype(iat)) + 1)**2, iat=1,iatom-1)]) + 1
     320             :                   iatom1 = 0
     321             :                   indx2 = 0
     322             :                   DO itype1 = 1, fi%atoms%ntype
     323             :                      ishift1 = sum([((2*l + 1)*(mpdata%num_radbasfn(l, itype1) - 1), l=0, fi%hybinp%lcutm1(itype1))])
     324             :                      DO ieq1 = 1, fi%atoms%neq(itype1)
     325             :                         iatom1 = iatom1 + 1
     326             :                         IF (iatom1 /= iatom) then
     327             :                            indx3 = indx2 + (ieq1 - 1)*ishift1 + 1
     328             :                            indx4 = indx3 + mpdata%num_radbasfn(0, itype1) - 2
     329             :                            n_size = mpdata%num_radbasfn(0, itype1) - 1
     330             : 
     331             :                            !$acc host_data use_device(mat_in, mt3_tmp, mat_out)
     332             :                            call CPP_zgemv("T", n_size, n_vec, cmplx_1, mat_in(indx3,1), sz_in, mt3_tmp(1, iatom, iatom1), 1, &
     333             :                                     cmplx_1, mat_out(indx1,1), sz_out)
     334             :                            !$acc end host_data
     335             :                         endif
     336             :                      END DO
     337             :                      indx2 = indx2 + fi%atoms%neq(itype1)*ishift1
     338             :                   END DO
     339             :                END DO
     340             : #ifndef _OPENACC
     341             :                !$OMP END PARALLEL DO
     342             : #endif
     343             :             !$acc end data !(mt3_tmp)
     344           6 :             deallocate(mt3_tmp) 
     345           6 :             call timestop("gamma point 2 noinv")
     346             :          END IF
     347             : 
     348          22 :          call timestart("reorder back")
     349          22 :          call back_order(fi%atoms, fi%hybinp%lcutm1, mpdata%num_radbasfn, new_order)
     350          22 :          call reorder(new_order, mat_in)
     351          22 :          call reorder(new_order, mat_out)
     352          22 :          call timestop("reorder back")
     353             :       
     354          22 :          call timestart("copyout")
     355             :       !$acc end data !mt2_tmp, mat_in, mat_out
     356             :       !$acc wait
     357          22 :       call timestop("copyout")
     358          22 :       call timestop("spmm_noinvs")
     359          22 :    end subroutine spmm_noinvs
     360             : end module m_spmm_noinv

Generated by: LCOV version 1.14