LCOV - code coverage report
Current view: top level - types - types_fft.F90 (source / functions) Hit Total Coverage
Test: FLEUR test coverage Lines: 55 65 84.6 %
Date: 2024-04-28 04:28:00 Functions: 4 6 66.7 %

          Line data    Source code
       1             : module m_types_fft
       2             : #ifdef CPP_FFT_MKL
       3             :    USE mkl_dfti
       4             : #endif
       5             :    USE m_selecFFT
       6             :    use m_judft
       7             :    USE iso_c_binding
       8             : #ifdef CPP_SPFFT
       9             :    USE spfft
      10             : #endif
      11             : #ifdef CPP_FFTW 
      12             :    use fftw3
      13             : #endif
      14             : #ifdef _OPENACC
      15             :    use openacc
      16             :    use cufft
      17             : #endif
      18             :    !$ use omp_lib
      19             :    implicit none
      20             :    private
      21             : #ifdef CPP_FFT_MKL
      22             :    type ptr_container
      23             :       type(dfti_descriptor), pointer :: dfti_handle
      24             :    endtype ptr_container
      25             : #endif
      26             : 
      27             :    type,public:: t_fft 
      28             :       logical :: initialized = .False.
      29             :       integer :: backend = -1
      30             :       integer :: batch_size = 1
      31             :       integer :: length(3) = [-1,-1,-1]
      32             :       logical :: forw
      33             :       ! cfft storage
      34             :       real, allocatable :: afft(:), bfft(:)
      35             : #ifdef _OPENACC 
      36             :       integer(4) :: cufft_plan
      37             : #endif
      38             : #ifdef CPP_FFT_MKL
      39             :       ! mkl
      40             :       type(ptr_container), allocatable :: container(:)
      41             : #endif
      42             : #ifdef CPP_SPFFT
      43             :       !SpFFT
      44             :       integer, allocatable :: indices(:)
      45             :       type(c_ptr)          :: transform = c_null_ptr, realSpacePtr = c_null_ptr
      46             :       integer              :: xyPlanesize
      47             :       COMPLEX(C_DOUBLE_COMPLEX), ALLOCATABLE :: recSpaceFunction(:)
      48             :       COMPLEX(C_DOUBLE_COMPLEX), POINTER     :: externalRealSpaceMesh(:, :, :)
      49             : #endif
      50             : #ifdef CPP_FFTW
      51             :       type(c_ptr), allocatable :: plan(:)
      52             :       type(c_ptr)              :: ptr_in, ptr_out
      53             :       complex(C_DOUBLE_COMPLEX), pointer :: in(:,:), out(:,:)       
      54             : #endif
      55             :    contains 
      56             :       procedure :: init => t_fft_init
      57             :       procedure :: exec => t_fft_exec_single
      58             :       procedure :: exec_batch => t_fft_exec_batched
      59             :       procedure :: free => t_fft_free
      60             :    end type t_fft
      61             : contains
      62      191863 :    subroutine t_fft_init(fft, length, forw, indices, batch_size, l_gpu)
      63             :       implicit none       
      64             :       class(t_fft)                  :: fft
      65             :       integer, intent(in)           :: length(3) !length of data in each direction
      66             :       logical, intent(in)           :: forw          !.true. for the forward transformation, .false. for the backward one
      67             :       INTEGER, OPTIONAL, INTENT(IN) :: indices(:)    !array of indices of relevant/nonzero elements in the FFT mesh
      68             :       integer, optional, intent(in) :: batch_size
      69             :       logical, optional, intent(in) :: l_gpu
      70             : 
      71             :       INTEGER, PARAMETER :: numOMPThreads = 1
      72             :       integer :: size_dat, ierr, fftMeshIndex, maxNumLocalZColumns
      73             :       integer :: temp, x, y, z, xCoord, yCoord, zCoord, i
      74             :       INTEGER, ALLOCATABLE :: sparseCoords(:)
      75             :       LOGICAL, ALLOCATABLE :: nonzeroArea(:, :)
      76             :       type(c_ptr)          :: grid = c_null_ptr
      77             :       logical :: in_openmp = .false.
      78             :       integer :: max_threads = 1, thread_id = 0
      79             :       integer :: n_plans
      80             :       integer :: n(3), dist
      81             :       integer, parameter :: stride = 1
      82             : 
      83      191863 :       !$ thread_id   = omp_get_thread_num()
      84      191863 :       !$ max_threads = omp_get_max_threads()
      85      191863 :       !$ in_openmp   = omp_in_parallel()
      86             : 
      87      191863 :       if(present(batch_size)) then
      88         528 :          fft%batch_size = batch_size
      89             :       else
      90      191335 :          fft%batch_size = 1
      91             :       endif
      92             : 
      93      191863 :       fft%initialized = .True.
      94      191863 :       fft%backend = defaultFFT_const
      95      191863 :       fft%backend = selecFFT(PRESENT(indices), l_gpu)
      96      767452 :       fft%length  = length
      97      191863 :       fft%forw    = forw
      98             : 
      99      383726 :       select case(fft%backend)
     100             : #ifdef CPP_FFTW
     101             :       case(FFTW_const)
     102      191863 :          n_plans = min(max_threads, fft%batch_size)
     103      575589 :          allocate(fft%plan(n_plans))
     104             : 
     105      767452 :          fft%ptr_in = fftw_alloc_complex(int(n_plans * product(length), C_SIZE_T))
     106     1151178 :          call c_f_pointer(fft%ptr_in, fft%in, [product(length), n_plans])
     107             : 
     108      767452 :          fft%ptr_out = fftw_alloc_complex(int(n_plans * product(length), C_SIZE_T))
     109     1151178 :          call c_f_pointer(fft%ptr_out, fft%out, [product(length), n_plans])
     110             : 
     111      383990 :          do i = 1,n_plans
     112      383990 :             !$omp critical
     113      192127 :             if(fft%forw) then
     114             :                fft%plan(i) = fftw_plan_dft_3d(fft%length(3), fft%length(2), fft%length(1),&
     115       32714 :                                              fft%in(:,i), fft%out(:,i), FFTW_FORWARD,FFTW_MEASURE) 
     116             :             else
     117             :                fft%plan(i) = fftw_plan_dft_3d(fft%length(3), fft%length(2), fft%length(1),&
     118      159413 :                                              fft%in(:,i), fft%out(:,i), FFTW_BACKWARD,FFTW_MEASURE) 
     119             :             endif
     120             :             !$omp end critical
     121             :          enddo
     122             : #endif
     123             : #ifdef _OPENACC
     124             :       case(cuFFT_const)
     125             :          n = [fft%length(3), fft%length(2), fft%length(1)]
     126             :          dist   = product(fft%length)
     127             :          ierr = cufftPlanMany(fft%cufft_plan, 3_4, n, &
     128             :                               n, stride, dist, n, stride, dist, CUFFT_Z2Z, fft%batch_size)
     129             : 
     130             :          if(ierr /= 0) then
     131             :             call acc_present_dump()
     132             :             call handle_cufft_error(ierr)
     133             :             call juDFT_error("cuFFT Plan many failed.")
     134             :          endif
     135             : #endif
     136             :       case(mklFFT_const)
     137             : #ifdef CPP_FFT_MKL
     138             :          n_plans = min(max_threads, fft%batch_size)
     139             :          allocate(fft%container(n_plans))
     140             :          do i = 1,n_plans
     141             :             ierr = DftiCreateDescriptor(fft%container(i)%dfti_handle, dfti_double, dfti_complex, 3, length)
     142             :             if (ierr /= 0) call juDFT_error("cant create descriptor", calledby="fft_interface")
     143             :             ierr = DftiCommitDescriptor(fft%container(i)%dfti_handle)
     144             :             if (ierr /= 0) call juDFT_error("can't commit descriptor", calledby="fft_interface")
     145             :          enddo
     146             : #endif
     147             : 
     148             : #ifdef CPP_SPFFT
     149             :       case(spFFT_const)
     150             :          fft%indices = indices
     151             :          ALLOCATE(sparseCoords(3*SIZE(fft%indices)))
     152             :          if(.not. allocated(fft%recSpaceFunction)) ALLOCATE(fft%recSpaceFunction(SIZE(fft%indices)))
     153             :          ALLOCATE(nonzeroArea(0:length(1) - 1, 0:length(2) - 1))
     154             :          nonzeroArea(:, :) = .FALSE.
     155             :          fft%xyPlaneSize = fft%length(1)*fft%length(2)
     156             :          DO i = 1, SIZE(fft%indices)
     157             :             zCoord = fft%indices(i)/fft%xyPlaneSize
     158             :             temp = MOD(fft%indices(i), fft%xyPlaneSize)
     159             :             yCoord = temp/length(1)
     160             :             xCoord = MOD(temp, length(1))
     161             : 
     162             :             sparseCoords(3*(i - 1) + 3) = zCoord
     163             :             sparseCoords(3*(i - 1) + 2) = yCoord
     164             :             sparseCoords(3*(i - 1) + 1) = xCoord
     165             : 
     166             :             nonzeroArea(xCoord, yCoord) = .TRUE.
     167             :          END DO
     168             : 
     169             :          maxNumLocalZColumns = COUNT(nonzeroArea)
     170             :          IF (fft%forw) THEN
     171             :             ierr = spfft_grid_create(grid, length(1), length(2), length(3), &
     172             :                                           maxNumLocalZColumns, SPFFT_PU_HOST, numOMPThreads); 
     173             :             IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in creating spFFT grid! (1)")
     174             : 
     175             :             ierr = spfft_transform_create(fft%transform, grid, SPFFT_PU_HOST, SPFFT_TRANS_C2C, &
     176             :                                                length(1), length(2), length(3), length(3), &
     177             :                                                size(fft%recSpaceFunction), SPFFT_INDEX_TRIPLETS, sparseCoords)
     178             :             IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in creating spFFT transform! (1)")
     179             : 
     180             :             ierr = spfft_grid_destroy(grid)
     181             :             IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in destroying spFFT grid! (1)")
     182             : 
     183             :             ierr = spfft_transform_get_space_domain(fft%transform, SPFFT_PU_HOST, fft%realSpacePtr)
     184             :             IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in obtaining spFFT space domain! (1)")
     185             : 
     186             :             CALL C_F_POINTER(fft%realSpacePtr, fft%externalRealSpaceMesh, [length(1), length(2), length(3)])
     187             :          ELSE
     188             :             ierr = spfft_grid_create(grid, length(1), length(2), length(3), &
     189             :                                           maxNumLocalZColumns, SPFFT_PU_HOST, numOMPThreads); 
     190             :             IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in creating spFFT grid! (2)")
     191             : 
     192             :             ierr = spfft_transform_create(fft%transform, grid, SPFFT_PU_HOST, SPFFT_TRANS_C2C, &
     193             :                                                length(1), length(2), length(3), length(3), &
     194             :                                                size(fft%recSpaceFunction), SPFFT_INDEX_TRIPLETS, sparseCoords)
     195             :             IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in creating spFFT transform! (2)")
     196             : 
     197             :             ierr = spfft_grid_destroy(grid)
     198             :             IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in destroying spFFT grid! (2)")
     199             : 
     200             :             ierr = spfft_transform_get_space_domain(fft%transform, SPFFT_PU_HOST, fft%realSpacePtr)
     201             :             IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in obtaining spFFT space domain! (2)")
     202             :          END IF
     203             : #endif
     204             :       case default 
     205           0 :          size_dat = product(length)
     206           0 :          allocate (fft%afft(size_dat), fft%bfft(size_dat), stat=ierr)
     207      191863 :          if (ierr /= 0) call juDFT_error("can't alloc afft & bfft", calledby="fft_interface")
     208             :       end select
     209      191863 :    end subroutine
     210             : 
     211             :    subroutine handle_cufft_error(ierr)
     212             :       implicit none 
     213             :       integer, intent(in) :: ierr 
     214             : 
     215             :       ! FROM
     216             :       ! https://docs.nvidia.com/hpc-sdk/compilers/fortran-cuda-interfaces/index.html#cf-fft-runtime
     217             : 
     218             :       select case (ierr)
     219             :       case  (1)  
     220             :          write (*,*) "CUFFT_INVALID_PLAN"
     221             :       case (2)
     222             :          write (*,*) "CUFFT_ALLOC_FAILED"
     223             :       case (3)
     224             :          write (*,*) "CUFFT_INVALID_TYPE"
     225             :       case (4)
     226             :          write (*,*) "CUFFT_INVALID_VALUE"
     227             :       case (5)
     228             :          write (*,*) "CUFFT_INTERNAL_ERROR"
     229             :       case (6)
     230             :          write (*,*) "CUFFT_EXEC_FAILED"
     231             :       case (7)
     232             :          write (*,*) "CUFFT_SETUP_FAILED"
     233             :       case (8)
     234             :          write (*,*) "CUFFT_INVALID_SIZE"
     235             :       case (9)
     236             :          write (*,*) "CUFFT_UNALIGNED_DATA"
     237             :       case default
     238             :          write (*,*) "unknow cuda errror"
     239             :       end select
     240             : 
     241             :    end subroutine handle_cufft_error
     242             : 
     243      200527 :    subroutine t_fft_exec_batched(fft, dat)
     244             :       USE m_cfft
     245             :       implicit none 
     246             :       class(t_fft), intent(inout) :: fft
     247             :       complex, intent(inout)      :: dat(:,:) 
     248             :       integer      :: isn, size_dat 
     249             :       INTEGER      :: i, x, y, z, fftMeshIndex, ierr, me, direction
     250             :       logical      :: in_omp
     251             : 
     252      802108 :       size_dat = product(fft%length)
     253             : 
     254      401054 :       select case(fft%backend)
     255             : #ifdef CPP_FFTW
     256             :       case(fftw_const)
     257      200527 :          me = 1
     258      200527 :          !$omp parallel do default(none) private(me, i) shared(fft, dat)
     259             :          do i = 1,size(dat,2)
     260             :             !$ me = omp_get_thread_num() + 1
     261             :             fft%in(:,me) = dat(:,i)
     262             :             call fftw_execute_dft(fft%plan(me), fft%in(:,me), fft%out(:,me))
     263             :             dat(:,i) = fft%out(:,me)
     264             :          enddo
     265             :          !$omp end parallel do
     266             : #endif
     267             :       case(mklFFT_const)
     268             : #ifdef CPP_FFT_MKL
     269             :          me = 1
     270             :          !$omp parallel do default(none) private(me, i, ierr) shared(fft, dat)
     271             :          do i = 1,size(dat,2)
     272             :             !$ me = omp_get_thread_num() + 1
     273             :             if (fft%forw) then
     274             :                ierr = DftiComputeForward(fft%container(me)%dfti_handle, dat(:,i))
     275             :             else
     276             :                ierr = DftiComputeBackward(fft%container(me)%dfti_handle, dat(:,i))
     277             :             end if
     278             :             if(ierr /= 0) call juDFT_error("problem executing dft")
     279             :          enddo
     280             :          !$omp end parallel do
     281             : #endif
     282             : #ifdef _OPENACC 
     283             :       case(cufft_const)
     284             :          in_omp = .False. 
     285             :          !$ in_omp = omp_in_parallel() 
     286             :          if(in_omp) call juDFT_error("calling cuFFT from within OMP")
     287             : 
     288             :          !$acc host_data use_device(dat)
     289             :          ierr = cufftExecZ2z(fft%cufft_plan, dat, dat, merge(CUFFT_FORWARD, CUFFT_INVERSE, fft%forw))
     290             :          !$acc end host_data
     291             :          if(ierr /= 0) call juDFT_error("executing cufft failed.")
     292             : #endif
     293             : #ifdef CPP_SPFFT
     294             :       case(spFFT_const)
     295             :          IF (fft%forw) THEN
     296             :             DO z = 1, SIZE(fft%externalRealSpaceMesh, 3)
     297             :                DO y = 1, SIZE(fft%externalRealSpaceMesh, 2)
     298             :                   DO x = 1, SIZE(fft%externalRealSpaceMesh, 1)
     299             :                      fftMeshIndex = (x - 1) + (y - 1)*fft%length(1) + (z - 1)*fft%xyPlaneSize + 1
     300             :                      fft%externalRealSpaceMesh(x, y, z) = dat(fftMeshIndex)
     301             :                   END DO
     302             :                END DO
     303             :             END DO
     304             :             ierr = spfft_transform_forward(fft%transform, SPFFT_PU_HOST, fft%recSpaceFunction, SPFFT_NO_SCALING)!SPFFT_FULL_SCALING)
     305             :             IF (ierr /= SPFFT_SUCCESS) THEN
     306             :                CALL juDFT_error("Error in spFFT forward fft%transform! (1)", calledby="fft_interface")
     307             :             END IF
     308             :             dat(:) = CMPLX(0.0, 0.0)
     309             :             DO i = 1, SIZE(fft%indices)
     310             :                dat(fft%indices(i) + 1) = fft%recSpaceFunction(i)
     311             :             END DO
     312             : 
     313             :          ELSE
     314             :             DO i = 1, SIZE(fft%indices)
     315             :                fft%recSpaceFunction(i) = dat(fft%indices(i) + 1)
     316             :             END DO
     317             :             ierr = spfft_transform_backward(fft%transform, fft%recSpaceFunction, SPFFT_PU_HOST)
     318             :             IF (ierr /= SPFFT_SUCCESS) THEN
     319             :                CALL juDFT_error("Error in spFFT backward fft%transform! (2)", calledby="fft_interface")
     320             :             END IF
     321             : 
     322             :             CALL C_F_POINTER(fft%realSpacePtr, fft%externalRealSpaceMesh, [fft%length(1), fft%length(2), fft%length(3)])
     323             : 
     324             :             DO z = 1, SIZE(fft%externalRealSpaceMesh, 3)
     325             :                DO y = 1, SIZE(fft%externalRealSpaceMesh, 2)
     326             :                   DO x = 1, SIZE(fft%externalRealSpaceMesh, 1)
     327             :                      fftMeshIndex = (x - 1) + (y - 1)*fft%length(1) + (z - 1)*fft%xyPlaneSize + 1
     328             :                      dat(fftMeshIndex) = fft%externalRealSpaceMesh(x, y, z)
     329             :                   END DO
     330             :                END DO
     331             :             END DO
     332             :          END IF
     333             : #endif
     334             :       case default
     335      200527 :          do i = 1,size(dat,2)
     336           0 :             fft%afft = real(dat(:,i))
     337           0 :             fft%bfft = aimag(dat(:,i))
     338             :       
     339           0 :             isn = merge(-1, 1, fft%forw)
     340           0 :             CALL cfft(fft%afft, fft%bfft, size_dat, fft%length(1), fft%length(1), isn)
     341           0 :             CALL cfft(fft%afft, fft%bfft, size_dat, fft%length(2), fft%length(1)*fft%length(2), isn)
     342           0 :             CALL cfft(fft%afft, fft%bfft, size_dat, fft%length(3), size_dat, isn)
     343           0 :             dat(:,i) = cmplx(fft%afft, fft%bfft)
     344             :          enddo
     345             :       end select
     346      200527 :    end subroutine t_fft_exec_batched
     347             : 
     348      191423 :    subroutine t_fft_exec_single(fft, dat)
     349             :       implicit none
     350             :       class(t_fft), intent(inout)   :: fft
     351             :       complex, intent(inout),target :: dat(:) 
     352             :       integer :: i
     353             :       type(c_ptr)      :: ptr
     354             :       complex, pointer :: tmp_2d(:,:) 
     355             : 
     356             :       ! if an array is 1D just pretend it's 2d
     357      191423 :       ptr = c_loc(dat)
     358      574269 :       call c_f_pointer(ptr, tmp_2d, [size(dat), 1])
     359             : 
     360      191423 :       call t_fft_exec_batched(fft, tmp_2d)
     361      191423 :    end subroutine t_fft_exec_single
     362             : 
     363      191863 :    subroutine t_fft_free(fft)
     364             :       implicit none 
     365             :       integer      :: ierr
     366             :       class(t_fft) :: fft 
     367             :       logical :: in_openmp = .false.
     368             :       integer :: i
     369      191863 :       !$ in_openmp   = omp_in_parallel()
     370             : 
     371      191863 :       if(allocated(fft%afft)) deallocate(fft%afft)
     372      191863 :       if(allocated(fft%bfft)) deallocate(fft%bfft)
     373      383726 :       select case(fft%backend)
     374             : #ifdef CPP_FFTW
     375             :       case(FFTW_const)
     376      191863 :          call fftw_free(fft%ptr_in)
     377      191863 :          call fftw_free(fft%ptr_out)
     378             : 
     379      383990 :          do i=1,size(fft%plan)
     380      384254 :             !$omp critical
     381      192127 :             call fftw_destroy_plan(fft%plan(i))
     382             :             !$omp end critical
     383      383990 :             fft%plan(i)  = c_null_ptr
     384             :          enddo     
     385      383726 :          deallocate(fft%plan)
     386             : #endif
     387             : #ifdef _OPENACC 
     388             :       case(cufft_const)
     389             :          ierr = cufftDestroy(fft%cufft_plan)
     390             :          if(ierr /= 0) call juDFT_error("cufftdestroy failed")
     391             : #endif
     392             :       case(mklFFT_const)
     393             : #ifdef CPP_FFT_MKL
     394             :          do i=1,size(fft%container)
     395             :             ierr = DftiFreeDescriptor(fft%container(i)%dfti_handle)
     396             :          enddo
     397             :          deallocate(fft%container)
     398             : #endif
     399             : #ifdef CPP_SPFFT
     400             :       case(spFFT_const)
     401             :          ierr = spfft_transform_destroy(fft%transform)
     402             :          IF (ierr /= SPFFT_SUCCESS) CALL juDFT_error("Error in destroying spFFT fft%transform! (1)")
     403             :          fft%transform    = c_null_ptr
     404             :          fft%realSpacePtr = c_null_ptr
     405             : #endif
     406             :       case default
     407             :          
     408             :       end select
     409             : 
     410      191863 :       fft%initialized = .False.
     411      191863 :       fft%backend    = -1
     412      767452 :       fft%length     = [-1,-1,-1]
     413      191863 :       fft%batch_size = -1
     414      191863 :    end subroutine t_fft_free
     415           0 : end module m_types_fft

Generated by: LCOV version 1.14