LCOV - code coverage report
Current view: top level - eigen - hsmt_nonsph.F90 (source / functions)
Test: FLEUR test coverage | Lines: 72 hit / 87 total (82.8 % coverage)
Date: 2024-04-26 04:44:34 | Functions: 1 hit / 1 total (100.0 % coverage)

          Line data    Source code
       1             : !--------------------------------------------------------------------------------
       2             : ! Copyright (c) 2016 Peter Grünberg Institut, Forschungszentrum Jülich, Germany
       3             : ! This file is part of FLEUR and available as free software under the conditions
       4             : ! of the MIT license as expressed in the LICENSE file in more detail.
       5             : !--------------------------------------------------------------------------------
       6             : MODULE m_hsmt_nonsph
       7             :    USE m_juDFT
       8             :    IMPLICIT NONE
       9             :    PRIVATE
      10             :    PUBLIC hsmt_nonsph
      11             : 
      12             : CONTAINS
      13       17948 :    SUBROUTINE hsmt_nonsph(n,fmpi,sym,atoms,ilSpinPr,ilSpin,igSpinPr,igSpin,chi,noco,nococonv,cell,lapw,td,fjgj,hmat,set0,lapwq,fjgjq)
      14             :       USE m_hsmt_fjgj
      15             :       USE m_types
      16             :       USE m_hsmt_ab
      17             : #ifdef _OPENACC
      18             :       USE cublas
      19             : #define CPP_zgemm cublaszgemm
      20             : #define CPP_zherk cublaszherk
      21             : #define CPP_data_c data_c
      22             : #else
      23             : #define CPP_zgemm zgemm
      24             : #define CPP_zherk zherk
      25             : #define CPP_data_c hmat%data_c
      26             : #endif
      27             : 
      28             :       TYPE(t_mpi),      INTENT(IN) :: fmpi
      29             :       TYPE(t_sym),      INTENT(IN) :: sym
      30             :       TYPE(t_noco),     INTENT(IN) :: noco
      31             :       TYPE(t_nococonv), INTENT(IN) :: nococonv
      32             :       TYPE(t_cell),     INTENT(IN) :: cell
      33             :       TYPE(t_atoms),    INTENT(IN) :: atoms
      34             :       TYPE(t_lapw),     INTENT(IN) :: lapw
      35             :       TYPE(t_tlmplm),   INTENT(IN) :: td
      36             :       TYPE(t_fjgj),     INTENT(IN) :: fjgj
      37             : 
      38             :       INTEGER,          INTENT(IN) :: n, ilSpinPr, ilSpin, igSpinPr, igSpin
      39             :       COMPLEX,          INTENT(IN) :: chi
      40             :       LOGICAL,          INTENT(IN) :: set0  !if true, initialize the hmat matrix with zeros
      41             : 
      42             :       CLASS(t_mat),INTENT(INOUT)     ::hmat
      43             : 
      44             :       TYPE(t_lapw), OPTIONAL, INTENT(IN) :: lapwq ! Additional set of lapw, in case
      45             :       TYPE(t_fjgj), OPTIONAL, INTENT(IN) :: fjgjq ! the left and right ones differ.
      46             : 
      47             :       INTEGER :: nn, na, ab_size, l, ll, size_ab_select
      48             :       INTEGER :: size_data_c, size_ab, size_ab2 !these data-dimensions are not available so easily in openacc, hence we store them
      49             :       INTEGER :: ikGPr, ikG
      50             :       REAL    :: rchi
      51             :       COMPLEX :: cchi
      52             :       LOGICAL :: l_samelapw
      53             : 
      54       17948 :       COMPLEX, ALLOCATABLE :: ab1(:,:),ab_select(:,:)
      55       17948 :       COMPLEX, ALLOCATABLE :: abCoeffs(:,:), ab2(:,:), h_loc(:,:), data_c(:,:)
      56             : 
      57       17948 :       COMPLEX, ALLOCATABLE :: abCoeffsPr(:,:)
      58       17948 :       TYPE(t_lapw) :: lapwPr
      59       17948 :       TYPE(t_fjgj) :: fjgjPr
      60             : 
      61       17948 :       CALL timestart("non-spherical setup")
      62             : 
      63       17948 :       l_samelapw = .FALSE.
      64       17948 :       IF (.NOT.PRESENT(lapwq)) l_samelapw = .TRUE.
      65             :       IF (.NOT.l_samelapw) THEN
      66           0 :          lapwPr = lapwq
      67           0 :          fjgjPr = fjgjq
      68             :       ELSE
      69       17948 :          lapwPr = lapw
      70       17948 :          fjgjPr = fjgj
      71             :       END IF
      72             :       !$acc data copyin(fjgjPr,fjgjPr%fj,fjgjPr%gj)
      73       53844 :       size_ab = maxval(lapw%nv)
      74             : 
      75       17948 :       IF (fmpi%n_size==1) Then
      76        3296 :          size_ab_select=size_ab
      77             :       ELSE
      78       14652 :          size_ab_select=lapwPr%num_local_cols(igSpin)
      79             :       END IF
      80             : 
      81       71792 :       ALLOCATE(ab_select(size_ab_select, 2 * atoms%lmaxd * (atoms%lmaxd + 2) + 2))
      82             :       ALLOCATE(abCoeffs(2 * atoms%lmaxd * (atoms%lmaxd + 2) + 2, MAXVAL(lapw%nv)),&
      83      161532 :              & ab1(size_ab, 2 * atoms%lmaxd * (atoms%lmaxd + 2) + 2))
      84             :       ! TODO: Check, whether this is necessary or shifting to
      85             :       !       max(MAXVAL(lapwq%nv),MAXVAL(lapw%nv)) in abCoeffs is also enough.
      86      107688 :       ALLOCATE(abCoeffsPr(2 * atoms%lmaxd * (atoms%lmaxd + 2) + 2, MAXVAL(lapwPr%nv)))
      87             : 
      88       17948 :       IF (igSpinPr.NE.igSpin) THEN
      89         544 :          ALLOCATE(ab2(lapwPr%nv(igSpinPr), 2 * atoms%lmaxd * (atoms%lmaxd + 2) + 2))
      90         136 :          size_ab2 = lapwPr%nv(igSpinPr)
      91             :       ELSE
      92       17812 :          ALLOCATE(ab2(1,1))
      93       17812 :          size_ab2 = 1
      94             :       END IF
      95             : 
      96             : #ifndef _OPENACC
      97       17948 :       IF (hmat%l_real) THEN
      98       23566 :          IF (ANY(SHAPE(hmat%data_c)/=SHAPE(hmat%data_r))) THEN
      99        3016 :             DEALLOCATE(hmat%data_c)
     100       12064 :             ALLOCATE(hmat%data_c(SIZE(hmat%data_r, 1), SIZE(hmat%data_r, 2)))
     101             :          END IF
     102        9866 :          !$OMP PARALLEL DO DEFAULT(shared)
     103             :          DO l = 1, size(hmat%data_c, 2)
     104             :             hmat%data_c(:,l) = 0.0
     105             :          END DO
     106             :          !$OMP END PARALLEL DO
     107             :       END IF
     108       17948 :       size_data_c = size(hmat%data_c, 1)
     109             : #else
     110             :       IF (hmat%l_real) THEN
     111             :          ALLOCATE(data_c(SIZE(hmat%data_r, 1), SIZE(hmat%data_r, 2)))
     112             :          size_data_c = size(data_c, 1)
     113             :       ELSE
     114             :          ALLOCATE(data_c(SIZE(hmat%data_c, 1), SIZE(hmat%data_c, 2)))
     115             :          size_data_c = size(data_c, 1)
     116             :       END IF
     117             : #endif
     118             : 
     119       71792 :       ALLOCATE(h_loc(SIZE(td%h_loc_nonsph, 1), SIZE(td%h_loc_nonsph, 1)))
     120   336335008 :       h_loc = td%h_loc_nonsph(0:, 0:, n, ilSpinPr, ilSpin)
     121             : 
     122             : #ifdef _OPENACC
     123             :       !$acc enter data create(ab2,ab1,abCoeffs,abCoeffsPr,data_c,ab_select)copyin(h_loc)
     124             :       !$acc kernels present(data_c) default(none)
     125             :       data_c(:, :)=0.0
     126             :       !$acc end kernels
     127             : #endif
     128             : 
     129       36228 :       DO nn = 1,atoms%neq(n)
     130       18280 :          na = atoms%firstAtom(n) - 1 + nn
     131       36228 :          IF ((sym%invsat(na)==0) .OR. (sym%invsat(na)==1)) THEN
     132       18112 :             rchi = MERGE(REAL(chi), REAL(chi)*2, (sym%invsat(na)==0))
     133       18112 :             cchi = MERGE(chi, chi*2, (sym%invsat(na)==0))
     134             : 
     135             :             ! abCoeffs for \sigma_{\alpha} and \sigma_{g}
     136             :             ! Denoted in comments as a
     137             :             ! [local spin primed -> '; global spin primed -> pr]
     138       18112 :             CALL timestart("hsmt_ab_1")
     139             :             CALL hsmt_ab(sym, atoms, noco, nococonv, ilSpin, igSpin, n, na, cell, &
     140       18112 :                        & lapw, fjgj, abCoeffs, ab_size, .TRUE.)
     141       18112 :             CALL timestop("hsmt_ab_1")
     142             : 
     143       18112 :             IF (l_samelapw.AND.(ilSpinPr==ilSpin)) THEN
     144             :                !!$acc update device(ab)
     145             :                !$acc host_data use_device(abCoeffs,ab1,h_loc)
     146             :                CALL CPP_zgemm("C", "N", lapw%nv(igSpin), ab_size, ab_size, cmplx(1.0, 0.0), &
     147             :                             & abCoeffs, SIZE(abCoeffs, 1), h_loc, size(td%h_loc_nonsph, 1), &
     148       17668 :                             & cmplx(0.0, 0.0), ab1, size_ab)
     149             :                !$acc end host_data
     150             :             ELSE ! Needed, because t^H .NE. t!
     151             :                !!$acc update device(ab)
     152             :                !$acc host_data use_device(abCoeffs,ab1,h_loc)
     153             :                CALL CPP_zgemm("C", "C", lapw%nv(igSpin), ab_size, ab_size, cmplx(1.0, 0.0), &
     154             :                             & abCoeffs, SIZE(abCoeffs, 1), h_loc, size(td%h_loc_nonsph, 1), &
     155         444 :                             & cmplx(0.0, 0.0), ab1, size_ab)
     156             :                !$acc end host_data
     157             :             END IF
     158             : 
     159             :             ! ab1 = MATMUL(TRANSPOSE(abCoeffs(:ab_size,:lapw%nv(igSpin))),h_loc(:ab_size,:ab_size,n,ilSpin))
     160             :             ! In locally diagonal case:
     161             :             ! ab1 = a^H * L (lower triangular matrix from Cholesky decomposition)
     162             :             ! Locally offdiagonal case:
     163             :             ! ab1 = a^H * t (potential matrix in lmp lm etc.)
     164             :             ! .NOT.l_samelapw:
     165             :             ! ab1 = a^H * t^H
     166             : 
     167             :             ! Of these ab1 coeffs only a part is needed in case of MPI parallelism
     168             :             !$acc kernels default(none) present(ab_select,ab1)copyin(fmpi)
     169       18112 :             IF (fmpi%n_size>1) THEN
     170   208651852 :                ab_select(:, :) = ab1(fmpi%n_rank+1:lapw%nv(igSpin):fmpi%n_size, :)
     171             :             ELSE
     172    61483020 :                ab_select(:, :) = ab1(:, :) !All of ab1 needed
     173             :             END IF
     174             :             !$acc end kernels
     175             : 
     176       18112 :             IF (igSpinPr==igSpin) THEN
     177       17976 :                IF (ilSpinPr==ilSpin) THEN
     178       17532 :                   IF (l_samelapw) THEN
     179       17532 :                      IF (fmpi%n_size==1) THEN !use z-herk trick on single PE
     180             :                         !$acc host_data use_device(data_c,ab1)
     181        3082 :                         IF (set0 .and. nn == 1) THEN
     182             :                            !CPP_data_c = CMPLX(0.0,0.0)
     183             :                            CALL CPP_zherk("U", "N", lapw%nv(igSpinPr), ab_size, Rchi, &
     184         324 :                                         & ab1, size_ab, 0.0, CPP_data_c, size_data_c)
     185             :                         ELSE
     186             :                            CALL CPP_zherk("U", "N", lapw%nv(igSpinPr), ab_size, Rchi, &
     187        2758 :                                         & ab1, size_ab, 1.0, CPP_data_c, size_data_c)
     188             :                         END IF
     189             :                         !$acc end host_data
     190             :                         ! conjgsolve:
     191             :                         ! data_c += Rchi * a^H * H * a
     192             :                         ! [only upper triangle]
     193             :                      ELSE ! zgemm case
     194             :                         !$acc host_data use_device(data_c,ab1,ab_select)
     195       14450 :                         IF (set0 .and. nn == 1) THEN
     196             :                            !CPP_data_c = CMPLX(0.0,0.0)
     197             :                            CALL CPP_zgemm("N", "C", lapw%nv(igSpinPr), size_ab_select, ab_size, cchi, &
     198             :                                         & ab1, size_ab, ab_select, lapw%num_local_cols(igSpinPr), &
     199        1724 :                                         & CMPLX(0.0, 0.0), CPP_data_c, size_data_c)
     200             :                         ELSE
     201             :                            CALL CPP_zgemm("N", "C", lapw%nv(igSpinPr), size_ab_select, ab_size, cchi, &
     202             :                                         & ab1, size_ab, ab_select, lapw%num_local_cols(igSpinPr), &
     203       12726 :                                         & CMPLX(1.0, 0.0), CPP_data_c, size_data_c)
     204             :                         END IF
     205             :                         !$acc end host_data
     206             :                         ! conjgsolve:
     207             :                         ! ab_select = a^H * L
     208             :                         ! ab1 = a^H * L
     209             :                         ! data_c += cchi * ab1 * abselect^H
     210             :                         !         = cchi * a^H * H * a
     211             :                      END IF
     212             :                   ELSE ! Case for additional q on left vector.
     213           0 :                      CALL timestart("hsmt_ab_2")
     214             :                      CALL hsmt_ab(sym, atoms, noco, nococonv, ilSpin, igSpin, n, na, cell, &
     215           0 :                                 & lapwPr, fjgjPr, abCoeffsPr, ab_size, .TRUE.)
     216             :                      !!$acc update device (abCoeffsPr)
     217           0 :                      CALL timestop("hsmt_ab_2")
     218             : 
     219             :                      !$acc host_data use_device(abCoeffsPr,data_c,ab1,ab_select)
     220           0 :                      IF (set0 .and. nn == 1) THEN
     221             :                         CALL CPP_zgemm("C", "C", lapwPr%nv(igSpin), size_ab_select, ab_size, chi, &
     222             :                                      & abCoeffsPr, SIZE(abCoeffsPr, 1), ab_select, size_ab_select, &
     223           0 :                                      & CMPLX(0.0, 0.0), CPP_data_c, SIZE_data_c)
     224             :                      ELSE
     225             :                         CALL CPP_zgemm("C", "C", lapwPr%nv(igSpin), size_ab_select, ab_size, chi, &
     226             :                                      & abCoeffsPr, SIZE(abCoeffsPr, 1), ab_select, size_ab_select, &
     227           0 :                                      & CMPLX(1.0, 0.0), CPP_data_c, SIZE_data_c)
     228             :                      END IF
     229             :                      !$acc end host_data
     230             :                      ! data_c += chi * aq * abselect^H
     231             :                      !         = chi * aq^H * t * a
     232             :                   END IF
     233             :                ELSE !This is the case of a local off-diagonal contribution.
     234             :                   !It is not Hermitian, so we NEED to use zgemm CALL
     235             : 
     236             :                   ! abCoeffs for \sigma_{\alpha}^{'} and \sigma_{g}
     237         444 :                   CALL timestart("hsmt_ab_3")
     238             :                   CALL hsmt_ab(sym, atoms, noco, nococonv, ilSpinPr, igSpin, n, na, cell, &
     239         444 :                              & lapwPr, fjgjPr, abCoeffsPr, ab_size, .TRUE.)
     240             :                   !!$acc update device(abCoeffsPr)
     241         444 :                   CALL timestop("hsmt_ab_3")
     242             : 
     243             :                   !$acc host_data use_device(abCoeffsPr,data_c,ab1,ab_select)
     244         444 :                   IF (set0 .and. nn == 1) THEN
     245             :                      !CPP_data_c = CMPLX(0.0,0.0)
     246             :                      CALL CPP_zgemm("C", "C", lapwPr%nv(igSpinPr), size_ab_select, ab_size, chi, &
     247             :                                   & abCoeffsPr, SIZE(abCoeffsPr, 1), ab_select, size_ab_select, &
     248         444 :                                   & CMPLX(0.0, 0.0), CPP_data_c, SIZE_data_c)
     249             :                   ELSE
     250             :                      CALL CPP_zgemm("C", "C", lapwPr%nv(igSpinPr), size_ab_select, ab_size, chi, &
     251             :                                   & abCoeffsPr, SIZE(abCoeffsPr, 1), ab_select, size_ab_select, &
     252           0 :                                   & CMPLX(1.0, 0.0), CPP_data_c, SIZE_data_c)
     253             :                   END IF
     254             :                   !$acc end host_data
     255             :                   ! conjgsolve:
     256             :                   ! ab_select = a^H * t
     257             :                   ! abCoeffs = a'
     258             :                   ! data_c += chi * abCoeffs^H * ab_select^H
     259             :                   !         = chi * a'^H * t * a
     260             :                   ! .NOT.l_samelapw:
     261             :                   ! ab_select = a^H * t^H
     262             :                   ! abCoeffs = aq'
     263             :                   ! data_c += chi * abCoeffs^H * ab_select^H
     264             :                   !         = chi * aq'^H * t * a
     265             :                END IF
     266             :             ELSE  !here the l_ss off-diagonal part starts
     267             :                !Second set of abCoeffs is needed
     268             :                ! abCoeffs for \sigma_{\alpha}^{'} and \sigma_{g}^{'}
     269         136 :                CALL timestart("hsmt_ab_4")
     270             :                CALL hsmt_ab(sym, atoms, noco, nococonv, ilSpinPr, igSpinPr, n, na, cell, &
     271         136 :                           & lapwPr, fjgjPr, abCoeffsPr, ab_size, .TRUE.)
     272         136 :                CALL timestop("hsmt_ab_4")
     273         136 :                IF (ilSpinPr==ilSpin) THEN
     274         136 :                   IF (l_samelapw) THEN
     275             :                      !!$acc update device (abCoeffs)
     276             :                      !$acc host_data use_device(abCoeffsPr,h_loc,ab2)
     277             :                      CALL CPP_zgemm("C", "N", lapwPr%nv(igSpinPr), ab_size, ab_size, CMPLX(1.0, 0.0), &
     278             :                                   & abCoeffsPr, SIZE(abCoeffsPr, 1), h_loc, size(td%h_loc_nonsph, 1), &
     279         136 :                                   & CMPLX(0.0, 0.0), ab2, size_ab2)
     280             :                      !$acc end host_data
     281             :                      !Multiply for Hamiltonian
     282             :                      !$acc host_data use_device(ab2,ab1,data_c,ab_select)
     283             :                      CALL CPP_zgemm("N", "C", lapwPr%nv(igSpinPr), lapwPr%num_local_cols(igSpin), ab_size, chi, &
     284             :                                   & ab2, size_ab2, ab_select, size_ab_select, &
     285         136 :                                   & CMPLX(1.0, 0.0), CPP_data_c, size_data_c)
     286             :                      !$acc end host_data
     287             :                      ! conjgsolve:
     288             :                      ! ab2 = aPr'^H * L
     289             :                      ! ab_select = a^H * L
     290             :                      ! data_c += chi * ab2 * ab_select^H
     291             :                      !         = chi * aPr'^H * H * a
     292             :                   ELSE
     293             :                      !$acc host_data use_device(abCoeffsPr,data_c,ab_select)
     294           0 :                      IF (set0 .AND. nn == 1) THEN
     295             :                         !CPP_data_c = CMPLX(0.0,0.0)
     296             :                         CALL CPP_zgemm("C", "C", lapwPr%nv(igSpinPr), lapwPr%num_local_cols(igSpin), ab_size, cchi, &
     297             :                                      & abCoeffsPr, SIZE(abCoeffsPr, 1), ab_select, size_ab_select, &
     298           0 :                                      & CMPLX(0.0, 0.0), CPP_data_c, SIZE_data_c)
     299             :                      ELSE
     300             :                         CALL CPP_zgemm("C", "C", lapwPr%nv(igSpinPr), lapwPr%num_local_cols(igSpin), ab_size, cchi, &
     301             :                                      & abCoeffsPr, SIZE(abCoeffsPr, 1), ab_select, size_ab_select, &
     302           0 :                                      CMPLX(1.0, 0.0), CPP_data_c, SIZE_data_c)
     303             :                      END IF
     304             :                      !$acc end host_data
     305             :                      ! abCoeffs = aqPr'
     306             :                      ! ab_select = a^H * t^H
     307             :                      ! data_c += cchi * abCoeffs^H *  abselect^H
     308             :                      !         = cchi * aqPr'^H * t * a
     309             :                   END IF
     310             :                ELSE
     311             :                   !$acc host_data use_device(abCoeffsPr,ab1,data_c,ab_select)
     312           0 :                   IF (set0 .AND. nn == 1) THEN
     313             :                      !CPP_data_c = CMPLX(0.0,0.0)
     314             :                      CALL CPP_zgemm("C", "C", lapwPr%nv(igSpinPr), lapwPr%num_local_cols(igSpin), ab_size, cchi, &
     315             :                                   & abCoeffsPr, SIZE(abCoeffsPr, 1), ab_select, size_ab_select, &
     316           0 :                                   & CMPLX(0.0, 0.0), CPP_data_c, SIZE_data_c)
     317             :                   ELSE
     318             :                      CALL CPP_zgemm("C", "C", lapwPr%nv(igSpinPr), lapwPr%num_local_cols(igSpin), ab_size, cchi, &
     319             :                                   & abCoeffsPr, SIZE(abCoeffsPr, 1), ab_select, size_ab_select, &
     320           0 :                                   CMPLX(1.0, 0.0), CPP_data_c, SIZE_data_c)
     321             :                   END IF
     322             :                   !$acc end host_data
     323             :                   ! conjgsolve:
     324             :                   ! ab_select = a^H * t
     325             :                   ! abCoeffs = aPr'
     326             :                   ! data_c += chi * abCoeffs^H * ab_select^H
     327             :                   !         = chi * aPr'^H * t * a
     328             :                   ! .NOT.l_samelapw:
     329             :                   ! ab_select = a^H * t^H
     330             :                   ! abCoeffs = aqPr'
     331             :                   ! data_c += chi * abCoeffs^H * ab_select^H
     332             :                   !         = chi * aqPr'^H * t * a
     333             :                END IF
     334             :             END IF
     335             :          END IF
     336             :       END DO
     337             : 
     338             : #ifdef _OPENACC
     339             :          IF (hmat%l_real) THEN
     340             :             !$acc kernels present(hmat,hmat%data_r,data_c) default(none)
     341             :             hmat%data_r = hmat%data_r + real(data_c)
     342             :             !$acc end kernels
     343             :          ELSE
     344             :             if (set0) THEN
     345             :                !$acc kernels present(hmat,hmat%data_c,data_c) default(none)
     346             :                hmat%data_c =  data_c
     347             :                !$acc end kernels
     348             :             else
     349             :                !$acc kernels present(hmat,hmat%data_c,data_c) default(none)
     350             :                hmat%data_c = hmat%data_c + data_c
     351             :                !$acc end kernels
     352             :             endif
     353             :          END IF
     354             :       
     355             : #else
     356       17948 :       IF (hmat%l_real) THEN
     357        9866 :          !$OMP PARALLEL DO DEFAULT(shared)
     358             :          DO l = 1, size(hmat%data_c, 2)
     359             :             hmat%data_r(:, l) = hmat%data_r(:, l) + REAL(hmat%data_c(:, l))
     360             :          END DO
     361             :          !$OMP END PARALLEL DO
     362             :       END IF
     363             : #endif
     364             :      !$acc exit data delete(ab2,ab1,abCoeffs,abCoeffsPr,data_c,ab_select,h_loc)
     365       17948 :       DEALLOCATE(ab_select,abCoeffs,abCoeffsPr,ab1,ab2,h_loc)
     366             :       IF (ALLOCATED(data_c)) DEALLOCATE(data_c)
     367             :       !$acc end data
     368       17948 :       CALL timestop("non-spherical setup")
     369       17948 :    END SUBROUTINE hsmt_nonsph
     370             : 
     371             : END MODULE m_hsmt_nonsph

Generated by: LCOV version 1.14