Line data Source code
1 : !--------------------------------------------------------------------------------
2 : ! Copyright (c) 2023 Peter Grünberg Institut, Forschungszentrum Jülich, Germany
3 : ! This file is part of FLEUR and available as free software under the conditions
4 : ! of the MIT license as expressed in the LICENSE file in more detail.
5 : !--------------------------------------------------------------------------------
6 :
7 : module m_matmul_dgemm
8 : #ifdef _OPENACC
9 : use openacc
10 : #endif
11 : #ifdef CPP_MAGMA
12 : use magma
13 : #endif
14 : #ifdef _CUDA
15 : use cublas
16 : #endif
17 : implicit none
18 : PRIVATE
19 : integer,PARAMETER:: blas_select=1,cublas_select=2,magma_select=3,magmablas_select=4
20 :
21 : interface blas_matmul
22 : module procedure:: blas_matmul_r,blas_matmul_c
23 : end interface
24 :
25 : public :: blas_matmul
26 :
27 : contains
28 49609 : subroutine blas_matmul_r(m,n,k,a,b,c,alpha,beta,op_a,op_b)
29 : INTEGER,INTENT(IN):: n,m,k
30 : REAL,INTENT(IN) :: a(:,:),b(:,:)
31 : REAL,INTENT(INOUT):: c(:,:)
32 : REAL,INTENT(IN),OPTIONAL:: alpha,beta
33 : CHARACTER,INTENT(IN),OPTIONAL :: op_a,op_b
34 :
35 : REAL :: alphaa,betaa
36 : CHARACTER :: op_aa,op_bb
37 : INTEGER :: lda,ldb,ldc
38 :
39 49609 : alphaa=1.0; betaa=0.0
40 49609 : if (present(alpha)) alphaa=alpha
41 49609 : if (present(beta)) betaa=beta
42 :
43 99218 : call priv_set_defaults(op_aa,op_bb,lda,ldb,ldc,m,n,k,a,b,c,op_a,op_b)
44 :
45 49609 : select case (priv_select_multiply_r(a,b,c))
46 : case (blas_select)
47 49609 : call dgemm(op_aa,op_bb,m,n,k, alphaa, a, lda, b,ldb,betaa, c, ldc)
48 : #ifdef _CUDA
49 : case (cublas_select)
50 : !$acc host_data use_device(a,b,c)
51 : call cublasDgemm(op_aa, op_bb, m, n, k, alphaa,a, lda, b, ldb, betaa, c, ldc)
52 : !$acc end host_data
53 : #endif
54 : #ifdef CPP_MAGMA
55 : case (magma_select)
56 : call magmaf_dgemm(op_aa, op_bb, m, n, k, alphaa,a, lda, b, ldb, betaa, c, ldc)
57 : case (magmablas_select)
58 : call magmablasf_dgemm(op_aa, op_bb, m, n, k, alphaa,a, lda, b, ldb, betaa, c, ldc)
59 : #endif
60 : end select
61 :
62 49609 : end subroutine
63 :
64 :
65 54930 : subroutine blas_matmul_c(m,n,k,a,b,c,alpha,beta,op_a,op_b)
66 : INTEGER,INTENT(IN):: n,m,k
67 : COMPLEX,INTENT(IN) :: a(:,:),b(:,:)
68 : COMPLEX,INTENT(INOUT):: c(:,:)
69 : COMPLEX,INTENT(IN),OPTIONAL:: alpha,beta
70 : CHARACTER,INTENT(IN),OPTIONAL :: op_a,op_b
71 :
72 : COMPLEX :: alphaa,betaa
73 : CHARACTER :: op_aa,op_bb
74 : INTEGER :: lda,ldb,ldc
75 :
76 54930 : alphaa=1.0; betaa=0.0
77 54930 : if (present(alpha)) alphaa=alpha
78 54930 : if (present(beta)) betaa=beta
79 :
80 109860 : call priv_set_defaults(op_aa,op_bb,lda,ldb,ldc,m,n,k,a,b,c,op_a,op_b)
81 :
82 54930 : select case (priv_select_multiply_c(a,b,c))
83 : case (blas_select)
84 274326 : call zgemm(op_aa,op_bb,m,n,k, alphaa, a, lda, b,ldb,betaa, c, ldc)
85 : #ifdef _CUDA
86 : case (cublas_select)
87 : !$acc host_data use_device(a,b,c)
88 : call cublaszgemm(op_aa, op_bb, m, n, k, alphaa,a, lda, b, ldb, betaa, c, ldc)
89 : !$acc end host_data
90 : #endif
91 : #ifdef CPP_MAGMA
92 : case (magma_select)
93 : call magmaf_zgemm(op_aa, op_bb, m, n, k, alphaa,a, lda, b, ldb, betaa, c, ldc)
94 : case (magmablas_select)
95 : call magmablasf_zgemm(op_aa, op_bb, m, n, k, alphaa,a, lda, b, ldb, betaa, c, ldc)
96 : #endif
97 : end select
98 :
99 54930 : end subroutine
100 :
101 :
102 104539 : subroutine priv_set_defaults(op_aa,op_bb,lda,ldb,ldc,m,n,k,a,b,c,op_a,op_b)
103 : INTEGER,INTENT(IN):: n,m,k
104 : CLASS(*),INTENT(IN) :: a(:,:),b(:,:),c(:,:)
105 : CHARACTER,INTENT(IN),OPTIONAL :: op_a,op_b
106 :
107 : CHARACTER,INTENT(OUT) :: op_aa,op_bb
108 : INTEGER,INTENT(OUT) :: lda,ldb,ldc
109 :
110 104539 : op_aa='N'; op_bb='N'
111 104539 : if (present(op_a)) op_aa=op_a
112 104539 : if (present(op_b)) op_bb=op_b
113 104539 : lda=size(a,1)
114 104539 : ldb=size(b,1)
115 104539 : ldc=size(c,1)
116 :
117 : END subroutine
118 :
119 : integer function priv_select_multiply_r(a,b,c)result(sel)
120 : REAL,INTENT(IN):: a(:,:),b(:,:),c(:,:)
121 :
122 : #ifdef _OPENACC
123 : if (acc_is_present(a).and.acc_is_present(b).and.acc_is_present(c)) THEN
124 : !All data on GPU
125 : #ifdef _CUDA
126 :
127 : sel=cublas_select;return
128 : #endif
129 : #ifdef CPP_MAGMA
130 : sel=magmablas_select; return
131 : #endif
132 : ENDIF
133 : #endif
134 49609 : sel=blas_select
135 : return
136 : end function
137 :
138 : integer function priv_select_multiply_c(a,b,c)result(sel)
139 : COMPLEX,INTENT(IN):: a(:,:),b(:,:),c(:,:)
140 54930 : sel=blas_select
141 : #ifdef _OPENACC
142 : if (acc_is_present(a).and.acc_is_present(b).and.acc_is_present(c)) THEN
143 : !All data on GPU
144 : #ifdef _CUDA
145 : sel=cublas_select
146 : #endif
147 : #ifdef CPP_MAGMA
148 : sel=magmablas_select
149 : #endif
150 : ENDIF
151 : #endif
152 :
153 : end function
154 :
155 0 : end module
|