Skip to content

Commit 060a472

Browse files
committed
fix: memory leak when precision=single
1 parent a2ec5d1 commit 060a472

File tree

4 files changed

+49
-9
lines changed

4 files changed

+49
-9
lines changed

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,13 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
112112
std::cout << " ** Closing DSP Hardware..." << std::endl;
113113
dspDestoryHandle(GlobalV::MY_RANK);
114114
#endif
115+
if(PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
116+
{
117+
delete this->kspw_psi;
118+
}
115119
if (PARAM.inp.precision == "single")
116120
{
117-
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
121+
delete this->__kspw_psi;
118122
}
119123

120124
delete this->psi;

source/module_hamilt_pw/hamilt_pwdft/VNL_in_pw.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,12 @@ void pseudopot_cell_vnl::getvnl(Device* ctx,
532532
delmem_var_op()(ctx, ylm);
533533
delmem_var_op()(ctx, vkb1);
534534
delmem_complex_op()(ctx, sk);
535-
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
535+
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
536536
{
537537
delmem_var_op()(ctx, gk);
538+
}
539+
if (PARAM.inp.device == "gpu")
540+
{
538541
delmem_int_op()(ctx, atom_nh);
539542
delmem_int_op()(ctx, atom_nb);
540543
delmem_int_op()(ctx, atom_na);

source/module_io/read_input_item_system.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,12 +775,28 @@ void ReadInput::item_system()
775775
para.input.device=base_device::information::get_device_flag(
776776
para.inp.device, para.inp.basis_type);
777777
};
778+
item.check_value = [](const Input_Item& item, const Parameter& para) {
779+
std::vector<std::string> avail_list = {"cpu", "gpu"};
780+
if (std::find(avail_list.begin(), avail_list.end(), para.input.device) == avail_list.end())
781+
{
782+
const std::string warningstr = nofound_str(avail_list, "device");
783+
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
784+
}
785+
};
778786
this->add_item(item);
779787
}
780788
{
781789
Input_Item item("precision");
782790
item.annotation = "the computing precision for ABACUS";
783791
read_sync_string(input.precision);
792+
item.check_value = [](const Input_Item& item, const Parameter& para) {
793+
std::vector<std::string> avail_list = {"single", "double"};
794+
if (std::find(avail_list.begin(), avail_list.end(), para.input.precision) == avail_list.end())
795+
{
796+
const std::string warningstr = nofound_str(avail_list, "precision");
797+
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
798+
}
799+
};
784800
this->add_item(item);
785801
}
786802
}

source/module_psi/psi_init.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,30 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
9999
const int nbands_start = this->psi_initer->nbands_start();
100100
const int nbands = psi->get_nbands();
101101
const int nbasis = psi->get_nbasis();
102-
const bool another_psi_space = (nbands_start != nbands || PARAM.inp.precision == "single");
102+
const bool not_equal = (nbands_start != nbands);
103103

104104
Psi<T>* psi_cpu = reinterpret_cast<psi::Psi<T>*>(psi);
105105
Psi<T, Device>* psi_device = kspw_psi;
106106

107-
if (another_psi_space)
107+
if (not_equal)
108108
{
109109
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nullptr);
110110
psi_device = PARAM.inp.device == "gpu" ? new psi::Psi<T, Device>(psi_cpu[0])
111111
: reinterpret_cast<psi::Psi<T, Device>*>(psi_cpu);
112112
}
113+
else if (PARAM.inp.precision == "single")
114+
{
115+
if (PARAM.inp.device == "cpu")
116+
{
117+
psi_cpu = reinterpret_cast<psi::Psi<T>*>(kspw_psi);
118+
psi_device = kspw_psi;
119+
}
120+
else
121+
{
122+
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nullptr);
123+
psi_device = kspw_psi;
124+
}
125+
}
113126

114127
// loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
115128
// like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
@@ -133,9 +146,9 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
133146

134147
if (this->ks_solver == "cg")
135148
{
136-
if (another_psi_space)
149+
if (not_equal)
137150
{
138-
// for diagH_subspace_init, psi_cpu->get_pointer() and kspw_psi->get_pointer() should be different
151+
// for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be different
139152
hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init(p_hamilt,
140153
psi_device->get_pointer(),
141154
nbands_start,
@@ -145,7 +158,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
145158
}
146159
else
147160
{
148-
// for diagH_subspace_init, psi_cpu->get_pointer() and kspw_psi->get_pointer() can be the same
161+
// for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same
149162
hsolver::DiagoIterAssist<T, Device>::diagH_subspace(p_hamilt,
150163
*psi_device,
151164
*kspw_psi,
@@ -155,21 +168,25 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
155168
}
156169
else // dav, bpcg
157170
{
158-
if (another_psi_space)
171+
if (psi_device->get_pointer() != kspw_psi->get_pointer())
159172
{
160173
syncmem_complex_op()(ctx, ctx, kspw_psi->get_pointer(), psi_device->get_pointer(), nbands * nbasis);
161174
}
162175
}
163176
} // end k-point loop
164177

165-
if (another_psi_space)
178+
if (not_equal)
166179
{
167180
delete psi_cpu;
168181
if(PARAM.inp.device == "gpu")
169182
{
170183
delete psi_device;
171184
}
172185
}
186+
else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu")
187+
{
188+
delete psi_cpu;
189+
}
173190

174191
ModuleBase::timer::tick("PSIInit", "initialize_psi");
175192
}

0 commit comments

Comments
 (0)