5
5
#include " module_base/global_variable.h"
6
6
#endif
7
7
8
+ #ifdef __CUDA
9
+ #include < base/macros/macros.h>
10
+ #include < cuda_runtime.h>
11
+ #include < thrust/complex.h>
12
+ #include < thrust/execution_policy.h>
13
+ #include < thrust/inner_product.h>
14
+ #include " module_base/tool_quit.h"
15
+
16
+ #include " cublas_v2.h"
17
+
18
+ namespace BlasUtils {
19
+
20
+ static cublasHandle_t cublas_handle = nullptr ;
21
+
22
+ void createGpuBlasHandle (){
23
+ if (cublas_handle == nullptr ) {
24
+ cublasErrcheck (cublasCreate (&cublas_handle));
25
+ }
26
+ }
27
+
28
+ void destoryBLAShandle (){
29
+ if (cublas_handle != nullptr ) {
30
+ cublasErrcheck (cublasDestroy (cublas_handle));
31
+ cublas_handle = nullptr ;
32
+ }
33
+ }
34
+
35
+
36
+ cublasOperation_t judge_trans (bool is_complex, const char & trans, const char * name)
37
+ {
38
+ if (trans == ' N' )
39
+ {
40
+ return CUBLAS_OP_N;
41
+ }
42
+ else if (trans == ' T' )
43
+ {
44
+ return CUBLAS_OP_T;
45
+ }
46
+ else if (is_complex && trans == ' C' )
47
+ {
48
+ return CUBLAS_OP_C;
49
+ }
50
+ return CUBLAS_OP_N;
51
+ }
52
+
53
+ } // namespace BlasUtils
54
+
55
+ #endif
56
+
8
57
void BlasConnector::axpy ( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
9
58
{
10
59
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
11
60
saxpy_ (&n, &alpha, X, &incX, Y, &incY);
12
- }
61
+ }
62
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
63
+ #ifdef __CUDA
64
+ cublasErrcheck (cublasSaxpy (BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
65
+ #endif
66
+ }
13
67
}
14
68
15
69
void BlasConnector::axpy ( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
16
70
{
17
71
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
18
72
daxpy_ (&n, &alpha, X, &incX, Y, &incY);
19
- }
73
+ }
74
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
75
+ #ifdef __CUDA
76
+ cublasErrcheck (cublasDaxpy (BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
77
+ #endif
78
+ }
20
79
}
21
80
22
81
void BlasConnector::axpy ( const int n, const std::complex<float > alpha, const std::complex<float > *X, const int incX, std::complex<float > *Y, const int incY, base_device::AbacusDevice_t device_type)
23
82
{
24
83
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
25
84
caxpy_ (&n, &alpha, X, &incX, Y, &incY);
26
- }
85
+ }
86
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
87
+ #ifdef __CUDA
88
+ cublasErrcheck (cublasCaxpy (BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY));
89
+ #endif
90
+ }
27
91
}
28
92
29
93
void BlasConnector::axpy ( const int n, const std::complex<double > alpha, const std::complex<double > *X, const int incX, std::complex<double > *Y, const int incY, base_device::AbacusDevice_t device_type)
30
94
{
31
95
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
32
96
zaxpy_ (&n, &alpha, X, &incX, Y, &incY);
33
- }
97
+ }
98
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
99
+ #ifdef __CUDA
100
+ cublasErrcheck (cublasZaxpy (BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY));
101
+ #endif
102
+ }
34
103
}
35
104
36
105
@@ -39,28 +108,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
39
108
{
40
109
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
41
110
sscal_ (&n, &alpha, X, &incX);
42
- }
111
+ }
112
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
113
+ #ifdef __CUDA
114
+ cublasErrcheck (cublasSscal (BlasUtils::cublas_handle, n, &alpha, X, incX));
115
+ #endif
116
+ }
43
117
}
44
118
45
119
void BlasConnector::scal ( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
46
120
{
47
121
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
48
122
dscal_ (&n, &alpha, X, &incX);
49
- }
123
+ }
124
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
125
+ #ifdef __CUDA
126
+ cublasErrcheck (cublasDscal (BlasUtils::cublas_handle, n, &alpha, X, incX));
127
+ #endif
128
+ }
50
129
}
51
130
52
131
void BlasConnector::scal ( const int n, const std::complex<float > alpha, std::complex<float > *X, const int incX, base_device::AbacusDevice_t device_type)
53
132
{
54
133
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
55
134
cscal_ (&n, &alpha, X, &incX);
56
- }
135
+ }
136
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
137
+ #ifdef __CUDA
138
+ cublasErrcheck (cublasCscal (BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX));
139
+ #endif
140
+ }
57
141
}
58
142
59
143
void BlasConnector::scal ( const int n, const std::complex<double > alpha, std::complex<double > *X, const int incX, base_device::AbacusDevice_t device_type)
60
144
{
61
145
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
62
146
zscal_ (&n, &alpha, X, &incX);
63
- }
147
+ }
148
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
149
+ #ifdef __CUDA
150
+ cublasErrcheck (cublasZscal (BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX));
151
+ #endif
152
+ }
64
153
}
65
154
66
155
@@ -70,6 +159,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
70
159
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
71
160
return sdot_ (&n, X, &incX, Y, &incY);
72
161
}
162
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
163
+ #ifdef __CUDA
164
+ float result = 0.0 ;
165
+ cublasErrcheck (cublasSdot (BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
166
+ return result;
167
+ #endif
168
+ }
73
169
return sdot_ (&n, X, &incX, Y, &incY);
74
170
}
75
171
@@ -78,6 +174,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
78
174
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
79
175
return ddot_ (&n, X, &incX, Y, &incY);
80
176
}
177
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
178
+ #ifdef __CUDA
179
+ double result = 0.0 ;
180
+ cublasErrcheck (cublasDdot (BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
181
+ return result;
182
+ #endif
183
+ }
81
184
return ddot_ (&n, X, &incX, Y, &incY);
82
185
}
83
186
@@ -92,13 +195,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
92
195
&alpha, b, &ldb, a, &lda,
93
196
&beta, c, &ldc);
94
197
}
95
- #ifdef __DSP
198
+ #ifdef __DSP
96
199
else if (device_type == base_device::AbacusDevice_t::DspDevice){
97
200
sgemm_mth_ (&transb, &transa, &n, &m, &k,
98
201
&alpha, b, &ldb, a, &lda,
99
202
&beta, c, &ldc, GlobalV::MY_RANK);
100
203
}
101
- #endif
204
+ #endif
205
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
206
+ #ifdef __CUDA
207
+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
208
+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
209
+ cublasErrcheck (cublasSgemm (BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
210
+ #endif
211
+ }
102
212
}
103
213
104
214
void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -110,13 +220,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
110
220
&alpha, b, &ldb, a, &lda,
111
221
&beta, c, &ldc);
112
222
}
113
- #ifdef __DSP
223
+ #ifdef __DSP
114
224
else if (device_type == base_device::AbacusDevice_t::DspDevice){
115
225
dgemm_mth_ (&transb, &transa, &n, &m, &k,
116
226
&alpha, b, &ldb, a, &lda,
117
227
&beta, c, &ldc, GlobalV::MY_RANK);
118
228
}
119
- #endif
229
+ #endif
230
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
231
+ #ifdef __CUDA
232
+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
233
+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
234
+ cublasErrcheck (cublasDgemm (BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
235
+ #endif
236
+ }
120
237
}
121
238
122
239
void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -128,13 +245,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
128
245
&alpha, b, &ldb, a, &lda,
129
246
&beta, c, &ldc);
130
247
}
131
- #ifdef __DSP
248
+ #ifdef __DSP
132
249
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
133
250
cgemm_mth_ (&transb, &transa, &n, &m, &k,
134
251
&alpha, b, &ldb, a, &lda,
135
252
&beta, c, &ldc, GlobalV::MY_RANK);
136
253
}
137
- #endif
254
+ #endif
255
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
256
+ #ifdef __CUDA
257
+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
258
+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
259
+ cublasErrcheck (cublasCgemm (BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (float2*)&alpha, (float2*)b, ldb, (float2*)a, lda, (float2*)&beta, (float2*)c, ldc));
260
+ #endif
261
+ }
138
262
}
139
263
140
264
void BlasConnector::gemm (const char transa, const char transb, const int m, const int n, const int k,
@@ -146,13 +270,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
146
270
&alpha, b, &ldb, a, &lda,
147
271
&beta, c, &ldc);
148
272
}
149
- #ifdef __DSP
273
+ #ifdef __DSP
150
274
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
151
275
zgemm_mth_ (&transb, &transa, &n, &m, &k,
152
276
&alpha, b, &ldb, a, &lda,
153
277
&beta, c, &ldc, GlobalV::MY_RANK);
154
278
}
155
- #endif
279
+ #endif
280
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
281
+ #ifdef __CUDA
282
+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
283
+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
284
+ cublasErrcheck (cublasZgemm (BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (double2*)&alpha, (double2*)b, ldb, (double2*)a, lda, (double2*)&beta, (double2*)c, ldc));
285
+ #endif
286
+ }
156
287
}
157
288
158
289
// Col-Major part
@@ -165,13 +296,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
165
296
&alpha, a, &lda, b, &ldb,
166
297
&beta, c, &ldc);
167
298
}
168
- #ifdef __DSP
299
+ #ifdef __DSP
169
300
else if (device_type == base_device::AbacusDevice_t::DspDevice){
170
301
sgemm_mth_ (&transb, &transa, &m, &n, &k,
171
302
&alpha, a, &lda, b, &ldb,
172
303
&beta, c, &ldc, GlobalV::MY_RANK);
173
304
}
174
- #endif
305
+ #endif
306
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
307
+ #ifdef __CUDA
308
+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
309
+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
310
+ cublasErrcheck (cublasSgemm (BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
311
+ #endif
312
+ }
175
313
}
176
314
177
315
void BlasConnector::gemm_cm (const char transa, const char transb, const int m, const int n, const int k,
@@ -183,13 +321,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
183
321
&alpha, a, &lda, b, &ldb,
184
322
&beta, c, &ldc);
185
323
}
186
- #ifdef __DSP
324
+ #ifdef __DSP
187
325
else if (device_type == base_device::AbacusDevice_t::DspDevice){
188
326
dgemm_mth_ (&transa, &transb, &m, &n, &k,
189
327
&alpha, a, &lda, b, &ldb,
190
328
&beta, c, &ldc, GlobalV::MY_RANK);
191
329
}
192
- #endif
330
+ #endif
331
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
332
+ #ifdef __CUDA
333
+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
334
+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
335
+ cublasErrcheck (cublasDgemm (BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
336
+ #endif
337
+ }
193
338
}
194
339
195
340
void BlasConnector::gemm_cm (const char transa, const char transb, const int m, const int n, const int k,
@@ -201,13 +346,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
201
346
&alpha, a, &lda, b, &ldb,
202
347
&beta, c, &ldc);
203
348
}
204
- #ifdef __DSP
349
+ #ifdef __DSP
205
350
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
206
351
cgemm_mth_ (&transa, &transb, &m, &n, &k,
207
352
&alpha, a, &lda, b, &ldb,
208
353
&beta, c, &ldc, GlobalV::MY_RANK);
209
354
}
210
- #endif
355
+ #endif
356
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
357
+ #ifdef __CUDA
358
+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
359
+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
360
+ cublasErrcheck (cublasCgemm (BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
361
+ #endif
362
+ }
211
363
}
212
364
213
365
void BlasConnector::gemm_cm (const char transa, const char transb, const int m, const int n, const int k,
@@ -219,13 +371,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
219
371
&alpha, a, &lda, b, &ldb,
220
372
&beta, c, &ldc);
221
373
}
222
- #ifdef __DSP
374
+ #ifdef __DSP
223
375
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
224
376
zgemm_mth_ (&transa, &transb, &m, &n, &k,
225
377
&alpha, a, &lda, b, &ldb,
226
378
&beta, c, &ldc, GlobalV::MY_RANK);
227
379
}
228
- #endif
380
+ #endif
381
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
382
+ #ifdef __CUDA
383
+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , transa, " gemm_op" );
384
+ cublasOperation_t cutransB = BlasUtils::judge_trans (false , transb, " gemm_op" );
385
+ cublasErrcheck (cublasZgemm (BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
386
+ #endif
387
+ }
229
388
}
230
389
231
390
// Symm and Hemm part. Only col-major is supported.
0 commit comments