Skip to content

Commit 0c2154f

Browse files
Refactor: refactor cg interface (#3293)
* init implementation
* fix cuda implementation
* address parallel problems
* fix compilation errors
* fix ut errors
* fix sdft ut error
* fix ut error
* fix typo
---------
Co-authored-by: Qianrui Liu <[email protected]>
1 parent f074af3 commit 0c2154f

File tree

15 files changed

+826
-511
lines changed

15 files changed

+826
-511
lines changed

source/module_base/module_container/ATen/core/tensor.cpp

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ namespace container {
99

1010
Tensor::Tensor() : Tensor(DataType::DT_FLOAT) {}
1111

12-
Tensor::Tensor(DataType data_type) : Tensor(data_type, TensorShape({})) {}
12+
Tensor::Tensor(DataType data_type) : Tensor(data_type, TensorShape({1})) {}
1313

1414
// Constructor that creates a tensor with the given data type and shape using the default allocator.
1515
Tensor::Tensor(DataType data_type, const TensorShape& shape)
@@ -291,21 +291,31 @@ bool Tensor::AllocateFrom(const Tensor& other, const TensorShape& shape) {
291291

292292
void Tensor::sync(const Tensor& rhs) {
293293
REQUIRES_OK(this->data_type_ == rhs.data_type_
294-
&& this->device_ == rhs.device_
295-
&& this->shape_ == rhs.shape_)
294+
&& this->device_ == rhs.device_)
296295

297-
TEMPLATE_ALL_2(data_type_, device_,
298-
kernels::synchronize_memory<T_, DEVICE_, DEVICE_>()(
299-
this->data<T_>(), rhs.data<T_>(), this->NumElements()))
296+
if (this->shape_ == rhs.shape_) {
297+
TEMPLATE_ALL_2(data_type_, device_,
298+
kernels::synchronize_memory<T_, DEVICE_, DEVICE_>()(
299+
this->data<T_>(), rhs.data<T_>(), this->NumElements()))
300+
}
301+
else {
302+
TEMPLATE_ALL_2(data_type_, device_,
303+
kernels::synchronize_memory_stride<T_, DEVICE_, DEVICE_>()(
304+
this->data<T_>(), rhs.data<T_>(), this->shape().dims(), rhs.shape().dims()))
305+
}
300306
}
301307

302308
Tensor Tensor::operator[](const int& index) const {
303309
REQUIRES_OK(
304-
index > 0 && index < shape_.dim_size(0),
310+
index >= 0 && index < shape_.dim_size(0),
305311
"Tensor index is out of bounds.")
306312

307313
TensorShape output_shape = this->shape_;
308314
output_shape.remove_dim(0);
315+
if (output_shape.ndim() == 0) {
316+
// If the output shape is empty, we need to add a dimension of size 1
317+
output_shape.add_dim(1);
318+
}
309319
auto data_ = reinterpret_cast<char*>(this->data()) + index * shape_.strides()[0] * SizeOfType(this->data_type_);
310320

311321
return TensorMap(data_, this->data_type_, this->device_, output_shape);

source/module_base/module_container/ATen/core/tensor.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -499,6 +499,12 @@ class Tensor {
499499
return this->NumElements() > 0;
500500
}
501501

502+
template<typename T>
503+
void set_value(T value) {
504+
TEMPLATE_ALL_2(this->data_type_, this->device_,
505+
kernels::set_memory<T, DEVICE_>()(this->data<T>(), value, this->NumElements()))
506+
}
507+
502508
protected:
503509

504510
/**

source/module_base/module_container/ATen/core/tensor_utils.h

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -270,7 +270,16 @@ void _internal_output(
270270

271271
template <typename T>
272272
T extract(const container::Tensor& tensor) {
273-
return reinterpret_cast<T*>(tensor.data())[0];
273+
if (tensor.device_type() == DeviceType::CpuDevice) {
274+
return reinterpret_cast<T*>(tensor.data())[0];
275+
}
276+
else {
277+
T result = 0;
278+
TEMPLATE_ALL_2(tensor.data_type(), tensor.device_type(),
279+
kernels::synchronize_memory<T, DEVICE_CPU, DEVICE_>()(
280+
&result, reinterpret_cast<T*>(tensor.data()), 1))
281+
return result;
282+
}
274283
}
275284

276285
} // namespace container

source/module_base/module_container/ATen/kernels/memory.h

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
#include <vector>
55
#include <complex>
66

7+
#include <base/macros/macros.h>
78
#include <ATen/core/tensor_types.h>
89

910
namespace container {
@@ -72,6 +73,26 @@ struct synchronize_memory {
7273
const size_t& size);
7374
};
7475

76+
template <typename T, typename Device_out, typename Device_in>
77+
struct synchronize_memory_stride {
78+
void operator()(
79+
T* arr_out,
80+
const T* arr_in,
81+
const std::vector<int64_t>& out_size,
82+
const std::vector<int64_t>& in_size)
83+
{
84+
REQUIRES_OK(in_size.size() == out_size.size() && in_size.size() <= 2);
85+
if (in_size.size() == 1) {
86+
synchronize_memory<T, Device_out, Device_in>()(arr_out, arr_in, in_size[0]);
87+
}
88+
else {
89+
for (int64_t ii = 0; ii < out_size[0]; ii++) {
90+
synchronize_memory<T, Device_out, Device_in>()(arr_out + ii * out_size[1], arr_in + ii * in_size[1], in_size[1]);
91+
}
92+
}
93+
}
94+
};
95+
7596
/**
7697
* @brief Casts memory between devices.
7798
*

source/module_hsolver/diagh.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,9 @@ class DiagH
2828
// virtual void init()=0;
2929
std::string method = "none";
3030

31-
virtual void diag(hamilt::Hamilt<T, Device> *phm_in, psi::Psi<T, Device> &psi, Real *eigenvalue_in) = 0;
31+
virtual void diag(hamilt::Hamilt<T, Device> *phm_in, psi::Psi<T, Device> &psi, Real *eigenvalue_in) {
32+
ModuleBase::WARNING_QUIT("diagh", "diag method not implemented for the base class!");
33+
};
3234

3335
};
3436

0 commit comments

Comments (0)