Skip to content

Commit 24abddd

Browse files
Refactor: Remove the global dependence of all remained functions in DeePKS. (#5835)
* Remove global dependence of cal_gevdm and rearrange the calling order for simplifying. * Move some checks from FORCE_STRESS to LCAO_Deepks_interface. * Remove the global dependence of cal_e_delta_band. * Move cal_gedm to deepks_basic.cpp * Remove the global dependence of functions related to pdm in DeePKS. * Revert "Remove the global dependence of functions related to pdm in DeePKS." This reverts commit 7a97a95. * Remove global dependence of pdm related functions in DeePKS. * Fix the compile bug of DeePKS UT test. * Remove the global dependence of functions related to phialpha in DeePKS. * Simplify some function for LCAO_deepks_io. * Update FORCE_STRESS.cpp * [pre-commit.ci lite] apply automatic fixes * Update esolver_ks_lcao.cpp * Update LCAO_deepks_interface.cpp --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 74b2954 commit 24abddd

31 files changed

+1043
-1286
lines changed

source/Makefile.Objects

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,21 +193,20 @@ OBJS_CELL=atom_pseudo.o\
193193
read_atom_species.o\
194194

195195
OBJS_DEEPKS=LCAO_deepks.o\
196+
deepks_basic.o\
197+
deepks_descriptor.o\
196198
deepks_force.o\
197199
deepks_fpre.o\
198200
deepks_spre.o\
199-
deepks_descriptor.o\
200201
deepks_orbital.o\
201202
deepks_orbpre.o\
203+
deepks_vdelta.o\
202204
deepks_vdpre.o\
203205
deepks_hmat.o\
206+
deepks_pdm.o\
207+
deepks_phialpha.o\
204208
LCAO_deepks_io.o\
205-
LCAO_deepks_pdm.o\
206-
LCAO_deepks_phialpha.o\
207-
LCAO_deepks_torch.o\
208-
LCAO_deepks_vdelta.o\
209209
LCAO_deepks_interface.o\
210-
cal_gedm.o\
211210

212211

213212
OBJS_ELECSTAT=elecstate.o\

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,16 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
225225
if (PARAM.inp.deepks_scf)
226226
{
227227
// load the DeePKS model from deep neural network
228-
GlobalC::ld.load_model(PARAM.inp.deepks_model);
228+
DeePKS_domain::load_model(PARAM.inp.deepks_model, GlobalC::ld.model_deepks);
229229
// read pdm from file for NSCF or SCF-restart, do it only once in whole calculation
230-
GlobalC::ld.read_projected_DM((PARAM.inp.init_chg == "file"), PARAM.inp.deepks_equiv, *orb_.Alpha);
230+
DeePKS_domain::read_pdm((PARAM.inp.init_chg == "file"),
231+
PARAM.inp.deepks_equiv,
232+
GlobalC::ld.init_pdm,
233+
GlobalC::ld.inlmax,
234+
GlobalC::ld.lmaxd,
235+
GlobalC::ld.inl_l,
236+
*orb_.Alpha,
237+
GlobalC::ld.pdm);
231238
}
232239
#endif
233240

@@ -928,9 +935,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
928935
// 1) calculate the kinetic energy density tau, sunliang 2024-09-18
929936
if (PARAM.inp.out_elf[0] > 0)
930937
{
931-
elecstate::lcao_cal_tau<TK>(&(this->GG),
932-
&(this->GK),
933-
this->pelec->charge);
938+
elecstate::lcao_cal_tau<TK>(&(this->GG), &(this->GK), this->pelec->charge);
934939
}
935940

936941
//! 2) call after_scf() of ESolver_KS
@@ -1047,7 +1052,6 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
10471052
std::shared_ptr<LCAO_Deepks> ld_shared_ptr(&GlobalC::ld, [](LCAO_Deepks*) {});
10481053
LCAO_Deepks_Interface<TK, TR> LDI(ld_shared_ptr);
10491054

1050-
ModuleBase::timer::tick("ESolver_KS_LCAO", "out_deepks_labels");
10511055
LDI.out_deepks_labels(this->pelec->f_en.etot,
10521056
this->pelec->klist->get_nks(),
10531057
ucell.nat,
@@ -1061,8 +1065,6 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep)
10611065
*(this->psi),
10621066
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
10631067
p_ham_deepks);
1064-
1065-
ModuleBase::timer::tick("ESolver_KS_LCAO", "out_deepks_labels");
10661068
}
10671069
#endif
10681070

source/module_esolver/lcao_before_scf.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,19 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
211211
{
212212
const Parallel_Orbitals* pv = &this->pv;
213213
// allocate <phi(0)|alpha(R)>, phialpha is different every ion step, so it is allocated here
214-
GlobalC::ld.allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
214+
DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
215215
// build and save <phi(0)|alpha(R)> at beginning
216-
GlobalC::ld.build_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, *(two_center_bundle_.overlap_orb_alpha));
216+
DeePKS_domain::build_phialpha(PARAM.inp.cal_force,
217+
ucell,
218+
orb_,
219+
this->gd,
220+
pv,
221+
*(two_center_bundle_.overlap_orb_alpha),
222+
GlobalC::ld.phialpha);
217223

218224
if (PARAM.inp.deepks_out_unittest)
219225
{
220-
GlobalC::ld.check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
226+
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
221227
}
222228
}
223229
#endif

source/module_esolver/lcao_others.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,19 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
217217
{
218218
const Parallel_Orbitals* pv = &this->pv;
219219
// allocate <phi(0)|alpha(R)>, phialpha is different every ion step, so it is allocated here
220-
GlobalC::ld.allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
220+
DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
221221
// build and save <phi(0)|alpha(R)> at beginning
222-
GlobalC::ld.build_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, *(two_center_bundle_.overlap_orb_alpha));
222+
DeePKS_domain::build_phialpha(PARAM.inp.cal_force,
223+
ucell,
224+
orb_,
225+
this->gd,
226+
pv,
227+
*(two_center_bundle_.overlap_orb_alpha),
228+
GlobalC::ld.phialpha);
223229

224230
if (PARAM.inp.deepks_out_unittest)
225231
{
226-
GlobalC::ld.check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
232+
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
227233
}
228234
}
229235
#endif

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 10 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -500,87 +500,16 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
500500
if (PARAM.inp.deepks_out_labels) // not parallelized yet
501501
{
502502
const std::string file_ftot = PARAM.globalv.global_out_dir + "deepks_ftot.npy";
503-
LCAO_deepks_io::save_npy_f(fcs, file_ftot, ucell.nat,
504-
GlobalV::MY_RANK); // Ty/Bohr, F_tot
503+
LCAO_deepks_io::save_npy_f(fcs, file_ftot, GlobalV::MY_RANK); // Ry/Bohr, F_tot
505504

505+
const std::string file_fbase = PARAM.globalv.global_out_dir + "deepks_fbase.npy";
506506
if (PARAM.inp.deepks_scf)
507507
{
508-
const std::string file_fbase = PARAM.globalv.global_out_dir + "deepks_fbase.npy";
509-
LCAO_deepks_io::save_npy_f(fcs - fvnl_dalpha,
510-
file_fbase,
511-
ucell.nat,
512-
GlobalV::MY_RANK); // Ry/Bohr, F_base
513-
514-
if (!PARAM.inp.deepks_equiv) // training with force label not supported by equivariant version now
515-
{
516-
torch::Tensor gdmx;
517-
if (PARAM.globalv.gamma_only_local)
518-
{
519-
const std::vector<std::vector<double>>& dm_gamma
520-
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();
521-
522-
DeePKS_domain::cal_gdmx(GlobalC::ld.lmaxd,
523-
GlobalC::ld.inlmax,
524-
kv.get_nks(),
525-
kv.kvec_d,
526-
GlobalC::ld.phialpha,
527-
GlobalC::ld.inl_index,
528-
dm_gamma,
529-
ucell,
530-
orb,
531-
pv,
532-
gd,
533-
gdmx);
534-
}
535-
else
536-
{
537-
const std::vector<std::vector<std::complex<double>>>& dm_k
538-
= dynamic_cast<const elecstate::ElecStateLCAO<std::complex<double>>*>(pelec)
539-
->get_DM()
540-
->get_DMK_vector();
541-
542-
DeePKS_domain::cal_gdmx(GlobalC::ld.lmaxd,
543-
GlobalC::ld.inlmax,
544-
kv.get_nks(),
545-
kv.kvec_d,
546-
GlobalC::ld.phialpha,
547-
GlobalC::ld.inl_index,
548-
dm_k,
549-
ucell,
550-
orb,
551-
pv,
552-
gd,
553-
gdmx);
554-
}
555-
std::vector<torch::Tensor> gevdm;
556-
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
557-
torch::Tensor gvx;
558-
DeePKS_domain::cal_gvx(ucell.nat,
559-
GlobalC::ld.inlmax,
560-
GlobalC::ld.des_per_atom,
561-
GlobalC::ld.inl_l,
562-
gevdm,
563-
gdmx,
564-
gvx);
565-
566-
if (PARAM.inp.deepks_out_unittest)
567-
{
568-
DeePKS_domain::check_gdmx(gdmx);
569-
DeePKS_domain::check_gvx(gvx);
570-
}
571-
572-
LCAO_deepks_io::save_npy_gvx(ucell.nat,
573-
GlobalC::ld.des_per_atom,
574-
gvx,
575-
PARAM.globalv.global_out_dir,
576-
GlobalV::MY_RANK);
577-
}
508+
LCAO_deepks_io::save_npy_f(fcs - fvnl_dalpha, file_fbase, GlobalV::MY_RANK); // Ry/Bohr, F_base
578509
}
579510
else
580511
{
581-
const std::string file_fbase = PARAM.globalv.global_out_dir + "deepks_fbase.npy";
582-
LCAO_deepks_io::save_npy_f(fcs, file_fbase, ucell.nat,
583-
GlobalV::MY_RANK); // no scf, F_base=F_tot
512+
LCAO_deepks_io::save_npy_f(fcs, file_fbase, GlobalV::MY_RANK); // no scf, F_base=F_tot
584513
}
585514
}
586515
#endif
@@ -758,80 +687,18 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
758687
ucell.omega,
759688
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_tot, w/ model
760689

761-
// wenfei add 2021/11/2
690+
const std::string file_sbase = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
762691
if (PARAM.inp.deepks_scf)
763692
{
764-
const std::string file_sbase = PARAM.globalv.global_out_dir + "deepks_sbase.npy";
765693
LCAO_deepks_io::save_npy_s(scs - svnl_dalpha,
766694
file_sbase,
767695
ucell.omega,
768696
GlobalV::MY_RANK); // change to energy unit Ry when printing, S_base;
769-
770-
if (!PARAM.inp.deepks_equiv) // training with stress label not supported by equivariant version now
771-
{
772-
torch::Tensor gdmepsl;
773-
if (PARAM.globalv.gamma_only_local)
774-
{
775-
const std::vector<std::vector<double>>& dm_gamma
776-
= dynamic_cast<const elecstate::ElecStateLCAO<double>*>(pelec)->get_DM()->get_DMK_vector();
777-
778-
DeePKS_domain::cal_gdmepsl(GlobalC::ld.lmaxd,
779-
GlobalC::ld.inlmax,
780-
kv.get_nks(),
781-
kv.kvec_d,
782-
GlobalC::ld.phialpha,
783-
GlobalC::ld.inl_index,
784-
dm_gamma,
785-
ucell,
786-
orb,
787-
pv,
788-
gd,
789-
gdmepsl);
790-
}
791-
else
792-
{
793-
const std::vector<std::vector<std::complex<double>>>& dm_k
794-
= dynamic_cast<const elecstate::ElecStateLCAO<std::complex<double>>*>(pelec)
795-
->get_DM()
796-
->get_DMK_vector();
797-
798-
DeePKS_domain::cal_gdmepsl(GlobalC::ld.lmaxd,
799-
GlobalC::ld.inlmax,
800-
kv.get_nks(),
801-
kv.kvec_d,
802-
GlobalC::ld.phialpha,
803-
GlobalC::ld.inl_index,
804-
dm_k,
805-
ucell,
806-
orb,
807-
pv,
808-
gd,
809-
gdmepsl);
810-
}
811-
812-
std::vector<torch::Tensor> gevdm;
813-
GlobalC::ld.cal_gevdm(ucell.nat, gevdm);
814-
torch::Tensor gvepsl;
815-
DeePKS_domain::cal_gvepsl(ucell.nat,
816-
GlobalC::ld.inlmax,
817-
GlobalC::ld.des_per_atom,
818-
GlobalC::ld.inl_l,
819-
gevdm,
820-
gdmepsl,
821-
gvepsl);
822-
823-
if (PARAM.inp.deepks_out_unittest)
824-
{
825-
DeePKS_domain::check_gdmepsl(gdmepsl);
826-
DeePKS_domain::check_gvepsl(gvepsl);
827-
}
828-
829-
LCAO_deepks_io::save_npy_gvepsl(ucell.nat,
830-
GlobalC::ld.des_per_atom,
831-
gvepsl,
832-
PARAM.globalv.global_out_dir,
833-
GlobalV::MY_RANK); // unitless, grad_vepsl
834-
}
697+
}
698+
else
699+
{
700+
LCAO_deepks_io::save_npy_s(scs, file_sbase, ucell.omega,
701+
GlobalV::MY_RANK); // sbase = stot
835702
}
836703
}
837704
#endif

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,23 @@ void Force_LCAO<double>::ftable(const bool isforce,
252252
if (PARAM.inp.deepks_scf)
253253
{
254254
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
255-
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);
256-
257255
DeePKS_domain::cal_descriptor(ucell.nat,
258256
GlobalC::ld.inlmax,
259257
GlobalC::ld.inl_l,
260258
GlobalC::ld.pdm,
261259
descriptor,
262260
GlobalC::ld.des_per_atom);
263-
GlobalC::ld.cal_gedm(ucell.nat, descriptor);
261+
DeePKS_domain::cal_gedm(ucell.nat,
262+
GlobalC::ld.lmaxd,
263+
GlobalC::ld.nmaxd,
264+
GlobalC::ld.inlmax,
265+
GlobalC::ld.des_per_atom,
266+
GlobalC::ld.inl_l,
267+
descriptor,
268+
GlobalC::ld.pdm,
269+
GlobalC::ld.model_deepks,
270+
GlobalC::ld.gedm,
271+
GlobalC::ld.E_delta);
264272

265273
const int nks = 1;
266274
DeePKS_domain::cal_f_delta<double>(dm_gamma,
@@ -302,32 +310,8 @@ void Force_LCAO<double>::ftable(const bool isforce,
302310
}
303311

304312
#ifdef __DEEPKS
305-
// It seems these test should not all be here, should be moved in the future
306-
// Also, these test are not in multi-k case now
307313
if (PARAM.inp.deepks_scf && PARAM.inp.deepks_out_unittest)
308314
{
309-
const int nks = 1; // 1 for gamma-only
310-
LCAO_deepks_io::print_dm(nks, PARAM.globalv.nlocal, this->ParaV->nrow, dm_gamma);
311-
312-
GlobalC::ld.check_projected_dm();
313-
314-
DeePKS_domain::check_descriptor(GlobalC::ld.inlmax,
315-
GlobalC::ld.des_per_atom,
316-
GlobalC::ld.inl_l,
317-
ucell,
318-
PARAM.globalv.global_out_dir,
319-
descriptor);
320-
321-
GlobalC::ld.check_gedm();
322-
323-
GlobalC::ld.cal_e_delta_band(dm_gamma, nks);
324-
325-
std::ofstream ofs("E_delta_bands.dat");
326-
ofs << std::setprecision(10) << GlobalC::ld.e_delta_band;
327-
328-
std::ofstream ofs1("E_delta.dat");
329-
ofs1 << std::setprecision(10) << GlobalC::ld.E_delta;
330-
331315
DeePKS_domain::check_f_delta(ucell.nat, fvnl_dalpha, svnl_dalpha);
332316
}
333317
#endif

0 commit comments

Comments
 (0)