Line data Source code
module m_spmm_noinv
   use iso_c_binding
   use m_spmm
#ifdef _OPENACC
   USE cublas
#define CPP_zgemm cublaszgemm
#define CPP_dgemm cublasdgemm
#define CPP_zgemv cublaszgemv
#define CPP_dgemv cublasdgemv
#define CPP_mtir_c mtir_tmp
#define CPP_mtir_r mtir_tmp
#else
#define CPP_zgemm zgemm
#define CPP_dgemm dgemm
#define CPP_zgemv zgemv
#define CPP_dgemv dgemv
#define CPP_mtir_c hybdat%coul(ikpt)%mtir%data_c
#define CPP_mtir_r hybdat%coul(ikpt)%mtir%data_r
#endif
contains
   ! Multiply the columns of mat_in by the sparse matrix stored in hybdat%coul(ikpt)
   ! (blocks mt1_c, mt2_c, mt3_c and mtir%data_c) and write the result to mat_out.
   ! Complex-valued ("noinv") variant; on GPUs the BLAS calls go to cuBLAS via OpenACC.
   !
   ! fi          input data (atoms, hybinp%lcutm1)
   ! mpdata      provides num_radbasfn(l, itype) and n_g(ikpt)
   ! hybdat      source of the coul(ikpt) blocks and nbasp. Without OpenACC and with
   !             conjg_mtir = .true., mtir%data_c is conjugated in place for the
   !             multiplication and conjugated back afterwards.
   ! ikpt        k-point index; for ikpt == 1 two extra correction passes are applied
   !             (timed as "gamma point 1/2 noinv")
   ! conjg_mtir  if .true., multiply by conjg(mtir) instead of mtir
   ! mat_in      input columns. Permuted in place by forw_order/reorder and permuted
   !             again with back_order before return (presumably restoring the caller
   !             ordering -- depends on m_reorder, not visible here)
   ! mat_out     result; zeroed first, so every element is overwritten
   subroutine spmm_noinvs(fi, mpdata, hybdat, ikpt, conjg_mtir, mat_in, mat_out)
      use m_juDFT
      use m_types
      use m_reorder
      use m_constants
      use m_calc_l_m_from_lm

      implicit none
      type(t_fleurinput), intent(in)    :: fi
      type(t_mpdata), intent(in)        :: mpdata
      type(t_hybdat), intent(inout)     :: hybdat
      integer, intent(in)               :: ikpt
      logical, intent(in)               :: conjg_mtir
      complex, intent(inout)            :: mat_in(:,:)
      complex, intent(inout)            :: mat_out(:,:)

      integer :: n_vec, i_vec, ibasm, iatom, itype, ieq, l, m, n_size
      integer :: indx0, indx1, indx2, indx3, n, iatom1, ieq1, ishift, itype1
      integer :: ishift1, indx4, lm, idx1_start, idx3_start, ld_mt1_tmp
      integer :: iat2, it2, l2, iat, ierr, i, sz_mtir, sz_in, sz_out, max_l_cut
      integer, allocatable :: new_order(:)
      complex, allocatable :: mt1_tmp(:,:,:,:), mt2_tmp(:,:,:,:), mt3_tmp(:,:,:), mat_in_line(:)
#ifdef _OPENACC
      complex, allocatable :: mtir_tmp(:,:)
#endif

      call timestart("spmm_noinvs")
      call timestart("copy mt2_c")
      ! local copy: it is conjugated further down without touching hybdat
      mt2_tmp = hybdat%coul(ikpt)%mt2_c
      call timestop("copy mt2_c")

      sz_in = size(mat_in, 1)
      sz_out = size(mat_out, 1)
      n_vec = size(mat_in, 2)

      allocate(mat_in_line(size(mat_in,2)))

      call timestart("copyin gpu")
      !$acc data copyin(mt2_tmp) copy(mat_in) copyout(mat_out) create(mat_in_line)
      !$acc wait
      call timestop("copyin gpu")

      ! Keep row nbasp+1 of mat_in as it is BEFORE the forward reordering;
      ! it is consumed by the first ikpt == 1 correction below.
      !$acc kernels present(mat_in_line, mat_in)
      mat_in_line = mat_in(hybdat%nbasp + 1, :)
      !$acc end kernels

      call timestart("reorder forw")
      allocate(new_order(size(mat_in,1)))
      call forw_order(fi%atoms, fi%hybinp%lcutm1, mpdata%num_radbasfn, new_order)
      call reorder(new_order, mat_in)
      call timestop("reorder forw")

      ibasm = calc_ibasm(fi, mpdata)

      ! compute vecout for the indices from 0:ibasm
      call timestart("0 > ibasm: small matricies")
      call timestart("alloc&cpy mt1_tmp")
      allocate(mt1_tmp, mold=hybdat%coul(ikpt)%mt1_c, stat=ierr)
      ! check the allocation status BEFORE querying the shape of mt1_tmp
      if (ierr /= 0) call judft_error("can't alloc mt1_tmp")
      ! The BLAS leading dimension is the column stride, i.e. the extent of the
      ! first dimension. Only the top-left n_size x n_size part of each (l, itype)
      ! slice is used in the multiplication.
      ld_mt1_tmp = size(mt1_tmp, dim=1)
      call zcopy(size(mt1_tmp), hybdat%coul(ikpt)%mt1_c, 1, mt1_tmp, 1)
      call timestop("alloc&cpy mt1_tmp")

      !$acc kernels present(mat_out)
      mat_out = cmplx_0
      !$acc end kernels

      ! defined value for the consistency check below, even when the atom loop is empty
      indx2 = 0

      !$acc data copyin(mt1_tmp)
#ifndef _OPENACC
      !$OMP PARALLEL DO default(none) schedule(dynamic)&
      !$OMP private(iatom, itype, idx1_start, iat2, it2, l2, indx1, idx3_start, indx3)&
      !$OMP private(lm, l, m, n_size, i_vec)&
      !$OMP lastprivate(indx2)&
      !$OMP shared(ibasm, mat_in, hybdat, mat_out, fi, mpdata, n_vec, ikpt, ld_mt1_tmp, sz_out, sz_in, mt1_tmp, mt2_tmp)
#endif
      do iatom = 1, fi%atoms%nat
         itype = fi%atoms%itype(iatom)

         ! first row of this atom: (num_radbasfn - 1) rows per (l, m) of all previous atoms
         idx1_start = 0
         do iat2 = 1, iatom - 1
            it2 = fi%atoms%itype(iat2)
            do l2 = 0, fi%hybinp%lcutm1(it2)
               idx1_start = idx1_start + (mpdata%num_radbasfn(l2, it2) - 1)*(2*l2 + 1)
            enddo
         enddo
         indx1 = idx1_start

         ! row just before this atom in the section after ibasm: (lcutm1 + 1)**2 rows per atom
         idx3_start = ibasm
         do iat2 = 1, iatom - 1
            it2 = fi%atoms%itype(iat2)
            idx3_start = idx3_start + (fi%hybinp%lcutm1(it2) + 1)**2
         enddo
         indx3 = idx3_start

         do lm = 1, (fi%hybinp%lcutm1(itype) + 1)**2
            call calc_l_m_from_lm(lm, l, m)
            indx1 = indx1 + 1
            indx2 = indx1 + mpdata%num_radbasfn(l, itype) - 2
            indx3 = indx3 + 1

            n_size = mpdata%num_radbasfn(l, itype) - 1

            ! mat_out(indx1:indx2, :) = mt1(1:n_size, 1:n_size, l, itype) * mat_in(indx1:indx2, :)
            !$acc host_data use_device(mt1_tmp, mat_in, mat_out)
            call CPP_zgemm("N", "N", n_size, n_vec, n_size, cmplx_1, mt1_tmp(1, 1, l, itype), ld_mt1_tmp, &
                           mat_in(indx1, 1), sz_in, cmplx_0, mat_out(indx1, 1), sz_out)
            !$acc end host_data

            ! add the rank-1 coupling to row indx3 (mt2 column for this (m, l, iatom))
            !$acc kernels present(mat_out, mt2_tmp, mat_in)
            do i_vec = 1, n_vec
               do i = 0, indx2 - indx1
                  mat_out(indx1 + i, i_vec) = mat_out(indx1 + i, i_vec) &
                     + mt2_tmp(i + 1, m, l, iatom)*mat_in(indx3, i_vec)
               enddo
            enddo
            !$acc end kernels

            indx1 = indx2
         enddo
      enddo
#ifndef _OPENACC
      !$OMP END PARALLEL DO
#endif
      !$acc end data
      !$acc wait
      deallocate(mt1_tmp)
      call timestop("0 > ibasm: small matricies")

      if (indx2 /= ibasm) call judft_error('spmvec: error counting basis functions')

      if (ikpt == 1) then
         call timestart("gamma point 1 noinv")
         call timestart("cpy mt3_tmp")
         allocate(mt3_tmp, mold=hybdat%coul(ikpt)%mt3_c, stat=ierr)
         if (ierr /= 0) call judft_error("can't alloc mt3_tmp")
         call zcopy(size(mt3_tmp), hybdat%coul(ikpt)%mt3_c, 1, mt3_tmp, 1)
         call timestop("cpy mt3_tmp")

         max_l_cut = maxval(fi%hybinp%lcutm1)
#ifdef _OPENACC
         !$acc data copyin(mt3_tmp)
#else
         !$OMP PARALLEL DO default(none) schedule(dynamic)&
         !$OMP private(iatom, itype, indx0, l, m, indx1, indx2, iatom1, indx3) &
         !$OMP private(indx4, i_vec, n_size, itype1, ishift1,ieq1) &
         !$OMP shared(fi, n_vec, mpdata, hybdat, ibasm, mat_out, mat_in, ikpt, mat_in_line, mt3_tmp, mt2_tmp, max_l_cut)
#endif
         do iatom = 1, fi%atoms%nat
            itype = fi%atoms%itype(iatom)
            indx0 = 0
            do iat = 1, iatom - 1
               indx0 = indx0 + sum([((2*l + 1)*(mpdata%num_radbasfn(l, fi%atoms%itype(iat)) - 1), &
                                     l=0, fi%hybinp%lcutm1(fi%atoms%itype(iat)))])
            enddo
            ! only the l = 0, m = 0 rows of this atom are updated here
            l = 0
            m = 0

            indx1 = indx0 + 1
            indx2 = indx1 + mpdata%num_radbasfn(l, itype) - 2

            iatom1 = 0
            indx3 = ibasm
            n_size = mpdata%num_radbasfn(l, itype) - 1
            do itype1 = 1, fi%atoms%ntype
               ishift1 = (fi%hybinp%lcutm1(itype1) + 1)**2
               do ieq1 = 1, fi%atoms%neq(itype1)
                  iatom1 = iatom1 + 1
                  ! first (lm = 1) row of atom iatom1 in the section after ibasm
                  indx4 = indx3 + (ieq1 - 1)*ishift1 + 1
                  if (iatom /= iatom1) then
                     !$acc kernels present(mat_out, mt3_tmp, mat_in) default(none)
                     do i_vec = 1, n_vec
                        mat_out(indx1:indx2, i_vec) = mat_out(indx1:indx2, i_vec) &
                           + mt3_tmp(:n_size, iatom1, iatom)*mat_in(indx4, i_vec)
                     enddo
                     !$acc end kernels
                  endif
               enddo
               indx3 = indx3 + fi%atoms%neq(itype1)*ishift1
            enddo
            if (indx3 /= hybdat%nbasp) call judft_error('spmvec: error counting index indx3')

            n_size = mpdata%num_radbasfn(l, itype) - 1
            ! coupling to the saved (pre-reorder) row nbasp+1 of mat_in
            !$acc kernels present(mat_out, mt2_tmp, mat_in_line) default(none)
            do i_vec = 1, n_vec
               mat_out(indx1:indx2, i_vec) = mat_out(indx1:indx2, i_vec) &
                  + mt2_tmp(:n_size, 0, max_l_cut + 1, iatom)*mat_in_line(i_vec)
            enddo
            !$acc end kernels
         enddo
#ifdef _OPENACC
         !$acc end data !(mt3_tmp)
#else
         !$OMP END PARALLEL DO
#endif
         call timestop("gamma point 1 noinv")
      endif

      ! compute vecout for the index-range from ibasm+1:nbasm
      ! indx1 = number of rows and columns of the mtir block that is applied
      call timestart("calc indx1")
      indx1 = sum([(((2*l + 1)*fi%atoms%neq(itype), l=0, fi%hybinp%lcutm1(itype)), &
                    itype=1, fi%atoms%ntype)]) + mpdata%n_g(ikpt)
      call timestop("calc indx1")

#ifdef _OPENACC
      ! GPU path: work on a private copy, so hybdat is never modified
      call timestart("copy mtir_tmp")
      allocate(mtir_tmp(hybdat%coul(ikpt)%mtir%matsize1, hybdat%coul(ikpt)%mtir%matsize2), stat=ierr)
      if (ierr /= 0) call judft_error("can't alloc mtir_tmp")
      call zlacpy("N", size(mtir_tmp,1), size(mtir_tmp,2), hybdat%coul(ikpt)%mtir%data_c, &
                  size(hybdat%coul(ikpt)%mtir%data_c,1), mtir_tmp, size(mtir_tmp,1))
      call timestop("copy mtir_tmp")
#endif

      call timestart("acc kernels")
      !$acc enter data copyin(mtir_tmp)
      if (conjg_mtir) then
         ! CPU path: conjugates hybdat data in place, undone after the zgemm below
         !$acc kernels present(mtir_tmp)
         CPP_mtir_c = conjg(CPP_mtir_c)
         !$acc end kernels
      endif
      call timestop("acc kernels")

      call timestart("ibasm+1->nbasm: zgemm")
      sz_mtir = size(CPP_mtir_c, 1)

      !$acc host_data use_device(CPP_mtir_c, mat_in, mat_out)
      call CPP_zgemm("N", "N", indx1, n_vec, indx1, cmplx_1, CPP_mtir_c, sz_mtir, &
                     mat_in(ibasm + 1, 1), sz_in, cmplx_0, mat_out(ibasm + 1, 1), sz_out)
      !$acc end host_data
      !$acc exit data delete(CPP_mtir_c)
#ifdef _OPENACC
      deallocate(mtir_tmp)
#else
      ! restore the original mtir data in hybdat
      if (conjg_mtir) then
         CPP_mtir_c = conjg(CPP_mtir_c)
      endif
#endif
      !$acc wait
      call timestop("ibasm+1->nbasm: zgemm")

      call timestart("dot prod")
      !$acc kernels present(mt2_tmp)
      mt2_tmp = conjg(mt2_tmp)
      !$acc end kernels

      ! row indx1 (after ibasm) += conjg(mt2(:, m, l, iatom))^T * mat_in(indx2:indx2+n-2, :)
      iatom = 0
      indx1 = ibasm
      indx2 = 0
      indx3 = 0
      do itype = 1, fi%atoms%ntype
         do ieq = 1, fi%atoms%neq(itype)
            iatom = iatom + 1
            do l = 0, fi%hybinp%lcutm1(itype)
               n = mpdata%num_radbasfn(l, itype)
               do m = -l, l
                  indx1 = indx1 + 1
                  indx2 = indx2 + 1
                  indx3 = indx3 + n - 1

                  !$acc host_data use_device(mat_in, mt2_tmp, mat_out)
                  call CPP_zgemv("T", n - 1, n_vec, cmplx_1, mat_in(indx2, 1), sz_in, &
                                 mt2_tmp(1, m, l, iatom), 1, cmplx_1, mat_out(indx1, 1), sz_out)
                  !$acc end host_data

                  indx2 = indx3
               enddo
            enddo
         enddo
      enddo
      call timestop("dot prod")

      if (ikpt == 1) then
         call timestart("gamma point 2 noinv")
         iatom = 0
         indx0 = 0

         ! row nbasp+1 += sum over atoms of conjg(mt2(:, 0, max_l_cut+1, iatom))^T * (l = 0 rows)
         max_l_cut = maxval(fi%hybinp%lcutm1)
         do itype = 1, fi%atoms%ntype
            ishift = sum([((2*l + 1)*(mpdata%num_radbasfn(l, itype) - 1), l=0, fi%hybinp%lcutm1(itype))])
            do ieq = 1, fi%atoms%neq(itype)
               iatom = iatom + 1
               indx1 = indx0 + 1
               indx2 = indx1 + mpdata%num_radbasfn(0, itype) - 2
               n_size = mpdata%num_radbasfn(0, itype) - 1

               !$acc host_data use_device(mat_in, mt2_tmp, mat_out)
               call CPP_zgemv("T", n_size, n_vec, cmplx_1, mat_in(indx1, 1), sz_in, &
                              mt2_tmp(1, 0, max_l_cut + 1, iatom), 1, cmplx_1, &
                              mat_out(hybdat%nbasp + 1, 1), sz_out)
               !$acc end host_data
               indx0 = indx0 + ishift
            enddo
         enddo

         !$acc data copyin(mt3_tmp)
         !$acc kernels present(mt3_tmp)
         mt3_tmp = conjg(mt3_tmp)
         !$acc end kernels
#ifndef _OPENACC
         !$OMP PARALLEL DO default(none) &
         !$OMP private(iatom, itype, indx1, iatom1, indx2, itype1, ishift1, indx3, indx4, n_size) &
         !$OMP shared(fi, mpdata, hybdat,mat_out, mat_in, ibasm, ikpt, n_vec, mt3_tmp, sz_out, sz_in)
#endif
         do iatom = 1, fi%atoms%nat
            itype = fi%atoms%itype(iatom)
            ! first (lm = 1) row of atom iatom in the section after ibasm
            indx1 = ibasm + sum([((fi%hybinp%lcutm1(fi%atoms%itype(iat)) + 1)**2, iat=1, iatom - 1)]) + 1
            iatom1 = 0
            indx2 = 0
            do itype1 = 1, fi%atoms%ntype
               ishift1 = sum([((2*l + 1)*(mpdata%num_radbasfn(l, itype1) - 1), l=0, fi%hybinp%lcutm1(itype1))])
               do ieq1 = 1, fi%atoms%neq(itype1)
                  iatom1 = iatom1 + 1
                  if (iatom1 /= iatom) then
                     indx3 = indx2 + (ieq1 - 1)*ishift1 + 1
                     indx4 = indx3 + mpdata%num_radbasfn(0, itype1) - 2
                     n_size = mpdata%num_radbasfn(0, itype1) - 1

                     !$acc host_data use_device(mat_in, mt3_tmp, mat_out)
                     call CPP_zgemv("T", n_size, n_vec, cmplx_1, mat_in(indx3, 1), sz_in, &
                                    mt3_tmp(1, iatom, iatom1), 1, cmplx_1, mat_out(indx1, 1), sz_out)
                     !$acc end host_data
                  endif
               enddo
               indx2 = indx2 + fi%atoms%neq(itype1)*ishift1
            enddo
         enddo
#ifndef _OPENACC
         !$OMP END PARALLEL DO
#endif
         !$acc end data !(mt3_tmp)
         deallocate(mt3_tmp)
         call timestop("gamma point 2 noinv")
      endif

      call timestart("reorder back")
      call back_order(fi%atoms, fi%hybinp%lcutm1, mpdata%num_radbasfn, new_order)
      call reorder(new_order, mat_in)
      call reorder(new_order, mat_out)
      call timestop("reorder back")

      call timestart("copyout")
      !$acc end data !mt2_tmp, mat_in, mat_out
      !$acc wait
      call timestop("copyout")
      call timestop("spmm_noinvs")
   end subroutine spmm_noinvs
end module m_spmm_noinv
|