LCOV - code coverage report
Current view: top level - hybrid - spmm_inv.F90 (source / functions) Hit Total Coverage
Test: FLEUR test coverage Lines: 74 74 100.0 %
Date: 2024-05-02 04:21:52 Functions: 1 1 100.0 %

          Line data    Source code
       module m_spmm_inv
   use iso_c_binding
   use m_spmm
#ifdef _OPENACC
      USE cublas
#define CPP_zgemm cublaszgemm
#define CPP_dgemm cublasdgemm
#define CPP_zgemv cublaszgemv
#define CPP_dgemv cublasdgemv
#define CPP_mtir_c mtir_tmp
#define CPP_mtir_r mtir_tmp
#else
#define CPP_zgemm zgemm
#define CPP_dgemm dgemm
#define CPP_zgemv zgemv
#define CPP_dgemv dgemv
#define CPP_mtir_c hybdat%coul(ikpt)%mtir%data_c
#define CPP_mtir_r hybdat%coul(ikpt)%mtir%data_r
#endif
   implicit none
contains
   !> Multiply the blocked real matrix stored in hybdat%coul(ikpt) with every
   !> column of mat_in and store the result in mat_out.
   !>
   !> The stored matrix is kept in separate pieces, each applied on its own:
   !>   * mt1_r  : per (l, itype) square blocks, applied with dgemm (rows 1..ibasm)
   !>   * mt2_r  : per (m, l, iatom) vectors coupling rows 1..ibasm with rows
   !>              ibasm+1..nbasp, applied in both directions
   !>   * mt3_r  : extra atom-atom couplings, only used when ikpt == 1
   !>   * mtir   : dense block for rows ibasm+1.., applied with one dgemm
   !> NOTE(review): hybdat%coul is presumably the Coulomb matrix in the
   !> mixed-product basis -- confirm against the type definition.
   !>
   !> With _OPENACC the BLAS calls map to cuBLAS and the data is staged on the
   !> device through OpenACC data regions; otherwise host BLAS + OpenMP is used.
   !>
   !> @param fi       input data (atoms%nat/ntype/neq/itype, hybinp%lcutm1)
   !> @param mpdata   basis data (num_radbasfn, n_g)
   !> @param hybdat   provides coul(ikpt) and nbasp
   !> @param ikpt     k-point index; ikpt == 1 enables the mt3/mt2 "gamma point" terms
   !> @param mat_in   input columns; permuted in place with the forw_order
   !>                 permutation on entry and permuted again with the back_order
   !>                 permutation before return (presumably restoring it -- confirm)
   !> @param mat_out  output columns, permuted with the back_order permutation
   !>                 before return
   !>
   !> The BLAS calls address mat_in/mat_out through their first element and use
   !> size(..., 1) as leading dimension, so both must be contiguous in memory.
   subroutine spmm_invs(fi, mpdata, hybdat, ikpt, mat_in, mat_out)
      use m_juDFT
      use m_types
      use m_reorder
      use m_calc_l_m_from_lm
      implicit none
      type(t_fleurinput), intent(in)    :: fi
      type(t_mpdata), intent(in)        :: mpdata
      type(t_hybdat), intent(in)        :: hybdat
      integer, intent(in)               :: ikpt
      real, intent(inout)               :: mat_in(:,:)
      real, intent(inout)               :: mat_out(:,:)

      integer :: n_vec, i_vec, ibasm, iatom, itype, ieq, l, m, n_size, sz_mtir, sz_hlp, sz_out, sz_mt1
      integer :: indx0, indx1, indx2, indx3, n, iatom1, ieq1, ishift, itype1, max_lcut_plus_1
      integer :: ishift1, indx4, lm, iat2, it2, l2, idx1_start, idx3_start, iat, ierr
      real, allocatable :: mt1_tmp(:,:,:,:), mt2_tmp(:,:,:,:), mat_in_line(:), mt3_tmp(:,:,:)
      integer, allocatable :: new_order(:)
#ifdef _OPENACC
      real, allocatable :: mtir_tmp(:,:)
#endif

      call timestart("spmm_invs")
      ! Row nbasp+1 is captured before the reordering below; it is only consumed
      ! by the ikpt == 1 branch.
      mat_in_line = mat_in(hybdat%nbasp + 1, :)

      n_vec = size(mat_in, 2)

      call timestart("reorder")
      allocate(new_order(size(mat_in,1)))
      call forw_order(fi%atoms, fi%hybinp%lcutm1, mpdata%num_radbasfn, new_order)
      !$acc data copy(mat_in)
         call reorder(new_order, mat_in)
      !$acc end data
      call timestop("reorder")

      ibasm = calc_ibasm(fi, mpdata)

      call timestart("copies out of hydat")
      mt1_tmp = hybdat%coul(ikpt)%mt1_r
      mt2_tmp = hybdat%coul(ikpt)%mt2_r
      if(ikpt == 1 ) then
         mt3_tmp = hybdat%coul(ikpt)%mt3_r
      endif
      ! Leading dimension of the 2-D slices mt1_tmp(:,:,l,itype) handed to dgemm.
      ! Fortran is column-major, so the column stride is the extent of dim 1
      ! (using dim 2 here was only correct while the slices were square).
      sz_mt1 = size(mt1_tmp, dim=1)
      sz_hlp  = size(mat_in, 1)
      sz_out  = size(mat_out, 1)
      call timestop("copies out of hydat")

      ! Slot of the l-index in mt2_tmp that holds the ikpt == 1 coupling vectors;
      ! loop-invariant, so computed once for both gamma-point sections.
      max_lcut_plus_1 = maxval(fi%hybinp%lcutm1) + 1

      !$acc data copyin(mat_in) copy(mat_out)
         !$acc data copyin(mt2_tmp)
            call timestart("0 > ibasm: small matricies")
            ! Rows 1..ibasm: for every atom and (l,m), apply the mt1 block and add
            ! the mt2 vector scaled by the matching row (> ibasm) of mat_in.
            ! The start offsets are recomputed per atom so iterations are independent.
#ifndef _OPENACC
            !$OMP PARALLEL DO default(none) schedule(dynamic)&
            !$OMP private(iatom, itype, idx1_start, iat2, it2, l2, indx1, idx3_start, indx3)&
            !$OMP private(lm, l, m, n_size, i_vec)&
            !$OMP lastprivate(indx2)&
            !$OMP shared(ibasm, mat_in, hybdat, mat_out, fi, mpdata, n_vec, ikpt, mt2_tmp, sz_out, sz_hlp, sz_mt1, mt1_tmp)
#endif
            !$acc data copyin(mt1_tmp)
               do iatom = 1,fi%atoms%nat
                  itype = fi%atoms%itype(iatom)

                  ! offset of this atom's rows within 1..ibasm
                  idx1_start = 0
                  do iat2 =1,iatom-1
                     it2 = fi%atoms%itype(iat2)
                     do l2 = 0, fi%hybinp%lcutm1(it2)
                        idx1_start = idx1_start + (mpdata%num_radbasfn(l2, it2)-1) * (2*l2+1)
                     enddo
                  enddo
                  indx1 = idx1_start

                  ! offset of this atom's (l,m) rows beyond ibasm
                  idx3_start = ibasm
                  do iat2 = 1,iatom-1
                     it2 = fi%atoms%itype(iat2)
                     idx3_start = idx3_start + (fi%hybinp%lcutm1(it2)+1)**2
                  enddo
                  indx3 = idx3_start

                  do lm = 1,(fi%hybinp%lcutm1(itype)+1)**2
                     call calc_l_m_from_lm(lm, l, m)
                     indx1 = indx1 + 1
                     indx2 = indx1 + mpdata%num_radbasfn(l, itype) - 2
                     indx3 = indx3 + 1

                     n_size = mpdata%num_radbasfn(l, itype) - 1
                     !$acc host_data use_device(mt1_tmp, mat_in, mat_out)
                     call CPP_dgemm("N","N", n_size, n_vec, n_size, 1.0, mt1_tmp(1,1,l,itype), sz_mt1,&
                                 mat_in(indx1,1), sz_hlp, 0.0, mat_out(indx1,1), sz_out)
                     !$acc end host_data

                     !$acc kernels present(mat_out, mat_in, mt2_tmp)
                     do i_vec = 1, n_vec
                        mat_out(indx1:indx2,i_vec) = mat_out(indx1:indx2,i_vec) + mat_in(indx3, i_vec) * mt2_tmp(:n_size,m,l,iatom)
                     enddo
                     !$acc end kernels
                     indx1 = indx2
                  END DO
               END DO
            !$acc end data
#ifndef _OPENACC
            !$OMP END PARALLEL DO
#endif
            call timestop("0 > ibasm: small matricies")

            ! indx2 is the last row written by the final atom (lastprivate above)
            IF (indx2 /= ibasm) call judft_error('spmm: error counting basis functions')

            IF (ikpt == 1) THEN
               ! Extra ikpt == 1 terms into the l = 0 rows of every atom:
               ! mt3 couplings to the first (l,m) row of every other atom, and the
               ! mt2 vector scaled by the saved row nbasp+1.
               !$acc data copyin(mt3_tmp, mat_in_line)
                  call timestart("gamma point 1 inv")
#ifndef _OPENACC
                  !$OMP parallel do default(none) &
                  !$OMP private(iatom, itype, l, indx0, indx1, indx2, indx3, indx4, iatom1, itype1, ishift1, i_vec, n_size)&
                  !$OMP shared(max_lcut_plus_1, fi, mpdata, hybdat, mat_out, ibasm, n_vec, ikpt)&
                  !$OMP shared(mat_in, mat_in_line, mt2_tmp, mt3_tmp)
#endif
                  do iatom = 1, fi%atoms%nat
                     itype = fi%atoms%itype(iatom)
                     l = 0

                     ! offset of this atom's rows within 1..ibasm
                     indx0 = 0
                     do iat = 1,iatom-1
                        indx0 = indx0 + sum([((2*l + 1)*(mpdata%num_radbasfn(l, fi%atoms%itype(iat)) - 1), &
                                              l=0, fi%hybinp%lcutm1(fi%atoms%itype(iat)))])
                     enddo

                     indx1 = indx0 + 1
                     indx2 = indx1 + mpdata%num_radbasfn(l, itype) - 2

                     indx3 = ibasm
                     n_size = mpdata%num_radbasfn(l, itype) - 1
                     do iatom1 = 1,fi%atoms%nat
                        itype1 = fi%atoms%itype(iatom1)
                        ishift1 = (fi%hybinp%lcutm1(itype1) + 1)**2
                        indx4 = indx3 + 1
                        IF (iatom /= iatom1) then
                           !$acc kernels present(mat_out, mat_in, mt3_tmp)
                           do i_vec = 1, n_vec
                              mat_out(indx1:indx2, i_vec) = mat_out(indx1:indx2, i_vec) &
                                 + mt3_tmp(:n_size, iatom1, iatom)*mat_in(indx4, i_vec)
                           enddo
                           !$acc end kernels
                        endif
                        indx3 = indx3 + ishift1
                     END DO

                     IF (indx3 /= hybdat%nbasp) call judft_error('spmvec: error counting index indx3')

                     !$acc kernels present(mat_out, mat_in_line, mt2_tmp)
                     do i_vec = 1, n_vec
                        mat_out(indx1:indx2, i_vec) = mat_out(indx1:indx2, i_vec) &
                           + mt2_tmp(:n_size, 0, max_lcut_plus_1, iatom)*mat_in_line(i_vec)
                     enddo
                     !$acc end kernels
                  END DO

#ifndef _OPENACC
                  !$OMP end parallel do
#endif
                  call timestop("gamma point 1 inv")
               !$acc end data !mt3_tmp
            END IF

            ! compute vecout for the index-range from ibasm+1:nbasm
            ! indx1 = order of the dense mtir block (MT (l,m) rows + n_g(ikpt))
            indx1 = sum([(((2*l + 1)*fi%atoms%neq(itype), l=0, fi%hybinp%lcutm1(itype)), &
                        itype=1, fi%atoms%ntype)]) + mpdata%n_g(ikpt)

            call timestart("ibasm+1 -> dgemm")
#ifdef _OPENACC
            call timestart("copy mtir_tmp")
            allocate(mtir_tmp(hybdat%coul(ikpt)%mtir%matsize1, hybdat%coul(ikpt)%mtir%matsize2), stat=ierr)
            if(ierr /= 0) call judft_error("can't alloc mtir_tmp")
            call dlacpy("N", size(mtir_tmp,1), size(mtir_tmp,2), hybdat%coul(ikpt)%mtir%data_r, &
                        size(hybdat%coul(ikpt)%mtir%data_r,1), mtir_tmp, size(mtir_tmp,1))
            call timestop("copy mtir_tmp")
#endif

            sz_mtir = size(CPP_mtir_r, 1)
            !$acc data copyin(CPP_mtir_r)
               !$acc host_data use_device(CPP_mtir_r, mat_in, mat_out)
               call CPP_dgemm("N", "N", indx1, n_vec, indx1, 1.0, CPP_mtir_r, sz_mtir, &
                        mat_in(ibasm + 1, 1), sz_hlp, 0.0, mat_out(ibasm + 1, 1), sz_out)
               !$acc end host_data
            !$acc end data ! CPP_mtir_r
#ifdef _OPENACC
            deallocate(mtir_tmp)
#endif
            call timestop("ibasm+1 -> dgemm")

            ! Transposed mt2 coupling: each (l,m) row beyond ibasm accumulates the
            ! dot product of mt2(:,m,l,iatom) with the matching n-1 rows of mat_in.
            call timestart("dot prod")
            iatom = 0
            indx1 = ibasm; indx2 = 0; indx3 = 0
            DO itype = 1, fi%atoms%ntype
               DO ieq = 1, fi%atoms%neq(itype)
                  iatom = iatom + 1
                  DO l = 0, fi%hybinp%lcutm1(itype)
                     n = mpdata%num_radbasfn(l, itype)
                     DO m = -l, l
                        indx1 = indx1 + 1
                        indx2 = indx2 + 1
                        indx3 = indx3 + n - 1

                        !$acc host_data use_device(mat_in, mt2_tmp, mat_out)
                        call CPP_dgemv("T", n-1, n_vec, 1.0, mat_in(indx2,1), sz_hlp, mt2_tmp(1, m, l, iatom), 1, &
                           1.0, mat_out(indx1,1), sz_out)
                        !$acc end host_data

                        indx2 = indx3
                     END DO

                  END DO
               END DO
            END DO
         !$acc end data ! mt2_tmp
         call timestop("dot prod")
      !$acc end data ! mat_in, mat_out

      IF (ikpt == 1) THEN
         ! Transposed counterparts of the "gamma point 1" terms (host only).
         call timestart("gamma point 2 inv")
         iatom = 0
         indx0 = 0
         DO itype = 1, fi%atoms%ntype
            ishift = sum([((2*l + 1)*(mpdata%num_radbasfn(l, itype) - 1), l=0, fi%hybinp%lcutm1(itype))])
            DO ieq = 1, fi%atoms%neq(itype)
               iatom = iatom + 1
               indx1 = indx0 + 1
               indx2 = indx1 + mpdata%num_radbasfn(0, itype) - 2
               n_size = mpdata%num_radbasfn(0, itype) - 1
               do i_vec = 1, n_vec
                  mat_out(hybdat%nbasp + 1, i_vec) = mat_out(hybdat%nbasp + 1, i_vec) &
                                                   + dot_product(mt2_tmp(:n_size, 0, max_lcut_plus_1, iatom), &
                                                                  mat_in(indx1:indx2, i_vec))
               enddo
               indx0 = indx0 + ishift
            END DO
         END DO

         !$OMP PARALLEL DO default(none) schedule(dynamic)&
         !$OMP private(iatom, itype, indx1, indx2, itype1, ishift1) &
         !$OMP private(ieq1, iatom1, indx3, indx4, n_size, i_vec) &
         !$OMP shared(fi, n_vec, mat_out, ibasm, mpdata, mat_in, hybdat, ikpt, mt3_tmp)
         do iatom = 1, fi%atoms%nat
            itype = fi%atoms%itype(iatom)
            ! first (l,m) row of atom iatom beyond ibasm
            indx1 = ibasm + sum([((fi%hybinp%lcutm1(fi%atoms%itype(iat)) + 1)**2, iat=1,iatom-1)]) + 1

            iatom1 = 0
            indx2 = 0
            DO itype1 = 1, fi%atoms%ntype
               ishift1 = sum([((2*l + 1)*(mpdata%num_radbasfn(l, itype1) - 1), l=0, fi%hybinp%lcutm1(itype1))])
               DO ieq1 = 1, fi%atoms%neq(itype1)
                  iatom1 = iatom1 + 1
                  IF (iatom1 == iatom) CYCLE

                  indx3 = indx2 + (ieq1 - 1)*ishift1 + 1
                  indx4 = indx3 + mpdata%num_radbasfn(0, itype1) - 2

                  n_size = mpdata%num_radbasfn(0, itype1) - 1
                  do i_vec = 1, n_vec
                     mat_out(indx1, i_vec) = mat_out(indx1, i_vec) &
                                          + dot_product(mt3_tmp(:n_size, iatom, iatom1), &
                                                         mat_in(indx3:indx4, i_vec))
                  enddo
               END DO
               indx2 = indx2 + fi%atoms%neq(itype1)*ishift1
            END DO
         END DO
         !$OMP END PARALLEL DO
         call timestop("gamma point 2 inv")

      END IF

      call timestart("reorder")
      call back_order(fi%atoms, fi%hybinp%lcutm1, mpdata%num_radbasfn, new_order)
      !$acc data copy(mat_in, mat_out)
         call reorder(new_order, mat_in)
         call reorder(new_order, mat_out)
      !$acc end data
      call timestop("reorder")
      call timestop("spmm_invs")
   end subroutine spmm_invs
end module m_spmm_inv

Generated by: LCOV version 1.14