4
4
5
5
using namespace hamilt ;
6
6
7
-
8
- template <typename T, typename Device>
9
- Operator<T, Device>::Operator(){}
10
-
11
- template <typename T, typename Device>
12
- Operator<T, Device>::~Operator ()
7
+ template <typename T, typename Device>
8
+ Operator<T, Device>::Operator()
13
9
{
14
- if (this ->hpsi != nullptr ) { delete this ->hpsi ;
15
10
}
11
+
12
+ template <typename T, typename Device>
13
+ Operator<T, Device>::~Operator ()
14
+ {
15
+ if (this ->hpsi != nullptr )
16
+ {
17
+ delete this ->hpsi ;
18
+ }
16
19
Operator* last = this ->next_op ;
17
20
Operator* last_sub = this ->next_sub_op ;
18
- while (last != nullptr || last_sub != nullptr )
21
+ while (last != nullptr || last_sub != nullptr )
19
22
{
20
- if (last_sub != nullptr )
21
- {// delete sub_chain first
23
+ if (last_sub != nullptr )
24
+ { // delete sub_chain first
22
25
Operator* node_delete = last_sub;
23
26
last_sub = last_sub->next_sub_op ;
24
27
node_delete->next_sub_op = nullptr ;
25
28
delete node_delete;
26
29
}
27
30
else
28
- {// delete main chain if sub_chain is deleted
31
+ { // delete main chain if sub_chain is deleted
29
32
Operator* node_delete = last;
30
33
last_sub = last->next_sub_op ;
31
34
node_delete->next_sub_op = nullptr ;
@@ -36,7 +39,7 @@ Operator<T, Device>::~Operator()
36
39
}
37
40
}
38
41
39
- template <typename T, typename Device>
42
+ template <typename T, typename Device>
40
43
typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& input) const
41
44
{
42
45
using syncmem_op = base_device::memory::synchronize_memory_op<T, Device, Device>;
@@ -46,37 +49,51 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
46
49
47
50
T* tmhpsi = this ->get_hpsi (input);
48
51
const T* tmpsi_in = std::get<0 >(psi_info);
49
- // if range in hpsi_info is illegal, the first return of to_range() would be nullptr
52
+ // if range in hpsi_info is illegal, the first return of to_range() would be nullptr
50
53
if (tmpsi_in == nullptr )
51
54
{
52
55
ModuleBase::WARNING_QUIT (" Operator" , " please choose correct range of psi for hPsi()!" );
53
56
}
54
- // if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
57
+ // if in_place, copy temporary hpsi to target hpsi_pointer, then delete hpsi and new a wrapper for return
55
58
T* hpsi_pointer = std::get<2 >(input);
56
59
if (this ->in_place )
57
60
{
58
61
// ModuleBase::GlobalFunc::COPYARRAY(this->hpsi->get_pointer(), hpsi_pointer, this->hpsi->size());
59
62
syncmem_op ()(this ->ctx , this ->ctx , hpsi_pointer, this ->hpsi ->get_pointer (), this ->hpsi ->size ());
60
63
delete this ->hpsi ;
61
- this ->hpsi = new psi::Psi<T, Device>(hpsi_pointer, *psi_input, 1 , nbands / psi_input->npol );
64
+ this ->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
65
+ 1 ,
66
+ nbands / psi_input->npol ,
67
+ psi_input->get_nbasis (),
68
+ psi_input->get_nbasis (),
69
+ true );
62
70
}
63
71
64
72
auto call_act = [&, this ](const Operator* op, const bool & is_first_node) -> void {
65
-
66
73
// a "psi" with the bands of needed range
67
- psi::Psi<T, Device> psi_wrapper (const_cast <T*>(tmpsi_in), 1 , nbands, psi_input->get_nbasis (), true );
68
-
69
-
74
+ psi::Psi<T, Device> psi_wrapper (const_cast <T*>(tmpsi_in),
75
+ 1 ,
76
+ nbands,
77
+ psi_input->get_nbasis (),
78
+ psi_input->get_nbasis (),
79
+ true );
80
+
70
81
switch (op->get_act_type ())
71
82
{
72
83
case 2 :
73
84
op->act (psi_wrapper, *this ->hpsi , nbands);
74
85
break ;
75
86
default :
76
- op->act (nbands, psi_input->get_nbasis (), psi_input->npol , tmpsi_in, this ->hpsi ->get_pointer (), psi_input->get_ngk (op->ik ), is_first_node);
87
+ op->act (nbands,
88
+ psi_input->get_nbasis (),
89
+ psi_input->npol ,
90
+ tmpsi_in,
91
+ this ->hpsi ->get_pointer (),
92
+ psi_input->get_current_nbas (),
93
+ is_first_node);
77
94
break ;
78
95
}
79
- };
96
+ };
80
97
81
98
ModuleBase::timer::tick (" Operator" , " hPsi" );
82
99
call_act (this , true ); // first node
@@ -91,39 +108,43 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
91
108
return hpsi_info (this ->hpsi , psi::Range (1 , 0 , 0 , nbands / psi_input->npol ), hpsi_pointer);
92
109
}
93
110
94
-
95
- template <typename T, typename Device>
96
- void Operator<T, Device>::init(const int ik_in)
111
+ template <typename T, typename Device>
112
+ void Operator<T, Device>::init(const int ik_in)
97
113
{
98
114
this ->ik = ik_in;
99
- if (this ->next_op != nullptr ) {
115
+ if (this ->next_op != nullptr )
116
+ {
100
117
this ->next_op ->init (ik_in);
101
118
}
102
119
}
103
120
104
- template <typename T, typename Device>
105
- void Operator<T, Device>::add(Operator* next)
121
+ template <typename T, typename Device>
122
+ void Operator<T, Device>::add(Operator* next)
106
123
{
107
- if (next==nullptr ) { return ;
108
- }
124
+ if (next == nullptr )
125
+ {
126
+ return ;
127
+ }
109
128
next->is_first_node = false ;
110
- if (next->next_op != nullptr ) { this ->add (next->next_op );
111
- }
129
+ if (next->next_op != nullptr )
130
+ {
131
+ this ->add (next->next_op );
132
+ }
112
133
Operator* last = this ;
113
- // loop to end of the chain
114
- while (last->next_op != nullptr )
134
+ // loop to end of the chain
135
+ while (last->next_op != nullptr )
115
136
{
116
- if (next->cal_type == last->cal_type )
137
+ if (next->cal_type == last->cal_type )
117
138
{
118
139
break ;
119
140
}
120
141
last = last->next_op ;
121
142
}
122
- if (next->cal_type == last->cal_type )
143
+ if (next->cal_type == last->cal_type )
123
144
{
124
- // insert next to sub chain of current node
145
+ // insert next to sub chain of current node
125
146
Operator* sub_last = last;
126
- while (sub_last->next_sub_op != nullptr )
147
+ while (sub_last->next_sub_op != nullptr )
127
148
{
128
149
sub_last = sub_last->next_sub_op ;
129
150
}
@@ -136,34 +157,45 @@ void Operator<T, Device>::add(Operator* next)
136
157
}
137
158
}
138
159
139
- template <typename T, typename Device>
160
+ template <typename T, typename Device>
140
161
T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
141
162
{
142
163
const int nbands_range = (std::get<1 >(info).range_2 - std::get<1 >(info).range_1 + 1 );
143
- // in_place call of hPsi, hpsi inputs as new psi,
144
- // create a new hpsi and delete old hpsi later
164
+ // in_place call of hPsi, hpsi inputs as new psi,
165
+ // create a new hpsi and delete old hpsi later
145
166
T* hpsi_pointer = std::get<2 >(info);
146
167
const T* psi_pointer = std::get<0 >(info)->get_pointer ();
147
- if (this ->hpsi != nullptr )
168
+ if (this ->hpsi != nullptr )
148
169
{
149
170
delete this ->hpsi ;
150
171
this ->hpsi = nullptr ;
151
172
}
152
- if (!hpsi_pointer)
173
+ if (!hpsi_pointer)
153
174
{
154
175
ModuleBase::WARNING_QUIT (" Operator::hPsi" , " hpsi_pointer can not be nullptr" );
155
176
}
156
- else if (hpsi_pointer == psi_pointer)
177
+ else if (hpsi_pointer == psi_pointer)
157
178
{
158
179
this ->in_place = true ;
159
- this ->hpsi = new psi::Psi<T, Device>(std::get<0 >(info)[0 ], 1 , nbands_range);
180
+ // this->hpsi = new psi::Psi<T, Device>(std::get<0>(info)[0], 1, nbands_range);
181
+ this ->hpsi = new psi::Psi<T, Device>(1 ,
182
+ nbands_range,
183
+ std::get<0 >(info)->get_nbasis (),
184
+ std::get<0 >(info)->get_nbasis (),
185
+ true );
160
186
}
161
187
else
162
188
{
163
189
this ->in_place = false ;
164
- this ->hpsi = new psi::Psi<T, Device>(hpsi_pointer, std::get<0 >(info)[0 ], 1 , nbands_range);
190
+
191
+ this ->hpsi = new psi::Psi<T, Device>(hpsi_pointer,
192
+ 1 ,
193
+ nbands_range,
194
+ std::get<0 >(info)->get_nbasis (),
195
+ std::get<0 >(info)->get_nbasis (),
196
+ true );
165
197
}
166
-
198
+
167
199
hpsi_pointer = this ->hpsi ->get_pointer ();
168
200
size_t total_hpsi_size = nbands_range * this ->hpsi ->get_nbasis ();
169
201
// ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size);
@@ -172,7 +204,8 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
172
204
return hpsi_pointer;
173
205
}
174
206
175
- namespace hamilt {
207
+ namespace hamilt
208
+ {
176
209
template class Operator <float , base_device::DEVICE_CPU>;
177
210
template class Operator <std::complex<float >, base_device::DEVICE_CPU>;
178
211
template class Operator <double , base_device::DEVICE_CPU>;
@@ -183,4 +216,4 @@ template class Operator<std::complex<float>, base_device::DEVICE_GPU>;
183
216
template class Operator <double , base_device::DEVICE_GPU>;
184
217
template class Operator <std::complex<double >, base_device::DEVICE_GPU>;
185
218
#endif
186
- }
219
+ } // namespace hamilt
0 commit comments