LCOV - code coverage report
Current view: top level - diagonalization - cuda.F90 (source / functions) Coverage Total Hit
Test: FLEUR test coverage Lines: 86.7 % 15 13
Test Date: 2026-03-18 04:40:43 Functions: 66.7 % 3 2

            Line data    Source code
       1              : !--------------------------------------------------------------------------------
       2              : ! Copyright (c) 2016 Peter Grünberg Institut, Forschungszentrum Jülich, Germany
       3              : ! This file is part of FLEUR and available as free software under the conditions
       4              : ! of the MIT license as expressed in the LICENSE file in more detail.
       5              : !--------------------------------------------------------------------------------
       6              : module m_cuda_diag
       7              :    use m_types_mat
       8              :    use m_types_mpimat
       9              :    use m_judft
      10              : #ifdef CPP_CUSOLVER
      11              :    use cusolverDn
      12              : #endif
      13              :    use m_types_solver
      14              :    implicit none
      15              : !**********************************************************
      16              : !     Solve the generalized eigenvalue problem
      17              : !     using the cusolver library
      18              : !**********************************************************
      19              :    type, extends(t_solver)::t_solver_cuda
      20              :    contains
      21              :       procedure        :: solve_gev => cuda_GEV
      22              :    end type
      23              :    public :: get_solver_cuda
      24              : 
      25              : #ifdef CPP_CUSOLVER
      26              :    type(cusolverDnHandle)  :: handle
      27              : #endif
      28              : 
      29              : contains
      30              : 
      31           97 :    function get_solver_cuda() result(solver)
      32              :       type(t_solver_cuda), pointer::solver
      33           97 :       allocate (solver)
      34           97 :       solver%name = "cuda"
      35              : #ifdef CPP_CUSOLVER
      36              :       solver%available = .true.
      37              : #else
      38           97 :       solver%available = .false.
      39              : #endif
      40           97 :       solver%parallel = .false.
      41           97 :       solver%serial = .true.
      42           97 :       solver%generalized = .true.
      43           97 :       solver%standard = .false.
      44           97 :       solver%single_precision = .false.
      45           97 :       solver%transform = .false.
      46           97 :       solver%GPU = .true.
      47           97 :    end function
      48              : 
      49            0 :    subroutine cuda_gev(self, hmat, smat, ne, eig, zmat, ikpt)
      50              :     !!Simple driver to solve Generalized Eigenvalue Problem using CuSolverDN
      51              :       implicit none
      52              :       class(t_solver_cuda) ::self
      53              :       class(t_mat), intent(INOUT) :: hmat, smat
      54              :       integer, intent(INOUT)      :: ne
      55              :       class(t_mat), allocatable, intent(OUT)    :: zmat
      56              :       real, intent(OUT)           :: eig(:)
      57              :       integer, intent(IN)         :: ikpt
      58              : 
      59              : #ifdef CPP_CUSOLVER
      60              :       integer                 :: istat, ne_found, lwork_d, devinfo(1)
      61              :       real, allocatable        :: work_d(:), eig_tmp(:)
      62              :       complex, allocatable     :: work_c(:)
      63              : 
      64              :       logical :: firstcall = .true.
      65              :       if (firstcall) then
      66              :          firstcall = .false.
      67              :          istat = cusolverDnCreate(handle)
      68              :          if (istat /= CUSOLVER_STATUS_SUCCESS) call judft_error('handle creation failed')
      69              :       end if
      70              : 
      71              :       allocate (t_mat::zmat)
      72              :       allocate (eig_tmp(hmat%matsize1))
      73              :       call zmat%alloc(hmat%l_real, hmat%matsize1, ne)
      74              :     !!$acc Data copyin(hmat,smat)
      75              :       if (hmat%l_real) then
      76              :          associate (h => hmat%data_r, s => smat%data_r)
      77              :             !$ACC DATA copyin(s)COPY(h)COPYOUT(eig_tmp)
      78              :             !$ACC HOST_DATA USE_DEVICE(s,h,eig_tmp)
      79              :             istat = cusolverDnDsygvdx_bufferSize(handle, CUSOLVER_EIG_TYPE_1, CUSOLVER_EIG_MODE_VECTOR, CUSOLVER_EIG_RANGE_I, &
      80              :                                                  CUBLAS_FILL_MODE_UPPER, hmat%matsize1, h, hmat%matsize1, &
      81              :                                                  s, smat%matsize1, 0.0, 0.0, 1, ne, ne_found, eig_tmp, lwork_d)
      82              :             !$acc end host_data
      83              :             if (istat /= CUSOLVER_STATUS_SUCCESS) call judft_error('cusolverDnZhegvdx_buffersize failed')
      84              :             allocate (work_d(lwork_d))
      85              :             !$ACC DATA create(work_d,devinfo)
      86              :             !$ACC HOST_DATA USE_DEVICE(s,h,eig_tmp,work_d,devinfo)
      87              :             istat = cusolverDnDsygvdx(handle, CUSOLVER_EIG_TYPE_1, CUSOLVER_EIG_MODE_VECTOR, CUSOLVER_EIG_RANGE_I, &
      88              :                                       CUBLAS_FILL_MODE_UPPER, hmat%matsize1, h, hmat%matsize1, &
      89              :                                       s, smat%matsize1, 0.0, 0.0, 1, ne, ne_found, eig_tmp, work_d, lwork_d, devinfo(1))
      90              :             !$ACC END HOST_DATA
      91              :             !$ACC END DATA
      92              :             !$ACC END DATA
      93              :             if (istat /= CUSOLVER_STATUS_SUCCESS) call judft_error('cusolverDnZhegvdx failed')
      94              :             ne = ne_found
      95              :             call zmat%alloc(hmat%l_real, hmat%matsize1, ne_found)
      96              :             zmat%data_r = h(:, :ne_found)
      97              :             eig = eig_tmp(:ne)
      98              :          end associate
      99              :       else
     100              :          associate (h => hmat%data_c, s => smat%data_c)
     101              :             !$ACC DATA copyin(s) COPY(h) COPYOUT(eig_tmp)
     102              :             !$ACC HOST_DATA USE_DEVICE(s,h,eig_tmp)
     103              :             istat = cusolverDnZhegvdx_bufferSize(handle, CUSOLVER_EIG_TYPE_1, CUSOLVER_EIG_MODE_VECTOR, CUSOLVER_EIG_RANGE_I, &
     104              :                                                  CUBLAS_FILL_MODE_UPPER, hmat%matsize1, h, hmat%matsize1, &
     105              :                                                  s, smat%matsize1, 0.0, 0.0, 1, ne, ne_found, eig_tmp, lwork_d)
     106              :             !$acc end host_data
     107              :             if (istat /= CUSOLVER_STATUS_SUCCESS) write (*, *) 'cusolverDnZhegvdx_buffersize failed'
     108              :             allocate (work_c(lwork_d))
     109              :             !$ACC DATA create(work_c,devinfo)
     110              :             !$ACC HOST_DATA USE_DEVICE(s,h,eig_tmp,work_c,devinfo)
     111              :             istat = cusolverDnZhegvdx(handle, CUSOLVER_EIG_TYPE_1, CUSOLVER_EIG_MODE_VECTOR, CUSOLVER_EIG_RANGE_I, &
     112              :                                       CUBLAS_FILL_MODE_UPPER, hmat%matsize1, h, hmat%matsize1, &
     113              :                                       s, smat%matsize1, 0.0, 0.0, 1, ne, ne_found, eig_tmp, work_c, lwork_d, devinfo(1))
     114              :             !$ACC END HOST_DATA
     115              :             !$acc update self(devinfo)
     116              :             if (istat /= CUSOLVER_STATUS_SUCCESS) then
     117              :                write (*, *) devinfo
     118              :                call judft_error('cusolverDnZhegvdx failed')
     119              :             end if
     120              :             !$ACC END DATA
     121              :             !$ACC END DATA
     122              :             ne = ne_found
     123              :             call zmat%alloc(hmat%l_real, hmat%matsize1, ne_found)
     124              :             zmat%data_c = h(:, :ne_found)
     125              :             eig = eig_tmp(:ne)
     126              : 
     127              :          end associate
     128              :       end if
     129              : #endif
     130              : 
     131            0 :    end subroutine
     132              : 
     133           97 : end module m_cuda_diag
        

Generated by: LCOV version 2.0-1