@@ -50,6 +50,36 @@ namespace BlasUtils{
50
50
return CUBLAS_OP_N;
51
51
}
52
52
53
+ cublasSideMode_t judge_side (const char & trans)
54
+ {
55
+ if (trans == ' L' )
56
+ {
57
+ return CUBLAS_SIDE_LEFT;
58
+ }
59
+ else if (trans == ' R' )
60
+ {
61
+ return CUBLAS_SIDE_RIGHT;
62
+ }
63
+ return CUBLAS_SIDE_LEFT;
64
+ }
65
+
66
+ cublasFillMode_t judge_fill (const char & trans)
67
+ {
68
+ if (trans == ' F' )
69
+ {
70
+ return CUBLAS_FILL_MODE_FULL;
71
+ }
72
+ else if (trans == ' U' )
73
+ {
74
+ return CUBLAS_FILL_MODE_UPPER;
75
+ }
76
+ else if (trans == ' D' )
77
+ {
78
+ return CUBLAS_FILL_MODE_LOWER;
79
+ }
80
+ return CUBLAS_FILL_MODE_FULL;
81
+ }
82
+
53
83
} // namespace BlasUtils
54
84
55
85
#endif
@@ -398,6 +428,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
398
428
&alpha, a, &lda, b, &ldb,
399
429
&beta, c, &ldc);
400
430
}
431
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
432
+ #ifdef __CUDA
433
+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
434
+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
435
+ cublasErrcheck (cublasSsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
436
+ #endif
437
+ }
401
438
}
402
439
403
440
void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
@@ -409,6 +446,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
409
446
&alpha, a, &lda, b, &ldb,
410
447
&beta, c, &ldc);
411
448
}
449
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
450
+ #ifdef __CUDA
451
+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
452
+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
453
+ cublasErrcheck (cublasDsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc));
454
+ #endif
455
+ }
412
456
}
413
457
414
458
void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
@@ -420,6 +464,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
420
464
&alpha, a, &lda, b, &ldb,
421
465
&beta, c, &ldc);
422
466
}
467
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
468
+ #ifdef __CUDA
469
+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
470
+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
471
+ cublasErrcheck (cublasCsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
472
+ #endif
473
+ }
423
474
}
424
475
425
476
void BlasConnector::symm_cm (const char side, const char uplo, const int m, const int n,
@@ -431,6 +482,13 @@ void BlasConnector::symm_cm(const char side, const char uplo, const int m, const
431
482
&alpha, a, &lda, b, &ldb,
432
483
&beta, c, &ldc);
433
484
}
485
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
486
+ #ifdef __CUDA
487
+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
488
+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
489
+ cublasErrcheck (cublasZsymm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
490
+ #endif
491
+ }
434
492
}
435
493
436
494
void BlasConnector::hemm_cm (char side, char uplo, int m, int n,
@@ -442,6 +500,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
442
500
&alpha, a, &lda, b, &ldb,
443
501
&beta, c, &ldc);
444
502
}
503
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
504
+ #ifdef __CUDA
505
+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
506
+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
507
+ cublasErrcheck (cublasChemm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
508
+ #endif
509
+ }
445
510
}
446
511
447
512
void BlasConnector::hemm_cm (char side, char uplo, int m, int n,
@@ -453,6 +518,13 @@ void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
453
518
&alpha, a, &lda, b, &ldb,
454
519
&beta, c, &ldc);
455
520
}
521
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
522
+ #ifdef __CUDA
523
+ cublasSideMode_t sideMode = BlasUtils::judge_side (side);
524
+ cublasFillMode_t fillMode = BlasUtils::judge_fill (uplo);
525
+ cublasErrcheck (cublasZhemm (BlasUtils::cublas_handle, sideMode, fillMode, m, n, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
526
+ #endif
527
+ }
456
528
}
457
529
458
530
void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -461,7 +533,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
461
533
{
462
534
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
463
535
sgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
464
- }
536
+ }
537
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
538
+ #ifdef __CUDA
539
+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , trans, " gemv_op" );
540
+ cublasErrcheck (cublasSgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
541
+ #endif
542
+ }
465
543
}
466
544
467
545
void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -470,7 +548,13 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
470
548
{
471
549
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
472
550
dgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
473
- }
551
+ }
552
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
553
+ #ifdef __CUDA
554
+ cublasOperation_t cutransA = BlasUtils::judge_trans (false , trans, " gemv_op" );
555
+ cublasErrcheck (cublasDgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha, A, lda, X, incx, &beta, Y, incy));
556
+ #endif
557
+ }
474
558
}
475
559
476
560
void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -479,7 +563,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
479
563
{
480
564
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
481
565
cgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
482
- }
566
+ }
567
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
568
+ #ifdef __CUDA
569
+ cuFloatComplex alpha_cu = make_cuFloatComplex (alpha.real (), alpha.imag ());
570
+ cuFloatComplex beta_cu = make_cuFloatComplex (beta.real (), beta.imag ());
571
+ cublasOperation_t cutransA = BlasUtils::judge_trans (true , trans, " gemv_op" );
572
+ cublasErrcheck (cublasCgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuFloatComplex*)A, lda, (cuFloatComplex*)X, incx, &beta_cu, (cuFloatComplex*)Y, incy));
573
+ #endif
574
+ }
483
575
}
484
576
485
577
void BlasConnector::gemv (const char trans, const int m, const int n,
@@ -488,7 +580,15 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
488
580
{
489
581
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
490
582
zgemv_ (&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
491
- }
583
+ }
584
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
585
+ #ifdef __CUDA
586
+ cuDoubleComplex alpha_cu = make_cuDoubleComplex (alpha.real (), alpha.imag ());
587
+ cuDoubleComplex beta_cu = make_cuDoubleComplex (beta.real (), beta.imag ());
588
+ cublasOperation_t cutransA = BlasUtils::judge_trans (true , trans, " gemv_op" );
589
+ cublasErrcheck (cublasZgemv (BlasUtils::cublas_handle, cutransA, m, n, &alpha_cu, (cuDoubleComplex*)A, lda, (cuDoubleComplex*)X, incx, &beta_cu, (cuDoubleComplex*)Y, incy));
590
+ #endif
591
+ }
492
592
}
493
593
494
594
// out = ||x||_2
@@ -497,6 +597,13 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
497
597
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
498
598
return snrm2_ ( &n, X, &incX );
499
599
}
600
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
601
+ #ifdef __CUDA
602
+ float result = 0.0 ;
603
+ cublasErrcheck (cublasSnrm2 (BlasUtils::cublas_handle, n, X, incX, &result));
604
+ return result;
605
+ #endif
606
+ }
500
607
return snrm2_ ( &n, X, &incX );
501
608
}
502
609
@@ -506,6 +613,13 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
506
613
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
507
614
return dnrm2_ ( &n, X, &incX );
508
615
}
616
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
617
+ #ifdef __CUDA
618
+ double result = 0.0 ;
619
+ cublasErrcheck (cublasDnrm2 (BlasUtils::cublas_handle, n, X, incX, &result));
620
+ return result;
621
+ #endif
622
+ }
509
623
return dnrm2_ ( &n, X, &incX );
510
624
}
511
625
@@ -515,6 +629,13 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
515
629
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
516
630
return dznrm2_ ( &n, X, &incX );
517
631
}
632
+ else if (device_type == base_device::AbacusDevice_t::GpuDevice){
633
+ #ifdef __CUDA
634
+ double result = 0.0 ;
635
+ cublasErrcheck (cublasDznrm2 (BlasUtils::cublas_handle, n, (double2*)X, incX, &result));
636
+ return result;
637
+ #endif
638
+ }
518
639
return dznrm2_ ( &n, X, &incX );
519
640
}
520
641
0 commit comments