Skip to content

Commit f4c792c

Browse files
authored
[Feature] Complete all kernels' GPU implement in blas_connector.cpp (#5833)
* initial commit * Fix compiling error * Fix trans comparison bug
1 parent beeb256 commit f4c792c

File tree

1 file changed

+125
-4
lines changed

1 file changed

+125
-4
lines changed

source/module_base/blas_connector.cpp

Lines changed: 125 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,36 @@ namespace BlasUtils{
5050
return CUBLAS_OP_N;
5151
}
5252

53+
cublasSideMode_t judge_side(const char& trans)
54+
{
55+
if (trans == 'L')
56+
{
57+
return CUBLAS_SIDE_LEFT;
58+
}
59+
else if (trans == 'R')
60+
{
61+
return CUBLAS_SIDE_RIGHT;
62+
}
63+
return CUBLAS_SIDE_LEFT;
64+
}
65+
66+
cublasFillMode_t judge_fill(const char& trans)
67+
{
68+
if (trans == 'F')
69+
{
70+
return CUBLAS_FILL_MODE_FULL;
71+
}
72+
else if (trans == 'U')
73+
{
74+
return CUBLAS_FILL_MODE_UPPER;
75+
}
76+
else if (trans == 'D')
77+
{
78+
return CUBLAS_FILL_MODE_LOWER;
79+
}
80+
return CUBLAS_FILL_MODE_FULL;
81+
}
82+
5383
} // namespace BlasUtils
5484

5585
#endif
@@ -398,6 +428,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
398428
&alpha, a, &lda, b, &ldb,
399429
&beta, c, &ldc);
400430
}
431+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
432+
#ifdef __CUDA
433+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
434+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
435+
cublasErrcheck(cublasSsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
436+
#endif
437+
}
401438
}
402439

403440
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
@@ -409,6 +446,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
409446
&alpha, a, &lda, b, &ldb,
410447
&beta, c, &ldc);
411448
}
449+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
450+
#ifdef __CUDA
451+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
452+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
453+
cublasErrcheck(cublasDsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
454+
#endif
455+
}
412456
}
413457

414458
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
@@ -420,6 +464,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
420464
&alpha, a, &lda, b, &ldb,
421465
&beta, c, &ldc);
422466
}
467+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
468+
#ifdef __CUDA
469+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
470+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
471+
cublasErrcheck(cublasCsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
472+
#endif
473+
}
423474
}
424475

425476
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
@@ -431,6 +482,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
431482
&alpha, a, &lda, b, &ldb,
432483
&beta, c, &ldc);
433484
}
485+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
486+
#ifdef __CUDA
487+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
488+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
489+
cublasErrcheck(cublasZsymm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
490+
#endif
491+
}
434492
}
435493

436494
void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
@@ -442,6 +500,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
442500
&alpha, a, &lda, b, &ldb,
443501
&beta, c, &ldc);
444502
}
503+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
504+
#ifdef __CUDA
505+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
506+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
507+
cublasErrcheck(cublasChemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
508+
#endif
509+
}
445510
}
446511

447512
void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
@@ -453,6 +518,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
453518
&alpha, a, &lda, b, &ldb,
454519
&beta, c, &ldc);
455520
}
521+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
522+
#ifdef __CUDA
523+
cublasSideMode_t sideMode = BlasUtils::judge_side(side);
524+
cublasFillMode_t fillMode = BlasUtils::judge_fill(uplo);
525+
cublasErrcheck(cublasZhemm(BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
526+
#endif
527+
}
456528
}
457529

458530
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -461,7 +533,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
461533
{
462534
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
463535
sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
464-
}
536+
}
537+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
538+
#ifdef __CUDA
539+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op");
540+
cublasErrcheck(cublasSgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
541+
#endif
542+
}
465543
}
466544

467545
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -470,7 +548,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
470548
{
471549
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
472550
dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
473-
}
551+
}
552+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
553+
#ifdef __CUDA
554+
cublasOperation_t cutransA = BlasUtils::judge_trans(false, trans, "gemv_op");
555+
cublasErrcheck(cublasDgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
556+
#endif
557+
}
474558
}
475559

476560
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -479,7 +563,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
479563
{
480564
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
481565
cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
482-
}
566+
}
567+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
568+
#ifdef __CUDA
569+
cuFloatComplex alpha_cu = make_cuFloatComplex(alpha.real(), alpha.imag());
570+
cuFloatComplex beta_cu = make_cuFloatComplex(beta.real(), beta.imag());
571+
cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op");
572+
cublasErrcheck(cublasCgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy));
573+
#endif
574+
}
483575
}
484576

485577
void BlasConnector::gemv(const char trans, const int m, const int n,
@@ -488,7 +580,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
488580
{
489581
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
490582
zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
491-
}
583+
}
584+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
585+
#ifdef __CUDA
586+
cuDoubleComplex alpha_cu = make_cuDoubleComplex(alpha.real(), alpha.imag());
587+
cuDoubleComplex beta_cu = make_cuDoubleComplex(beta.real(), beta.imag());
588+
cublasOperation_t cutransA = BlasUtils::judge_trans(true, trans, "gemv_op");
589+
cublasErrcheck(cublasZgemv(BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy));
590+
#endif
591+
}
492592
}
493593

494594
// out = ||x||_2
@@ -497,6 +597,13 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
497597
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
498598
return snrm2_( &n, X, &incX );
499599
}
600+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
601+
#ifdef __CUDA
602+
float result = 0.0;
603+
cublasErrcheck(cublasSnrm2(BlasUtils::cublas_handle, n, X, incX, &result));
604+
return result;
605+
#endif
606+
}
500607
return snrm2_( &n, X, &incX );
501608
}
502609

@@ -506,6 +613,13 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
506613
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
507614
return dnrm2_( &n, X, &incX );
508615
}
616+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
617+
#ifdef __CUDA
618+
double result = 0.0;
619+
cublasErrcheck(cublasDnrm2(BlasUtils::cublas_handle, n, X, incX, &result));
620+
return result;
621+
#endif
622+
}
509623
return dnrm2_( &n, X, &incX );
510624
}
511625

@@ -515,6 +629,13 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
515629
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
516630
return dznrm2_( &n, X, &incX );
517631
}
632+
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
633+
#ifdef __CUDA
634+
double result = 0.0;
635+
cublasErrcheck(cublasDznrm2(BlasUtils::cublas_handle, n, (double2*)X, incX, &result));
636+
return result;
637+
#endif
638+
}
518639
return dznrm2_( &n, X, &incX );
519640
}
520641

0 commit comments

Comments
 (0)