Skip to content

[Feature] Add some GPU kernels to blas_connector #5799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 183 additions & 24 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,101 @@
#include "module_base/global_variable.h"
#endif

#ifdef __CUDA
#include <base/macros/macros.h>
#include <cuda_runtime.h>
#include <thrust/complex.h>
#include <thrust/execution_policy.h>
#include <thrust/inner_product.h>
#include "module_base/tool_quit.h"

#include "cublas_v2.h"

namespace BlasUtils{

// cuBLAS handle shared by the GPU branches in this translation unit.
// It is nullptr until createGpuBlasHandle() has been called and is reset to
// nullptr again by destoryBLAShandle().
static cublasHandle_t cublas_handle = nullptr;

/// Create the shared cuBLAS handle if it does not exist yet.
/// Calling this while a handle is already alive does nothing.
void createGpuBlasHandle(){
    if (cublas_handle == nullptr) {
        cublasErrcheck(cublasCreate(&cublas_handle));
    }
}

/// Destroy the shared cuBLAS handle, if any, and reset it to nullptr so that
/// createGpuBlasHandle() can build a fresh one later.
/// Calling this without a live handle does nothing.
void destoryBLAShandle(){
    if (cublas_handle != nullptr) {
        cublasErrcheck(cublasDestroy(cublas_handle));
        cublas_handle = nullptr;
    }
}

/// Translate a BLAS-style transpose flag into the matching cuBLAS operation.
///
/// @param is_complex  true when the operands are complex; only then does
///                    'C' map to a conjugate transpose.
/// @param trans       'N', 'T' or 'C' (either case, as Fortran BLAS accepts).
/// @param name        caller-supplied label (e.g. "gemm_op"); currently unused.
/// @return CUBLAS_OP_N / CUBLAS_OP_T / CUBLAS_OP_C. For real operands 'C' is
///         returned as CUBLAS_OP_T, because conjugation is the identity on
///         real data. Any unrecognised flag falls back to CUBLAS_OP_N.
cublasOperation_t judge_trans(bool is_complex, const char& trans, const char* name)
{
    (void)name; // kept for interface compatibility; silences unused warnings

    switch (trans)
    {
    case 'N':
    case 'n':
        return CUBLAS_OP_N;
    case 'T':
    case 't':
        return CUBLAS_OP_T;
    case 'C':
    case 'c':
        return is_complex ? CUBLAS_OP_C : CUBLAS_OP_T;
    default:
        // Unknown flag: keep the historical "no transpose" fallback.
        return CUBLAS_OP_N;
    }
}

} // namespace BlasUtils

#endif

/// Single-precision axpy (Y = alpha * X + Y), dispatched on device_type.
///
/// - CpuDevice: forwards to the Fortran routine saxpy_.
/// - GpuDevice: forwards to cublasSaxpy on the shared BlasUtils handle when
///   the file is compiled with __CUDA; without __CUDA this branch is empty.
/// Any other device type leaves Y untouched.
///
/// NOTE(review): on the GPU path X and Y are presumably device pointers and
/// BlasUtils::createGpuBlasHandle() must have run first - confirm with callers.
void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        saxpy_(&n, &alpha, X, &incX, Y, &incY);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
        cublasErrcheck(cublasSaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
#endif
    }
}

/// Double-precision axpy (Y = alpha * X + Y), dispatched on device_type.
///
/// - CpuDevice: forwards to the Fortran routine daxpy_.
/// - GpuDevice: forwards to cublasDaxpy on the shared BlasUtils handle when
///   the file is compiled with __CUDA; without __CUDA this branch is empty.
/// Any other device type leaves Y untouched.
///
/// NOTE(review): on the GPU path X and Y are presumably device pointers and
/// BlasUtils::createGpuBlasHandle() must have run first - confirm with callers.
void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        daxpy_(&n, &alpha, X, &incX, Y, &incY);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
        cublasErrcheck(cublasDaxpy(BlasUtils::cublas_handle, n, &alpha, X, incX, Y, incY));
#endif
    }
}

/// Single-precision complex axpy (Y = alpha * X + Y), dispatched on device_type.
///
/// - CpuDevice: forwards to the Fortran routine caxpy_.
/// - GpuDevice: forwards to cublasCaxpy on the shared BlasUtils handle when
///   the file is compiled with __CUDA; without __CUDA this branch is empty.
/// Any other device type leaves Y untouched.
///
/// NOTE(review): on the GPU path X and Y are presumably device pointers and
/// BlasUtils::createGpuBlasHandle() must have run first - confirm with callers.
void BlasConnector::axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        caxpy_(&n, &alpha, X, &incX, Y, &incY);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
        // std::complex<float> is reinterpreted as CUDA's float2; the casts keep
        // the const qualification of the read-only operands alpha and X.
        cublasErrcheck(cublasCaxpy(BlasUtils::cublas_handle,
                                   n,
                                   reinterpret_cast<const float2*>(&alpha),
                                   reinterpret_cast<const float2*>(X),
                                   incX,
                                   reinterpret_cast<float2*>(Y),
                                   incY));
#endif
    }
}

/// Double-precision complex axpy (Y = alpha * X + Y), dispatched on device_type.
///
/// - CpuDevice: forwards to the Fortran routine zaxpy_.
/// - GpuDevice: forwards to cublasZaxpy on the shared BlasUtils handle when
///   the file is compiled with __CUDA; without __CUDA this branch is empty.
/// Any other device type leaves Y untouched.
///
/// NOTE(review): on the GPU path X and Y are presumably device pointers and
/// BlasUtils::createGpuBlasHandle() must have run first - confirm with callers.
void BlasConnector::axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        zaxpy_(&n, &alpha, X, &incX, Y, &incY);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
        // std::complex<double> is reinterpreted as CUDA's double2; the casts keep
        // the const qualification of the read-only operands alpha and X.
        cublasErrcheck(cublasZaxpy(BlasUtils::cublas_handle,
                                   n,
                                   reinterpret_cast<const double2*>(&alpha),
                                   reinterpret_cast<const double2*>(X),
                                   incX,
                                   reinterpret_cast<double2*>(Y),
                                   incY));
#endif
    }
}


Expand All @@ -39,28 +108,48 @@ void BlasConnector::scal( const int n, const float alpha, float *X, const int i
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
sscal_(&n, &alpha, X, &incX);
}
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
cublasErrcheck(cublasSscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
#endif
}
}

/// Double-precision scal (X = alpha * X), dispatched on device_type.
///
/// - CpuDevice: forwards to the Fortran routine dscal_.
/// - GpuDevice: forwards to cublasDscal on the shared BlasUtils handle when
///   the file is compiled with __CUDA; without __CUDA this branch is empty.
/// Any other device type leaves X untouched.
///
/// NOTE(review): on the GPU path X is presumably a device pointer and
/// BlasUtils::createGpuBlasHandle() must have run first - confirm with callers.
void BlasConnector::scal( const int n, const double alpha, double *X, const int incX, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        dscal_(&n, &alpha, X, &incX);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
        cublasErrcheck(cublasDscal(BlasUtils::cublas_handle, n, &alpha, X, incX));
#endif
    }
}

/// Single-precision complex scal (X = alpha * X), dispatched on device_type.
///
/// - CpuDevice: forwards to the Fortran routine cscal_.
/// - GpuDevice: forwards to cublasCscal on the shared BlasUtils handle when
///   the file is compiled with __CUDA; without __CUDA this branch is empty.
/// Any other device type leaves X untouched.
///
/// NOTE(review): on the GPU path X is presumably a device pointer and
/// BlasUtils::createGpuBlasHandle() must have run first - confirm with callers.
void BlasConnector::scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        cscal_(&n, &alpha, X, &incX);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
        // std::complex<float> is reinterpreted as CUDA's float2; alpha is only
        // read, so its cast keeps the const qualification.
        cublasErrcheck(cublasCscal(BlasUtils::cublas_handle,
                                   n,
                                   reinterpret_cast<const float2*>(&alpha),
                                   reinterpret_cast<float2*>(X),
                                   incX));
#endif
    }
}

/// Double-precision complex scal (X = alpha * X), dispatched on device_type.
///
/// - CpuDevice: forwards to the Fortran routine zscal_.
/// - GpuDevice: forwards to cublasZscal on the shared BlasUtils handle when
///   the file is compiled with __CUDA; without __CUDA this branch is empty.
/// Any other device type leaves X untouched.
///
/// NOTE(review): on the GPU path X is presumably a device pointer and
/// BlasUtils::createGpuBlasHandle() must have run first - confirm with callers.
void BlasConnector::scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        zscal_(&n, &alpha, X, &incX);
    }
    else if (device_type == base_device::AbacusDevice_t::GpuDevice) {
#ifdef __CUDA
        // std::complex<double> is reinterpreted as CUDA's double2; alpha is only
        // read, so its cast keeps the const qualification.
        cublasErrcheck(cublasZscal(BlasUtils::cublas_handle,
                                   n,
                                   reinterpret_cast<const double2*>(&alpha),
                                   reinterpret_cast<double2*>(X),
                                   incX));
#endif
    }
}


Expand All @@ -70,6 +159,13 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return sdot_(&n, X, &incX, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
float result = 0.0;
cublasErrcheck(cublasSdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
return result;
#endif
}
return sdot_(&n, X, &incX, Y, &incY);
}

Expand All @@ -78,6 +174,13 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return ddot_(&n, X, &incX, Y, &incY);
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
double result = 0.0;
cublasErrcheck(cublasDdot(BlasUtils::cublas_handle, n, X, incX, Y, incY, &result));
return result;
#endif
}
return ddot_(&n, X, &incX, Y, &incY);
}

Expand All @@ -92,13 +195,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -110,13 +220,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, &alpha, b, ldb, a, lda, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -128,13 +245,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (float2*)&alpha, (float2*)b, ldb, (float2*)a, lda, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -146,13 +270,20 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mth_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, n, m, k, (double2*)&alpha, (double2*)b, ldb, (double2*)a, lda, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

// Col-Major part
Expand All @@ -165,13 +296,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mth_(&transb, &transa, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasSgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -183,13 +321,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasDgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
#endif
}
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -201,13 +346,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasCgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (float2*)&alpha, (float2*)a, lda, (float2*)b, ldb, (float2*)&beta, (float2*)c, ldc));
#endif
}
}

void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -219,13 +371,20 @@ void BlasConnector::gemm_cm(const char transa, const char transb, const int m, c
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc);
}
#ifdef __DSP
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mth_(&transa, &transb, &m, &n, &k,
&alpha, a, &lda, b, &ldb,
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
#endif
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
cublasOperation_t cutransA = BlasUtils::judge_trans(false, transa, "gemm_op");
cublasOperation_t cutransB = BlasUtils::judge_trans(false, transb, "gemm_op");
cublasErrcheck(cublasZgemm(BlasUtils::cublas_handle, cutransA, cutransB, m, n, k, (double2*)&alpha, (double2*)a, lda, (double2*)b, ldb, (double2*)&beta, (double2*)c, ldc));
#endif
}
}

// Symm and Hemm part. Only col-major is supported.
Expand Down
Loading
Loading