Line data Source code
1 : module m_types_fft
2 : #ifdef CPP_FFT_MKL
3 : USE mkl_dfti
4 : #endif
5 : USE m_selecFFT
6 : use m_judft
7 : USE iso_c_binding
8 : #ifdef CPP_SPFFT
9 : USE spfft
10 : #endif
11 : #ifdef CPP_FFTW
12 : use fftw3
13 : #endif
14 : #ifdef _OPENACC
15 : use openacc
16 : use cufft
17 : #endif
18 : !$ use omp_lib
19 : implicit none
20 : private
21 : #ifdef CPP_FFT_MKL
22 : type ptr_container
23 : type(dfti_descriptor), pointer :: dfti_handle
24 : endtype ptr_container
25 : #endif
26 :
27 : type,public:: t_fft
28 : logical :: initialized = .False.
29 : integer :: backend = -1
30 : integer :: batch_size = 1
31 : integer :: length(3) = [-1,-1,-1]
32 : logical :: forw
33 : ! cfft storage
34 : real, allocatable :: afft(:), bfft(:)
35 : #ifdef _OPENACC
36 : integer(4) :: cufft_plan
37 : #endif
38 : #ifdef CPP_FFT_MKL
39 : ! mkl
40 : type(ptr_container), allocatable :: container(:)
41 : #endif
42 : #ifdef CPP_SPFFT
43 : !SpFFT
44 : integer, allocatable :: indices(:)
45 : type(c_ptr) :: transform = c_null_ptr, realSpacePtr = c_null_ptr
46 : integer :: xyPlanesize
47 : COMPLEX(C_DOUBLE_COMPLEX), ALLOCATABLE :: recSpaceFunction(:)
48 : COMPLEX(C_DOUBLE_COMPLEX), POINTER :: externalRealSpaceMesh(:, :, :)
49 : #endif
50 : #ifdef CPP_FFTW
51 : type(c_ptr), allocatable :: plan(:)
52 : type(c_ptr) :: ptr_in, ptr_out
53 : complex(C_DOUBLE_COMPLEX), pointer :: in(:,:), out(:,:)
54 : #endif
55 : contains
56 : procedure :: init => t_fft_init
57 : procedure :: exec => t_fft_exec_single
58 : procedure :: exec_batch => t_fft_exec_batched
59 : procedure :: free => t_fft_free
60 : end type t_fft
61 : contains
62 191863 : subroutine t_fft_init(fft, length, forw, indices, batch_size, l_gpu)
63 : implicit none
64 : class(t_fft) :: fft
65 : integer, intent(in) :: length(3) !length of data in each direction
66 : logical, intent(in) :: forw !.true. for the forward transformation, .false. for the backward one
67 : INTEGER, OPTIONAL, INTENT(IN) :: indices(:) !array of indices of relevant/nonzero elements in the FFT mesh
68 : integer, optional, intent(in) :: batch_size
69 : logical, optional, intent(in) :: l_gpu
70 :
71 : INTEGER, PARAMETER :: numOMPThreads = 1
72 : integer :: size_dat, ierr, fftMeshIndex, maxNumLocalZColumns
73 : integer :: temp, x, y, z, xCoord, yCoord, zCoord, i
74 : INTEGER, ALLOCATABLE :: sparseCoords(:)
75 : LOGICAL, ALLOCATABLE :: nonzeroArea(:, :)
76 : type(c_ptr) :: grid = c_null_ptr
77 : logical :: in_openmp = .false.
78 : integer :: max_threads = 1, thread_id = 0
79 : integer :: n_plans
80 : integer :: n(3), dist
81 : integer, parameter :: stride = 1
82 :
83 191863 : !$ thread_id = omp_get_thread_num()
84 191863 : !$ max_threads = omp_get_max_threads()
85 191863 : !$ in_openmp = omp_in_parallel()
86 :
87 191863 : if(present(batch_size)) then
88 528 : fft%batch_size = batch_size
89 : else
90 191335 : fft%batch_size = 1
91 : endif
92 :
93 191863 : fft%initialized = .True.
94 191863 : fft%backend = defaultFFT_const
95 191863 : fft%backend = selecFFT(PRESENT(indices), l_gpu)
96 767452 : fft%length = length
97 191863 : fft%forw = forw
98 :
99 383726 : select case(fft%backend)
100 : #ifdef CPP_FFTW
101 : case(FFTW_const)
102 191863 : n_plans = min(max_threads, fft%batch_size)
103 575589 : allocate(fft%plan(n_plans))
104 :
105 767452 : fft%ptr_in = fftw_alloc_complex(int(n_plans * product(length), C_SIZE_T))
106 1151178 : call c_f_pointer(fft%ptr_in, fft%in, [product(length), n_plans])
107 :
108 767452 : fft%ptr_out = fftw_alloc_complex(int(n_plans * product(length), C_SIZE_T))
109 1151178 : call c_f_pointer(fft%ptr_out, fft%out, [product(length), n_plans])
110 :
111 383990 : do i = 1,n_plans
112 383990 : !$omp critical
113 192127 : if(fft%forw) then
114 : fft%plan(i) = fftw_plan_dft_3d(fft%length(3), fft%length(2), fft%length(1),&
115 32714 : fft%in(:,i), fft%out(:,i), FFTW_FORWARD,FFTW_MEASURE)
116 : else
117 : fft%plan(i) = fftw_plan_dft_3d(fft%length(3), fft%length(2), fft%length(1),&
118 159413 : fft%in(:,i), fft%out(:,i), FFTW_BACKWARD,FFTW_MEASURE)
119 : endif
120 : !$omp end critical
121 : enddo
122 : #endif
123 : #ifdef _OPENACC
124 : case(cuFFT_const)
125 : n = [fft%length(3), fft%length(2), fft%length(1)]
126 : dist = product(fft%length)
127 : ierr = cufftPlanMany(fft%cufft_plan, 3_4, n, &
128 : n, stride, dist, n, stride, dist, CUFFT_Z2Z, fft%batch_size)
129 :
130 : if(ierr /= 0) then
131 : call acc_present_dump()
132 : call handle_cufft_error(ierr)
133 : call juDFT_error("cuFFT Plan many failed.")
134 : endif
135 : #endif
136 : case(mklFFT_const)
137 : #ifdef CPP_FFT_MKL
138 : n_plans = min(max_threads, fft%batch_size)
139 : allocate(fft%container(n_plans))
140 : do i = 1,n_plans
141 : ierr = DftiCreateDescriptor(fft%container(i)%dfti_handle, dfti_double, dfti_complex, 3, length)
142 : if (ierr /= 0) call juDFT_error("cant create descriptor", calledby="fft_interface")
143 : ierr = DftiCommitDescriptor(fft%container(i)%dfti_handle)
144 : if (ierr /= 0) call juDFT_error("can't commit descriptor", calledby="fft_interface")
145 : enddo
146 : #endif
147 :
148 : #ifdef CPP_SPFFT
149 : case(spFFT_const)
150 : fft%indices = indices
151 : ALLOCATE(sparseCoords(3*SIZE(fft%indices)))
152 : if(.not. allocated(fft%recSpaceFunction)) ALLOCATE(fft%recSpaceFunction(SIZE(fft%indices)))
153 : ALLOCATE(nonzeroArea(0:length(1) - 1, 0:length(2) - 1))
154 : nonzeroArea(:, :) = .FALSE.
155 : fft%xyPlaneSize = fft%length(1)*fft%length(2)
156 : DO i = 1, SIZE(fft%indices)
157 : zCoord = fft%indices(i)/fft%xyPlaneSize
158 : temp = MOD(fft%indices(i), fft%xyPlaneSize)
159 : yCoord = temp/length(1)
160 : xCoord = MOD(temp, length(1))
161 :
162 : sparseCoords(3*(i - 1) + 3) = zCoord
163 : sparseCoords(3*(i - 1) + 2) = yCoord
164 : sparseCoords(3*(i - 1) + 1) = xCoord
165 :
166 : nonzeroArea(xCoord, yCoord) = .TRUE.
167 : END DO
168 :
169 : maxNumLocalZColumns = COUNT(nonzeroArea)
170 : IF (fft%forw) THEN
171 : ierr = spfft_grid_create(grid, length(1), length(2), length(3), &
172 : maxNumLocalZColumns, SPFFT_PU_HOST, numOMPThreads);
173 : IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in creating spFFT grid! (1)")
174 :
175 : ierr = spfft_transform_create(fft%transform, grid, SPFFT_PU_HOST, SPFFT_TRANS_C2C, &
176 : length(1), length(2), length(3), length(3), &
177 : size(fft%recSpaceFunction), SPFFT_INDEX_TRIPLETS, sparseCoords)
178 : IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in creating spFFT transform! (1)")
179 :
180 : ierr = spfft_grid_destroy(grid)
181 : IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in destroying spFFT grid! (1)")
182 :
183 : ierr = spfft_transform_get_space_domain(fft%transform, SPFFT_PU_HOST, fft%realSpacePtr)
184 : IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in obtaining spFFT space domain! (1)")
185 :
186 : CALL C_F_POINTER(fft%realSpacePtr, fft%externalRealSpaceMesh, [length(1), length(2), length(3)])
187 : ELSE
188 : ierr = spfft_grid_create(grid, length(1), length(2), length(3), &
189 : maxNumLocalZColumns, SPFFT_PU_HOST, numOMPThreads);
190 : IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in creating spFFT grid! (2)")
191 :
192 : ierr = spfft_transform_create(fft%transform, grid, SPFFT_PU_HOST, SPFFT_TRANS_C2C, &
193 : length(1), length(2), length(3), length(3), &
194 : size(fft%recSpaceFunction), SPFFT_INDEX_TRIPLETS, sparseCoords)
195 : IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in creating spFFT transform! (2)")
196 :
197 : ierr = spfft_grid_destroy(grid)
198 : IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in destroying spFFT grid! (2)")
199 :
200 : ierr = spfft_transform_get_space_domain(fft%transform, SPFFT_PU_HOST, fft%realSpacePtr)
201 : IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in obtaining spFFT space domain! (2)")
202 : END IF
203 : #endif
204 : case default
205 0 : size_dat = product(length)
206 0 : allocate (fft%afft(size_dat), fft%bfft(size_dat), stat=ierr)
207 191863 : if (ierr /= 0) call juDFT_error("can't alloc afft & bfft", calledby="fft_interface")
208 : end select
209 191863 : end subroutine
210 :
211 : subroutine handle_cufft_error(ierr)
212 : implicit none
213 : integer, intent(in) :: ierr
214 :
215 : ! FROM
216 : ! https://docs.nvidia.com/hpc-sdk/compilers/fortran-cuda-interfaces/index.html#cf-fft-runtime
217 :
218 : select case (ierr)
219 : case (1)
220 : write (*,*) "CUFFT_INVALID_PLAN"
221 : case (2)
222 : write (*,*) "CUFFT_ALLOC_FAILED"
223 : case (3)
224 : write (*,*) "CUFFT_INVALID_TYPE"
225 : case (4)
226 : write (*,*) "CUFFT_INVALID_VALUE"
227 : case (5)
228 : write (*,*) "CUFFT_INTERNAL_ERROR"
229 : case (6)
230 : write (*,*) "CUFFT_EXEC_FAILED"
231 : case (7)
232 : write (*,*) "CUFFT_SETUP_FAILED"
233 : case (8)
234 : write (*,*) "CUFFT_INVALID_SIZE"
235 : case (9)
236 : write (*,*) "CUFFT_UNALIGNED_DATA"
237 : case default
238 : write (*,*) "unknow cuda errror"
239 : end select
240 :
241 : end subroutine handle_cufft_error
242 :
243 200527 : subroutine t_fft_exec_batched(fft, dat)
244 : USE m_cfft
245 : implicit none
246 : class(t_fft), intent(inout) :: fft
247 : complex, intent(inout) :: dat(:,:)
248 : integer :: isn, size_dat
249 : INTEGER :: i, x, y, z, fftMeshIndex, ierr, me, direction
250 : logical :: in_omp
251 :
252 802108 : size_dat = product(fft%length)
253 :
254 401054 : select case(fft%backend)
255 : #ifdef CPP_FFTW
256 : case(fftw_const)
257 200527 : me = 1
258 200527 : !$omp parallel do default(none) private(me, i) shared(fft, dat)
259 : do i = 1,size(dat,2)
260 : !$ me = omp_get_thread_num() + 1
261 : fft%in(:,me) = dat(:,i)
262 : call fftw_execute_dft(fft%plan(me), fft%in(:,me), fft%out(:,me))
263 : dat(:,i) = fft%out(:,me)
264 : enddo
265 : !$omp end parallel do
266 : #endif
267 : case(mklFFT_const)
268 : #ifdef CPP_FFT_MKL
269 : me = 1
270 : !$omp parallel do default(none) private(me, i, ierr) shared(fft, dat)
271 : do i = 1,size(dat,2)
272 : !$ me = omp_get_thread_num() + 1
273 : if (fft%forw) then
274 : ierr = DftiComputeForward(fft%container(me)%dfti_handle, dat(:,i))
275 : else
276 : ierr = DftiComputeBackward(fft%container(me)%dfti_handle, dat(:,i))
277 : end if
278 : if(ierr /= 0) call juDFT_error("problem executing dft")
279 : enddo
280 : !$omp end parallel do
281 : #endif
282 : #ifdef _OPENACC
283 : case(cufft_const)
284 : in_omp = .False.
285 : !$ in_omp = omp_in_parallel()
286 : if(in_omp) call juDFT_error("calling cuFFT from within OMP")
287 :
288 : !$acc host_data use_device(dat)
289 : ierr = cufftExecZ2z(fft%cufft_plan, dat, dat, merge(CUFFT_FORWARD, CUFFT_INVERSE, fft%forw))
290 : !$acc end host_data
291 : if(ierr /= 0) call juDFT_error("executing cufft failed.")
292 : #endif
293 : #ifdef CPP_SPFFT
294 : case(spFFT_const)
295 : IF (fft%forw) THEN
296 : DO z = 1, SIZE(fft%externalRealSpaceMesh, 3)
297 : DO y = 1, SIZE(fft%externalRealSpaceMesh, 2)
298 : DO x = 1, SIZE(fft%externalRealSpaceMesh, 1)
299 : fftMeshIndex = (x - 1) + (y - 1)*fft%length(1) + (z - 1)*fft%xyPlaneSize + 1
300 : fft%externalRealSpaceMesh(x, y, z) = dat(fftMeshIndex)
301 : END DO
302 : END DO
303 : END DO
304 : ierr = spfft_transform_forward(fft%transform, SPFFT_PU_HOST, fft%recSpaceFunction, SPFFT_NO_SCALING)!SPFFT_FULL_SCALING)
305 : IF (ierr /= SPFFT_SUCCESS) THEN
306 : CALL juDFT_error("Error in spFFT forward fft%transform! (1)", calledby="fft_interface")
307 : END IF
308 : dat(:) = CMPLX(0.0, 0.0)
309 : DO i = 1, SIZE(fft%indices)
310 : dat(fft%indices(i) + 1) = fft%recSpaceFunction(i)
311 : END DO
312 :
313 : ELSE
314 : DO i = 1, SIZE(fft%indices)
315 : fft%recSpaceFunction(i) = dat(fft%indices(i) + 1)
316 : END DO
317 : ierr = spfft_transform_backward(fft%transform, fft%recSpaceFunction, SPFFT_PU_HOST)
318 : IF (ierr /= SPFFT_SUCCESS) THEN
319 : CALL juDFT_error("Error in spFFT backward fft%transform! (2)", calledby="fft_interface")
320 : END IF
321 :
322 : CALL C_F_POINTER(fft%realSpacePtr, fft%externalRealSpaceMesh, [fft%length(1), fft%length(2), fft%length(3)])
323 :
324 : DO z = 1, SIZE(fft%externalRealSpaceMesh, 3)
325 : DO y = 1, SIZE(fft%externalRealSpaceMesh, 2)
326 : DO x = 1, SIZE(fft%externalRealSpaceMesh, 1)
327 : fftMeshIndex = (x - 1) + (y - 1)*fft%length(1) + (z - 1)*fft%xyPlaneSize + 1
328 : dat(fftMeshIndex) = fft%externalRealSpaceMesh(x, y, z)
329 : END DO
330 : END DO
331 : END DO
332 : END IF
333 : #endif
334 : case default
335 200527 : do i = 1,size(dat,2)
336 0 : fft%afft = real(dat(:,i))
337 0 : fft%bfft = aimag(dat(:,i))
338 :
339 0 : isn = merge(-1, 1, fft%forw)
340 0 : CALL cfft(fft%afft, fft%bfft, size_dat, fft%length(1), fft%length(1), isn)
341 0 : CALL cfft(fft%afft, fft%bfft, size_dat, fft%length(2), fft%length(1)*fft%length(2), isn)
342 0 : CALL cfft(fft%afft, fft%bfft, size_dat, fft%length(3), size_dat, isn)
343 0 : dat(:,i) = cmplx(fft%afft, fft%bfft)
344 : enddo
345 : end select
346 200527 : end subroutine t_fft_exec_batched
347 :
348 191423 : subroutine t_fft_exec_single(fft, dat)
349 : implicit none
350 : class(t_fft), intent(inout) :: fft
351 : complex, intent(inout),target :: dat(:)
352 : integer :: i
353 : type(c_ptr) :: ptr
354 : complex, pointer :: tmp_2d(:,:)
355 :
356 : ! if an array is 1D just pretend it's 2d
357 191423 : ptr = c_loc(dat)
358 574269 : call c_f_pointer(ptr, tmp_2d, [size(dat), 1])
359 :
360 191423 : call t_fft_exec_batched(fft, tmp_2d)
361 191423 : end subroutine t_fft_exec_single
362 :
363 191863 : subroutine t_fft_free(fft)
364 : implicit none
365 : integer :: ierr
366 : class(t_fft) :: fft
367 : logical :: in_openmp = .false.
368 : integer :: i
369 191863 : !$ in_openmp = omp_in_parallel()
370 :
371 191863 : if(allocated(fft%afft)) deallocate(fft%afft)
372 191863 : if(allocated(fft%bfft)) deallocate(fft%bfft)
373 383726 : select case(fft%backend)
374 : #ifdef CPP_FFTW
375 : case(FFTW_const)
376 191863 : call fftw_free(fft%ptr_in)
377 191863 : call fftw_free(fft%ptr_out)
378 :
379 383990 : do i=1,size(fft%plan)
380 384254 : !$omp critical
381 192127 : call fftw_destroy_plan(fft%plan(i))
382 : !$omp end critical
383 383990 : fft%plan(i) = c_null_ptr
384 : enddo
385 383726 : deallocate(fft%plan)
386 : #endif
387 : #ifdef _OPENACC
388 : case(cufft_const)
389 : ierr = cufftDestroy(fft%cufft_plan)
390 : if(ierr /= 0) call juDFT_error("cufftdestroy failed")
391 : #endif
392 : case(mklFFT_const)
393 : #ifdef CPP_FFT_MKL
394 : do i=1,size(fft%container)
395 : ierr = DftiFreeDescriptor(fft%container(i)%dfti_handle)
396 : enddo
397 : deallocate(fft%container)
398 : #endif
399 : #ifdef CPP_SPFFT
400 : case(spFFT_const)
401 : ierr = spfft_transform_destroy(fft%transform)
402 : IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in destroying spFFT fft%transform! (1)")
403 : fft%transform = c_null_ptr
404 : fft%realSpacePtr = c_null_ptr
405 : #endif
406 : case default
407 :
408 : end select
409 :
410 191863 : fft%initialized = .False.
411 191863 : fft%backend = -1
412 767452 : fft%length = [-1,-1,-1]
413 191863 : fft%batch_size = -1
414 191863 : end subroutine t_fft_free
415 0 : end module m_types_fft
|