Skip to content

Commit 4484536

Browse files
authored
Refactor: remove GlobalC::Pkpoint (#5846)
* Refactor: remove GlobalC::Pkpoint
1 parent 16714c6 commit 4484536

22 files changed

+61
-101
lines changed

source/module_cell/klist.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,12 @@ void K_Vectors::set(const UnitCell& ucell,
148148
// It's very important in parallel case,
149149
// firstly do the mpi_k() and then
150150
// do set_kup_and_kdw()
151-
GlobalC::Pkpoints.kinfo(nkstot,
152-
GlobalV::KPAR,
153-
GlobalV::MY_POOL,
154-
GlobalV::RANK_IN_POOL,
155-
GlobalV::NPROC,
156-
nspin_in); // assign k points to several process pools
151+
this->para_k.kinfo(nkstot,
152+
GlobalV::KPAR,
153+
GlobalV::MY_POOL,
154+
GlobalV::RANK_IN_POOL,
155+
GlobalV::NPROC,
156+
nspin_in); // assign k points to several process pools
157157
#ifdef __MPI
158158
// distribute K point data to the corresponding process
159159
this->mpi_k(); // 2008-4-29
@@ -1163,7 +1163,7 @@ void K_Vectors::mpi_k()
11631163

11641164
Parallel_Common::bcast_double(koffset, 3);
11651165

1166-
this->nks = GlobalC::Pkpoints.nks_pool[GlobalV::MY_POOL];
1166+
this->nks = this->para_k.nks_pool[GlobalV::MY_POOL];
11671167

11681168
GlobalV::ofs_running << std::endl;
11691169
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "k-point number in this process", nks);
@@ -1217,7 +1217,7 @@ void K_Vectors::mpi_k()
12171217
for (int i = 0; i < nks; i++)
12181218
{
12191219
// 3 is because each k point has three value:kx, ky, kz
1220-
k_index = i + GlobalC::Pkpoints.startk_pool[GlobalV::MY_POOL];
1220+
k_index = i + this->para_k.startk_pool[GlobalV::MY_POOL];
12211221
kvec_c[i].x = kvec_c_aux[k_index * 3];
12221222
kvec_c[i].y = kvec_c_aux[k_index * 3 + 1];
12231223
kvec_c[i].z = kvec_c_aux[k_index * 3 + 2];

source/module_cell/klist.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "module_base/global_variable.h"
66
#include "module_base/matrix3.h"
77
#include "module_cell/unitcell.h"
8-
8+
#include "parallel_kpoints.h"
99
#include <vector>
1010

1111
class K_Vectors
@@ -31,6 +31,9 @@ class K_Vectors
3131
K_Vectors& operator=(const K_Vectors&) = default;
3232
K_Vectors& operator=(K_Vectors&& rhs) = default;
3333

34+
Parallel_Kpoints para_k; ///< parallel for kpoints
35+
36+
3437
/**
3538
* @brief Set up the k-points for the system.
3639
*

source/module_cell/parallel_kpoints.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,6 @@
33
#include "module_base/parallel_common.h"
44
#include "module_base/parallel_global.h"
55

6-
Parallel_Kpoints::Parallel_Kpoints()
7-
{
8-
}
9-
10-
Parallel_Kpoints::~Parallel_Kpoints()
11-
{
12-
}
13-
146
// the kpoints here are reduced after symmetry applied.
157
void Parallel_Kpoints::kinfo(int& nkstot_in,
168
const int& kpar_in,
@@ -227,7 +219,7 @@ void Parallel_Kpoints::pool_collection(double* value_re,
227219
return;
228220
}
229221

230-
void Parallel_Kpoints::pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik)
222+
void Parallel_Kpoints::pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik) const
231223
{
232224
const int dim2 = w.getBound2();
233225
const int dim3 = w.getBound3();
@@ -237,7 +229,7 @@ void Parallel_Kpoints::pool_collection(std::complex<double>* value, const Module
237229
}
238230

239231
template <class T, class V>
240-
void Parallel_Kpoints::pool_collection_aux(T* value, const V& w, const int& dim, const int& ik)
232+
void Parallel_Kpoints::pool_collection_aux(T* value, const V& w, const int& dim, const int& ik) const
241233
{
242234
#ifdef __MPI
243235
const int ik_now = ik - this->startk_pool[this->my_pool];

source/module_cell/parallel_kpoints.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
class Parallel_Kpoints
1010
{
1111
public:
12-
Parallel_Kpoints();
13-
~Parallel_Kpoints();
12+
Parallel_Kpoints(){};
13+
~Parallel_Kpoints(){};
1414

1515
void kinfo(int& nkstot_in,
1616
const int& kpar_in,
@@ -28,9 +28,9 @@ class Parallel_Kpoints
2828
const ModuleBase::realArray& a,
2929
const ModuleBase::realArray& b,
3030
const int& ik);
31-
void pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik);
31+
void pool_collection(std::complex<double>* value, const ModuleBase::ComplexArray& w, const int& ik) const;
3232
template <class T, class V>
33-
void pool_collection_aux(T* value, const V& w, const int& dim, const int& ik);
33+
void pool_collection_aux(T* value, const V& w, const int& dim, const int& ik) const;
3434
#ifdef __MPI
3535
/**
3636
* @brief gather kpoints from all processors
@@ -46,8 +46,8 @@ class Parallel_Kpoints
4646
// int* nproc_pool = nullptr; it is not used
4747

4848
// inforamation about kpoints, dim: KPAR
49-
std::vector<int> nks_pool; // number of k-points in each pool
50-
std::vector<int> startk_pool; // the first k-point in each pool
49+
std::vector<int> nks_pool; // number of k-points in each pool, here use k-points without spin
50+
std::vector<int> startk_pool; // the first k-point in each pool, here use k-points without spin
5151

5252
// information about which pool each k-point belongs to,
5353
std::vector<int> whichpool; // whichpool[k] : the pool which k belongs to, dim: nkstot_np

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
414414
// qianrui modify 2020-10-18
415415
if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "md" || PARAM.inp.calculation == "relax")
416416
{
417-
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv, &(GlobalC::Pkpoints));
417+
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv);
418418
}
419419

420420
const int nspin0 = (PARAM.inp.nspin == 2) ? 2 : 1;
@@ -432,8 +432,7 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
432432
0.0,
433433
PARAM.inp.out_band[1],
434434
this->pelec->ekb,
435-
this->kv,
436-
&(GlobalC::Pkpoints));
435+
this->kv);
437436
}
438437
} // out_band
439438

@@ -452,7 +451,6 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
452451
PARAM.inp.dos_scale,
453452
PARAM.inp.dos_sigma,
454453
*(this->pelec->klist),
455-
GlobalC::Pkpoints,
456454
ucell,
457455
this->pelec->eferm,
458456
PARAM.inp.nbands,

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ void ESolver_KS_PW<T, Device>::before_all_runners(UnitCell& ucell, const Input_p
209209
GlobalV::MY_RANK,
210210
ucell,
211211
this->sf,
212-
GlobalC::Pkpoints,
212+
this->kv.para_k,
213213
this->ppcell,
214214
*this->pw_wfc);
215215
allocate_psi(this->psi, this->kv.get_nks(), this->kv.ngk.data(), PARAM.inp.nbands, this->pw_wfc->npwk_max);
@@ -844,7 +844,7 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
844844
}
845845

846846
//! 2) Print occupation numbers into istate.info
847-
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv, &(GlobalC::Pkpoints));
847+
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv);
848848

849849
//! 3) Compute density of states (DOS)
850850
if (PARAM.inp.out_dos)
@@ -883,8 +883,7 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
883883
0.0,
884884
PARAM.inp.out_band[1],
885885
this->pelec->ekb,
886-
this->kv,
887-
&(GlobalC::Pkpoints));
886+
this->kv);
888887
}
889888
}
890889

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ void ESolver_SDFT_PW<T, Device>::after_all_runners(UnitCell& ucell)
266266
GlobalV::ofs_running << std::setprecision(16);
267267
GlobalV::ofs_running << " !FINAL_ETOT_IS " << this->pelec->f_en.etot * ModuleBase::Ry_to_eV << " eV" << std::endl;
268268
GlobalV::ofs_running << " --------------------------------------------\n\n" << std::endl;
269-
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv, &(GlobalC::Pkpoints));
269+
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv);
270270
}
271271

272272
template <>
@@ -277,7 +277,7 @@ void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>::after_all_r
277277
GlobalV::ofs_running << std::setprecision(16);
278278
GlobalV::ofs_running << " !FINAL_ETOT_IS " << this->pelec->f_en.etot * ModuleBase::Ry_to_eV << " eV" << std::endl;
279279
GlobalV::ofs_running << " --------------------------------------------\n\n" << std::endl;
280-
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv, &(GlobalC::Pkpoints));
280+
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, this->kv);
281281

282282
if (this->method_sto == 2)
283283
{

source/module_hamilt_pw/hamilt_pwdft/global.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ namespace GlobalC
264264
#include "module_cell/unitcell.h"
265265
namespace GlobalC
266266
{
267-
extern Parallel_Kpoints Pkpoints;
268267
extern Restart restart; // Peize Lin add 2020.04.04
269268
} // namespace GlobalC
270269

source/module_io/dos_nao.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ namespace ModuleIO
1414
/// @param[in] dos_scale
1515
/// @param[in] dos_sigma
1616
/// @param[in] kv
17-
/// @param[in] Pkpoints
1817
/// @param[in] ucell
1918
/// @param[in] eferm
2019
/// @param[in] nbands
@@ -28,7 +27,6 @@ namespace ModuleIO
2827
const double& dos_scale,
2928
const double& dos_sigma,
3029
const K_Vectors& kv,
31-
const Parallel_Kpoints& Pkpoints,
3230
const UnitCell& ucell,
3331
const elecstate::efermi& eferm,
3432
int nbands,
@@ -45,7 +43,7 @@ namespace ModuleIO
4543
{
4644
std::stringstream ss3;
4745
ss3 << PARAM.globalv.global_out_dir << "Fermi_Surface_" << i << ".bxsf";
48-
nscf_fermi_surface(ss3.str(), nbands, eferm.ef, kv, Pkpoints, ucell, ekb);
46+
nscf_fermi_surface(ss3.str(), nbands, eferm.ef, kv, ucell, ekb);
4947
}
5048
}
5149

@@ -69,7 +67,6 @@ template void out_dos_nao(
6967
const double& dos_scale,
7068
const double& dos_sigma,
7169
const K_Vectors& kv,
72-
const Parallel_Kpoints& Pkpoints,
7370
const UnitCell& ucell,
7471
const elecstate::efermi& eferm,
7572
int nbands,
@@ -84,7 +81,6 @@ template void out_dos_nao(
8481
const double& dos_scale,
8582
const double& dos_sigma,
8683
const K_Vectors& kv,
87-
const Parallel_Kpoints& Pkpoints,
8884
const UnitCell& ucell,
8985
const elecstate::efermi& eferm,
9086
int nbands,

source/module_io/dos_nao.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ namespace ModuleIO
1818
const double& dos_scale,
1919
const double& dos_sigma,
2020
const K_Vectors& kv,
21-
const Parallel_Kpoints& Pkpoints,
2221
const UnitCell& ucell,
2322
const elecstate::efermi& eferm,
2423
int nbands,

source/module_io/nscf_band.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@ void ModuleIO::nscf_band(
1212
const double &fermie,
1313
const int &precision,
1414
const ModuleBase::matrix& ekb,
15-
const K_Vectors& kv,
16-
const Parallel_Kpoints* Pkpoints)
15+
const K_Vectors& kv)
1716
{
1817
ModuleBase::TITLE("ModuleIO","nscf_band");
1918
ModuleBase::timer::tick("ModuleIO", "nscf_band");
2019
// number of k points without spin; nspin = 1,2, nkstot = nkstot_np * nspin;
2120
// nspin = 4, nkstot = nkstot_np
22-
const int nkstot_np = Pkpoints->nkstot_np;
23-
const int nks_np = Pkpoints->nks_np;
21+
const int nkstot_np = kv.para_k.nkstot_np;
22+
const int nks_np = kv.para_k.nks_np;
2423

2524
#ifdef __MPI
2625
if(GlobalV::MY_RANK==0)
@@ -33,7 +32,7 @@ void ModuleIO::nscf_band(
3332
klength.resize(nkstot_np);
3433
klength[0] = 0.0;
3534
std::vector<ModuleBase::Vector3<double>> kvec_c_global;
36-
Pkpoints->gatherkvec(kv.kvec_c, kvec_c_global);
35+
kv.para_k.gatherkvec(kv.kvec_c, kvec_c_global);
3736
for(int ik=0; ik<nkstot_np; ik++)
3837
{
3938
if (ik>0)
@@ -43,10 +42,10 @@ void ModuleIO::nscf_band(
4342
klength[ik] += (kv.kl_segids[ik] == kv.kl_segids[ik-1]) ? delta.norm() : 0.0;
4443
}
4544
/* first find if present kpoint in present pool */
46-
if ( GlobalV::MY_POOL == Pkpoints->whichpool[ik] )
45+
if ( GlobalV::MY_POOL == kv.para_k.whichpool[ik] )
4746
{
4847
/* then get the local kpoint index, which starts definitly from 0 */
49-
const int ik_now = ik - Pkpoints->startk_pool[GlobalV::MY_POOL];
48+
const int ik_now = ik - kv.para_k.startk_pool[GlobalV::MY_POOL];
5049
/* if present kpoint corresponds the spin of the present one */
5150
assert( kv.isk[ik_now+is*nks_np] == is );
5251
if ( GlobalV::RANK_IN_POOL == 0)

source/module_io/nscf_band.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,14 @@ namespace ModuleIO
1616
* @param precision precision of the output
1717
* @param ekb eigenvalues of k points and bands
1818
* @param kv klist
19-
* @param Pkpoints parallel kpoints
2019
*/
2120
void nscf_band(const int& is,
2221
const std::string& out_band_dir,
2322
const int& nband,
2423
const double& fermie,
2524
const int& precision,
2625
const ModuleBase::matrix& ekb,
27-
const K_Vectors& kv,
28-
const Parallel_Kpoints* Pkpoints);
26+
const K_Vectors& kv);
2927
}
3028

3129
#endif

source/module_io/nscf_fermi_surf.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ void ModuleIO::nscf_fermi_surface(const std::string &out_band_dir,
88
const int &nband,
99
const double &ef,
1010
const K_Vectors& kv,
11-
const Parallel_Kpoints& Pkpoints,
1211
const UnitCell& ucell,
1312
const ModuleBase::matrix &ekb)
1413
{
@@ -29,7 +28,7 @@ void ModuleIO::nscf_fermi_surface(const std::string &out_band_dir,
2928

3029
for(int ik=0; ik<kv.get_nkstot(); ik++)
3130
{
32-
if ( GlobalV::MY_POOL == Pkpoints.whichpool[ik] )
31+
if ( GlobalV::MY_POOL == kv.para_k.whichpool[ik] )
3332
{
3433
if( GlobalV::RANK_IN_POOL == 0)
3534
{
@@ -58,7 +57,7 @@ void ModuleIO::nscf_fermi_surface(const std::string &out_band_dir,
5857
ofs << " " << ucell.G.e31 << " " << ucell.G.e32 << " " << ucell.G.e33 << std::endl;
5958
}
6059

61-
const int ik_now = ik - Pkpoints.startk_pool[GlobalV::MY_POOL];
60+
const int ik_now = ik - kv.para_k.startk_pool[GlobalV::MY_POOL];
6261
ofs << "ik= " << ik << std::endl;
6362
ofs << kv.kvec_c[ik_now].x << " " << kv.kvec_c[ik_now].y << " " << kv.kvec_c[ik_now].z << std::endl;
6463

source/module_io/nscf_fermi_surf.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ void nscf_fermi_surface(const std::string& out_band_dir,
1111
const int& nband,
1212
const double& ef,
1313
const K_Vectors& kv,
14-
const Parallel_Kpoints& Pkpoints,
1514
const UnitCell& ucell,
1615
const ModuleBase::matrix& ekb);
1716
}

source/module_io/numerical_basis.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -654,9 +654,9 @@ void Numerical_Basis::output_k(std::ofstream& ofs, const K_Vectors& kv)
654654
// temprary restrict kpar=1 for NSPIN=2 case for generating_orbitals
655655
int pool = 0;
656656
if (PARAM.inp.nspin != 2) {
657-
pool = GlobalC::Pkpoints.whichpool[ik];
657+
pool = kv.para_k.whichpool[ik];
658658
}
659-
const int iknow = ik - GlobalC::Pkpoints.startk_pool[GlobalV::MY_POOL];
659+
const int iknow = ik - kv.para_k.startk_pool[GlobalV::MY_POOL];
660660
if (GlobalV::RANK_IN_POOL == 0)
661661
{
662662
if (GlobalV::MY_POOL == 0)
@@ -671,7 +671,7 @@ void Numerical_Basis::output_k(std::ofstream& ofs, const K_Vectors& kv)
671671
else
672672
{
673673

674-
int startpro_pool = GlobalC::Pkpoints.get_startpro_pool(pool);
674+
int startpro_pool = kv.para_k.get_startpro_pool(pool);
675675
MPI_Status ierror;
676676
MPI_Recv(&kx, 1, MPI_DOUBLE, startpro_pool, ik * 4, MPI_COMM_WORLD, &ierror);
677677
MPI_Recv(&ky, 1, MPI_DOUBLE, startpro_pool, ik * 4 + 1, MPI_COMM_WORLD, &ierror);
@@ -755,7 +755,7 @@ void Numerical_Basis::output_overlap_Q(std::ofstream& ofs, const std::vector<Mod
755755
{
756756
ModuleBase::ComplexArray Qtmp(overlap_Q[ik].getBound1(), overlap_Q[ik].getBound2(), overlap_Q[ik].getBound3());
757757
Qtmp.zero_out();
758-
GlobalC::Pkpoints.pool_collection(Qtmp.ptr, overlap_Q_k, ik);
758+
kv.para_k.pool_collection(Qtmp.ptr, overlap_Q_k, ik);
759759
if (GlobalV::MY_RANK == 0)
760760
{
761761
// ofs << "\n ik=" << ik;
@@ -803,12 +803,12 @@ void Numerical_Basis::output_overlap_Sq(const std::string& name, std::ofstream&
803803
{
804804
for (int ik = 0; ik < nkstot; ik++)
805805
{
806-
if (GlobalV::MY_POOL == GlobalC::Pkpoints.whichpool[ik])
806+
if (GlobalV::MY_POOL == kv.para_k.whichpool[ik])
807807
{
808808
if (GlobalV::RANK_IN_POOL == 0)
809809
{
810810
ofs.open(name.c_str(), std::ios::app);
811-
const int ik_now = ik - GlobalC::Pkpoints.startk_pool[GlobalV::MY_POOL] + is * nkstot;
811+
const int ik_now = ik - kv.para_k.startk_pool[GlobalV::MY_POOL] + is * nkstot;
812812

813813
const int size = overlap_Sq[ik_now].getSize();
814814
for (int i = 0; i < size; i++)

0 commit comments

Comments
 (0)