Skip to content

Commit 9ab9150

Browse files
[Feature] Add some GPU kernels to blas_connector (#5799)
* Initial commit * Modify CMakeLists * Complete CMakeLists in module_base * Add blas_connector.cpp definition * Fix module_base tests * Fix tests failure * fix opt_test * OPTFIX2 * Return all changes * Fix global_func_text * Fix MPI Bug * return base_math_chebyshev * Fix MPI bug * Finish * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 414446a commit 9ab9150

11 files changed

+291
-115
lines changed

source/module_base/blas_connector.cpp

Lines changed: 183 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,101 @@
55
#include "module_base/global_variable.h"
66
#endif
77

8+
#ifdef __CUDA
9+
#include <base/macros/macros.h>
10+
#include <cuda_runtime.h>
11+
#include <thrust/complex.h>
12+
#include <thrust/execution_policy.h>
13+
#include <thrust/inner_product.h>
14+
#include "module_base/tool_quit.h"
15+
16+
#include "cublas_v2.h"
17+
18+
namespace BlasUtils{
19+
20+
static cublasHandle_t cublas_handle = nullptr;
21+
22+
void createGpuBlasHandle(){
23+
if (cublas_handle == nullptr) {
24+
cublasErrcheck(cublasCreate(&cublas_handle));
25+
}
26+
}
27+
28+
void destoryBLAShandle(){
29+
if (cublas_handle != nullptr) {
30+
cublasErrcheck(cublasDestroy(cublas_handle));
31+
cublas_handle = nullptr;
32+
}
33+
}
34+
35+
36+
cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name)
37+
{
38+
if (trans == 'N')
39+
{
40+
return CUBLAS_OP_N;
41+
}
42+
else if(trans == 'T')
43+
{
44+
return CUBLAS_OP_T;
45+
}
46+
else if(is_complex && trans == 'C')
47+
{
48+
return CUBLAS_OP_C;
49+
}
50+
return CUBLAS_OP_N;
51+
}
52+
53+
} // namespace BlasUtils
54+
55+
#endif
56+
857
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
958
{
1059
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
1160
saxpy_(&n, &alpha, X, &incX, Y, &incY);
12-
}
61+
}
62+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
63+
#ifdef __CUDA
64+
cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
65+
#endif
66+
}
1367
}
1468

1569
void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
1670
{
1771
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
1872
daxpy_(&n, &alpha, X, &incX, Y, &incY);
19-
}
73+
}
74+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
75+
#ifdef __CUDA
76+
cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
77+
#endif
78+
}
2079
}
2180

2281
void BlasConnector::axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY, base_device::AbacusDevice_t device_type)
2382
{
2483
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
2584
caxpy_(&n, &alpha, X, &incX, Y, &incY);
26-
}
85+
}
86+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
87+
#ifdef __CUDA
88+
cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX, (float2*)Y, incY));
89+
#endif
90+
}
2791
}
2892

2993
void BlasConnector::axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY, base_device::AbacusDevice_t device_type)
3094
{
3195
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
3296
zaxpy_(&n, &alpha, X, &incX, Y, &incY);
33-
}
97+
}
98+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
99+
#ifdef __CUDA
100+
cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX, (double2*)Y, incY));
101+
#endif
102+
}
34103
}
35104

36105

@@ -39,28 +108,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
39108
{
40109
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
41110
sscal_(&n, &alpha, X, &incX);
42-
}
111+
}
112+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
113+
#ifdef __CUDA
114+
cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
115+
#endif
116+
}
43117
}
44118

45119
void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
46120
{
47121
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
48122
dscal_(&n, &alpha, X, &incX);
49-
}
123+
}
124+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
125+
#ifdef __CUDA
126+
cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
127+
#endif
128+
}
50129
}
51130

52131
void BlasConnector::scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX, base_device::AbacusDevice_t device_type)
53132
{
54133
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
55134
cscal_(&n, &alpha, X, &incX);
56-
}
135+
}
136+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
137+
#ifdef __CUDA
138+
cublasErrcheck(cublasCscal(BlasUtils::cublas_handle, n, (float2*)&alpha, (float2*)X, incX));
139+
#endif
140+
}
57141
}
58142

59143
void BlasConnector::scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX, base_device::AbacusDevice_t device_type)
60144
{
61145
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
62146
zscal_(&n, &alpha, X, &incX);
63-
}
147+
}
148+
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
149+
#ifdef __CUDA
150+
cublasErrcheck(cublasZscal(BlasUtils::cublas_handle, n, (double2*)&alpha, (double2*)X, incX));
151+
#endif
152+
}
64153
}
65154

66155

@@ -70,6 +159,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
70159
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
71160
return sdot_(&n, X, &incX, Y, &incY);
72161
}
162+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
163+
#ifdef __CUDA
164+
float result = 0.0;
165+
cublasErrcheck(cublasSdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
166+
return result;
167+
#endif
168+
}
73169
return sdot_(&n, X, &incX, Y, &incY);
74170
}
75171

@@ -78,6 +174,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
78174
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
79175
return ddot_(&n, X, &incX, Y, &incY);
80176
}
177+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
178+
#ifdef __CUDA
179+
double result = 0.0;
180+
cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
181+
return result;
182+
#endif
183+
}
81184
return ddot_(&n, X, &incX, Y, &incY);
82185
}
83186

@@ -92,13 +195,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
92195
&alpha, b, &ldb, a, &lda,
93196
&beta, c, &ldc);
94197
}
95-
#ifdef __DSP
198+
#ifdef __DSP
96199
else if (device_type == base_device::AbacusDevice_t::DspDevice){
97200
sgemm_mth_(&transb, &transa, &n, &m, &k,
98201
&alpha, b, &ldb, a, &lda,
99202
&beta, c, &ldc, GlobalV::MY_RANK);
100203
}
101-
#endif
204+
#endif
205+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
206+
#ifdef __CUDA
207+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
208+
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
209+
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
210+
#endif
211+
}
102212
}
103213

104214
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -110,13 +220,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
110220
&alpha, b, &ldb, a, &lda,
111221
&beta, c, &ldc);
112222
}
113-
#ifdef __DSP
223+
#ifdef __DSP
114224
else if (device_type == base_device::AbacusDevice_t::DspDevice){
115225
dgemm_mth_(&transb, &transa, &n, &m, &k,
116226
&alpha, b, &ldb, a, &lda,
117227
&beta, c, &ldc, GlobalV::MY_RANK);
118228
}
119-
#endif
229+
#endif
230+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
231+
#ifdef __CUDA
232+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
233+
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
234+
cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
235+
#endif
236+
}
120237
}
121238

122239
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -128,13 +245,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
128245
&alpha, b, &ldb, a, &lda,
129246
&beta, c, &ldc);
130247
}
131-
#ifdef __DSP
248+
#ifdef __DSP
132249
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
133250
cgemm_mth_(&transb, &transa, &n, &m, &k,
134251
&alpha, b, &ldb, a, &lda,
135252
&beta, c, &ldc, GlobalV::MY_RANK);
136253
}
137-
#endif
254+
#endif
255+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
256+
#ifdef __CUDA
257+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
258+
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
259+
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (float2*)&alpha, (float2*)b, ldb, (float2*)a, lda, (float2*)&beta, (float2*)c, ldc));
260+
#endif
261+
}
138262
}
139263

140264
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
@@ -146,13 +270,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
146270
&alpha, b, &ldb, a, &lda,
147271
&beta, c, &ldc);
148272
}
149-
#ifdef __DSP
273+
#ifdef __DSP
150274
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
151275
zgemm_mth_(&transb, &transa, &n, &m, &k,
152276
&alpha, b, &ldb, a, &lda,
153277
&beta, c, &ldc, GlobalV::MY_RANK);
154278
}
155-
#endif
279+
#endif
280+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
281+
#ifdef __CUDA
282+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
283+
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
284+
cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (double2*)&alpha, (double2*)b, ldb, (double2*)a, lda, (double2*)&beta, (double2*)c, ldc));
285+
#endif
286+
}
156287
}
157288

158289
// Col-Major part
@@ -165,13 +296,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
165296
&alpha, a, &lda, b, &ldb,
166297
&beta, c, &ldc);
167298
}
168-
#ifdef __DSP
299+
#ifdef __DSP
169300
else if (device_type == base_device::AbacusDevice_t::DspDevice){
170301
sgemm_mth_(&transb, &transa, &m, &n, &k,
171302
&alpha, a, &lda, b, &ldb,
172303
&beta, c, &ldc, GlobalV::MY_RANK);
173304
}
174-
#endif
305+
#endif
306+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
307+
#ifdef __CUDA
308+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
309+
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
310+
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
311+
#endif
312+
}
175313
}
176314

177315
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
@@ -183,13 +321,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
183321
&alpha, a, &lda, b, &ldb,
184322
&beta, c, &ldc);
185323
}
186-
#ifdef __DSP
324+
#ifdef __DSP
187325
else if (device_type == base_device::AbacusDevice_t::DspDevice){
188326
dgemm_mth_(&transa, &transb, &m, &n, &k,
189327
&alpha, a, &lda, b, &ldb,
190328
&beta, c, &ldc, GlobalV::MY_RANK);
191329
}
192-
#endif
330+
#endif
331+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
332+
#ifdef __CUDA
333+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
334+
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
335+
cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
336+
#endif
337+
}
193338
}
194339

195340
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
@@ -201,13 +346,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
201346
&alpha, a, &lda, b, &ldb,
202347
&beta, c, &ldc);
203348
}
204-
#ifdef __DSP
349+
#ifdef __DSP
205350
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
206351
cgemm_mth_(&transa, &transb, &m, &n, &k,
207352
&alpha, a, &lda, b, &ldb,
208353
&beta, c, &ldc, GlobalV::MY_RANK);
209354
}
210-
#endif
355+
#endif
356+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
357+
#ifdef __CUDA
358+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
359+
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
360+
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
361+
#endif
362+
}
211363
}
212364

213365
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
@@ -219,13 +371,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
219371
&alpha, a, &lda, b, &ldb,
220372
&beta, c, &ldc);
221373
}
222-
#ifdef __DSP
374+
#ifdef __DSP
223375
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
224376
zgemm_mth_(&transa, &transb, &m, &n, &k,
225377
&alpha, a, &lda, b, &ldb,
226378
&beta, c, &ldc, GlobalV::MY_RANK);
227379
}
228-
#endif
380+
#endif
381+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
382+
#ifdef __CUDA
383+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
384+
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
385+
cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
386+
#endif
387+
}
229388
}
230389

231390
// Symm and Hemm part. Only col-major is supported.

0 commit comments

Comments
 (0)