Line data Source code
1 : !--------------------------------------------------------------------------------
2 : ! Copyright (c) 2016 Peter Grünberg Institut, Forschungszentrum Jülich, Germany
3 : ! This file is part of FLEUR and available as free software under the conditions
4 : ! of the MIT license as expressed in the LICENSE file in more detail.
5 : !--------------------------------------------------------------------------------
6 : module m_cuda_diag
7 : use m_types_mat
8 : use m_types_mpimat
9 : use m_judft
10 : #ifdef CPP_CUSOLVER
11 : use cusolverDn
12 : #endif
13 : use m_types_solver
14 : implicit none
15 : !**********************************************************
16 : ! Solve the generalized eigenvalue problem
17 : ! using the cusolver library
18 : !**********************************************************
!> Solver type for the cuSOLVER backend; extends the generic t_solver.
type, extends(t_solver) :: t_solver_cuda
contains
   !> solve_gev is implemented by the module procedure cuda_gev.
   procedure :: solve_gev => cuda_gev
end type t_solver_cuda

public :: get_solver_cuda

#ifdef CPP_CUSOLVER
!> Module-level cuSOLVER handle (cusolverDnHandle); cuda_gev creates it on its first call.
type(cusolverDnHandle) :: handle
#endif
28 :
29 : contains
30 :
function get_solver_cuda() result(solver)
   !! Allocate a new t_solver_cuda object and fill in its descriptive flags.
   !! The result is a pointer to the freshly allocated object; this function
   !! does not deallocate it.
   type(t_solver_cuda), pointer :: solver

   ! Compile-time switch: the solver is only usable when built with cuSOLVER.
#ifdef CPP_CUSOLVER
   logical, parameter :: built_with_cusolver = .true.
#else
   logical, parameter :: built_with_cusolver = .false.
#endif

   allocate (solver)

   ! Identification and availability
   solver%name = "cuda"
   solver%available = built_with_cusolver

   ! Execution model
   solver%GPU = .true.
   solver%serial = .true.
   solver%parallel = .false.

   ! Supported problem types and options
   solver%generalized = .true.
   solver%standard = .false.
   solver%single_precision = .false.
   solver%transform = .false.
end function get_solver_cuda
48 :
subroutine cuda_gev(self, hmat, smat, ne, eig, zmat, ikpt)
   !! Simple driver to solve the generalized eigenvalue problem with cuSOLVER
   !! (cusolverDnDsygvdx for real matrices, cusolverDnZhegvdx for complex ones).
   !! The matrices are moved to the device with OpenACC data regions and the
   !! device addresses are handed to cuSOLVER through host_data.
   !!
   !! Arguments:
   !!   self  solver object (not referenced in this routine)
   !!   hmat  first matrix of the pencil; copied to the device and back, the
   !!         eigenvectors are read from its leading columns after the solve
   !!   smat  second matrix of the pencil; only copied to the device (copyin)
   !!   ne    in: number of eigenpairs requested (index range 1..ne)
   !!         out: number of eigenpairs reported by cuSOLVER
   !!   eig   eigenvalues; the first ne entries are set
   !!   zmat  newly allocated t_mat holding the eigenvectors (matsize1 x ne)
   !!   ikpt  not referenced in this routine
   !!
   !! Any cuSOLVER failure (bad status or non-zero devinfo) ends in judft_error.
   implicit none
   ! No intent is given for self: NOTE(review) the binding presumably has to match
   ! the interface of t_solver%solve_gev, which is not visible here.
   class(t_solver_cuda) ::self
   class(t_mat), intent(INOUT) :: hmat, smat
   integer, intent(INOUT) :: ne
   class(t_mat), allocatable, intent(OUT) :: zmat
   real, intent(OUT) :: eig(:)
   integer, intent(IN) :: ikpt

#ifdef CPP_CUSOLVER
   integer :: istat, ne_found, lwork_d, devinfo(1)
   real, allocatable :: work_d(:), eig_tmp(:)
   complex, allocatable :: work_c(:)

   ! Saved across calls so the module-level handle is created exactly once.
   logical, save :: firstcall = .true.

   if (firstcall) then
      firstcall = .false.
      istat = cusolverDnCreate(handle)
      if (istat /= CUSOLVER_STATUS_SUCCESS) call judft_error('handle creation failed')
   end if

   allocate (t_mat::zmat)
   ! cuSOLVER gets an eigenvalue buffer of full matrix dimension.
   allocate (eig_tmp(hmat%matsize1))
   ! zmat is sized only after the solve, once the number of eigenpairs is known.

   if (hmat%l_real) then
      associate (h => hmat%data_r, s => smat%data_r)
         !$acc data copyin(s) copy(h) copyout(eig_tmp)

         ! Workspace query
         !$acc host_data use_device(s, h, eig_tmp)
         istat = cusolverDnDsygvdx_bufferSize(handle, CUSOLVER_EIG_TYPE_1, CUSOLVER_EIG_MODE_VECTOR, &
                                              CUSOLVER_EIG_RANGE_I, CUBLAS_FILL_MODE_UPPER, &
                                              hmat%matsize1, h, hmat%matsize1, s, smat%matsize1, &
                                              0.0, 0.0, 1, ne, ne_found, eig_tmp, lwork_d)
         !$acc end host_data
         if (istat /= CUSOLVER_STATUS_SUCCESS) call judft_error('cusolverDnDsygvdx_buffersize failed')
         allocate (work_d(lwork_d))

         ! Actual solve; workspace and devinfo live on the device only.
         !$acc data create(work_d, devinfo)
         !$acc host_data use_device(s, h, eig_tmp, work_d, devinfo)
         istat = cusolverDnDsygvdx(handle, CUSOLVER_EIG_TYPE_1, CUSOLVER_EIG_MODE_VECTOR, &
                                   CUSOLVER_EIG_RANGE_I, CUBLAS_FILL_MODE_UPPER, &
                                   hmat%matsize1, h, hmat%matsize1, s, smat%matsize1, &
                                   0.0, 0.0, 1, ne, ne_found, eig_tmp, work_d, lwork_d, devinfo(1))
         !$acc end host_data
         ! Fetch devinfo before its device copy disappears with the data region.
         !$acc update self(devinfo)
         !$acc end data
         !$acc end data

         if (istat /= CUSOLVER_STATUS_SUCCESS) then
            write (*, *) devinfo
            call judft_error('cusolverDnDsygvdx failed')
         end if
         ! A successful status does not imply a successful factorization/solve.
         if (devinfo(1) /= 0) then
            write (*, *) devinfo
            call judft_error('cusolverDnDsygvdx failed: devinfo /= 0')
         end if

         ne = ne_found
         call zmat%alloc(hmat%l_real, hmat%matsize1, ne_found)
         zmat%data_r = h(:, :ne_found)
         ! Explicit section: eig may be longer than the number of values found.
         eig(:ne) = eig_tmp(:ne)
      end associate
   else
      associate (h => hmat%data_c, s => smat%data_c)
         !$acc data copyin(s) copy(h) copyout(eig_tmp)

         ! Workspace query
         !$acc host_data use_device(s, h, eig_tmp)
         istat = cusolverDnZhegvdx_bufferSize(handle, CUSOLVER_EIG_TYPE_1, CUSOLVER_EIG_MODE_VECTOR, &
                                              CUSOLVER_EIG_RANGE_I, CUBLAS_FILL_MODE_UPPER, &
                                              hmat%matsize1, h, hmat%matsize1, s, smat%matsize1, &
                                              0.0, 0.0, 1, ne, ne_found, eig_tmp, lwork_d)
         !$acc end host_data
         ! lwork_d is meaningless after a failed query, so stop here.
         if (istat /= CUSOLVER_STATUS_SUCCESS) call judft_error('cusolverDnZhegvdx_buffersize failed')
         allocate (work_c(lwork_d))

         ! Actual solve; workspace and devinfo live on the device only.
         !$acc data create(work_c, devinfo)
         !$acc host_data use_device(s, h, eig_tmp, work_c, devinfo)
         istat = cusolverDnZhegvdx(handle, CUSOLVER_EIG_TYPE_1, CUSOLVER_EIG_MODE_VECTOR, &
                                   CUSOLVER_EIG_RANGE_I, CUBLAS_FILL_MODE_UPPER, &
                                   hmat%matsize1, h, hmat%matsize1, s, smat%matsize1, &
                                   0.0, 0.0, 1, ne, ne_found, eig_tmp, work_c, lwork_d, devinfo(1))
         !$acc end host_data
         ! Fetch devinfo before its device copy disappears with the data region.
         !$acc update self(devinfo)
         !$acc end data
         !$acc end data

         if (istat /= CUSOLVER_STATUS_SUCCESS) then
            write (*, *) devinfo
            call judft_error('cusolverDnZhegvdx failed')
         end if
         ! A successful status does not imply a successful factorization/solve.
         if (devinfo(1) /= 0) then
            write (*, *) devinfo
            call judft_error('cusolverDnZhegvdx failed: devinfo /= 0')
         end if

         ne = ne_found
         call zmat%alloc(hmat%l_real, hmat%matsize1, ne_found)
         zmat%data_c = h(:, :ne_found)
         ! Explicit section: eig may be longer than the number of values found.
         eig(:ne) = eig_tmp(:ne)
      end associate
   end if
#else
   ! Without cuSOLVER nothing can be computed; fail loudly instead of
   ! returning with every intent(out) argument undefined.
   call judft_error('cuda_gev called, but the code was compiled without CPP_CUSOLVER')
#endif

end subroutine cuda_gev
132 :
133 97 : end module m_cuda_diag
|