Line data Source code
1 : module m_work_package
2 : use m_types
3 : use m_distribute_mpi
4 : use m_divide_most_evenly
5 : use m_mtir_size
6 : #ifdef _OPENACC
7 : use openacc
8 : use iso_c_binding
9 : #endif
10 : implicit none
11 : private
!> One slice of bands inside a q-package, filled by t_band_package_init.
!> Integer components default to 0 so a package that was never initialised
!> has a well-defined (empty) state.
type, public :: t_band_package
   integer :: start_idx = 0  ! slice start index as produced by divide_most_evenly
   integer :: psize     = 0  ! slice length as produced by divide_most_evenly
   integer :: rank      = 0  ! 1-based position of this slice within its q-package
   integer :: size      = 0  ! number of slices (n_parts) in the q-package
contains
   procedure :: init => t_band_package_init
end type t_band_package

!> Work for one q-point of a k-point's EIBZ, split into band packages.
type, public :: t_q_package
   integer :: rank = 0   ! index iq of this q-point in fi%kpts%EIBZ(nk) (set by t_k_package_init)
   integer :: size = 0   ! fi%kpts%EIBZ(nk)%nkpt
   integer :: ptr  = 0   ! fi%kpts%EIBZ(nk)%pointer(iq)
   type(t_hybmpi) :: submpi
   type(t_band_package), allocatable :: band_packs(:)
contains
   procedure :: init => t_q_package_init
   procedure :: free => t_q_package_free
end type t_q_package

type, public :: t_qwps
   ! NOTE(review): scalar allocatable despite the plural name -- confirm
   ! whether an array (q_packs(:)) was intended.
   type(t_q_package), allocatable :: q_packs
end type t_qwps

!> All q-packages belonging to one k-point owned by this process.
type, public :: t_k_package
   integer :: nk   = 0   ! k-point index (0 = not yet initialised)
   integer :: rank = 0   ! 0-based position in the owning work package's k_packs
   integer :: size = 0   ! n_kpacks of the owning work package
   type(t_hybmpi) :: submpi
   type(t_q_package), allocatable :: q_packs(:)
contains
   procedure :: init  => t_k_package_init
   procedure :: print => t_k_package_print
   procedure :: free  => t_k_package_free
end type t_k_package

!> Top-level container: the k-points assigned round-robin to this work package.
type, public :: t_work_package
   integer :: rank       = 0  ! position of this work package (passed to init)
   integer :: size       = 0  ! number of work packages (passed to init)
   integer :: n_kpacks   = 0  ! number of k-packages owned here
   integer :: max_kpacks = 0  ! maximum n_kpacks over all processes
   type(t_k_package), allocatable :: k_packs(:)
   type(t_hybmpi) :: submpi
contains
   procedure :: init     => t_work_package_init
   procedure :: print    => t_work_package_print
   procedure :: owner_nk => t_work_package_owner_nk
   procedure :: has_nk   => t_work_package_has_nk
   procedure :: free     => t_work_package_free
end type t_work_package
52 :
53 : contains
!> Release all k-packages owned by this work package.
!> Each k-package frees its own q-packages first. This is a no-op when
!> k_packs is not allocated.
subroutine t_work_package_free(work_pack)
   implicit none
   class(t_work_package), intent(inout) :: work_pack
   integer :: ik

   if (.not. allocated(work_pack%k_packs)) return

   do ik = 1, size(work_pack%k_packs)
      call work_pack%k_packs(ik)%free()
   end do
   deallocate(work_pack%k_packs)
end subroutine t_work_package_free
66 :
!> Release all q-packages of this k-package.
!> Each q-package frees its band packages first. This is a no-op when
!> q_packs is not allocated.
subroutine t_k_package_free(k_pack)
   implicit none
   class(t_k_package), intent(inout) :: k_pack
   integer :: iq

   if (.not. allocated(k_pack%q_packs)) return

   do iq = 1, size(k_pack%q_packs)
      call k_pack%q_packs(iq)%free()
   end do
   deallocate(k_pack%q_packs)
end subroutine t_k_package_free
79 :
!> Release the band packages of this q-package. Safe to call repeatedly.
subroutine t_q_package_free(q_pack)
   implicit none
   class(t_q_package), intent(inout) :: q_pack

   if (allocated(q_pack%band_packs)) then
      deallocate(q_pack%band_packs)
   end if
end subroutine t_q_package_free
86 :
!> Initialise a work package.
!> Records its position (rank/size) and communicator, then builds its
!> k-packages via split_into_work_packages. The call is wrapped in the
!> "t_work_package_init" timer.
subroutine t_work_package_init(work_pack, fi, hybdat, mpdata, wp_mpi, jsp, rank, size)
   implicit none
   class(t_work_package), intent(inout) :: work_pack
   type(t_fleurinput),    intent(in)    :: fi
   type(t_hybdat),        intent(in)    :: hybdat
   type(t_mpdata),        intent(in)    :: mpdata
   type(t_hybmpi),        intent(in)    :: wp_mpi
   integer,               intent(in)    :: rank, size, jsp

   call timestart("t_work_package_init")

   ! position in the round-robin k-point distribution and the communicator
   ! that is handed down to every k-package
   work_pack%rank   = rank
   work_pack%size   = size
   work_pack%submpi = wp_mpi

   call split_into_work_packages(work_pack, fi, hybdat, mpdata, jsp)

   call timestop("t_work_package_init")
end subroutine t_work_package_init
105 :
!> Build the q-packages for k-point nk.
!>
!> The q-points of fi%kpts%EIBZ(nk) are dealt round-robin over
!> n_groups = min(k_wide_mpi%size, #q) groups. distribute_mpi derives the
!> q-wide communicator and this process's group (q_rank) from the per-group
!> counts. This process then initialises one q-package for every q-point of
!> its group.
!>
!> k_pack%rank and k_pack%size are not set here; split_into_work_packages
!> assigns them before calling init.
subroutine t_k_package_init(k_pack, fi, hybdat, mpdata, k_wide_mpi, jsp, nk)
   implicit none
   class(t_k_package), intent(inout) :: k_pack
   type(t_fleurinput), intent(in)    :: fi
   type(t_hybdat),     intent(in)    :: hybdat
   type(t_mpdata),     intent(in)    :: mpdata
   type(t_hybmpi),     intent(in)    :: k_wide_mpi
   integer,            intent(in)    :: nk, jsp

   type(t_hybmpi)       :: q_wide_mpi
   integer              :: n_q, n_groups, grp, iq, jq, cnt, q_rank
   integer, allocatable :: loc_qs(:)

   n_q      = fi%kpts%EIBZ(nk)%nkpt
   n_groups = min(k_wide_mpi%size, n_q)

   ! Group grp receives q-points grp, grp+n_groups, ... <= n_q.
   ! That is (n_q - grp)/n_groups + 1 of them; grp <= n_groups <= n_q keeps
   ! the count positive.
   loc_qs = [((n_q - grp)/n_groups + 1, grp = 1, n_groups)]

   call distribute_mpi(loc_qs, k_wide_mpi, q_wide_mpi, q_rank)

   k_pack%submpi = k_wide_mpi
   k_pack%nk     = nk

   ! Tolerate re-initialisation instead of failing on "already allocated".
   if (allocated(k_pack%q_packs)) call k_pack%free()
   allocate(k_pack%q_packs(loc_qs(q_rank+1)))

   cnt = 0
   do iq = q_rank+1, n_q, n_groups
      cnt = cnt + 1
      jq  = fi%kpts%EIBZ(nk)%pointer(iq)
      call k_pack%q_packs(cnt)%init(fi, hybdat, mpdata, q_wide_mpi, jsp, nk, iq, jq)
   enddo
end subroutine t_k_package_init
140 :
!> Initialise the q-package for q-point index `rank` (EIBZ pointer `ptr`) of k-point nk.
!>
!> 1. Determine ikqpt, the k-point index of bkf(:,nk) + bkf(:,ptr) after
!>    mapping through to_first_bz / get_nk.
!> 2. Ask calc_n_parts how many band packages are needed.
!> 3. Split hybdat%nobd(ikqpt, jsp) bands over them with divide_most_evenly.
subroutine t_q_package_init(q_pack, fi, hybdat, mpdata, q_wide_mpi, jsp, nk, rank, ptr)
   implicit none
   class(t_q_package), intent(inout) :: q_pack
   type(t_fleurinput), intent(in)    :: fi
   type(t_hybdat),     intent(in)    :: hybdat
   type(t_mpdata),     intent(in)    :: mpdata
   type(t_hybmpi),     intent(in)    :: q_wide_mpi
   integer,            intent(in)    :: rank, ptr, jsp, nk

   integer              :: n_parts, ikqpt, i
   integer, allocatable :: start_idx(:), psize(:)

   ! submpi must be set before calc_n_parts, which reads q_pack%submpi%size.
   q_pack%submpi = q_wide_mpi
   q_pack%rank   = rank
   q_pack%size   = fi%kpts%EIBZ(nk)%nkpt
   q_pack%ptr    = ptr

   ikqpt   = fi%kpts%get_nk(fi%kpts%to_first_bz(fi%kpts%bkf(:,nk) + fi%kpts%bkf(:,ptr)))
   n_parts = calc_n_parts(fi, hybdat, mpdata%n_g, q_pack, ikqpt, jsp)

   allocate(start_idx(n_parts), psize(n_parts))
   ! Tolerate re-initialisation instead of failing on "already allocated".
   if (allocated(q_pack%band_packs)) deallocate(q_pack%band_packs)
   allocate(q_pack%band_packs(n_parts))

   call divide_most_evenly(hybdat%nobd(ikqpt, jsp), n_parts, start_idx, psize)

   do i = 1, n_parts
      call q_pack%band_packs(i)%init(start_idx(i), psize(i), i, n_parts)
   enddo
end subroutine t_q_package_init
171 :
!> Fill a band package with its slice start index and length, its
!> 1-based position (rank) and the total number of slices (size).
subroutine t_band_package_init(band_pack, start_idx, psize, rank, size)
   implicit none
   class(t_band_package), intent(inout) :: band_pack
   integer,               intent(in)    :: start_idx, psize, rank, size

   ! extent of the slice
   band_pack%start_idx = start_idx
   band_pack%psize     = psize
   ! position among the sibling slices
   band_pack%rank      = rank
   band_pack%size      = size
end subroutine t_band_package_init
182 :
!> Print a summary header for this work package and let each k-package print itself.
!> The header is always printed. The k-package loop is skipped when k_packs is
!> not allocated (before init or after free), where size() would be invalid.
!> NOTE(review): intent(inout) is kept for binding-interface compatibility;
!> the routine only reads work_pack.
subroutine t_work_package_print(work_pack)
   implicit none
   class(t_work_package), intent(inout) :: work_pack
   integer :: i

   write (*,*) "WP (" // int2str(work_pack%rank) // "/" // int2str(work_pack%size) // ") has: "

   if (.not. allocated(work_pack%k_packs)) return

   do i = 1, size(work_pack%k_packs)
      call work_pack%k_packs(i)%print()
   enddo
end subroutine t_work_package_print
193 :
!> Write the k-point index of this k-package to standard output (two records).
subroutine t_k_package_print(k_pack)
   implicit none
   class(t_k_package), intent(in) :: k_pack

   associate (k_index => k_pack%nk)
      write(*, *) "kpoint: "
      write(*, *) "nk = ", k_index
   end associate
end subroutine t_k_package_print
201 :
!> Assign k-points round-robin and build one k-package per owned k-point.
!>
!> Rank r owns k-points r+1, r+1+size, ... <= nkpt. The first
!> modulo(nkpt, size) ranks therefore own one k-point more than the rest.
!> max_kpacks is the maximum of n_kpacks over MPI_COMM_WORLD, or the local
!> value without MPI. A warning is issued on ranks that own fewer k-packages
!> than that maximum.
subroutine split_into_work_packages(work_pack, fi, hybdat, mpdata, jsp)
#ifdef CPP_MPI
   use mpi
#endif
   implicit none
   class(t_work_package), intent(inout) :: work_pack
   type(t_fleurinput),    intent(in)    :: fi
   type(t_hybdat),        intent(in)    :: hybdat
   type(t_mpdata),        intent(in)    :: mpdata
   integer,               intent(in)    :: jsp
   integer :: k_cnt, i
#ifdef CPP_MPI
   integer :: ierr
#endif

   ! Exact integer ceil/floor of nkpt/size; no detour through default real.
   work_pack%n_kpacks = fi%kpts%nkpt / work_pack%size
   if (work_pack%rank < modulo(fi%kpts%nkpt, work_pack%size)) then
      work_pack%n_kpacks = work_pack%n_kpacks + 1
   endif

   ! Tolerate re-initialisation instead of failing on "already allocated".
   if (allocated(work_pack%k_packs)) call work_pack%free()
   allocate(work_pack%k_packs(work_pack%n_kpacks))

#ifdef CPP_MPI
   call MPI_AllReduce(work_pack%n_kpacks, work_pack%max_kpacks, 1, MPI_INTEGER, MPI_MAX, MPI_COMM_WORLD, ierr)
#else
   work_pack%max_kpacks = work_pack%n_kpacks
#endif
   if (work_pack%n_kpacks /= work_pack%max_kpacks) then
      call judft_warn("Your parallization is not efficient. Make sure that nkpts%pe == 0 or nkpts <= pe")
   endif

   ! get my k-list
   k_cnt = 0
   do i = work_pack%rank+1, fi%kpts%nkpt, work_pack%size
      k_cnt = k_cnt + 1
      work_pack%k_packs(k_cnt)%rank = k_cnt - 1
      work_pack%k_packs(k_cnt)%size = work_pack%n_kpacks
      call work_pack%k_packs(k_cnt)%init(fi, hybdat, mpdata, work_pack%submpi, jsp, i)
   enddo
end subroutine split_into_work_packages
241 :
242 :
!> Return the work-package rank that owns k-point nk.
!> This matches the round-robin assignment in split_into_work_packages, where
!> rank r owns k-points r+1, r+1+size, ...
!> (The unused local `use m_types_hybmpi` was removed.)
function t_work_package_owner_nk(work_pack, nk) result(owner)
   implicit none
   class(t_work_package), intent(in) :: work_pack
   integer,               intent(in) :: nk
   integer :: owner

   owner = modulo(nk - 1, work_pack%size)
end function t_work_package_owner_nk
252 :
!> Return .true. if one of the first n_kpacks k-packages handles k-point nk.
!> Returns .false. when k_packs is not allocated.
function t_work_package_has_nk(work_pack, nk) result(has_nk)
   implicit none
   class(t_work_package), intent(in) :: work_pack
   integer,               intent(in) :: nk
   logical :: has_nk

   has_nk = .false.
   ! t_work_package_free deallocates k_packs but leaves n_kpacks unchanged.
   ! The allocation status, not n_kpacks, therefore decides whether there is
   ! anything to search. min() also keeps the section within bounds.
   if (.not. allocated(work_pack%k_packs)) return

   has_nk = any(work_pack%k_packs(1:min(work_pack%n_kpacks, size(work_pack%k_packs)))%nk == nk)
end function t_work_package_has_nk
268 :
!> Choose how many band packages a q-package is split into.
!>
!> The band-block size psize is chosen as the largest value (starting from
!> maxval(hybdat%nobd)) whose estimated peak memory fits target_memsize().
!> The estimate is 2*cprod + coulomb + exch + indx + nsest. psize = 1 is
!> accepted unconditionally.
!>
!> n_parts = ceil(maxval(nobd)/psize) is then rounded up to a multiple of
!> q_pack%submpi%size. Finally it is capped at hybdat%nobd(ikqpt, jsp); the
!> cap may undo that rounding.
function calc_n_parts(fi, hybdat, n_g, q_pack, ikqpt, jsp) result(n_parts)
   implicit none
   type(t_fleurinput), intent(in) :: fi
   type(t_hybdat),     intent(in) :: hybdat
   integer,            intent(in) :: n_g(:), ikqpt, jsp
   class(t_q_package), intent(in) :: q_pack

   integer :: n_parts, ikpt

   integer(8), parameter :: i8_one = 1
   integer(8) :: coulomb_size, exch_size, indx_size, nsest_size, target_size, rc_factor
   integer(8) :: cprod_size, spmm_peak
   integer(8) :: max_nbasm, max_nbands, psize

   ! bytes per matrix element: presumably 8-byte real with inversion
   ! symmetry, 16-byte complex otherwise
   rc_factor  = merge(8, 16, fi%sym%invs)
   max_nbasm  = maxval(hybdat%nbasm)
   max_nbands = maxval(hybdat%nbands)

   target_size = target_memsize(fi, hybdat)

   ! largest squared mtir_size over all k-points
   coulomb_size = 0
   do ikpt = 1, fi%kpts%nkpt
      coulomb_size = max(int(mtir_size(fi, n_g, ikpt), kind=8)**2, coulomb_size)
   enddo

   ! sizes in bytes
   coulomb_size = rc_factor * coulomb_size
   exch_size    = rc_factor * maxval(i8_one*hybdat%nbands)**2
   indx_size    = 4 * maxval(i8_one*hybdat%nbands)**2
   nsest_size   = 4 * maxval(i8_one*hybdat%nbands)

   ! max(...,1) keeps the division below defined when nobd is all zero.
   psize = max(maxval(hybdat%nobd), 1)
   do while (psize > 1)
      cprod_size = max_nbasm * max_nbands * psize * rc_factor
      spmm_peak  = 2*cprod_size + coulomb_size + exch_size + indx_size + nsest_size
      if (spmm_peak <= target_size) exit
      psize = psize - 1
   enddo

   ! exact integer ceiling of maxval(nobd) / psize
   n_parts = int((maxval(hybdat%nobd) + psize - 1) / psize)

   do while (mod(n_parts, q_pack%submpi%size) /= 0)
      n_parts = n_parts + 1
   enddo

   if (n_parts > hybdat%nobd(ikqpt, jsp)) then
      write (*,*) "too many parts... reducing to nobd"
      n_parts = hybdat%nobd(ikqpt, jsp)
   endif
end function calc_n_parts
328 :
!> Memory budget in bytes used by calc_n_parts.
!> With OpenACC: 75 % of the currently free memory of device 0.
!> Otherwise: a fixed 15 GB.
!> fi and hybdat are currently unused but are kept for interface stability.
integer(8) function target_memsize(fi, hybdat)
   implicit none
   type(t_fleurinput), intent(in) :: fi
   type(t_hybdat),     intent(in) :: hybdat

#ifdef _OPENACC
   integer(C_SIZE_T) :: gpu_mem

   gpu_mem = acc_get_property(0, acc_device_current, acc_property_free_memory)
   ! integer arithmetic avoids single-precision rounding of 0.75*gpu_mem
   target_memsize = (int(gpu_mem, kind=8) / 4) * 3
#else
   ! 15 GB as an exact integer(8) literal; int(15e9, kind=8) goes through
   ! a default-real literal that is not exactly representable.
   target_memsize = 15000000000_8
#endif
end function target_memsize
346 0 : end module m_work_package
|