Line data Source code
1 : module m_wavefproducts_aux
2 : use m_types_fftGrid
3 : use m_types
4 : CONTAINS
!> Interstitial contribution to the wave-function products, evaluated with FFTs.
!!
!! Outline of what the routine does:
!!  1. `stars%ustep` is placed on an FFT grid built for `gcutoff`, scaled by
!!     1/sqrt(omtil)/gridLength and transformed with `fft%exec`.
!!  2. The eigenvectors `bandoi..bandof` at k-point `ikqpt` are read with
!!     `read_z`, transformed by `wavef2rs` and multiplied by the transformed
!!     step function.
!!  3. For every band `iband` at k-point `ik`, conjg(psi_k)*psi_kqpt is
!!     formed, transformed with `fft%exec_batch`, and the grid entries listed
!!     in `g_ptr` are stored in rows `hybdat%nbasp+1 .. hybdat%nbasp+n_g(iq)`
!!     of `cprod`, column `iob + (iband-1)*psize`.
!!
!! Arguments (only what this routine itself shows):
!!  - fi           : input container; cell, sym, input, mpinp, hybinp, atoms,
!!                   kpts and noco are read.
!!  - ik           : k-point whose bands 1..hybdat%nbands(ik,jsp) are looped over.
!!  - iq           : index into mpdata%n_g / mpdata%gptm_ptr.
!!  - g_t          : integer vector subtracted from the mixed-basis G vectors
!!                   in setup_g_ptr.
!!  - jsp          : spin index.
!!  - bandoi/bandof: first and last band read at k-point `ikqpt`.
!!  - ikqpt        : k-point index passed to lapw_ikqpt%init and read_z.
!!  - z_k          : eigenvectors at `ik`.
!!  - z_kqpt_p     : initialised from z_kqpt and passed to read_z as `parent_z`.
!!  - c_phase_kqpt : passed to read_z as `c_phase`.
!!  - cprod        : result; `data_r` is written if `cprod%l_real`, else `data_c`.
subroutine wavefproducts_IS_FFT(fi, ik, iq, g_t, jsp, bandoi, bandof, mpdata, hybdat, lapw, stars, nococonv, &
                                ikqpt, z_k, z_kqpt_p, c_phase_kqpt, cprod)
   !$ use omp_lib
   use m_constants
   use m_judft
   use m_fft_interface
   use m_io_hybrid
#ifdef CPP_MPI
   use mpi
#endif
   implicit NONE
   type(t_fleurinput), intent(in)    :: fi
   TYPE(t_nococonv), INTENT(IN)      :: nococonv
   TYPE(t_lapw), INTENT(IN)          :: lapw
   TYPE(t_mpdata), intent(in)        :: mpdata
   TYPE(t_hybdat), INTENT(INOUT)     :: hybdat
   type(t_stars), intent(in)         :: stars
   type(t_mat), intent(in)           :: z_k
   type(t_mat), intent(inout)        :: z_kqpt_p, cprod
   ! - scalars -
   INTEGER, INTENT(IN)               :: ik, iq, jsp, g_t(3), bandoi, bandof
   INTEGER, INTENT(IN)               :: ikqpt
   ! - arrays -
   complex, intent(inout)            :: c_phase_kqpt(hybdat%nbands(ikqpt,jsp))

   complex, allocatable :: prod(:,:), psi_k(:, :), psi_kqpt(:,:)

   type(t_mat)     :: z_kqpt
   type(t_lapw)    :: lapw_ikqpt
   type(t_fft)     :: fft, wavef2rs_fft
   type(t_fftgrid) :: stepf, grid

   integer :: igptm, iob, j
   integer :: ok, nbasfcn, psize, iband, ierr, i, max_igptm
   integer, allocatable :: band_list(:), g_ptr(:)
   real    :: inv_vol, gcutoff, max_imag

   logical :: real_warned

   real_warned = .False.

   call timestart("wavef_IS_FFT")
   max_igptm = mpdata%n_g(iq)

   gcutoff = (2*fi%input%rkmax + fi%mpinp%g_cutoff) * fi%hybinp%fftcut
   inv_vol = 1/sqrt(fi%cell%omtil)
   psize = bandof - bandoi + 1
   !this is for the exact result. Christoph recommend 2*gmax+gcutm for later
   if (2*fi%input%rkmax + fi%mpinp%g_cutoff > fi%input%gmax) then
      ! message now states the same relation as the test above
      write (*, *) "WARNING: not accurate enough: 2*kmax+gcutm > fi%input%gmax"
      !call juDFT_error("not accurate enough: 2*kmax+gcutm >= fi%input%gmax")
   endif

   ! --- step function on the FFT grid -------------------------------------
   call stepf%init(fi%cell, fi%sym, gcutoff)
   call stepf%putfieldOnGrid(stars, stars%ustep)
   call fft%init(stepf%dimensions, .false., batch_size=1, l_gpu=.True.)
   !$acc data copyin(stepf, stepf%grid, stepf%gridlength)
   ! after we transform psi_k*stepf*psi_kqpt back to
   ! G-space we have to divide by stepf%gridLength. We do this now

   !$acc kernels default(none) present(stepf, stepf%grid, stepf%gridLength)
   stepf%grid = stepf%grid * inv_vol / stepf%gridLength
   !$acc end kernels

   call fft%exec(stepf%grid)
   call fft%free()

   ! g_ptr(igptm): FFT-grid index of (mixed-basis G vector igptm) - g_t
   call setup_g_ptr(mpdata, stepf, g_t, iq, g_ptr)

   ! --- eigenvectors at ikqpt ---------------------------------------------
   CALL lapw_ikqpt%init(fi, nococonv, ikqpt)

   nbasfcn = lapw_ikqpt%hyb_num_bas_fun(fi)
   call z_kqpt%alloc(z_k%l_real, nbasfcn, psize)
   call z_kqpt_p%init(z_kqpt)

   band_list = [(i, i=bandoi, bandof)]
   call read_z(fi%atoms, fi%cell, hybdat, fi%kpts, fi%sym, fi%noco, nococonv, fi%input, ikqpt, jsp, z_kqpt, &
               c_phase=c_phase_kqpt, parent_z=z_kqpt_p, list=band_list)
#ifdef CPP_MPI
   call timestart("read_z barrier")
   call MPI_Barrier(MPI_COMM_WORLD, ierr)
   ! NOTE(review): max_q is decremented only in MPI builds; the reason is not
   ! visible in this routine -- confirm against the caller.
   hybdat%max_q = hybdat%max_q - 1
   call timestop("read_z barrier")
#endif

   allocate(psi_kqpt(0:stepf%gridLength-1, psize), stat=ierr)
   if(ierr /= 0) call juDFT_error("can't alloc psi_kqpt")

   !$acc data create(psi_kqpt)
   call grid%init(fi%cell, fi%sym, gcutoff)
   call wavef2rs_fft%init(grid%dimensions, .false., batch_size=psize, l_gpu=.True.)
   !$acc data copyin(z_kqpt, z_kqpt%l_real, z_kqpt%data_r, z_kqpt%data_c, lapw_ikqpt, lapw_ikqpt%nv, lapw_ikqpt%gvec,&
   !$acc             jsp, bandoi, bandof, psize, grid, grid%dimensions)
   call timestart("1st wavef2rs")
   ! columns 1..psize of z_kqpt hold the bands bandoi..bandof selected via band_list
   call wavef2rs(fi, lapw_ikqpt, z_kqpt, gcutoff, 1, psize, jsp, grid, wavef2rs_fft, psi_kqpt)
   call timestop("1st wavef2rs")

   ! fold the (already scaled and transformed) step function into psi_kqpt once
   !$acc kernels default(none) present(psi_kqpt, stepf, stepf%grid)
   do iob = 1, psize
      psi_kqpt(:,iob) = psi_kqpt(:,iob) * stepf%grid
   enddo
   !$acc end kernels
   !$acc end data
   call wavef2rs_fft%free()
   !call grid%free()

   ! --- loop over the bands at ik -----------------------------------------
   call timestart("Big OMP loop")
#ifndef _OPENACC
   !$OMP PARALLEL default(none) &
   !$OMP private(iband, iob, igptm, prod, psi_k, ok, fft, wavef2rs_fft, max_imag, grid) &
   !$OMP shared(hybdat, psi_kqpt, cprod, mpdata, iq, g_t, psize, gcutoff, max_igptm)&
   !$OMP shared(jsp, z_k, stars, lapw, fi, inv_vol, ik, real_warned, bandoi, stepf, g_ptr)
#endif

   ! In the OpenMP build everything up to END PARALLEL is executed by every
   ! thread on its private prod/psi_k/fft/grid objects.
   ! call timestart("alloc&init")
   allocate (prod(0:stepf%gridLength - 1, psize), stat=ok)
   if (ok /= 0) call juDFT_error("can't alloc prod")
   allocate (psi_k(0:stepf%gridLength - 1, 1), stat=ok)
   if (ok /= 0) call juDFT_error("can't alloc psi_k")

   call fft%init(stepf%dimensions, .true., batch_size=psize, l_gpu=.True.)
   call grid%init(fi%cell, fi%sym, gcutoff)
   call wavef2rs_fft%init(grid%dimensions, .false., batch_size=1, l_gpu=.True.)
   ! call timestop("alloc&init")

   !$acc data copyin(z_k, z_k%l_real, z_k%data_r, z_k%data_c, lapw, lapw%nv, lapw%gvec)&
   !$acc      copyin(hybdat, hybdat%nbasp, g_ptr, grid, grid%dimensions, jsp)&
   !$acc      create(psi_k, prod)
#ifndef _OPENACC
   !$OMP DO
#endif
   do iband = 1, hybdat%nbands(ik,jsp)
      ! single band iband of z_k -> psi_k(:,1)
      call wavef2rs(fi, lapw, z_k, gcutoff, iband, iband, jsp, grid, wavef2rs_fft, psi_k)

      !$acc kernels default(none) present(prod, psi_k, psi_kqpt, stepf, stepf%gridlength)
      do iob = 1, psize
         do j = 0, stepf%gridlength-1
            prod(j,iob) = conjg(psi_k(j, 1)) * psi_kqpt(j, iob)
         enddo
      enddo
      !$acc end kernels

      call fft%exec_batch(prod)

      if (cprod%l_real) then
         ! NOTE(review): real_warned is shared and written without
         ! synchronisation in the OpenMP build; several threads may print the
         ! warning before one of them sets the flag.
         if (.not. real_warned) then
            !$acc kernels present(prod) copyout(max_imag)
            max_imag = maxval(abs(aimag(prod)))
            !$acc end kernels
            if(max_imag > 1e-8) then
               write (*, *) "WARNING: imaginary part of the wavefunction product is too large"
               real_warned = .True.
            endif
         endif

         ! only the real part is kept for real-valued cprod
         !$acc kernels default(none) present(cprod, cprod%data_r, prod, g_ptr)
         !$acc loop independent
         do iob = 1, psize
            !$acc loop independent
            DO igptm = 1, max_igptm
               cprod%data_r(hybdat%nbasp + igptm, iob + (iband - 1)*psize) = real(prod(g_ptr(igptm), iob))
            enddo
         enddo
         !$acc end kernels
      else
         !$acc kernels default(none) present(cprod, cprod%data_c, prod, g_ptr)
         !$acc loop independent
         do iob = 1, psize
            !$acc loop independent
            DO igptm = 1, max_igptm
               cprod%data_c(hybdat%nbasp + igptm, iob + (iband - 1)*psize) = prod(g_ptr(igptm), iob)
            enddo
         enddo
         !$acc end kernels
      endif
   enddo
#ifndef _OPENACC
   !$OMP END DO
#endif
   !$acc end data
   call fft%free()
   !call grid%free()
   call wavef2rs_fft%free()
   !$acc end data ! psi_kqpt
   deallocate (prod, psi_k)
   !$acc end data ! stepf, stepf%grid

#ifndef _OPENACC
   !$OMP END PARALLEL
#endif
   !call stepf%free()

   call timestop("Big OMP loop")
   deallocate(psi_kqpt)
   call timestop("wavef_IS_FFT")
end subroutine wavefproducts_IS_FFT
207 :
!> Build the lookup table from mixed-basis G vectors to FFT-grid indices.
!!
!! For every igptm = 1..mpdata%n_g(iq) the vector
!! mpdata%g(:, mpdata%gptm_ptr(igptm, iq)) - g_t is converted with
!! stepf%g2fft and the result stored in g_out(igptm).
!! g_out is (re)allocated to length mpdata%n_g(iq); any previous contents
!! are discarded.
subroutine setup_g_ptr(mpdata, stepf, g_t, iq, g_out)
   implicit none
   type(t_mpdata), intent(in)          :: mpdata
   type(t_fftgrid), intent(in)         :: stepf
   integer, intent(in)                 :: g_t(:), iq
   integer, allocatable, intent(inout) :: g_out(:)

   integer :: n_gpts, ig, g_shifted(3)

   n_gpts = mpdata%n_g(iq)

   ! start from a fresh table of exactly the required length
   if (allocated(g_out)) deallocate (g_out)
   allocate (g_out(n_gpts))

   do ig = 1, n_gpts
      g_shifted = mpdata%g(:, mpdata%gptm_ptr(ig, iq)) - g_t
      g_out(ig) = stepf%g2fft(g_shifted)
   end do
end subroutine setup_g_ptr
225 :
!> Put a block of wave functions on the FFT grid and run a batched FFT on it.
!!
!! Each band nu = bandoi..bandof of zmat is written into psi(:,nu) by
!! grid%put_state_on_external_grid; afterwards fft%exec_batch(psi) transforms
!! all columns in one call.
!!
!!  - fi, gcutoff : not referenced here; kept so existing callers stay valid.
!!  - lapw, jspin : forwarded to put_state_on_external_grid.
!!  - zmat        : eigenvector matrix; column nu is band nu.
!!  - grid        : FFT grid helper (intent(inout) because its type-bound
!!                  procedure is invoked on it).
!!  - fft         : FFT object, already initialised by the caller.
!!  - psi         : result; first index starts at 0, second index at bandoi.
subroutine wavef2rs(fi, lapw, zmat, gcutoff, bandoi, bandof, jspin, grid, fft, psi)
   ! put block of wave functions through FFT
   !$ use omp_lib
   use m_types
   use m_fft_interface
   implicit none
   type(t_fleurinput), intent(in) :: fi
   type(t_lapw), intent(in)       :: lapw
   type(t_mat), intent(in)        :: zmat
   integer, intent(in)            :: jspin, bandoi, bandof
   real, intent(in)               :: gcutoff
   type(t_fftgrid), intent(inout) :: grid
   type(t_fft), intent(inout)     :: fft
   complex, intent(inout)         :: psi(0:, bandoi:) ! (nv,ne)

   integer :: nu

   ! bands are independent of each other, so they are distributed over
   ! threads in the OpenMP (non-OpenACC) build
#ifndef _OPENACC
   !$omp parallel do default(none) private(nu) shared(grid, bandoi, bandof, lapw, jspin, zMat, psi)
#endif
   do nu = bandoi, bandof
      call grid%put_state_on_external_grid(lapw, jspin, zMat, nu, psi(:,nu), l_gpu=.True.)
   enddo
#ifndef _OPENACC
   !$omp end parallel do
#endif

   call fft%exec_batch(psi)
end subroutine wavef2rs
255 :
!> Enumerate the distinct vectors G_lapw + G_mixed - g_t.
!!
!! For every LAPW vector lapw%gvec(:, ig1, jsp), ig1 = 1..lapw%nv(jsp), and
!! every mixed-basis vector selected by mpdata%gptm_ptr(:, iq), the combined
!! vector g is looked up in `pointer`. The first time a g is seen it gets the
!! next running number, which is stored in pointer(g) and the vector itself
!! in gpt0(:, number).
!!
!!  - g_bounds : half-widths of the index box; pointer is allocated as
!!               (-g_bounds(1):g_bounds(1), ... , -g_bounds(3):g_bounds(3)).
!!  - pointer  : on return 0 for vectors that never occurred, otherwise the
!!               column of gpt0 holding that vector. (Re)allocated here.
!!  - gpt0     : (3, size(pointer)); only columns 1..ngpt0 are defined.
!!  - ngpt0    : number of distinct vectors found.
!!
!! A vector outside the index box aborts through juDFT_error instead of
!! indexing `pointer` out of bounds.
subroutine prep_list_of_gvec(lapw, mpdata, g_bounds, g_t, iq, jsp, pointer, gpt0, ngpt0)
   use m_types
   use m_juDFT
   implicit none
   type(t_lapw), intent(in)            :: lapw
   TYPE(t_mpdata), intent(in)          :: mpdata
   integer, intent(in)                 :: g_bounds(:), g_t(:), iq, jsp
   integer, allocatable, intent(inout) :: pointer(:, :, :), gpt0(:, :)
   integer, intent(inout)              :: ngpt0

   integer :: ic, ig1, igptm, iigptm, ok, g(3)

   ! the arguments are intent(inout): drop stale allocations so the routine
   ! can be called repeatedly with the same actual arguments
   if (allocated(pointer)) deallocate (pointer)
   if (allocated(gpt0)) deallocate (gpt0)

   allocate (pointer(-g_bounds(1):g_bounds(1), &
                     -g_bounds(2):g_bounds(2), &
                     -g_bounds(3):g_bounds(3)), stat=ok)
   IF (ok /= 0) call juDFT_error('wavefproducts_noinv2: error allocation pointer')
   allocate (gpt0(3, size(pointer)), stat=ok)
   IF (ok /= 0) call juDFT_error('wavefproducts_noinv2: error allocation gpt0')

   call timestart("prep list of Gvec")
   pointer = 0
   ic = 0
   DO ig1 = 1, lapw%nv(jsp)
      DO igptm = 1, mpdata%n_g(iq)
         iigptm = mpdata%gptm_ptr(igptm, iq)
         g = lapw%gvec(:, ig1, jsp) + mpdata%g(:, iigptm) - g_t
         ! guard the table lookup below against vectors outside the box
         IF (any(abs(g) > g_bounds(1:3))) &
            call juDFT_error('wavefproducts_noinv2: G vector outside g_bounds')
         IF (pointer(g(1), g(2), g(3)) == 0) THEN
            ic = ic + 1
            gpt0(:, ic) = g
            pointer(g(1), g(2), g(3)) = ic
         END IF
      END DO
   END DO
   ngpt0 = ic
   call timestop("prep list of Gvec")
end subroutine prep_list_of_gvec
292 :
!> Number of basis functions: plane waves plus local orbitals.
!!
!! Without non-collinearity this is lapw%nv(1) + atoms%nlotot. If
!! noco%l_noco is set, the second component lapw%nv(2) and another
!! atoms%nlotot are added.
function calc_number_of_basis_functions(lapw, atoms, noco) result(nbasfcn)
   use m_types
   implicit NONE
   type(t_lapw), intent(in)  :: lapw
   type(t_atoms), intent(in) :: atoms
   type(t_noco), intent(in)  :: noco
   integer                   :: nbasfcn

   ! contribution that is always present
   nbasfcn = lapw%nv(1) + atoms%nlotot

   ! lapw%nv(2) is only touched in the non-collinear case
   if (noco%l_noco) nbasfcn = nbasfcn + lapw%nv(2) + atoms%nlotot
end function calc_number_of_basis_functions
307 :
!> Outer product of two complex vectors: outer(i, j) = x(i)*y(j).
!! No complex conjugation is applied to either argument.
function outer_prod(x, y) result(outer)
   implicit NONE
   complex, intent(in) :: x(:), y(:)
   complex             :: outer(size(x), size(y))

   ! replicate x along the columns and y along the rows, then multiply
   ! element by element
   outer = spread(x, dim=2, ncopies=size(y))*spread(y, dim=1, ncopies=size(x))
end function outer_prod
320 : end module m_wavefproducts_aux
|