Line data Source code
! Module m_spmm_inv: provides spmm_invs, which applies the data stored in
! hybdat%coul(ikpt) (mt1_r, mt2_r, mt3_r and mtir%data_r) to a block of
! real column vectors (mat_in) and stores the result in mat_out.
1 : module m_spmm_inv
2 : use iso_c_binding
3 : use m_spmm
! Preprocessor aliases: with _OPENACC defined the BLAS calls are routed to the
! cublas* names and CPP_mtir_c / CPP_mtir_r refer to a local array mtir_tmp;
! otherwise the plain BLAS names are used and CPP_mtir_c / CPP_mtir_r expand
! to the data_c / data_r components of hybdat%coul(ikpt)%mtir.
4 : #ifdef _OPENACC
5 : USE cublas
6 : #define CPP_zgemm cublaszgemm
7 : #define CPP_dgemm cublasdgemm
8 : #define CPP_zgemv cublaszgemv
9 : #define CPP_dgemv cublasdgemv
10 : #define CPP_mtir_c mtir_tmp
11 : #define CPP_mtir_r mtir_tmp
12 : #else
13 : #define CPP_zgemm zgemm
14 : #define CPP_dgemm dgemm
15 : #define CPP_zgemv zgemv
16 : #define CPP_dgemv dgemv
17 : #define CPP_mtir_c hybdat%coul(ikpt)%mtir%data_c
18 : #define CPP_mtir_r hybdat%coul(ikpt)%mtir%data_r
19 : #endif
20 : contains
! spmm_invs: real-valued product of the sparsely stored data in
! hybdat%coul(ikpt) with the columns of mat_in; the result is written to mat_out.
!
! Arguments:
!   fi      - input data; fi%atoms (nat, ntype, neq, itype) and
!             fi%hybinp%lcutm1 are read here.
!   mpdata  - num_radbasfn(l, itype) and n_g(ikpt) are read here.
!   hybdat  - nbasp and coul(ikpt)%{mt1_r, mt2_r, mt3_r, mtir} are read here.
!   ikpt    - index into hybdat%coul; ikpt == 1 enables the two extra
!             "gamma point" blocks below.
!   mat_in  - input vectors, one per column. Passed to reorder() at the
!             start and again (after back_order) at the end, hence inout.
!   mat_out - result vectors, one per column; also passed to reorder()
!             at the end.
!
! Errors: judft_error is called if the running index counters do not end
! at ibasm / hybdat%nbasp, or (OpenACC build) if mtir_tmp cannot be allocated.
!
! NOTE(review): the "_invs" suffix and the use of the *_r components suggest
! this is the variant for systems with inversion symmetry - confirm.
! NOTE(review): irank is declared but not referenced in this routine; ierr is
! only referenced inside the _OPENACC branch.
21 66 : subroutine spmm_invs(fi, mpdata, hybdat, ikpt, mat_in, mat_out)
22 : use m_juDFT
23 : use m_types
24 : use m_reorder
25 : use m_calc_l_m_from_lm
26 : implicit none
27 : type(t_fleurinput), intent(in) :: fi
28 : type(t_mpdata), intent(in) :: mpdata
29 : type(t_hybdat), intent(in) :: hybdat
30 : integer, intent(in) :: ikpt
31 : real, intent(inout) :: mat_in(:,:)
32 : real, intent(inout) :: mat_out(:,:)
33 :
34 : integer :: n_vec, i_vec, ibasm, iatom, itype, ieq, l, m, n_size, sz_mtir, sz_hlp, sz_out, sz_mt1
35 : integer :: indx0, indx1, indx2, indx3, n, iatom1, ieq1, ishift, itype1, max_lcut_plus_1
36 : integer :: ishift1, indx4, lm, iat2, it2, l2, idx1_start, idx3_start, iat, irank, ierr
37 66 : real, allocatable :: mt1_tmp(:,:,:,:), mt2_tmp(:,:,:,:), mat_in_line(:), mt3_tmp(:,:,:)
38 66 : integer, allocatable :: new_order(:)
39 : #ifdef _OPENACC
40 : real, allocatable :: mtir_tmp(:,:)
41 : #endif
42 :
43 66 : call timestart("spmm_invs")
! Save row nbasp+1 of mat_in before mat_in is reordered below; it is used
! later in the first gamma-point block.
44 27922 : mat_in_line = mat_in(hybdat%nbasp + 1, :)
45 :
46 66 : n_vec = size(mat_in, 2)
47 :
! Build the forward ordering and apply it to mat_in in place.
! (presumably reorder() permutes the rows of mat_in - verify in m_reorder)
48 66 : call timestart("reorder")
49 198 : allocate(new_order(size(mat_in,1)))
50 66 : call forw_order(fi%atoms, fi%hybinp%lcutm1, mpdata%num_radbasfn, new_order)
51 : !$acc data copy(mat_in)
52 66 : call reorder(new_order, mat_in)
53 : !$acc end data
54 66 : call timestop("reorder")
55 :
56 66 : ibasm = calc_ibasm(fi, mpdata)
57 :
58 :
! Local copies of the hybdat arrays (these are the arrays named in the
! OpenACC data clauses and OpenMP shared lists below). mt3 is only copied
! for ikpt == 1, the only case in which it is used.
! NOTE(review): sz_mt1 is size(mt1_tmp, dim=2) but is passed to dgemm as the
! leading dimension of mt1_tmp; that is only correct if dims 1 and 2 of
! mt1_tmp have the same extent - confirm against the allocation of mt1_r.
59 66 : call timestart("copies out of hydat")
60 36828 : mt1_tmp = hybdat%coul(ikpt)%mt1_r
61 46156 : mt2_tmp = hybdat%coul(ikpt)%mt2_r
62 66 : if(ikpt == 1 ) then
63 450 : mt3_tmp = hybdat%coul(ikpt)%mt3_r
64 : endif
65 66 : sz_mt1 = size(mt1_tmp,dim=2)
66 66 : sz_hlp = size(mat_in, 1)
67 66 : sz_out = size(mat_out, 1)
68 66 : call timestop("copies out of hydat")
69 :
70 :
! The data regions opened here for mat_in/mat_out and mt2_tmp stay open
! until the matching "end data" lines after the "dot prod" loop.
71 : !$acc data copyin(mat_in) copy(mat_out)
72 : !$acc data copyin(mt2_tmp)
73 66 : call timestart("0 > ibasm: small matricies")
74 : ! compute vecout for the indices from 0:ibasm
! The OpenMP directive is only compiled when _OPENACC is not defined.
! indx2 is lastprivate because its value from the final iteration is
! checked against ibasm after the loop.
75 : #ifndef _OPENACC
76 : !$OMP PARALLEL DO default(none) schedule(dynamic)&
77 : !$OMP private(iatom, itype, idx1_start, iat2, it2, l2, indx1, idx3_start, indx3)&
78 : !$OMP private(lm, l, m, n_size, i_vec)&
79 : !$OMP lastprivate(indx2)&
80 66 : !$OMP shared(ibasm, mat_in, hybdat, mat_out, fi, mpdata, n_vec, ikpt, mt2_tmp, sz_out, sz_hlp, sz_mt1, mt1_tmp)
81 : #endif
82 : !$acc data copyin(mt1_tmp)
83 : do iatom = 1,fi%atoms%nat
84 : itype = fi%atoms%itype(iatom)
85 :
! Row offset of this atom's first block: sum over all preceding atoms of
! (num_radbasfn - 1)*(2l + 1). It is recomputed from scratch for every
! atom, so an iteration does not depend on the previous one.
86 : idx1_start = 0
87 : do iat2 =1,iatom-1
88 : it2 = fi%atoms%itype(iat2)
89 : do l2 = 0, fi%hybinp%lcutm1(it2)
90 : idx1_start = idx1_start + (mpdata%num_radbasfn(l2, it2)-1) * (2*l2+1)
91 : enddo
92 : enddo
93 : indx1 = idx1_start
94 :
! Offset of this atom's rows behind ibasm: one row per (l,m) pair,
! i.e. (lcutm1 + 1)**2 rows for every preceding atom.
95 : idx3_start = ibasm
96 : do iat2 = 1,iatom-1
97 : it2 = fi%atoms%itype(iat2)
98 : idx3_start = idx3_start + (fi%hybinp%lcutm1(it2)+1)**2
99 : enddo
100 : indx3 = idx3_start
101 :
102 : do lm = 1,(fi%hybinp%lcutm1(itype)+1)**2
103 : call calc_l_m_from_lm(lm, l, m)
104 : indx1 = indx1 + 1
105 : indx2 = indx1 + mpdata%num_radbasfn(l, itype) - 2
106 : indx3 = indx3 + 1
107 :
108 : n_size = mpdata%num_radbasfn(l, itype) - 1
! mat_out(indx1:indx2,:) = mt1_tmp(1:n_size,1:n_size,l,itype) * mat_in(indx1:indx2,:)
! (beta = 0.0, so these rows of mat_out are overwritten here).
109 : !$acc host_data use_device(mt1_tmp, mat_in, mat_out)
110 : call CPP_dgemm("N","N", n_size, n_vec, n_size, 1.0, mt1_tmp(1,1,l,itype), sz_mt1,&
111 : mat_in(indx1,1), sz_hlp, 0.0, mat_out(indx1,1), sz_out)
112 : !$acc end host_data
113 :
! Add the mt2 column for this (m, l, iatom), scaled by the single element
! mat_in(indx3, i_vec), to the same rows of every output column.
114 : !$acc kernels present(mat_out, mat_in, mt2_tmp)
115 : do i_vec = 1, n_vec
116 : mat_out(indx1:indx2,i_vec) = mat_out(indx1:indx2,i_vec) + mat_in(indx3, i_vec) * mt2_tmp(:n_size,m,l,iatom)
117 : enddo
118 : !$acc end kernels
! The next block starts directly after the last row of this one
! (indx1 is incremented at the top of the next lm iteration).
119 : indx1 = indx2
120 : END DO
121 : END DO
122 : !$acc end data
123 : #ifndef _OPENACC
124 : !$OMP END PARALLEL DO
125 : #endif
126 66 : call timestop("0 > ibasm: small matricies")
127 :
! Consistency check: the last row index written above must equal ibasm.
128 66 : IF (indx2 /= ibasm) call judft_error('spmm: error counting basis functions')
129 :
! Extra contributions for ikpt == 1, part 1: only the l = 0 rows of each
! atom (indx1:indx2 below) are updated.
130 66 : IF (ikpt == 1) THEN
131 : !$acc data copyin(mt3_tmp, mat_in_line)
132 18 : call timestart("gamma point 1 inv")
133 18 : iatom = 0
134 18 : indx0 = 0
135 : #ifndef _OPENACC
136 : !$OMP parallel do default(none) &
137 : !$OMP private(iatom, itype, ishift, l, indx0, indx1, indx2, indx3, indx4, iatom1, itype1, ishift1, i_vec, n_size)&
138 18 : !$OMP private(max_lcut_plus_1) shared(fi, mpdata, hybdat, mat_out, ibasm, n_vec, ikpt, mat_in, mat_in_line, mt2_tmp, mt3_tmp)
139 : #endif
140 : do iatom = 1, fi%atoms%nat
141 : itype = fi%atoms%itype(iatom)
142 : ishift = sum([((2*l + 1)*(mpdata%num_radbasfn(l, itype) - 1), l=0, fi%hybinp%lcutm1(itype))])
! l is fixed to 0 for the remainder of this iteration.
143 : l = 0
144 :
! indx0: number of rows belonging to all atoms before iatom.
145 : indx0 = 0
146 : do iat = 1,iatom-1
147 : indx0 = indx0 + sum([((2*l + 1)*(mpdata%num_radbasfn(l, fi%atoms%itype(iat)) - 1), l=0, fi%hybinp%lcutm1(fi%atoms%itype(iat)))])
148 : enddo
149 :
150 : indx1 = indx0 + 1
151 : indx2 = indx1 + mpdata%num_radbasfn(l, itype) - 2
152 :
! Walk over all atoms; indx4 is the first row behind ibasm that belongs to
! iatom1. Every atom other than iatom contributes
! mt3_tmp(:, iatom1, iatom) * mat_in(indx4, :).
153 : indx3 = ibasm
154 : n_size = mpdata%num_radbasfn(l, itype) - 1
155 : do iatom1 = 1,fi%atoms%nat
156 : itype1 = fi%atoms%itype(iatom1)
157 : ishift1 = (fi%hybinp%lcutm1(itype1) + 1)**2
158 : indx4 = indx3 + 1
159 : IF (iatom /= iatom1) then
160 : !$acc kernels present(mat_out, mat_in, mt3_tmp)
161 : do i_vec = 1, n_vec
162 : mat_out(indx1:indx2, i_vec) = mat_out(indx1:indx2, i_vec) &
163 : + mt3_tmp(:n_size, iatom1, iatom)*mat_in(indx4, i_vec)
164 : enddo
165 : !$acc end kernels
166 : endif
167 : indx3 = indx3 + ishift1
168 : END DO
169 :
! After visiting all atoms the counter must have reached hybdat%nbasp.
170 : IF (indx3 /= hybdat%nbasp) call judft_error('spmvec: error counting index indx3')
171 :
! Add the mt2 slice stored at m = 0, l = maxval(lcutm1) + 1, scaled by the
! saved (pre-reorder) row nbasp+1 of mat_in.
172 : n_size = mpdata%num_radbasfn(l, itype) - 1
173 : max_lcut_plus_1 = maxval(fi%hybinp%lcutm1) + 1
174 : !$acc kernels present(mat_out, mat_in_line, mt2_tmp)
175 : do i_vec = 1, n_vec
176 : mat_out(indx1:indx2, i_vec) = mat_out(indx1:indx2, i_vec) &
177 : + mt2_tmp(:n_size, 0, max_lcut_plus_1, iatom)*mat_in_line(i_vec)
178 : enddo
179 : !$acc end kernels
180 : END DO
181 :
182 : #ifndef _OPENACC
183 : !$OMP end parallel do
184 : #endif
185 18 : call timestop("gamma point 1 inv")
186 : !$acc end data !mt3_tmp
187 : END IF
188 :
189 : ! compute vecout for the index-range from ibasm+1:nbasm
190 :
! indx1 is reused here as the order of the square mtir block:
! sum over types of neq*(2l + 1) for l = 0..lcutm1, plus n_g(ikpt).
191 : indx1 = sum([(((2*l + 1)*fi%atoms%neq(itype), l=0, fi%hybinp%lcutm1(itype)), &
192 1034 : itype=1, fi%atoms%ntype)]) + mpdata%n_g(ikpt)
193 :
194 66 : call timestart("ibasm+1 -> dgemm")
! OpenACC build only: copy mtir%data_r into the local mtir_tmp, which is
! what CPP_mtir_r expands to in that build.
195 : #ifdef _OPENACC
196 : call timestart("copy mtir_tmp")
197 : allocate(mtir_tmp(hybdat%coul(ikpt)%mtir%matsize1, hybdat%coul(ikpt)%mtir%matsize2), stat=ierr)
198 : if(ierr /= 0) call judft_error("can't alloc mtir_tmp")
199 : call dlacpy("N", size(mtir_tmp,1), size(mtir_tmp,2), hybdat%coul(ikpt)%mtir%data_r, &
200 : size(hybdat%coul(ikpt)%mtir%data_r,1), mtir_tmp, size(mtir_tmp,1))
201 : call timestop("copy mtir_tmp")
202 : #endif
203 :
204 :
! mat_out(ibasm+1:ibasm+indx1,:) = mtir(1:indx1,1:indx1) * mat_in(ibasm+1:ibasm+indx1,:)
! (beta = 0.0, so these rows of mat_out are overwritten).
205 66 : sz_mtir = size(CPP_mtir_r, 1)
206 : !$acc data copyin(CPP_mtir_r)
207 : !$acc host_data use_device(CPP_mtir_r, mat_in, mat_out)
208 : call CPP_dgemm("N", "N", indx1, n_vec, indx1, 1.0, CPP_mtir_r, sz_mtir, &
209 66 : mat_in(ibasm + 1, 1), sz_hlp, 0.0, mat_out(ibasm + 1, 1), sz_out)
210 : !$acc end host_data
211 : !$acc end data ! CPP_mtir_r
212 : #ifdef _OPENACC
213 : deallocate(mtir_tmp)
214 : #endif
215 66 : call timestop("ibasm+1 -> dgemm")
216 :
! For every (atom, l, m): add to row indx1 of mat_out (accessed with stride
! sz_out, beta = 1.0) the transposed product of the n-1 rows of mat_in
! starting at indx2 with the vector mt2_tmp(1:n-1, m, l, iatom).
! indx1 counts rows behind ibasm; indx2/indx3 mark the start/end of the
! current (n-1)-row block of mat_in.
217 66 : call timestart("dot prod")
218 66 : iatom = 0
219 66 : indx1 = ibasm; indx2 = 0; indx3 = 0
220 154 : DO itype = 1, fi%atoms%ntype
221 242 : DO ieq = 1, fi%atoms%neq(itype)
222 88 : iatom = iatom + 1
223 616 : DO l = 0, fi%hybinp%lcutm1(itype)
224 440 : n = mpdata%num_radbasfn(l, itype)
225 2728 : DO m = -l, l
226 2200 : indx1 = indx1 + 1
227 2200 : indx2 = indx2 + 1
228 2200 : indx3 = indx3 + n - 1
229 :
230 : !$acc host_data use_device(mat_in, mt2_tmp, mat_out)
231 : call CPP_dgemv("T", n-1, n_vec, 1.0, mat_in(indx2,1), sz_hlp, mt2_tmp(1, m, l, iatom), 1, &
232 2200 : 1.0, mat_out(indx1,1), sz_out)
233 : !$acc end host_data
234 :
235 2640 : indx2 = indx3
236 : END DO
237 :
238 : END DO
239 : END DO
240 : END DO
241 : !$acc end data ! mt2_tmp
242 66 : call timestop("dot prod")
243 : !$acc end data ! mat_in, mat_out
244 :
245 :
246 :
! Extra contributions for ikpt == 1, part 2. This block carries no OpenACC
! directives; it runs after the mat_in/mat_out data region has been closed.
247 66 : IF (ikpt == 1) THEN
248 18 : call timestart("gamma point 2 inv")
! First loop: accumulate into row nbasp+1 of mat_out the dot product of the
! mt2 slice (m = 0, l = maxval(lcutm1) + 1) with the l = 0 rows
! (indx1:indx2) of each atom in mat_in.
249 18 : iatom = 0
250 18 : indx0 = 0
251 42 : DO itype = 1, fi%atoms%ntype
252 288 : ishift = sum([((2*l + 1)*(mpdata%num_radbasfn(l, itype) - 1), l=0, fi%hybinp%lcutm1(itype))])
253 66 : DO ieq = 1, fi%atoms%neq(itype)
254 24 : iatom = iatom + 1
255 24 : indx1 = indx0 + 1
256 24 : indx2 = indx1 + mpdata%num_radbasfn(0, itype) - 2
257 24 : n_size = mpdata%num_radbasfn(0, itype) - 1
258 9490 : do i_vec = 1, n_vec
259 : mat_out(hybdat%nbasp + 1, i_vec) = mat_out(hybdat%nbasp + 1, i_vec) &
260 : + dot_product(mt2_tmp(:n_size, 0, maxval(fi%hybinp%lcutm1) + 1, iatom), &
261 102278 : mat_in(indx1:indx2, i_vec))
262 : enddo
263 48 : indx0 = indx0 + ishift
264 : END DO
265 : END DO
266 :
! Second loop: for each atom, add to its first row behind ibasm (indx1) the
! dot products of mt3_tmp(:, iatom, iatom1) with the l = 0 rows
! (indx3:indx4) of every other atom iatom1.
! NOTE(review): the last two indices of mt3_tmp are (iatom, iatom1) here but
! (iatom1, iatom) in the "gamma point 1 inv" block - confirm this is intended.
267 : !$OMP PARALLEL DO default(none) schedule(dynamic)&
268 : !$OMP private(iatom, itype, indx1, indx2, itype1, ishift1) &
269 : !$OMP private(ieq1, iatom1, indx3, indx4, n_size, i_vec) &
270 18 : !$OMP shared(fi, n_vec, mat_out, ibasm, mpdata, mat_in, hybdat, ikpt, mt3_tmp)
271 : do iatom = 1, fi%atoms%nat
272 : itype = fi%atoms%itype(iatom)
273 : indx1 = ibasm + sum([((fi%hybinp%lcutm1(fi%atoms%itype(iat)) + 1)**2, iat=1,iatom-1)]) + 1
274 :
275 : iatom1 = 0
276 : indx2 = 0
277 : DO itype1 = 1, fi%atoms%ntype
278 : ishift1 = sum([((2*l + 1)*(mpdata%num_radbasfn(l, itype1) - 1), l=0, fi%hybinp%lcutm1(itype1))])
279 : DO ieq1 = 1, fi%atoms%neq(itype1)
280 : iatom1 = iatom1 + 1
281 : IF (iatom1 == iatom) CYCLE
282 :
283 : indx3 = indx2 + (ieq1 - 1)*ishift1 + 1
284 : indx4 = indx3 + mpdata%num_radbasfn(0, itype1) - 2
285 :
286 : n_size = mpdata%num_radbasfn(0, itype1) - 1
287 : do i_vec = 1, n_vec
288 : mat_out(indx1, i_vec) = mat_out(indx1, i_vec) &
289 : + dot_product(mt3_tmp(:n_size, iatom, iatom1), &
290 : mat_in(indx3:indx4, i_vec))
291 : enddo
292 : END DO
293 : indx2 = indx2 + fi%atoms%neq(itype1)*ishift1
294 : END DO
295 : END DO
296 : !$OMP END PARALLEL DO
297 18 : call timestop("gamma point 2 inv")
298 :
299 : END IF
300 :
! Rebuild new_order with back_order and apply it to both mat_in and mat_out.
! (presumably this restores the caller's original ordering - verify in m_reorder)
301 66 : call timestart("reorder")
302 66 : call back_order(fi%atoms, fi%hybinp%lcutm1, mpdata%num_radbasfn, new_order)
303 : !$acc data copy(mat_in, mat_out)
304 66 : call reorder(new_order, mat_in)
305 66 : call reorder(new_order, mat_out)
306 : !$acc end data
307 66 : call timestop("reorder")
308 66 : call timestop("spmm_invs")
309 66 : end subroutine spmm_invs
310 : end module m_spmm_inv
|