Commit 113a59f

jspark1105 authored and facebook-github-bot committed
Use caffe2::int8::Int8TensorCPU when input type is uint8_t
Summary: We use caffe2::int8::Int8TensorCPU for quantized tensors with a uint8_t element type.

Differential Revision: D10156452

fbshipit-source-id: a4c260cd0bfecfb783adc468de25587e04badd79
1 parent 06360c3 commit 113a59f
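
Both diffs below make the same change to the two benchmark binaries. For reference, here is a minimal consolidated sketch of the new input-filling logic; the helper name fillInputBlob and its signature are hypothetical (invented for illustration), while the caffe2 calls are the ones used in the diffs:

#include <string>
#include <vector>

#include "caffe2/core/blob.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/tensor_int8.h"

// Hypothetical helper showing the pattern introduced by this commit:
// uint8_t inputs go into caffe2::int8::Int8TensorCPU (the quantized tensor
// wrapper, whose underlying data lives in its `t` member), while float
// inputs keep using a plain TensorCPU.
void fillInputBlob(
    caffe2::Blob* blob,
    const std::vector<int>& input_dims,
    const std::string& input_type) {
  if (input_type == "uint8_t") {
    caffe2::int8::Int8TensorCPU* tensor =
        blob->GetMutable<caffe2::int8::Int8TensorCPU>();
    CHECK_NOTNULL(tensor);
    tensor->t.Resize(input_dims);
    tensor->t.mutable_data<uint8_t>();  // allocate uint8_t storage
  } else if (input_type == "float") {
    caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
    CHECK_NOTNULL(tensor);
    tensor->Resize(input_dims);
    tensor->mutable_data<float>();  // allocate float storage
  } else {
    CAFFE_THROW("Unsupported input type: ", input_type);
  }
}

The reason for the switch is that Int8TensorCPU bundles the raw uint8_t tensor with its quantization parameters, so operators that expect quantized inputs can consume the blob directly.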

File tree

2 files changed: +19 additions, -9 deletions

binaries/benchmark_helper.cc

Lines changed: 9 additions & 4 deletions
@@ -28,6 +28,7 @@
 #include "caffe2/core/logging.h"
 #include "caffe2/core/net.h"
 #include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
 #include "caffe2/utils/bench_utils.h"
 #include "caffe2/utils/string_utils.h"
 #include "observers/net_observer_reporter_print.h"
@@ -163,12 +164,16 @@ void loadInput(
       CAFFE_THROW("Not support GPU on mobile.");
 #endif
     } else {
-      caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
-      CHECK_NOTNULL(tensor);
-      tensor->Resize(input_dims);
       if (input_type_list[i] == "uint8_t") {
-        tensor->mutable_data<uint8_t>();
+        caffe2::int8::Int8TensorCPU* tensor =
+            blob->GetMutable<caffe2::int8::Int8TensorCPU>();
+        CHECK_NOTNULL(tensor);
+        tensor->t.Resize(input_dims);
+        tensor->t.mutable_data<uint8_t>();
       } else if (input_type_list[i] == "float") {
+        caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
+        CHECK_NOTNULL(tensor);
+        tensor->Resize(input_dims);
         tensor->mutable_data<float>();
       } else {
         CAFFE_THROW("Unsupported input type: ", input_type_list[i]);

binaries/speed_benchmark.cc

Lines changed: 10 additions & 5 deletions
@@ -20,6 +20,7 @@
 #include "caffe2/core/init.h"
 #include "caffe2/core/logging.h"
 #include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
 #ifdef CAFFE2_OPTIMIZER
 #include "caffe2/opt/optimizer.h"
 #endif
@@ -137,14 +138,18 @@ int main(int argc, char** argv) {
       if (blob == nullptr) {
         blob = workspace->CreateBlob(input_names[i]);
       }
-      caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
-      CHECK_NOTNULL(tensor);
-      tensor->Resize(input_dims);
       if (input_type_list[i] == "uint8_t") {
-        tensor->mutable_data<uint8_t>();
+        caffe2::int8::Int8TensorCPU* tensor =
+            blob->GetMutable<caffe2::int8::Int8TensorCPU>();
+        CHECK_NOTNULL(tensor);
+        tensor->t.Resize(input_dims);
+        tensor->t.mutable_data<uint8_t>();
       } else if (input_type_list[i] == "float") {
+        caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
+        CHECK_NOTNULL(tensor);
+        tensor->Resize(input_dims);
         tensor->mutable_data<float>();
-     } else {
+      } else {
         CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
       }
     }
