LCOV - code coverage report
Current view: top level - hybrid - work_package.F90 (source / functions) Hit Total Coverage
Test: FLEUR test coverage Lines: 98 121 81.0 %
Date: 2024-05-02 04:21:52 Functions: 7 22 31.8 %

          Line data    Source code
       1             : module m_work_package
       2             :    use m_types
       3             :    use m_distribute_mpi
       4             :    use m_divide_most_evenly
       5             :    use m_mtir_size
       6             : #ifdef _OPENACC 
       7             :    use openacc
       8             :    use iso_c_binding
       9             : #endif
      10             :    implicit none
      11             :    private
      12             :    type,public:: t_band_package  
      13             :       integer :: start_idx, psize, rank, size
      14             :    contains 
      15             :       procedure :: init => t_band_package_init
      16             :    end type t_band_package
      17             : 
      18             :    type,public:: t_q_package 
      19             :       integer :: rank, size, ptr
      20             :       type(t_hybmpi) :: submpi
      21             :       type(t_band_package), allocatable :: band_packs(:)
      22             :    contains
      23             :       procedure :: init => t_q_package_init
      24             :       procedure :: free => t_q_package_free
      25             :    end type t_q_package 
      26             : 
      27             :    type,public:: t_qwps
      28             :       type(t_q_package), allocatable :: q_packs 
      29             :    end type t_qwps
      30             : 
      31             :    type,public:: t_k_package
      32             :       integer :: nk, rank, size
      33             :       type(t_hybmpi) :: submpi
      34             :       type(t_q_package), allocatable :: q_packs(:)
      35             :    contains
      36             :       procedure :: init  => t_k_package_init 
      37             :       procedure :: print => t_k_package_print
      38             :       procedure :: free  => t_k_package_free
      39             :    end type t_k_package 
      40             : 
      41             :    type,public:: t_work_package 
      42             :       integer :: rank, size, n_kpacks, max_kpacks
      43             :       type(t_k_package), allocatable :: k_packs(:)
      44             :       type(t_hybmpi) :: submpi
      45             :    contains
      46             :       procedure :: init  => t_work_package_init 
      47             :       procedure :: print => t_work_package_print
      48             :       procedure :: owner_nk => t_work_package_owner_nk
      49             :       procedure :: has_nk => t_work_package_has_nk
      50             :       procedure :: free => t_work_package_free
      51             :    end type t_work_package
      52             : 
      53             : contains
      54          16 :    subroutine t_work_package_free(work_pack)
      55             :       implicit none 
      56             :       class(t_work_package), intent(inout) :: work_pack 
      57             :       integer :: i
      58             : 
      59          16 :       if(allocated(work_pack%k_packs)) then 
      60          40 :          do i = 1, size(work_pack%k_packs)
      61          40 :             call work_pack%k_packs(i)%free() 
      62             :          enddo
      63          40 :          deallocate(work_pack%k_packs)
      64             :       endif
      65          16 :    end subroutine t_work_package_free 
      66             : 
      67          24 :    subroutine t_k_package_free(k_pack)
      68             :       implicit none 
      69             :       class(t_k_package), intent(inout) :: k_pack 
      70             :       integer :: i
      71             : 
      72          24 :       if(allocated(k_pack%q_packs)) then 
      73         112 :          do i = 1, size(k_pack%q_packs)
      74         112 :             call k_pack%q_packs(i)%free()
      75             :          enddo 
      76         112 :          deallocate(k_pack%q_packs)
      77             :       endif
      78          24 :    end subroutine t_k_package_free 
      79             : 
      80          88 :    subroutine t_q_package_free(q_pack) 
      81             :       implicit none 
      82             :       class(t_q_package), intent(inout) :: q_pack 
      83             : 
      84          88 :       if(allocated(q_pack%band_packs)) deallocate(q_pack%band_packs)
      85           0 :    end subroutine t_q_package_free
      86             : 
      87          16 :    subroutine t_work_package_init(work_pack, fi, hybdat, mpdata, wp_mpi, jsp, rank, size) 
      88             :       implicit none 
      89             :       class(t_work_package), intent(inout) :: work_pack
      90             :       type(t_fleurinput), intent(in)       :: fi
      91             :       type(t_hybdat), intent(in)           :: hybdat 
      92             :       type(t_mpdata), intent(in)           :: mpdata
      93             :       type(t_hybmpi), intent(in)           :: wp_mpi
      94             :       integer, intent(in)                  :: rank, size, jsp
      95             : 
      96          16 :       call timestart("t_work_package_init")
      97          16 :       work_pack%rank    = rank
      98          16 :       work_pack%size    = size
      99          16 :       work_pack%submpi  = wp_mpi
     100             : 
     101          16 :       call split_into_work_packages(work_pack, fi, hybdat, mpdata, jsp)
     102             : 
     103          16 :       call timestop("t_work_package_init")
     104          16 :    end subroutine t_work_package_init
     105             : 
     106          24 :    subroutine t_k_package_init(k_pack, fi, hybdat, mpdata, k_wide_mpi, jsp, nk)
     107             :       implicit none 
     108             :       class(t_k_package), intent(inout) :: k_pack
     109             :       type(t_fleurinput), intent(in)    :: fi
     110             :       type(t_hybdat), intent(in)        :: hybdat
     111             :       type(t_mpdata), intent(in)        :: mpdata
     112             :       type(t_hybmpi), intent(in)        :: k_wide_mpi
     113             :       type(t_hybmpi)                    :: q_wide_mpi
     114             : 
     115             :       integer, intent(in)  :: nk, jsp
     116             :       integer              :: iq, jq, loc_num_qs, i, cnt, n_groups, idx, q_rank, w_cnt
     117          24 :       integer, allocatable :: loc_qs(:)
     118             : 
     119          24 :       n_groups = min(k_wide_mpi%size, fi%kpts%EIBZ(nk)%nkpt)
     120          96 :       allocate(loc_qs(n_groups), source=0)
     121          48 :       do w_cnt = 1, n_groups 
     122          48 :          do i = w_cnt, fi%kpts%EIBZ(nk)%nkpt, n_groups 
     123          88 :             loc_qs(w_cnt) = loc_qs(w_cnt) + 1 
     124             :          enddo
     125             :       enddo
     126             : 
     127          24 :       call distribute_mpi(loc_qs, k_wide_mpi, q_wide_mpi, q_rank)
     128             : 
     129          24 :       k_pack%submpi = k_wide_mpi
     130          24 :       k_pack%nk = nk 
     131             :       
     132         160 :       allocate(k_pack%q_packs(loc_qs(q_rank+1)))
     133          24 :       cnt = 0
     134         112 :       do iq = q_rank+1,fi%kpts%EIBZ(nk)%nkpt, n_groups
     135          88 :          cnt = cnt + 1
     136          88 :          jq = fi%kpts%EIBZ(nk)%pointer(iq)
     137         112 :          call k_pack%q_packs(cnt)%init(fi, hybdat, mpdata, q_wide_mpi, jsp, nk, iq, jq)
     138             :       enddo
     139          24 :    end subroutine t_k_package_init
     140             : 
     141          88 :    subroutine t_q_package_init(q_pack, fi, hybdat, mpdata, q_wide_mpi, jsp, nk, rank, ptr)
     142             :       implicit none 
     143             :       class(t_q_package), intent(inout) :: q_pack 
     144             :       type(t_fleurinput), intent(in)    :: fi
     145             :       type(t_hybdat), intent(in)        :: hybdat
     146             :       type(t_mpdata), intent(in)        :: mpdata
     147             :       type(t_hybmpi), intent(in)        :: q_wide_mpi
     148             :       integer, intent(in)               :: rank, ptr, jsp, nk
     149             : 
     150             :       integer              :: target_psize
     151             :       integer              :: n_parts, ikqpt, i
     152          88 :       integer, allocatable :: start_idx(:), psize(:)
     153             : 
     154          88 :       q_pack%submpi = q_wide_mpi
     155          88 :       q_pack%rank   = rank 
     156          88 :       q_pack%size   = fi%kpts%EIBZ(nk)%nkpt
     157          88 :       q_pack%ptr    = ptr
     158             : 
     159         440 :       ikqpt = fi%kpts%get_nk(fi%kpts%to_first_bz(fi%kpts%bkf(:,nk) + fi%kpts%bkf(:,ptr)))
     160          88 :       n_parts = calc_n_parts(fi, hybdat, mpdata%n_g, q_pack, ikqpt, jsp)
     161             :       
     162         352 :       allocate(start_idx(n_parts), psize(n_parts))
     163         264 :       allocate(q_pack%band_packs(n_parts))
     164             : 
     165          88 :       call divide_most_evenly(hybdat%nobd(ikqpt, jsp), n_parts, start_idx, psize)
     166             : 
     167         176 :       do i = 1, n_parts
     168          88 :          call q_pack%band_packs(i)%init(start_idx(i), psize(i), i, n_parts)
     169             :       enddo
     170          88 :    end subroutine t_q_package_init
     171             : 
     172          88 :    subroutine t_band_package_init(band_pack, start_idx, psize, rank, size)
     173             :       implicit none 
     174             :       class(t_band_package), intent(inout) :: band_pack 
     175             :       integer, intent(in)                  :: rank, size, start_idx, psize
     176             : 
     177          88 :       band_pack%start_idx = start_idx
     178          88 :       band_pack%psize     = psize 
     179          88 :       band_pack%rank      = rank
     180          88 :       band_pack%size      = size
     181           0 :    end subroutine t_band_package_init
     182             : 
     183           0 :    subroutine t_work_package_print(work_pack)
     184             :       implicit none
     185             :       class(t_work_package), intent(inout) :: work_pack
     186             :       integer :: i 
     187             : 
     188           0 :       write (*,*) "WP (" // int2str(work_pack%rank) // "/" // int2str(work_pack%size) // ") has: "
     189           0 :       do i = 1,size(work_pack%k_packs)
     190           0 :          call work_pack%k_packs(i)%print()
     191             :       enddo
     192           0 :    end subroutine t_work_package_print 
     193             : 
     194           0 :    subroutine t_k_package_print(k_pack)
     195             :       implicit none 
     196             :       class(t_k_package), intent(in) :: k_pack 
     197             : 
     198           0 :       write (*,*) "kpoint: "
     199           0 :       write (*,*) "nk = ", k_pack%nk
     200           0 :    end subroutine t_k_package_print
     201             : 
     202          16 :    subroutine split_into_work_packages(work_pack, fi, hybdat, mpdata, jsp)
     203             : #ifdef CPP_MPI
     204             :       use mpi 
     205             : #endif
     206             :       implicit none 
     207             :       class(t_work_package), intent(inout) :: work_pack
     208             :       type(t_fleurinput), intent(in)       :: fi
     209             :       type(t_hybdat), intent(in)           :: hybdat
     210             :       type(t_mpdata), intent(in)           :: mpdata
     211             :       integer, intent(in)                  :: jsp
     212             :       integer :: k_cnt, i, ierr
     213             :       
     214          16 :       if(work_pack%rank < modulo(fi%kpts%nkpt, work_pack%size)) then
     215           8 :          work_pack%n_kpacks = ceiling(1.0*fi%kpts%nkpt / work_pack%size)
     216             :       else 
     217           8 :          work_pack%n_kpacks = floor(1.0*fi%kpts%nkpt / work_pack%size)
     218             :       endif
     219          72 :       allocate(work_pack%k_packs(work_pack%n_kpacks))
     220             : 
     221             : #ifdef CPP_MPI
     222          16 :       call MPI_AllReduce(work_pack%n_kpacks, work_pack%max_kpacks, 1, MPI_INTEGER, MPI_MAX, MPI_COMM_WORLD, ierr)
     223             : #else    
     224             :       work_pack%max_kpacks = work_pack%n_kpacks
     225             : #endif
     226          16 :       if(work_pack%n_kpacks /= work_pack%max_kpacks) then
     227           8 :          call judft_warn("Your parallization is not efficient. Make sure that nkpts%pe == 0 or nkpts <= pe")
     228             :       endif 
     229             : 
     230             :       
     231             :       ! get my k-list
     232          16 :       k_cnt = 1
     233          16 :       do i = work_pack%rank+1, fi%kpts%nkpt, work_pack%size
     234          24 :          work_pack%k_packs(k_cnt)%rank = k_cnt -1
     235          24 :          work_pack%k_packs(k_cnt)%size = work_pack%n_kpacks
     236             : 
     237          24 :          call work_pack%k_packs(k_cnt)%init(fi, hybdat, mpdata, work_pack%submpi, jsp, i)
     238          24 :          k_cnt = k_cnt + 1
     239             :       enddo
     240          16 :    end subroutine split_into_work_packages
     241             : 
     242             : 
     243           0 :    function t_work_package_owner_nk(work_pack, nk) result(owner) 
     244             :       use m_types_hybmpi
     245             :       implicit none 
     246             :       class(t_work_package), intent(in) :: work_pack
     247             :       integer, intent(in)               :: nk
     248             :       integer                           :: owner
     249             : 
     250           0 :       owner = modulo(nk-1, work_pack%size)
     251           0 :    end function t_work_package_owner_nk
     252             : 
     253           0 :    function t_work_package_has_nk(work_pack, nk) result(has_nk) 
     254             :       implicit none 
     255             :       class(t_work_package), intent(in) :: work_pack
     256             :       integer, intent(in)               :: nk
     257             :       logical                           :: has_nk
     258             :       integer :: i 
     259             : 
     260           0 :       has_nk = .false.
     261           0 :       do i = 1, work_pack%n_kpacks 
     262           0 :          if (work_pack%k_packs(i)%nk == nk) then
     263             :             has_nk = .True.
     264             :             exit
     265             :          endif
     266             :       enddo
     267           0 :    end function t_work_package_has_nk
     268             : 
     269          88 :    function calc_n_parts(fi, hybdat, n_g, q_pack, ikqpt, jsp) result(n_parts)
     270             :       implicit none 
     271             :       type(t_fleurinput), intent(in) :: fi
     272             :       type(t_hybdat), intent(in)     :: hybdat
     273             :       integer, intent(in)            :: n_g(:), ikqpt, jsp
     274             :       class(t_q_package), intent(in) :: q_pack 
     275             :       
     276             :       integer :: n_parts, me, ierr, ikpt
     277             : 
     278             :       integer(8), parameter :: i8_one = 1
     279             :       integer(8)            :: coulomb_size, exch_size, indx_size, nsest_size, target_size, rc_factor 
     280             :       integer(8)            :: cprod_size, spmm_peak, max_peak
     281             :       integer(8)            :: max_nbasm, max_nbands, psize
     282             : 
     283          88 :       rc_factor  = merge(8, 16, fi%sym%invs)
     284         792 :       max_nbasm  = maxval(hybdat%nbasm)
     285        1276 :       max_nbands = maxval(hybdat%nbands)
     286             : 
     287          88 :       target_size = target_memsize(fi, hybdat)
     288          88 :       coulomb_size = 0.0
     289         352 :       do ikpt = 1,fi%kpts%nkpt
     290         352 :          coulomb_size = max(int(mtir_size(fi, n_g, ikpt),kind=8)**2, coulomb_size)
     291             :       enddo
     292             :       ! size in byte
     293          88 :       coulomb_size = rc_factor * coulomb_size
     294        1276 :       exch_size    = rc_factor * maxval(i8_one*hybdat%nbands)**2
     295        1276 :       indx_size    = 4 *         maxval(i8_one*hybdat%nbands)**2
     296        1276 :       nsest_size   = 4 *         maxval(i8_one*hybdat%nbands)
     297             : 
     298        1276 :       psize = maxval(hybdat%nobd)
     299          88 :       do while(psize > 1)
     300          88 :          cprod_size = max_nbasm * max_nbands * psize * rc_factor
     301             : 
     302          88 :          spmm_peak = 2*cprod_size + coulomb_size + exch_size + indx_size + nsest_size
     303             : 
     304         176 :          max_peak = maxval([spmm_peak])
     305             : 
     306          88 :          if(max_peak <= target_size) then 
     307             :             exit 
     308             :          endif
     309          88 :          psize = psize - 1
     310             :       enddo
     311             : 
     312        1276 :       n_parts = ceiling(1.0*maxval(hybdat%nobd)/psize)
     313          88 :       do while(mod(n_parts, q_pack%submpi%size) /= 0)
     314           0 :          n_parts = n_parts + 1
     315             :       enddo
     316             : 
     317          88 :       if(n_parts > hybdat%nobd(ikqpt, jsp)) then 
     318           0 :          write (*,*) "too many parts... reducing to nobd"
     319           0 :          n_parts = hybdat%nobd(ikqpt, jsp)
     320             :       endif
     321             : #ifdef CPP_MPI
     322          88 :       call MPI_COMM_RANK(MPI_COMM_WORLD, me, ierr)
     323             : #else
     324             :       me = 0
     325             : #endif
     326             :       !if(me == 0) write (*,*) "psize: " // int2str(psize) // " max_peak: " // int2str(max_peak) // " nparts: " // int2str(n_parts)
     327          88 :    end function calc_n_parts
     328             : 
     329             :    integer(8) function target_memsize(fi, hybdat)
     330             : 
     331             :       implicit none 
     332             :       type(t_fleurinput), intent(in) :: fi
     333             :       type(t_hybdat), intent(in)     :: hybdat
     334             : 
     335             : #ifdef _OPENACC    
     336             :       integer           :: ikpt
     337             :       integer(C_SIZE_T) :: gpu_mem
     338             :       real              :: coulomb_size, exch_size
     339             : 
     340             :       gpu_mem = acc_get_property(0,acc_device_current, acc_property_free_memory)
     341             :       target_memsize = int(0.75*gpu_mem, kind=8)
     342             : #else
     343          88 :       target_memsize = int(15e9, kind=8) ! 15 Gb
     344             : #endif
     345             :    end function target_memsize
     346           0 : end module m_work_package

Generated by: LCOV version 1.14