Skip to content

Commit 8905ddf

Browse files
Refactor: refactor the constructors of Psi class (#5761)
* remove Psi(const Psi& psi_in, const int nk_in, int nband_in); * fix bug * fix bug * [pre-commit.ci lite] apply automatic fixes * remove device value in psi * update Psi(const Psi& psi_in, const int nk_in, int nband_in) * update get_ngk usage * fix bug about ngk * [pre-commit.ci lite] apply automatic fixes * fix bug * format operator * [pre-commit.ci lite] apply automatic fixes * fix bug * fix bug * fix bug * fix bug * add get_cur_effective_basis func * fix bug * update get_cur_effective_basis * check bugs * update Constructor 8-1 * fix bug * fix bug * fix bug * fix bug maybe * fix bug * check correct * check 1 * fix unit test * fix unit bug * update get_ngk func * remove get-ngk in velocity-pw * fix bug * [pre-commit.ci lite] apply automatic fixes * fix 186_PW_SKG_ALL bug * format source/module_io/unk_overlap_pw.cpp * update Constructor in psi * [pre-commit.ci lite] apply automatic fixes * debug unit test * fix ri test bug * [pre-commit.ci lite] apply automatic fixes * fix psi-ut bug * remove Psi<T, Device>::Psi(T* psi_pointer, const Psi& psi_in, const int nk_in, int nband_in) * remove useless code * update Psi(const Psi& psi_in, const int nk_in, const int nband_in); * remove Psi(const Psi& psi_in, const int nk_in, const int nband_in); * refactor psi code * fix sdft bug * change to get_current_ngk --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 4ddec65 commit 8905ddf

38 files changed

+647
-484
lines changed

source/module_elecstate/cal_dm.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
2727
//dm.fix_k(ik);
2828
dm[ik].create(ParaV->ncol, ParaV->nrow);
2929
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
30-
psi::Psi<double> wg_wfc(wfc, 1);
30+
psi::Psi<double> wg_wfc(1,
31+
wfc.get_nbands(),
32+
wfc.get_nbasis(),
33+
wfc.get_nbasis(),
34+
true);
35+
wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size());
3136

3237
int ib_global = 0;
3338
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
@@ -41,7 +46,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
4146
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
4247
}
4348
}
44-
if (ib_global >= wg.nc) continue;
49+
if (ib_global >= wg.nc) { continue;
50+
}
4551
const double wg_local = wg(ik, ib_global);
4652
double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
4753
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
@@ -99,7 +105,8 @@ inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg,
99105
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
100106
}
101107
}
102-
if (ib_global >= wg.nc) continue;
108+
if (ib_global >= wg.nc) { continue;
109+
}
103110
const double wg_local = wg(ik, ib_global);
104111
std::complex<double>* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
105112
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);

source/module_elecstate/elecstate_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
183183

184184
this->init_rho_data();
185185
int ik = psi.get_current_k();
186-
int npw = psi.get_current_nbas();
186+
int npw = psi.get_current_ngk();
187187
int current_spin = 0;
188188
if (PARAM.inp.nspin == 2)
189189
{
@@ -287,7 +287,7 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
287287
psi.fix_k(ik);
288288
const T* psi_now = psi.get_pointer();
289289
const int currect_spin = this->klist->isk[ik];
290-
const int npw = psi.get_current_nbas();
290+
const int npw = psi.get_current_ngk();
291291

292292
// get |beta>
293293
if (this->ppcell->nkb > 0)

source/module_elecstate/elecstate_pw_cal_tau.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
1515
for (int ik = 0; ik < psi.get_nk(); ++ik)
1616
{
1717
psi.fix_k(ik);
18-
int npw = psi.get_current_nbas();
18+
int npw = psi.get_current_ngk();
1919
int current_spin = 0;
2020
if (PARAM.inp.nspin == 2)
2121
{

source/module_elecstate/module_dm/cal_dm_psi.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,14 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV,
3232
// dm.fix_k(ik);
3333
// dm[ik].create(ParaV->ncol, ParaV->nrow);
3434
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
35-
psi::Psi<double> wg_wfc(wfc, 1);
35+
36+
psi::Psi<double> wg_wfc(1,
37+
wfc.get_nbands(),
38+
wfc.get_nbasis(),
39+
wfc.get_nbasis(),
40+
true);
41+
wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size());
42+
3643

3744
int ib_global = 0;
3845
for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
@@ -89,7 +96,12 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV,
8996
// dm.fix_k(ik);
9097
//dm[ik].create(ParaV->ncol, ParaV->nrow);
9198
// wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
92-
psi::Psi<std::complex<double>> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), nullptr);
99+
psi::Psi<std::complex<double>> wg_wfc(1,
100+
wfc.get_nbands(),
101+
wfc.get_nbasis(),
102+
wfc.get_nbasis(),
103+
true);
104+
93105
const std::complex<double>* pwfc = wfc.get_pointer();
94106
std::complex<double>* pwg_wfc = wg_wfc.get_pointer();
95107
#ifdef _OPENMP

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
10831083
//! initialize the gradients of Etotal with respect to occupation numbers and wfc,
10841084
//! and set all elements to 0.
10851085
ModuleBase::matrix dE_dOccNum(this->pelec->wg.nr, this->pelec->wg.nc, true);
1086-
psi::Psi<TK> dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis());
1086+
psi::Psi<TK> dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis(), this->kv.ngk, true);
10871087
dE_dWfc.zero_out();
10881088

10891089
double Etotal_RDMFT = this->rdmft_solver.run(dE_dOccNum, dE_dWfc);

source/module_esolver/esolver_of.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,11 @@ void ESolver_OF::before_opt(const int istep, UnitCell& ucell)
222222

223223
// Refresh the arrays
224224
delete this->psi_;
225-
this->psi_ = new psi::Psi<double>(1, PARAM.inp.nspin, this->pw_rho->nrxx);
225+
this->psi_ = new psi::Psi<double>(1,
226+
PARAM.inp.nspin,
227+
this->pw_rho->nrxx,
228+
this->pw_rho->nrxx,
229+
true);
226230
for (int is = 0; is < PARAM.inp.nspin; ++is)
227231
{
228232
this->pphi_[is] = this->psi_->get_pointer(is);

source/module_esolver/esolver_of_tool.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@ void ESolver_OF::init_elecstate(UnitCell& ucell)
7171
void ESolver_OF::allocate_array()
7272
{
7373
// Initialize the "wavefunction", which is sqrt(rho)
74-
this->psi_ = new psi::Psi<double>(1, PARAM.inp.nspin, this->pw_rho->nrxx);
74+
this->psi_ = new psi::Psi<double>(1,
75+
PARAM.inp.nspin,
76+
this->pw_rho->nrxx,
77+
this->pw_rho->nrxx,
78+
true);
7579
ModuleBase::Memory::record("OFDFT::Psi", sizeof(double) * PARAM.inp.nspin * this->pw_rho->nrxx);
7680
this->pphi_ = new double*[PARAM.inp.nspin];
7781
for (int is = 0; is < PARAM.inp.nspin; ++is)

source/module_hamilt_general/operator.cpp

Lines changed: 81 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,31 @@
44

55
using namespace hamilt;
66

7-
8-
template<typename T, typename Device>
9-
Operator<T, Device>::Operator(){}
10-
11-
template<typename T, typename Device>
12-
Operator<T, Device>::~Operator()
7+
template <typename T, typename Device>
8+
Operator<T, Device>::Operator()
139
{
14-
if(this->hpsi != nullptr) { delete this->hpsi;
1510
}
11+
12+
template <typename T, typename Device>
13+
Operator<T, Device>::~Operator()
14+
{
15+
if (this->hpsi != nullptr)
16+
{
17+
delete this->hpsi;
18+
}
1619
Operator* last = this->next_op;
1720
Operator* last_sub = this->next_sub_op;
18-
while(last != nullptr || last_sub != nullptr)
21+
while (last != nullptr || last_sub != nullptr)
1922
{
20-
if(last_sub != nullptr)
21-
{//delete sub_chain first
23+
if (last_sub != nullptr)
24+
{ // delete sub_chain first
2225
Operator* node_delete = last_sub;
2326
last_sub = last_sub->next_sub_op;
2427
node_delete->next_sub_op = nullptr;
2528
delete node_delete;
2629
}
2730
else
28-
{//delete main chain if sub_chain is deleted
31+
{ // delete main chain if sub_chain is deleted
2932
Operator* node_delete = last;
3033
last_sub = last->next_sub_op;
3134
node_delete->next_sub_op = nullptr;
@@ -36,7 +39,7 @@ Operator<T, Device>::~Operator()
3639
}
3740
}
3841

39-
template<typename T, typename Device>
42+
template <typename T, typename Device>
4043
typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& input) const
4144
{
4245
using syncmem_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
@@ -46,37 +49,51 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
4649

4750
T* tmhpsi = this->get_hpsi(input);
4851
const T* tmpsi_in = std::get<0>(psi_info);
49-
//if range in hpsi_info is illegal, the first return of to_range() would be nullptr
52+
// if range in hpsi_info is illegal, the first return of to_range() would be nullptr
5053
if (tmpsi_in == nullptr)
5154
{
5255
ModuleBase::WARNING_QUIT("Operator", "please choose correct range of psi for hPsi()!");
5356
}
54-
//if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
57+
// if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
5558
T* hpsi_pointer = std::get<2>(input);
5659
if (this->in_place)
5760
{
5861
// ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
5962
syncmem_op()(this->ctx, this->ctx, hpsi_pointer, this->hpsi->get_pointer(), this->hpsi->size());
6063
delete this->hpsi;
61-
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
64+
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
65+
1,
66+
nbands / psi_input->npol,
67+
psi_input->get_nbasis(),
68+
psi_input->get_nbasis(),
69+
true);
6270
}
6371

6472
auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {
65-
6673
// a "psi" with the bands of needed range
67-
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis(), true);
68-
69-
74+
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in),
75+
1,
76+
nbands,
77+
psi_input->get_nbasis(),
78+
psi_input->get_nbasis(),
79+
true);
80+
7081
switch (op->get_act_type())
7182
{
7283
case 2:
7384
op->act(psi_wrapper, *this->hpsi, nbands);
7485
break;
7586
default:
76-
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node);
87+
op->act(nbands,
88+
psi_input->get_nbasis(),
89+
psi_input->npol,
90+
tmpsi_in,
91+
this->hpsi->get_pointer(),
92+
psi_input->get_current_nbas(),
93+
is_first_node);
7794
break;
7895
}
79-
};
96+
};
8097

8198
ModuleBase::timer::tick("Operator", "hPsi");
8299
call_act(this, true); // first node
@@ -91,39 +108,43 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
91108
return hpsi_info(this->hpsi, psi::Range(1, 0, 0, nbands / psi_input->npol), hpsi_pointer);
92109
}
93110

94-
95-
template<typename T, typename Device>
96-
void Operator<T, Device>::init(const int ik_in)
111+
template <typename T, typename Device>
112+
void Operator<T, Device>::init(const int ik_in)
97113
{
98114
this->ik = ik_in;
99-
if(this->next_op != nullptr) {
115+
if (this->next_op != nullptr)
116+
{
100117
this->next_op->init(ik_in);
101118
}
102119
}
103120

104-
template<typename T, typename Device>
105-
void Operator<T, Device>::add(Operator* next)
121+
template <typename T, typename Device>
122+
void Operator<T, Device>::add(Operator* next)
106123
{
107-
if(next==nullptr) { return;
108-
}
124+
if (next == nullptr)
125+
{
126+
return;
127+
}
109128
next->is_first_node = false;
110-
if(next->next_op != nullptr) { this->add(next->next_op);
111-
}
129+
if (next->next_op != nullptr)
130+
{
131+
this->add(next->next_op);
132+
}
112133
Operator* last = this;
113-
//loop to end of the chain
114-
while(last->next_op != nullptr)
134+
// loop to end of the chain
135+
while (last->next_op != nullptr)
115136
{
116-
if(next->cal_type==last->cal_type)
137+
if (next->cal_type == last->cal_type)
117138
{
118139
break;
119140
}
120141
last = last->next_op;
121142
}
122-
if(next->cal_type == last->cal_type)
143+
if (next->cal_type == last->cal_type)
123144
{
124-
//insert next to sub chain of current node
145+
// insert next to sub chain of current node
125146
Operator* sub_last = last;
126-
while(sub_last->next_sub_op != nullptr)
147+
while (sub_last->next_sub_op != nullptr)
127148
{
128149
sub_last = sub_last->next_sub_op;
129150
}
@@ -136,34 +157,45 @@ void Operator<T, Device>::add(Operator* next)
136157
}
137158
}
138159

139-
template<typename T, typename Device>
160+
template <typename T, typename Device>
140161
T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
141162
{
142163
const int nbands_range = (std::get<1>(info).range_2 - std::get<1>(info).range_1 + 1);
143-
//in_place call of hPsi, hpsi inputs as new psi,
144-
//create a new hpsi and delete old hpsi later
164+
// in_place call of hPsi, hpsi inputs as new psi,
165+
// create a new hpsi and delete old hpsi later
145166
T* hpsi_pointer = std::get<2>(info);
146167
const T* psi_pointer = std::get<0>(info)->get_pointer();
147-
if(this->hpsi != nullptr)
168+
if (this->hpsi != nullptr)
148169
{
149170
delete this->hpsi;
150171
this->hpsi = nullptr;
151172
}
152-
if(!hpsi_pointer)
173+
if (!hpsi_pointer)
153174
{
154175
ModuleBase::WARNING_QUIT("Operator::hPsi", "hpsi_pointer can not be nullptr");
155176
}
156-
else if(hpsi_pointer == psi_pointer)
177+
else if (hpsi_pointer == psi_pointer)
157178
{
158179
this->in_place = true;
159-
this->hpsi = new psi::Psi<T, Device>(std::get<0>(info)[0], 1, nbands_range);
180+
// this->hpsi = new psi::Psi<T, Device>(std::get<0>(info)[0], 1, nbands_range);
181+
this->hpsi = new psi::Psi<T, Device>(1,
182+
nbands_range,
183+
std::get<0>(info)->get_nbasis(),
184+
std::get<0>(info)->get_nbasis(),
185+
true);
160186
}
161187
else
162188
{
163189
this->in_place = false;
164-
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, std::get<0>(info)[0], 1, nbands_range);
190+
191+
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
192+
1,
193+
nbands_range,
194+
std::get<0>(info)->get_nbasis(),
195+
std::get<0>(info)->get_nbasis(),
196+
true);
165197
}
166-
198+
167199
hpsi_pointer = this->hpsi->get_pointer();
168200
size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis();
169201
// ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size);
@@ -172,7 +204,8 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
172204
return hpsi_pointer;
173205
}
174206

175-
namespace hamilt {
207+
namespace hamilt
208+
{
176209
template class Operator<float, base_device::DEVICE_CPU>;
177210
template class Operator<std::complex<float>, base_device::DEVICE_CPU>;
178211
template class Operator<double, base_device::DEVICE_CPU>;
@@ -183,4 +216,4 @@ template class Operator<std::complex<float>, base_device::DEVICE_GPU>;
183216
template class Operator<double, base_device::DEVICE_GPU>;
184217
template class Operator<std::complex<double>, base_device::DEVICE_GPU>;
185218
#endif
186-
}
219+
} // namespace hamilt

0 commit comments

Comments
 (0)