Skip to content

Refactor:remove MPI part funcs of ucell #5810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ void Driver::driver_run()
// the life of ucell should begin here, mohan 2024-05-12
UnitCell ucell;
ucell.setup(PARAM.inp.latname,
PARAM.inp.ntype,
PARAM.inp.lmaxmax,
PARAM.inp.init_vel,
PARAM.inp.fixed_axes);
PARAM.inp.ntype,
PARAM.inp.lmaxmax,
PARAM.inp.init_vel,
PARAM.inp.fixed_axes);

ucell.setup_cell(PARAM.globalv.global_in_stru, GlobalV::ofs_running);
Check_Atomic_Stru::check_atomic_stru(ucell, PARAM.inp.min_dist_coef);
Expand Down
5 changes: 2 additions & 3 deletions source/module_basis/module_ao/test/ORB_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ void test_orb::TearDown()
}
ooo.clear_after_ions(OGT, ORB, 0, nproj);
delete[] nproj;
delete[] orbital_fn;
return;
}

Expand Down Expand Up @@ -75,7 +74,7 @@ void test_orb::set_orbs()
ORB.init(ofs_running,
ntype_read,
"./",
orbital_fn,
orbital_fn.data(),
descriptor_file,
lmax,
lcao_ecut,
Expand Down Expand Up @@ -114,7 +113,7 @@ void test_orb::set_files()

ModuleBase::GlobalFunc::SCAN_BEGIN(ifs, "NUMERICAL_ORBITAL");

orbital_fn = new std::string[ntype_read];
orbital_fn.resize(ntype_read);

for (int it = 0; it < ntype_read; it++)
{
Expand Down
4 changes: 2 additions & 2 deletions source/module_basis/module_ao/test/ORB_unittest.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class test_orb : public testing::Test
double randr(double Rmax);
void gen_table_center2();

bool force_flag = 0;
bool force_flag = false;
int my_rank = 0;
int ntype_read;

Expand All @@ -66,7 +66,7 @@ class test_orb : public testing::Test
int lmax = 1;
double lat0 = 1.0;
std::string case_dir = "./GaAs/";
std::string* orbital_fn;
std::vector<std::string> orbital_fn;
std::string descriptor_file;
};
#endif
108 changes: 107 additions & 1 deletion source/module_cell/bcast_cell.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
#include "unitcell.h"

#include "module_base/parallel_common.h"
#include "module_parameter/parameter.h"
#ifdef __EXX
#include "module_ri/serialization_cereal.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#endif
namespace unitcell
{
void bcast_atoms_tau(Atom* atoms,
Expand All @@ -12,4 +17,105 @@ namespace unitcell
}
#endif
}

void bcast_atoms_pseudo(Atom* atoms,
const int ntype)
{
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
for (int i = 0; i < ntype; i++)
{
atoms[i].bcast_atom2();
}
#endif
}

void bcast_Lattice(Lattice& lat)
{
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
// distribute lattice parameters.
ModuleBase::Matrix3& latvec = lat.latvec;
ModuleBase::Matrix3& latvec_supercell = lat.latvec_supercell;
Parallel_Common::bcast_string(lat.Coordinate);
Parallel_Common::bcast_double(lat.lat0);
Parallel_Common::bcast_double(lat.lat0_angstrom);
Parallel_Common::bcast_double(lat.tpiba);
Parallel_Common::bcast_double(lat.tpiba2);
Parallel_Common::bcast_double(lat.omega);
Parallel_Common::bcast_string(lat.latName);

// distribute lattice vectors.
Parallel_Common::bcast_double(latvec.e11);
Parallel_Common::bcast_double(latvec.e12);
Parallel_Common::bcast_double(latvec.e13);
Parallel_Common::bcast_double(latvec.e21);
Parallel_Common::bcast_double(latvec.e22);
Parallel_Common::bcast_double(latvec.e23);
Parallel_Common::bcast_double(latvec.e31);
Parallel_Common::bcast_double(latvec.e32);
Parallel_Common::bcast_double(latvec.e33);

// distribute lattice vectors.
for (int i = 0; i < 3; i++)
{
Parallel_Common::bcast_double(lat.a1[i]);
Parallel_Common::bcast_double(lat.a2[i]);
Parallel_Common::bcast_double(lat.a3[i]);
Parallel_Common::bcast_double(lat.latcenter[i]);
Parallel_Common::bcast_int(lat.lc[i]);
}

// distribute superlattice vectors.
Parallel_Common::bcast_double(latvec_supercell.e11);
Parallel_Common::bcast_double(latvec_supercell.e12);
Parallel_Common::bcast_double(latvec_supercell.e13);
Parallel_Common::bcast_double(latvec_supercell.e21);
Parallel_Common::bcast_double(latvec_supercell.e22);
Parallel_Common::bcast_double(latvec_supercell.e23);
Parallel_Common::bcast_double(latvec_supercell.e31);
Parallel_Common::bcast_double(latvec_supercell.e32);
Parallel_Common::bcast_double(latvec_supercell.e33);

// distribute Change the lattice vectors or not
#endif
}

void bcast_magnetism(Magnetism& magnet, const int ntype)
{
#ifdef __MPI
MPI_Barrier(MPI_COMM_WORLD);
Parallel_Common::bcast_double(magnet.start_magnetization, ntype);
if (PARAM.inp.nspin == 4)
{
Parallel_Common::bcast_double(magnet.ux_[0]);
Parallel_Common::bcast_double(magnet.ux_[1]);
Parallel_Common::bcast_double(magnet.ux_[2]);
}
#endif
}

void bcast_unitcell(UnitCell& ucell)
{
#ifdef __MPI
const int ntype = ucell.ntype;
Parallel_Common::bcast_int(ucell.nat);

bcast_Lattice(ucell.lat);
bcast_magnetism(ucell.magnet,ntype);
bcast_atoms_tau(ucell.atoms,ntype);

for (int i = 0; i < ntype; i++)
{
Parallel_Common::bcast_string(ucell.orbital_fn[i]);
}

#ifdef __EXX
ModuleBase::bcast_data_cereal(GlobalC::exx_info.info_ri.files_abfs,
MPI_COMM_WORLD,
0);
#endif
return;
#endif
}
}
40 changes: 40 additions & 0 deletions source/module_cell/bcast_cell.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,50 @@
#ifndef BCAST_CELL_H
#define BCAST_CELL_H

#include "module_cell/unitcell.h"
namespace unitcell
{
/**
* @brief broadcast the tau array of the atoms
*
* @param atoms: the atoms to be broadcasted [in/out]
* @param ntype: the number of types of the atoms [in]
*/
void bcast_atoms_tau(Atom* atoms,
const int ntype);

/**
* @brief broadcast the pseduo of the atoms
*
* @param atoms: the atoms to be broadcasted [in/out]
* @param ntype: the number of types of the atoms [in]
*/
void bcast_atoms_pseudo(Atom* atoms,
const int ntype);
/**
* @brief broadcast the lattice
*
* @param lat: the lattice to be broadcasted [in/out]
*/
void bcast_Lattice(Lattice& lat);

/**
* @brief broadcast the magnetism
*
* @param magnet: the magnetism to be broadcasted [in/out]
* @param nytpe: the number of types of the atoms [in]
*/
void bcast_magnetism(Magnetism& magnet,
const int ntype);

/**
* @brief broadcast the unitcell
*
* @param ucell: the unitcell to be broadcasted [in/out]
*/
void bcast_unitcell(UnitCell& ucell);


}

#endif // BCAST_CELL_H
4 changes: 2 additions & 2 deletions source/module_cell/module_neighbor/test/prepare_unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ class UcellTestPrepare
delete[] ucell->atom_mass;
delete[] ucell->pseudo_fn;
delete[] ucell->pseudo_type;
delete[] ucell->orbital_fn;

delete[] ucell->magnet.start_magnetization; //mag set here
ucell->atom_label = new std::string[ucell->ntype];
ucell->atom_mass = new double[ucell->ntype];
ucell->pseudo_fn = new std::string[ucell->ntype];
ucell->pseudo_type = new std::string[ucell->ntype];
ucell->orbital_fn = new std::string[ucell->ntype];
ucell->orbital_fn.resize(ucell->ntype);
ucell->magnet.start_magnetization = new double[ucell->ntype]; //mag set here
ucell->magnet.ux_[0] = 0.0; // ux_ set here
ucell->magnet.ux_[1] = 0.0;
Expand Down
2 changes: 0 additions & 2 deletions source/module_cell/read_atoms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ int UnitCell::read_atom_species(std::ifstream &ifa, std::ofstream &ofs_running)
delete[] atom_mass;
delete[] pseudo_fn;
delete[] pseudo_type;
delete[] orbital_fn;
this->atom_mass = new double[ntype]; //atom masses
this->atom_label = new std::string[ntype]; //atom labels
this->pseudo_fn = new std::string[ntype]; //file name of pseudopotential
this->pseudo_type = new std::string[ntype]; // type of pseudopotential
this->orbital_fn = new std::string[ntype]; // filename of orbitals

std::string word;
//==========================================
Expand Down
4 changes: 2 additions & 2 deletions source/module_cell/test/prepare_unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ class UcellTestPrepare
delete[] ucell->atom_mass;
delete[] ucell->pseudo_fn;
delete[] ucell->pseudo_type;
delete[] ucell->orbital_fn;

delete[] ucell->magnet.start_magnetization; //mag set here
ucell->atom_label = new std::string[ucell->ntype];
ucell->atom_mass = new double[ucell->ntype];
ucell->pseudo_fn = new std::string[ucell->ntype];
ucell->pseudo_type = new std::string[ucell->ntype];
ucell->orbital_fn = new std::string[ucell->ntype];
ucell->orbital_fn.resize(ucell->ntype);
ucell->magnet.start_magnetization = new double[ucell->ntype]; //mag set here
ucell->magnet.ux_[0] = 0.0; // ux_ set here
ucell->magnet.ux_[1] = 0.0;
Expand Down
5 changes: 0 additions & 5 deletions source/module_cell/test/support/mock_unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ UnitCell::~UnitCell() {
delete[] atom_mass;
delete[] pseudo_fn;
delete[] pseudo_type;
delete[] orbital_fn;
if (set_atom_flag) {
delete[] atoms;
}
Expand All @@ -37,10 +36,6 @@ bool UnitCell::read_atom_positions(std::ifstream& ifpos,
bool UnitCell::judge_big_cell() const { return true; }
void UnitCell::update_stress(ModuleBase::matrix& scs) {}
void UnitCell::update_force(ModuleBase::matrix& fcs) {}
#ifdef __MPI
void UnitCell::bcast_unitcell() {}
void UnitCell::bcast_unitcell2() {}
#endif
void UnitCell::set_iat2itia() {}
void UnitCell::setup_cell(const std::string& fn, std::ofstream& log) {}
void UnitCell::read_orb_file(int it,
Expand Down
Loading
Loading