Skip to content

Commit 4ddec65

Browse files
Refactor:Remove Ucell::update_pos_taud (#5794)
* update the update_pos_taud * change func with vector3 update_pos_taud * modify the input format * add unittest for the update_pos_tau * update test for relax_new * add update_vel for ucell * [pre-commit.ci lite] apply automatic fixes --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 9ab9150 commit 4ddec65

File tree

14 files changed

+220
-435
lines changed

14 files changed

+220
-435
lines changed

source/module_cell/atom_spec.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Atom
2222
std::vector<bool> iw2_new;
2323
int nw = 0; // number of local orbitals (l,n,m) of this type
2424

25-
void set_index(void);
25+
void set_index();
2626

2727
int type = 0; // Index of atom type
2828
int na = 0; // Number of atoms in this type.
@@ -34,8 +34,7 @@ class Atom
3434

3535
std::string label = "\0"; // atomic symbol
3636
std::vector<ModuleBase::Vector3<double>> tau; // Cartesian coordinates of each atom in this type.
37-
std::vector<ModuleBase::Vector3<double>>
38-
dis; // direct displacements of each atom in this type in current step liuyu modift 2023-03-22
37+
std::vector<ModuleBase::Vector3<double>> dis; // direct displacements of each atom in this type in current step liuyu modift 2023-03-22
3938
std::vector<ModuleBase::Vector3<double>> taud; // Direct coordinates of each atom in this type.
4039
std::vector<ModuleBase::Vector3<double>> vel; // velocities of each atom in this type.
4140
std::vector<ModuleBase::Vector3<double>> force; // force acting on each atom in this type.
@@ -54,8 +53,8 @@ class Atom
5453
void print_Atom(std::ofstream& ofs);
5554
void update_force(ModuleBase::matrix& fcs);
5655
#ifdef __MPI
57-
void bcast_atom(void);
58-
void bcast_atom2(void);
56+
void bcast_atom();
57+
void bcast_atom2();
5958
#endif
6059
};
6160

source/module_cell/test/support/mock_unitcell.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@ bool UnitCell::read_atom_positions(std::ifstream& ifpos,
3333
std::ofstream& ofs_warning) {
3434
return true;
3535
}
36-
void UnitCell::update_pos_taud(double* posd_in) {}
37-
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {}
38-
void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {}
39-
void UnitCell::bcast_atoms_tau() {}
36+
4037
bool UnitCell::judge_big_cell() const { return true; }
4138
void UnitCell::update_stress(ModuleBase::matrix& scs) {}
4239
void UnitCell::update_force(ModuleBase::matrix& fcs) {}

source/module_cell/test/unitcell_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,7 @@ TEST_F(UcellTest, UpdateVel)
10211021
{
10221022
vel_in[iat].set(iat * 0.1, iat * 0.1, iat * 0.1);
10231023
}
1024-
ucell->update_vel(vel_in);
1024+
unitcell::update_vel(vel_in,ucell->ntype,ucell->nat,ucell->atoms);
10251025
for (int iat = 0; iat < ucell->nat; ++iat)
10261026
{
10271027
EXPECT_DOUBLE_EQ(vel_in[iat].x, 0.1 * iat);

source/module_cell/test/unitcell_test_para.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ TEST_F(UcellTest, UpdatePosTau)
153153
}
154154
delete[] pos_in;
155155
}
156-
TEST_F(UcellTest, UpdatePosTaud)
156+
TEST_F(UcellTest, UpdatePosTaud_pointer)
157157
{
158158
double* pos_in = new double[ucell->nat * 3];
159159
ModuleBase::Vector3<double>* tmp = new ModuleBase::Vector3<double>[ucell->nat];
@@ -167,7 +167,8 @@ TEST_F(UcellTest, UpdatePosTaud)
167167
ucell->iat2iait(iat, &ia, &it);
168168
tmp[iat] = ucell->atoms[it].taud[ia];
169169
}
170-
ucell->update_pos_taud(pos_in);
170+
unitcell::update_pos_taud(ucell->lat,pos_in,ucell->ntype,
171+
ucell->nat,ucell->atoms);
171172
for (int iat = 0; iat < ucell->nat; ++iat)
172173
{
173174
int it, ia;
@@ -180,6 +181,37 @@ TEST_F(UcellTest, UpdatePosTaud)
180181
delete[] pos_in;
181182
}
182183

184+
//test update_pos_taud with ModuleBase::Vector3<double> version
185+
TEST_F(UcellTest, UpdatePosTaud_Vector3)
186+
{
187+
ModuleBase::Vector3<double>* pos_in = new ModuleBase::Vector3<double>[ucell->nat];
188+
ModuleBase::Vector3<double>* tmp = new ModuleBase::Vector3<double>[ucell->nat];
189+
ucell->set_iat2itia();
190+
for (int iat = 0; iat < ucell->nat; ++iat)
191+
{
192+
for (int ik = 0; ik < 3; ++ik)
193+
{
194+
pos_in[iat][ik] = 0.01;
195+
}
196+
int it=0;
197+
int ia=0;
198+
ucell->iat2iait(iat, &ia, &it);
199+
tmp[iat] = ucell->atoms[it].taud[ia];
200+
}
201+
unitcell::update_pos_taud(ucell->lat,pos_in,ucell->ntype,
202+
ucell->nat,ucell->atoms);
203+
for (int iat = 0; iat < ucell->nat; ++iat)
204+
{
205+
int it, ia;
206+
ucell->iat2iait(iat, &ia, &it);
207+
for (int ik = 0; ik < 3; ++ik)
208+
{
209+
EXPECT_DOUBLE_EQ(ucell->atoms[it].taud[ia][ik], tmp[iat][ik] + 0.01);
210+
}
211+
}
212+
delete[] tmp;
213+
delete[] pos_in;
214+
}
183215
TEST_F(UcellTest, ReadPseudo)
184216
{
185217
PARAM.input.pseudo_dir = pp_dir;

source/module_cell/unitcell.cpp

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -314,65 +314,6 @@ std::vector<ModuleBase::Vector3<int>> UnitCell::get_constrain() const
314314
return constrain;
315315
}
316316

317-
318-
319-
void UnitCell::update_pos_taud(double* posd_in) {
320-
int iat = 0;
321-
for (int it = 0; it < this->ntype; it++) {
322-
Atom* atom = &this->atoms[it];
323-
for (int ia = 0; ia < atom->na; ia++) {
324-
for (int ik = 0; ik < 3; ++ik) {
325-
atom->taud[ia][ik] += posd_in[3 * iat + ik];
326-
atom->dis[ia][ik] = posd_in[3 * iat + ik];
327-
}
328-
iat++;
329-
}
330-
}
331-
assert(iat == this->nat);
332-
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
333-
this->bcast_atoms_tau();
334-
}
335-
336-
// posd_in is atomic displacements here liuyu 2023-03-22
337-
void UnitCell::update_pos_taud(const ModuleBase::Vector3<double>* posd_in) {
338-
int iat = 0;
339-
for (int it = 0; it < this->ntype; it++) {
340-
Atom* atom = &this->atoms[it];
341-
for (int ia = 0; ia < atom->na; ia++) {
342-
for (int ik = 0; ik < 3; ++ik) {
343-
atom->taud[ia][ik] += posd_in[iat][ik];
344-
atom->dis[ia][ik] = posd_in[iat][ik];
345-
}
346-
iat++;
347-
}
348-
}
349-
assert(iat == this->nat);
350-
unitcell::periodic_boundary_adjustment(this->atoms,this->latvec, this->ntype);
351-
this->bcast_atoms_tau();
352-
}
353-
354-
void UnitCell::update_vel(const ModuleBase::Vector3<double>* vel_in) {
355-
int iat = 0;
356-
for (int it = 0; it < this->ntype; ++it) {
357-
Atom* atom = &this->atoms[it];
358-
for (int ia = 0; ia < atom->na; ++ia) {
359-
this->atoms[it].vel[ia] = vel_in[iat];
360-
++iat;
361-
}
362-
}
363-
assert(iat == this->nat);
364-
}
365-
366-
367-
void UnitCell::bcast_atoms_tau() {
368-
#ifdef __MPI
369-
MPI_Barrier(MPI_COMM_WORLD);
370-
for (int i = 0; i < ntype; i++) {
371-
atoms[i].bcast_atom(); // bcast tau array
372-
}
373-
#endif
374-
}
375-
376317
//==============================================================
377318
// Calculate various lattice related quantities for given latvec
378319
//==============================================================

source/module_cell/unitcell.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,6 @@ class UnitCell {
200200
void print_cell(std::ofstream& ofs) const;
201201
void print_cell_xyz(const std::string& fn) const;
202202

203-
void update_pos_taud(const ModuleBase::Vector3<double>* posd_in);
204-
void update_pos_taud(double* posd_in);
205-
void update_vel(const ModuleBase::Vector3<double>* vel_in);
206-
void bcast_atoms_tau();
207203
bool judge_big_cell() const;
208204

209205
void update_stress(ModuleBase::matrix& scs); // updates stress

source/module_cell/update_cell.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,75 @@ void update_pos_tau(const Lattice& lat,
374374
bcast_atoms_tau(atoms, ntype);
375375
}
376376

377+
void update_pos_taud(const Lattice& lat,
378+
const double* posd_in,
379+
const int ntype,
380+
const int nat,
381+
Atom* atoms)
382+
{
383+
int iat = 0;
384+
for (int it = 0; it < ntype; it++)
385+
{
386+
Atom* atom = &atoms[it];
387+
for (int ia = 0; ia < atom->na; ia++)
388+
{
389+
for (int ik = 0; ik < 3; ++ik)
390+
{
391+
atom->taud[ia][ik] += posd_in[3 * iat + ik];
392+
atom->dis[ia][ik] = posd_in[3 * iat + ik];
393+
}
394+
iat++;
395+
}
396+
}
397+
assert(iat == nat);
398+
periodic_boundary_adjustment(atoms,lat.latvec,ntype);
399+
bcast_atoms_tau(atoms, ntype);
400+
}
401+
402+
// posd_in is atomic displacements here liuyu 2023-03-22
403+
void update_pos_taud(const Lattice& lat,
404+
const ModuleBase::Vector3<double>* posd_in,
405+
const int ntype,
406+
const int nat,
407+
Atom* atoms)
408+
{
409+
int iat = 0;
410+
for (int it = 0; it < ntype; it++)
411+
{
412+
Atom* atom = &atoms[it];
413+
for (int ia = 0; ia < atom->na; ia++)
414+
{
415+
for (int ik = 0; ik < 3; ++ik)
416+
{
417+
atom->taud[ia][ik] += posd_in[iat][ik];
418+
atom->dis[ia][ik] = posd_in[iat][ik];
419+
}
420+
iat++;
421+
}
422+
}
423+
assert(iat == nat);
424+
periodic_boundary_adjustment(atoms,lat.latvec,ntype);
425+
bcast_atoms_tau(atoms, ntype);
426+
}
427+
428+
void update_vel(const ModuleBase::Vector3<double>* vel_in,
429+
const int ntype,
430+
const int nat,
431+
Atom* atoms)
432+
{
433+
int iat = 0;
434+
for (int it = 0; it < ntype; ++it)
435+
{
436+
Atom* atom = &atoms[it];
437+
for (int ia = 0; ia < atom->na; ++ia)
438+
{
439+
atoms[it].vel[ia] = vel_in[iat];
440+
++iat;
441+
}
442+
}
443+
assert(iat == nat);
444+
}
445+
377446
void periodic_boundary_adjustment(Atom* atoms,
378447
const ModuleBase::Matrix3& latvec,
379448
const int ntype)

source/module_cell/update_cell.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,48 @@ namespace unitcell
4848
const int ntype,
4949
const int nat,
5050
Atom* atoms);
51+
52+
/**
53+
* @brief update the position and tau of the atoms
54+
*
55+
* @param lat: the lattice of the atoms [in]
56+
* @param pos_in: the position of the atoms in direct coordinate system [in]
57+
* @param ntype: the number of types of the atoms [in]
58+
* @param nat: the number of atoms [in]
59+
* @param atoms: the atoms to be updated [out]
60+
*/
61+
void update_pos_taud(const Lattice& lat,
62+
const double* posd_in,
63+
const int ntype,
64+
const int nat,
65+
Atom* atoms);
66+
/**
67+
* @brief update the velocity of the atoms
68+
*
69+
* @param lat: the lattice of the atoms [in]
70+
* @param pos_in: the position of the atoms in direct coordinate system
71+
* in ModuleBase::Vector3 version [in]
72+
* @param ntype: the number of types of the atoms [in]
73+
* @param nat: the number of atoms [in]
74+
* @param atoms: the atoms to be updated [out]
75+
*/
76+
void update_pos_taud(const Lattice& lat,
77+
const ModuleBase::Vector3<double>* posd_in,
78+
const int ntype,
79+
const int nat,
80+
Atom* atoms);
81+
/**
82+
* @brief update the velocity of the atoms
83+
*
84+
* @param vel_in: the velocity of the atoms [in]
85+
* @param ntype: the number of types of the atoms [in]
86+
* @param nat: the number of atoms [in]
87+
* @param atoms: the atoms to be updated [out]
88+
*/
89+
void update_vel(const ModuleBase::Vector3<double>* vel_in,
90+
const int ntype,
91+
const int nat,
92+
Atom* atoms);
5193
}
5294
//
5395
#endif // UPDATE_CELL_H

source/module_md/md_base.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
#include "md_base.h"
2-
32
#include "md_func.h"
43
#ifdef __MPI
54
#include "mpi.h"
65
#endif
76
#include "module_io/print_info.h"
8-
7+
#include "module_cell/update_cell.h"
98
MD_base::MD_base(const Parameter& param_in, UnitCell& unit_in) : mdp(param_in.mdp), ucell(unit_in)
109
{
1110
my_rank = param_in.globalv.myrank;
@@ -112,7 +111,7 @@ void MD_base::update_pos()
112111
MPI_Bcast(pos, ucell.nat * 3, MPI_DOUBLE, 0, MPI_COMM_WORLD);
113112
#endif
114113

115-
ucell.update_pos_taud(pos);
114+
unitcell::update_pos_taud(ucell.lat,pos,ucell.ntype,ucell.nat,ucell.atoms);
116115

117116
return;
118117
}

source/module_md/run_md.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include "msst.h"
1111
#include "nhchain.h"
1212
#include "verlet.h"
13-
13+
#include "module_cell/update_cell.h"
1414
namespace Run_MD
1515
{
1616

@@ -97,7 +97,7 @@ void md_line(UnitCell& unit_in, ModuleESolver::ESolver* p_esolver, const Paramet
9797

9898
if ((mdrun->step_ + mdrun->step_rst_) % param_in.mdp.md_restartfreq == 0)
9999
{
100-
unit_in.update_vel(mdrun->vel);
100+
unitcell::update_vel(mdrun->vel,unit_in.ntype,unit_in.nat,unit_in.atoms);
101101
std::stringstream file;
102102
file << PARAM.globalv.global_stru_dir << "STRU_MD_" << mdrun->step_ + mdrun->step_rst_;
103103
// changelog 20240509

source/module_relax/relax_new/relax.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ void Relax::move_cell_ions(UnitCell& ucell, const bool is_new_dir)
633633
ucell.symm.symmetrize_vec3_nat(move_ion);
634634
}
635635

636-
ucell.update_pos_taud(move_ion);
636+
unitcell::update_pos_taud(ucell.lat,move_ion,ucell.ntype,ucell.nat,ucell.atoms);
637637

638638
// Print the structure file.
639639
ucell.print_tau();

0 commit comments

Comments
 (0)