Skip to content

fix: memory leak when precision=single #5839

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,18 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
container::kernels::destroyGpuBlasHandle();
container::kernels::destroyGpuSolverHandle();
#endif
delete reinterpret_cast<psi::Psi<T, Device>*>(this->kspw_psi);
}
#ifdef __DSP
std::cout << " ** Closing DSP Hardware..." << std::endl;
dspDestoryHandle(GlobalV::MY_RANK);
#endif
if(PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
{
delete this->kspw_psi;
}
if (PARAM.inp.precision == "single")
{
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
delete this->__kspw_psi;
}

delete this->psi;
Expand Down
5 changes: 4 additions & 1 deletion source/module_hamilt_pw/hamilt_pwdft/VNL_in_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,12 @@ void pseudopot_cell_vnl::getvnl(Device* ctx,
delmem_var_op()(ctx, ylm);
delmem_var_op()(ctx, vkb1);
delmem_complex_op()(ctx, sk);
if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
if (PARAM.inp.device == "gpu" || PARAM.inp.precision == "single")
{
delmem_var_op()(ctx, gk);
}
if (PARAM.inp.device == "gpu")
{
delmem_int_op()(ctx, atom_nh);
delmem_int_op()(ctx, atom_nb);
delmem_int_op()(ctx, atom_na);
Expand Down
16 changes: 16 additions & 0 deletions source/module_io/read_input_item_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -775,12 +775,28 @@ void ReadInput::item_system()
para.input.device=base_device::information::get_device_flag(
para.inp.device, para.inp.basis_type);
};
item.check_value = [](const Input_Item& item, const Parameter& para) {
std::vector<std::string> avail_list = {"cpu", "gpu"};
if (std::find(avail_list.begin(), avail_list.end(), para.input.device) == avail_list.end())
{
const std::string warningstr = nofound_str(avail_list, "device");
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
}
};
this->add_item(item);
}
// Scoped block that configures and registers the "precision" input item.
{
Input_Item item("precision");
item.annotation = "the computing precision for ABACUS";
// NOTE(review): read_sync_string is defined outside this view; presumably it
// sets up reading/synchronising input.precision as a string -- confirm.
read_sync_string(input.precision);
// Value check: only "single" and "double" are accepted for para.input.precision.
item.check_value = [](const Input_Item& item, const Parameter& para) {
std::vector<std::string> avail_list = {"single", "double"};
if (std::find(avail_list.begin(), avail_list.end(), para.input.precision) == avail_list.end())
{
// Any other value is reported through ModuleBase::WARNING_QUIT, using the
// message string returned by nofound_str(avail_list, "precision").
const std::string warningstr = nofound_str(avail_list, "precision");
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
}
};
// Hand the fully configured item to this ReadInput instance.
this->add_item(item);
}
}
Expand Down
35 changes: 26 additions & 9 deletions source/module_psi/psi_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void PSIInit<T, Device>::prepare_init(const int& random_seed)
this->psi_initer = std::unique_ptr<psi_initializer<T>>(new psi_initializer_random<T>());
}
else if (this->init_wfc == "atomic"
|| (this->init_wfc == "atomic+random" && this->ucell.natomwfc != PARAM.inp.nbands))
|| (this->init_wfc == "atomic+random" && this->ucell.natomwfc < PARAM.inp.nbands))
{
this->psi_initer = std::unique_ptr<psi_initializer<T>>(new psi_initializer_atomic<T>());
}
Expand Down Expand Up @@ -99,17 +99,30 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
const int nbands_start = this->psi_initer->nbands_start();
const int nbands = psi->get_nbands();
const int nbasis = psi->get_nbasis();
const bool another_psi_space = (nbands_start != nbands || PARAM.inp.precision == "single");
const bool not_equal = (nbands_start != nbands);

Psi<T>* psi_cpu = reinterpret_cast<psi::Psi<T>*>(psi);
Psi<T, Device>* psi_device = kspw_psi;

if (another_psi_space)
if (not_equal)
{
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nullptr);
psi_device = PARAM.inp.device == "gpu" ? new psi::Psi<T, Device>(psi_cpu[0])
: reinterpret_cast<psi::Psi<T, Device>*>(psi_cpu);
}
else if (PARAM.inp.precision == "single")
{
if (PARAM.inp.device == "cpu")
{
psi_cpu = reinterpret_cast<psi::Psi<T>*>(kspw_psi);
psi_device = kspw_psi;
}
else
{
psi_cpu = new Psi<T>(1, nbands_start, nbasis, nullptr);
psi_device = kspw_psi;
}
}

// loop over kpoints, make it possible to only allocate memory for psig at the only one kpt
// like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
Expand All @@ -126,16 +139,16 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
this->psi_initer->init_psig(psi_cpu->get_pointer(), ik);
if (psi_device->get_pointer() != psi_cpu->get_pointer())
{
castmem_h2d_op()(ctx, cpu_ctx, psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis);
syncmem_h2d_op()(ctx, cpu_ctx, psi_device->get_pointer(), psi_cpu->get_pointer(), nbands_start * nbasis);
}

std::vector<typename GetTypeReal<T>::type> etatom(nbands_start, 0.0);

if (this->ks_solver == "cg")
{
if (another_psi_space)
if (not_equal)
{
// for diagH_subspace_init, psi_cpu->get_pointer() and kspw_psi->get_pointer() should be different
// for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be different
hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init(p_hamilt,
psi_device->get_pointer(),
nbands_start,
Expand All @@ -145,7 +158,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
}
else
{
// for diagH_subspace_init, psi_cpu->get_pointer() and kspw_psi->get_pointer() can be the same
// for diagH_subspace, psi_device->get_pointer() and kspw_psi->get_pointer() can be the same
hsolver::DiagoIterAssist<T, Device>::diagH_subspace(p_hamilt,
*psi_device,
*kspw_psi,
Expand All @@ -155,21 +168,25 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
}
else // dav, bpcg
{
if (another_psi_space)
if (psi_device->get_pointer() != kspw_psi->get_pointer())
{
syncmem_complex_op()(ctx, ctx, kspw_psi->get_pointer(), psi_device->get_pointer(), nbands * nbasis);
}
}
} // end k-point loop

if (another_psi_space)
if (not_equal)
{
delete psi_cpu;
if(PARAM.inp.device == "gpu")
{
delete psi_device;
}
}
else if (PARAM.inp.precision == "single" && PARAM.inp.device == "gpu")
{
delete psi_cpu;
}

ModuleBase::timer::tick("PSIInit", "initialize_psi");
}
Expand Down
3 changes: 1 addition & 2 deletions source/module_psi/psi_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ class PSIInit

//-------------------------OP--------------------------------------------
using syncmem_complex_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
using castmem_h2d_op
= base_device::memory::cast_memory_op<T, T, Device, base_device::DEVICE_CPU>;
using syncmem_h2d_op = base_device::memory::synchronize_memory_op<T, Device, base_device::DEVICE_CPU>;
};

///@brief allocate the wavefunction
Expand Down
Loading