Skip to content

Commit e317774

Browse files
zheng-xqMikhail Zolotukhin
authored andcommitted
Fixed CudaCodeGen output streams. Switch to __ldg by default (pytorch#148)
1 parent 5af5528 commit e317774

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

torch/csrc/jit/tensorexpr/cuda_codegen.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ScopedVarName {
3333
const std::string& name)
3434
: ScopedVarName(&manager->unique_name_mapping_, var, name) {}
3535

36-
~ScopedVarName() {
36+
~ScopedVarName() noexcept(false) {
3737
auto iter = mapping_->find(var_);
3838
TORCH_CHECK(iter != mapping_->end(), "Invalid var entry");
3939
mapping_->erase(var_);
@@ -124,29 +124,34 @@ void CudaPrinter::visit(const For* v) {
124124
}
125125
}
126126

127+
void CudaPrinter::visit(const Load* v) {
128+
// TODO: find a better metric in using ldg or not. Support different dtypes.
129+
os() << "__ldg(" << v->base_handle() << " + " << v->index() << ")";
130+
}
131+
127132
void CudaCodeGen::Initialize() {
128133
printer_.reset(new CudaPrinter(&oss_));
129134
// TODO: handle multiple kernels.
130135
// TODO: handle dynamic dimension.
131136
// TODO: call nvrtc.
132-
oss_ << "extern \"C\" __global__" << std::endl << "void f(";
137+
os() << "extern \"C\" __global__" << std::endl << "void f(";
133138
const std::vector<BufferArg> buffer_args = this->buffer_args();
134139
for (int i = 0; i < buffer_args.size(); i++) {
135140
if (i > 0) {
136-
oss_ << ", ";
141+
os() << ", ";
137142
}
138143
const BufferArg& buffer_arg = buffer_args[i];
139144
const Var& var = buffer_arg.var();
140145
Dtype dtype = buffer_arg.dtype();
141-
oss_ << dtype.ToCppString() << (buffer_arg.isVar() ? " " : "* ")
146+
os() << dtype.ToCppString() << (buffer_arg.isVar() ? " " : "* ")
142147
<< name_manager()->get_unique_name(var);
143148
}
144-
oss_ << ") {";
149+
os() << ") {";
145150

146-
oss_ << std::endl;
151+
os() << std::endl;
147152
stmt().accept(printer_.get());
148-
oss_ << std::endl;
149-
oss_ << "}";
153+
os() << std::endl;
154+
os() << "}";
150155

151156
// Check that all block extents had been set.
152157
const std::vector<Expr>& gpu_block_extents = printer_->gpu_block_extents();

torch/csrc/jit/tensorexpr/cuda_codegen.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace tensorexpr {
2222
// A class that overrides the underlying IRPrinter to produce Cuda C.
2323
class CudaPrinter : public IRPrinter {
2424
public:
25-
explicit CudaPrinter(std::ostream* os) : IRPrinter(*os), os_(os) {}
25+
explicit CudaPrinter(std::ostream* os) : IRPrinter(*os) {}
2626

2727
void visit(const Cast* v) {
2828
auto dtype = v->dtype();
@@ -38,9 +38,7 @@ class CudaPrinter : public IRPrinter {
3838

3939
void visit(const For* v);
4040

41-
std::ostream& os() {
42-
return *os_;
43-
}
41+
void visit(const Load* v);
4442

4543
const std::vector<Expr>& gpu_block_extents() const {
4644
return gpu_block_extents_;
@@ -53,7 +51,6 @@ class CudaPrinter : public IRPrinter {
5351
using IRPrinter::name_manager;
5452

5553
private:
56-
std::ostream* os_ = nullptr;
5754
std::vector<Expr> gpu_block_extents_;
5855
std::vector<Expr> gpu_thread_extents_;
5956
};
@@ -94,6 +91,10 @@ class TORCH_API CudaCodeGen : public CodeGen {
9491
return printer_->name_manager();
9592
}
9693

94+
std::ostream& os() {
95+
return printer_->os();
96+
}
97+
9798
std::ostringstream oss_;
9899
std::unique_ptr<CudaPrinter> printer_;
99100
CUfunction function_;

torch/csrc/jit/tensorexpr/ir_printer.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,6 @@ class TORCH_API IRPrinter : public IRVisitor {
6464
}
6565

6666
private:
67-
std::ostream& raw_os() {
68-
return printer_os_;
69-
}
70-
7167
PrinterStream printer_os_;
7268
UniqueNameManager name_manager_;
7369
};

0 commit comments

Comments
 (0)