Skip to content

Refactor: Use psi_initializer instead of wavefunc #5775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
1dd449e
use psi_initializer
Qianruipku Dec 26, 2024
5f031f9
Merge branch 'develop' of https://github.com/deepmodeling/abacus-deve…
Qianruipku Dec 27, 2024
a40f8d0
fix compile
Qianruipku Dec 27, 2024
6fbf545
same results of random init
Qianruipku Dec 27, 2024
57148e3
make atomic initialized results right
Qianruipku Dec 30, 2024
ef6d8c0
finish refactor
Qianruipku Jan 4, 2025
cbc5a99
merge
Qianruipku Jan 4, 2025
f48d076
fix compile
Qianruipku Jan 4, 2025
b454fa1
fix compile
Qianruipku Jan 4, 2025
32c6bb8
fix UTs
Qianruipku Jan 5, 2025
4cb0318
update results
Qianruipku Jan 5, 2025
07f6759
update results
Qianruipku Jan 5, 2025
85b1951
update GPU results
Qianruipku Jan 5, 2025
6f45435
Merge branch 'develop' of https://github.com/deepmodeling/abacus-deve…
Qianruipku Jan 5, 2025
d3fb83f
update
Qianruipku Jan 5, 2025
a30d765
refactor pw
Qianruipku Jan 6, 2025
53ce201
change 108_PW_RE_PINT_RKS results
Qianruipku Jan 6, 2025
312baca
update results
Qianruipku Jan 6, 2025
71ab40a
remove openmp for random generate
Qianruipku Jan 6, 2025
3880bd3
update
Qianruipku Jan 6, 2025
15e6673
remove psi_initializer in Doc
Qianruipku Jan 6, 2025
32ee682
Merge branch 'develop' into hotfix
Qianruipku Jan 6, 2025
4785da0
remove omp2
Qianruipku Jan 6, 2025
22da03d
Merge branch 'develop' into hotfix
Qianruipku Jan 6, 2025
dcc87f4
Merge branch 'develop' into hotfix
Qianruipku Jan 6, 2025
abbc916
Merge branch 'develop' into hotfix
Qianruipku Jan 7, 2025
5e39aeb
merge
Qianruipku Jan 8, 2025
991214f
fix compile
Qianruipku Jan 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 2 additions & 15 deletions docs/advanced/input_files/input-main.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
- [kpar](#kpar)
- [bndpar](#bndpar)
- [latname](#latname)
- [psi\_initializer](#psi_initializer)
- [init\_wfc](#init_wfc)
- [init\_chg](#init_chg)
- [init\_vel](#init_vel)
Expand Down Expand Up @@ -93,6 +92,7 @@
- [scf\_os\_stop](#scf_os_stop)
- [scf\_os\_thr](#scf_os_thr)
- [scf\_os\_ndim](#scf_os_ndim)
- [sc\_os\_ndim](#sc_os_ndim)
- [chg\_extrap](#chg_extrap)
- [lspinorb](#lspinorb)
- [noncolin](#noncolin)
Expand Down Expand Up @@ -467,7 +467,7 @@
- [abs\_broadening](#abs_broadening)
- [ri\_hartree\_benchmark](#ri_hartree_benchmark)
- [aims\_nbasis](#aims_nbasis)
- [Reduced Density Matrix Functional Theory](#Reduced-Density-Matrix-Functional-Theory)
- [Reduced Density Matrix Functional Theory](#reduced-density-matrix-functional-theory)
- [rdmft](#rdmft)
- [rdmft\_power\_alpha](#rdmft_power_alpha)

Expand Down Expand Up @@ -580,17 +580,6 @@ These variables are used to control general system parameters.
- triclinic: triclinic (14)
- **Default**: none

### psi_initializer

- **Type**: Integer
- **Description**: enable the experimental feature psi_initializer, to support use numerical atomic orbitals initialize wavefunction (`basis_type pw` case).

NOTE: this feature is not well-implemented for `nspin 4` case (closed presently), and cannot use with `calculation nscf`/`esolver_type sdft` cases.
Available options are:
- 0: disable psi_initializer
- 1: enable psi_initializer
- **Default**: 0

### init_wfc

- **Type**: String
Expand All @@ -602,8 +591,6 @@ These variables are used to control general system parameters.
- atomic+random: add small random numbers on atomic pseudo-wavefunctions
- file: from binary files `WAVEFUNC*.dat`, which are output by setting [out_wfc_pw](#out_wfc_pw) to `2`.
- random: random numbers

with `psi_initializer 1`, two more options are supported:
- nao: from numerical atomic orbitals. If they are not enough, other wave functions are initialized with random numbers.
- nao+random: add small random numbers on numerical atomic orbitals
- **Default**: atomic
Expand Down
4 changes: 2 additions & 2 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ OBJS_PSI=psi.o\

OBJS_PSI_INITIALIZER=psi_initializer.o\
psi_initializer_random.o\
psi_initializer_file.o\
psi_initializer_atomic.o\
psi_initializer_atomic_random.o\
psi_initializer_nao.o\
Expand Down Expand Up @@ -494,6 +495,7 @@ OBJS_IO=input_conv.o\
to_wannier90_lcao.o\
fR_overlap.o\
unk_overlap_pw.o\
write_pao.o\
write_wfc_pw.o\
winput.o\
write_cube.o\
Expand Down Expand Up @@ -669,8 +671,6 @@ OBJS_SRCPW=H_Ewald_pw.o\
of_stress_pw.o\
symmetry_rho.o\
symmetry_rhog.o\
wavefunc.o\
wf_atomic.o\
psi_init.o\
elecond.o\
sto_tool.o\
Expand Down
21 changes: 11 additions & 10 deletions source/module_basis/module_pw/pw_basis_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ PW_Basis_K::~PW_Basis_K()
delete[] igl2isz_k;
delete[] igl2ig_k;
delete[] gk2;
delete[] ig2ixyz_k_;
#if defined(__CUDA) || defined(__ROCM)
if (this->device == "gpu") {
if (this->precision == "single") {
Expand Down Expand Up @@ -169,6 +168,7 @@ void PW_Basis_K::setupIndGk()
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->d_igl2isz_k, this->igl2isz_k, this->npwk_max * this->nks);
}
#endif
this->get_ig2ixyz_k();
return;
}

Expand Down Expand Up @@ -334,8 +334,12 @@ int& PW_Basis_K::getigl2ig(const int ik, const int igl) const

void PW_Basis_K::get_ig2ixyz_k()
{
delete[] this->ig2ixyz_k_;
this->ig2ixyz_k_ = new int [this->npwk_max * this->nks];
if (this->device != "gpu")
{
//only GPU need to get ig2ixyz_k
return;
}
int * ig2ixyz_k_cpu = new int [this->npwk_max * this->nks];
ModuleBase::Memory::record("PW_B_K::ig2ixyz", sizeof(int) * this->npwk_max * this->nks);
assert(gamma_only == false); //We only finish non-gamma_only fft on GPU temperarily.
for(int ik = 0; ik < this->nks; ++ik)
Expand All @@ -348,15 +352,12 @@ void PW_Basis_K::get_ig2ixyz_k()
int ixy = this->is2fftixy[is];
int iy = ixy % this->ny;
int ix = ixy / this->ny;
ig2ixyz_k_[igl + ik * npwk_max] = iz + iy * nz + ix * ny * nz;
ig2ixyz_k_cpu[igl + ik * npwk_max] = iz + iy * nz + ix * ny * nz;
}
}
#if defined(__CUDA) || defined(__ROCM)
if (this->device == "gpu") {
resmem_int_op()(gpu_ctx, ig2ixyz_k, this->npwk_max * this->nks);
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->ig2ixyz_k, this->ig2ixyz_k_, this->npwk_max * this->nks);
}
#endif
resmem_int_op()(gpu_ctx, ig2ixyz_k, this->npwk_max * this->nks);
syncmem_int_h2d_op()(gpu_ctx, cpu_ctx, this->ig2ixyz_k, ig2ixyz_k_cpu, this->npwk_max * this->nks);
delete[] ig2ixyz_k_cpu;
}

std::vector<int> PW_Basis_K::get_ig2ix(const int ik) const
Expand Down
7 changes: 3 additions & 4 deletions source/module_basis/module_pw/pw_basis_k.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ class PW_Basis_K : public PW_Basis
const bool xprime_in = true
);

void get_ig2ixyz_k();

public:
int nks=0;//number of k points in this pool
ModuleBase::Vector3<double> *kvec_d=nullptr; // Direct coordinates of k points
Expand All @@ -88,8 +86,7 @@ class PW_Basis_K : public PW_Basis

int *igl2isz_k=nullptr, * d_igl2isz_k = nullptr; //[npwk_max*nks] map (igl,ik) to (is,iz)
int *igl2ig_k=nullptr;//[npwk_max*nks] map (igl,ik) to ig
int *ig2ixyz_k=nullptr;
int *ig2ixyz_k_=nullptr;
int *ig2ixyz_k=nullptr; ///< [npw] map ig to ixyz

double *gk2=nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]

Expand All @@ -108,6 +105,8 @@ class PW_Basis_K : public PW_Basis
double * d_gk2 = nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks]
//create igl2isz_k map array for fft
void setupIndGk();
// get ig2ixyz_k
void get_ig2ixyz_k();
//calculate G+K, it is a private function
ModuleBase::Vector3<double> cal_GplusK_cartesian(const int ik, const int ig) const;

Expand Down
7 changes: 0 additions & 7 deletions source/module_basis/module_pw/test/test4-4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,6 @@ TEST_F(PWTEST,test4_4)
}
}

//check getig2ixyz_k
pwtest.get_ig2ixyz_k();
for(int igl = 0; igl < npwk ; ++igl)
{
EXPECT_GE(pwtest.ig2ixyz_k_[igl + ik * pwtest.npwk_max], 0);
}

}
delete []tmp;
delete [] rhor;
Expand Down
3 changes: 1 addition & 2 deletions source/module_cell/read_atoms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ int UnitCell::read_atom_species(std::ifstream &ifa, std::ofstream &ofs_running)
||(PARAM.inp.basis_type == "lcao_in_pw")
||(
(PARAM.inp.basis_type == "pw")
&&(PARAM.inp.psi_initializer)
&&(PARAM.inp.init_wfc.substr(0, 3) == "nao")
)
|| PARAM.inp.onsite_radius > 0.0
Expand Down Expand Up @@ -453,7 +452,7 @@ bool UnitCell::read_atom_positions(std::ifstream &ifpos, std::ofstream &ofs_runn
}
else if(PARAM.inp.basis_type == "pw")
{
if ((PARAM.inp.psi_initializer)&&(PARAM.inp.init_wfc.substr(0, 3) == "nao") || PARAM.inp.onsite_radius > 0.0)
if ((PARAM.inp.init_wfc.substr(0, 3) == "nao") || PARAM.inp.onsite_radius > 0.0)
{
std::string orbital_file = PARAM.inp.orbital_dir + orbital_fn[it];
this->read_orb_file(it, orbital_file, ofs_running, &(atoms[it]));
Expand Down
3 changes: 1 addition & 2 deletions source/module_cell/unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,8 +533,7 @@ void UnitCell::cal_nwfc(std::ofstream& log) {
// Use localized basis
//=====================
if ((PARAM.inp.basis_type == "lcao") || (PARAM.inp.basis_type == "lcao_in_pw")
|| ((PARAM.inp.basis_type == "pw") && (PARAM.inp.psi_initializer)
&& (PARAM.inp.init_wfc.substr(0, 3) == "nao")
|| ((PARAM.inp.basis_type == "pw") && (PARAM.inp.init_wfc.substr(0, 3) == "nao")
&& (PARAM.inp.esolver_type == "ksdft"))) // xiaohui add 2013-09-02
{
ModuleBase::GlobalFunc::AUTO_SET("NBANDS", PARAM.inp.nbands);
Expand Down
28 changes: 15 additions & 13 deletions source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace ModuleESolver
template <typename T>
ESolver_KS_LIP<T>::~ESolver_KS_LIP()
{
delete this->psi_local;
// delete Hamilt
this->deallocate_hamilt();
}
Expand All @@ -79,11 +80,22 @@ namespace ModuleESolver
this->p_hamilt = nullptr;
}
}
template <typename T>
void ESolver_KS_LIP<T>::before_scf(UnitCell& ucell, const int istep)
{
ESolver_KS_PW<T>::before_scf(ucell, istep);
this->p_psi_init->initialize_lcao_in_pw(this->psi_local, GlobalV::ofs_running);
}

template <typename T>
void ESolver_KS_LIP<T>::before_all_runners(UnitCell& ucell, const Input_para& inp)
{
ESolver_KS_PW<T>::before_all_runners(ucell, inp);
delete this->psi_local;
this->psi_local = new psi::Psi<T>(this->psi->get_nk(),
this->p_psi_init->psi_initer->nbands_start(),
this->psi->get_nbasis(),
this->psi->get_ngk_pointer());
#ifdef __EXX
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax"
|| PARAM.inp.calculation == "cell-relax"
Expand All @@ -94,14 +106,14 @@ namespace ModuleESolver
this->exx_lip = std::unique_ptr<Exx_Lip<T>>(new Exx_Lip<T>(GlobalC::exx_info.info_lip,
ucell.symm,
&this->kv,
this->p_wf_init,
this->psi_local,
this->kspw_psi,
this->pw_wfc,
this->pw_rho,
this->sf,
&ucell,
this->pelec));
// this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_wf_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec);
// this->exx_lip.init(GlobalC::exx_info.info_lip, cell.symm, &this->kv, this->p_psi_init, this->kspw_psi, this->pw_wfc, this->pw_rho, this->sf, &cell, this->pelec);
}
}
#endif
Expand Down Expand Up @@ -136,18 +148,8 @@ namespace ModuleESolver
hsolver::DiagoIterAssist<T>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;

// It is not a good choice to overload another solve function here, this will spoil the concept of
// multiple inheritance and polymorphism. But for now, we just do it in this way.
// In the future, there will be a series of class ESolver_KS_LCAO_PW, HSolver_LCAO_PW and so on.
std::weak_ptr<psi::Psi<T>> psig = this->p_wf_init->get_psig();

if (psig.expired())
{
ModuleBase::WARNING_QUIT("ESolver_KS_PW::hamilt2density_single", "psig lifetime is expired");
}

hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge,ucell.tpiba,ucell.nat);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, *this->psi_local, skip_charge,ucell.tpiba,ucell.nat);

// add exx
#ifdef __EXX
Expand Down
4 changes: 4 additions & 0 deletions source/module_esolver/esolver_ks_lcaopw.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ namespace ModuleESolver
void before_all_runners(UnitCell& ucell, const Input_para& inp) override;
void after_all_runners(UnitCell& ucell) override;

virtual void before_scf(UnitCell& ucell, const int istep) override;

protected:
virtual void iter_init(UnitCell& ucell, const int istep, const int iter) override;
virtual void iter_finish(UnitCell& ucell, const int istep, int& iter) override;
Expand All @@ -35,6 +37,8 @@ namespace ModuleESolver

virtual void allocate_hamilt(const UnitCell& ucell) override;
virtual void deallocate_hamilt() override;

psi::Psi<T, base_device::DEVICE_CPU>* psi_local = nullptr; ///< psi for all local NAOs

#ifdef __EXX
std::unique_ptr<Exx_Lip<T>> exx_lip;
Expand Down
76 changes: 20 additions & 56 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
}

delete this->psi;
delete this->p_wf_init;
delete this->p_psi_init;
}

template <typename T, typename Device>
Expand Down Expand Up @@ -189,26 +189,6 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
&(this->pelec->f_en.vtxc));
}

//! 7) prepare some parameters for electronic wave functions initilization
this->p_wf_init = new psi::PSIInit<T, Device>(PARAM.inp.init_wfc,
PARAM.inp.ks_solver,
PARAM.inp.basis_type,
PARAM.inp.psi_initializer,
this->pw_wfc);
this->p_wf_init->prepare_init(&(this->sf),
&ucell,
1,
#ifdef __MPI
&GlobalC::Pkpoints,
GlobalV::MY_RANK,
#endif
&this->ppcell);

if (this->psi != nullptr)
{
delete this->psi;
this->psi = nullptr;
}

//! initalize local pseudopotential
this->locpp.init_vloc(ucell, this->pw_rhod);
Expand All @@ -219,17 +199,19 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
this->ppcell.init_vnl(ucell, this->pw_rhod);
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "NON-LOCAL POTENTIAL");

//! Allocate psi
this->p_wf_init->allocate_psi(this->psi,
this->kv.get_nkstot(),
this->kv.get_nks(),
this->kv.ngk.data(),
this->pw_wfc->npwk_max,
&this->sf,
&this->ppcell,
ucell);

assert(this->psi != nullptr);
//! Allocate and initialize psi
this->p_psi_init = new psi::PSIInit<T, Device>(PARAM.inp.init_wfc,
PARAM.inp.ks_solver,
PARAM.inp.basis_type,
GlobalV::MY_RANK,
ucell,
this->sf,
GlobalC::Pkpoints,
this->ppcell,
*this->pw_wfc);
allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk.data(), PARAM.inp.nbands, this->pw_wfc->npwk_max);
this->p_psi_init->prepare_init(PARAM.inp.pw_seed);

this->kspw_psi = PARAM.inp.device == "gpu" || PARAM.inp.precision == "single"
? new psi::Psi<T, Device>(this->psi[0])
: reinterpret_cast<psi::Psi<T, Device>*>(this->psi);
Expand Down Expand Up @@ -267,7 +249,7 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)

this->pw_wfc->collect_local_pw(PARAM.inp.erf_ecut, PARAM.inp.erf_height, PARAM.inp.erf_sigma);

this->p_wf_init->make_table(this->kv.get_nks(), &this->sf, &this->ppcell, ucell);
this->p_psi_init->prepare_init(PARAM.inp.pw_seed);
}
if (ucell.ionic_position_updated)
{
Expand Down Expand Up @@ -407,29 +389,11 @@ void ESolver_KS_PW<T, Device>::before_scf(UnitCell& ucell, const int istep)
auto* dftu = ModuleDFTU::DFTU::get_instance();
dftu->init(ucell, nullptr, this->kv.get_nks());
}
// after init_rho (in pelec->init_scf), we have rho now.
// before hamilt2density, we update Hk and initialize psi

// before_scf function will be called everytime before scf. However, once
// atomic coordinates changed, structure factor will change, therefore all
// atomwise properties will change. So we need to reinitialize psi every
// time before scf. But for random wavefunction, we dont, because random
// wavefunction is not related to atomic coordinates. What the old strategy
// does is only to initialize for once...
if (((PARAM.inp.init_wfc == "random") && (istep == 0)) || (PARAM.inp.init_wfc != "random"))
{
this->p_wf_init->initialize_psi(this->psi,
this->kspw_psi,
this->p_hamilt,
this->ppcell,
ucell,
GlobalV::ofs_running,
this->already_initpsi);

if (this->already_initpsi == false)
{
this->already_initpsi = true;
}

if (!this->already_initpsi)
{
this->p_psi_init->initialize_psi(this->psi, this->kspw_psi, this->p_hamilt, GlobalV::ofs_running);
this->already_initpsi = true;
}

ModuleBase::timer::tick("ESolver_KS_PW", "before_scf");
Expand Down
Loading
Loading