Skip to content

Commit baf8afe

Browse files
bertmaher authored and Mikhail Zolotukhin committed
Allow CodeGen to take Var args (interpreter support only) (pytorch#78)
* Test demonstrating dynamic shape * Allow binding of Vars to args in interpreter * Pass BufferArgs to LLVMCodeGen * clang-format-diff
1 parent f5b7ac5 commit baf8afe

File tree

8 files changed

+134
-66
lines changed

8 files changed

+134
-66
lines changed

test/cpp/tensorexpr/test_expr.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,5 +269,25 @@ void testExprBinaryMath01() {
269269
EXPECT_NEAR(eval.value().as<float>(), v_ref, 1e-6) << "fail: " << v_expr;
270270
}
271271
}
272+
273+
void testExprDynamicShapeAdd() {
274+
auto testWithSize = [](int32_t size) {
275+
Var n("n", kInt32);
276+
Buffer a(Var("a", kHandle), kFloat32, {n});
277+
Buffer b(Var("b", kHandle), kFloat32, {n});
278+
Buffer c(Var("c", kHandle), kFloat32, {n});
279+
Var i("i", kInt32);
280+
Stmt s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
281+
std::vector<float> aData(size, 1.0f);
282+
std::vector<float> bData(size, 2.0f);
283+
std::vector<float> cData(size, 0.0f);
284+
SimpleIREvaluator(s, a, b, c, n)(aData, bData, cData, size);
285+
ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
286+
};
287+
testWithSize(1);
288+
testWithSize(16);
289+
testWithSize(37);
290+
}
291+
272292
} // namespace jit
273293
} // namespace torch

test/cpp/tensorexpr/test_llvm.cpp

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ void testLLVMBufferTest() {
101101
std::vector<int32_t> v(5);
102102
std::vector<void*> args({v.data()});
103103
auto rv = IntImm::make(0);
104-
LLVMCodeGen cg(rv, {&a});
104+
LLVMCodeGen cg(rv, {a});
105105
EXPECT_EQ(cg.value<int>(args), 0);
106106
}
107107

@@ -116,7 +116,7 @@ void testLLVMBlockTest() {
116116
Store::make(a, IntImm::make(0), IntImm::make(4), IntImm::make(1)),
117117
});
118118

119-
LLVMCodeGen cg(block, {&a});
119+
LLVMCodeGen cg(block, {a});
120120
EXPECT_EQ(cg.value<int>(args), 0);
121121
EXPECT_EQ(v[0], 4);
122122
EXPECT_EQ(v[1], 4);
@@ -133,7 +133,7 @@ void testLLVMLoadStoreTest() {
133133
IntImm::make(0),
134134
Load::make(a, IntImm::make(0), IntImm::make(1)),
135135
IntImm::make(1));
136-
LLVMCodeGen cg(store, {&a, &b});
136+
LLVMCodeGen cg(store, {a, b});
137137
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
138138
EXPECT_EQ(cg.value<int>(args), 0);
139139
EXPECT_EQ(a_buffer[0], 42);
@@ -151,7 +151,7 @@ void testLLVMVecLoadStoreTest() {
151151
Ramp::make(0, 1, 4),
152152
Load::make(a, Ramp::make(0, 1, 4), Broadcast::make(IntImm::make(1), 4)),
153153
Broadcast::make(IntImm::make(1), 4));
154-
LLVMCodeGen cg(store, {&a, &b});
154+
LLVMCodeGen cg(store, {a, b});
155155
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
156156
EXPECT_EQ(cg.value<int>(args), 0);
157157
EXPECT_EQ(a_buffer[0], 1);
@@ -176,7 +176,7 @@ void testLLVMMemcpyTest() {
176176
auto expr =
177177
For::make(i, 0, N, Store::make(b, i, Load::make(a, i, mask), mask));
178178

179-
LLVMCodeGen cg(expr, {&a, &b});
179+
LLVMCodeGen cg(expr, {a, b});
180180

181181
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
182182
ASSERT_EQ(cg.value<int>(args), 0);
@@ -194,10 +194,9 @@ void testLLVMBzeroTest() {
194194

195195
auto mask = IntImm::make(1);
196196
Var i("i", kInt32);
197-
auto expr =
198-
For::make(i, 0, N, Store::make(b, i, IntImm::make(0), mask));
197+
auto expr = For::make(i, 0, N, Store::make(b, i, IntImm::make(0), mask));
199198

200-
LLVMCodeGen cg(expr, {&b});
199+
LLVMCodeGen cg(expr, {b});
201200

202201
std::vector<void*> args({b_buffer.data()});
203202
ASSERT_EQ(cg.value<int>(args), 0);
@@ -227,7 +226,7 @@ void testLLVMElemwiseAdd() {
227226
Add::make(Load::make(a, i, mask), Load::make(b, i, mask)),
228227
mask));
229228

230-
LLVMCodeGen cg(expr, {&a, &b, &c});
229+
LLVMCodeGen cg(expr, {a, b, c});
231230

232231
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
233232
ASSERT_EQ(cg.value<int>(args), 0);
@@ -257,7 +256,7 @@ void testLLVMElemwiseAddFloat() {
257256
N,
258257
Store::make(c, i, Load::make(a, i, mask) + Load::make(b, i, mask), mask));
259258

260-
LLVMCodeGen cg(expr, {&a, &b, &c});
259+
LLVMCodeGen cg(expr, {a, b, c});
261260

262261
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
263262
ASSERT_EQ(cg.value<int>(args), 0);
@@ -282,10 +281,14 @@ void testLLVMElemwiseLog10Float() {
282281
auto expr = For::make(
283282
i,
284283
0,
285-
N/4,
286-
Store::make(b, Ramp::make(i * 4, 1, 4), log10(Load::make(a, Ramp::make(i * 4, 1, 4), mask)), mask));
284+
N / 4,
285+
Store::make(
286+
b,
287+
Ramp::make(i * 4, 1, 4),
288+
log10(Load::make(a, Ramp::make(i * 4, 1, 4), mask)),
289+
mask));
287290

288-
LLVMCodeGen cg(expr, {&a, &b});
291+
LLVMCodeGen cg(expr, {a, b});
289292

290293
std::vector<void*> args({a_buffer.data(), b_buffer.data()});
291294
ASSERT_EQ(cg.value<int>(args), 0);
@@ -317,7 +320,7 @@ void testLLVMElemwiseMaxInt() {
317320
Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
318321
mask));
319322

320-
LLVMCodeGen cg(expr, {&a, &b, &c});
323+
LLVMCodeGen cg(expr, {a, b, c});
321324

322325
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
323326
ASSERT_EQ(cg.value<int>(args), 0);
@@ -351,7 +354,7 @@ void testLLVMElemwiseMinInt() {
351354
Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
352355
mask));
353356

354-
LLVMCodeGen cg(expr, {&a, &b, &c});
357+
LLVMCodeGen cg(expr, {a, b, c});
355358

356359
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
357360
ASSERT_EQ(cg.value<int>(args), 0);
@@ -385,7 +388,7 @@ void testLLVMElemwiseMaxNumFloat() {
385388
Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
386389
mask));
387390

388-
LLVMCodeGen cg(expr, {&a, &b, &c});
391+
LLVMCodeGen cg(expr, {a, b, c});
389392

390393
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
391394
ASSERT_EQ(cg.value<int>(args), 0);
@@ -419,7 +422,7 @@ void testLLVMElemwiseMaxNumNaNFloat() {
419422
Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
420423
mask));
421424

422-
LLVMCodeGen cg(expr, {&a, &b, &c});
425+
LLVMCodeGen cg(expr, {a, b, c});
423426

424427
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
425428
ASSERT_EQ(cg.value<int>(args), 0);
@@ -452,7 +455,7 @@ void testLLVMElemwiseMinNumFloat() {
452455
Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
453456
mask));
454457

455-
LLVMCodeGen cg(expr, {&a, &b, &c});
458+
LLVMCodeGen cg(expr, {a, b, c});
456459

457460
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
458461
ASSERT_EQ(cg.value<int>(args), 0);
@@ -486,7 +489,7 @@ void testLLVMElemwiseMinNumNaNFloat() {
486489
Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false),
487490
mask));
488491

489-
LLVMCodeGen cg(expr, {&a, &b, &c});
492+
LLVMCodeGen cg(expr, {a, b, c});
490493

491494
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
492495
ASSERT_EQ(cg.value<int>(args), 0);
@@ -520,7 +523,7 @@ void testLLVMElemwiseMaximumFloat() {
520523
Max::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
521524
mask));
522525

523-
LLVMCodeGen cg(expr, {&a, &b, &c});
526+
LLVMCodeGen cg(expr, {a, b, c});
524527

525528
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
526529
ASSERT_EQ(cg.value<int>(args), 0);
@@ -554,7 +557,7 @@ void testLLVMElemwiseMaximumNaNFloat() {
554557
Max::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
555558
mask));
556559

557-
LLVMCodeGen cg(expr, {&a, &b, &c});
560+
LLVMCodeGen cg(expr, {a, b, c});
558561

559562
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
560563
ASSERT_EQ(cg.value<int>(args), 0);
@@ -589,7 +592,7 @@ void testLLVMElemwiseMinimumFloat() {
589592
Min::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
590593
mask));
591594

592-
LLVMCodeGen cg(expr, {&a, &b, &c});
595+
LLVMCodeGen cg(expr, {a, b, c});
593596

594597
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
595598
ASSERT_EQ(cg.value<int>(args), 0);
@@ -623,7 +626,7 @@ void testLLVMElemwiseMinimumNaNFloat() {
623626
Min::make(Load::make(a, i, mask), Load::make(b, i, mask), true),
624627
mask));
625628

626-
LLVMCodeGen cg(expr, {&a, &b, &c});
629+
LLVMCodeGen cg(expr, {a, b, c});
627630

628631
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
629632
ASSERT_EQ(cg.value<int>(args), 0);
@@ -668,7 +671,7 @@ void testLLVMCompareSelectIntEQ() {
668671
CompareSelectOperation::kEQ),
669672
mask));
670673

671-
LLVMCodeGen cg(expr, {&a, &b, &c});
674+
LLVMCodeGen cg(expr, {a, b, c});
672675

673676
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
674677
ASSERT_EQ(cg.value<int>(args), 0);
@@ -707,7 +710,7 @@ void testLLVMCompareSelectFloatEQ() {
707710
CompareSelectOperation::kEQ),
708711
mask));
709712

710-
LLVMCodeGen cg(expr, {&a, &b, &c});
713+
LLVMCodeGen cg(expr, {a, b, c});
711714

712715
std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
713716
ASSERT_EQ(cg.value<int>(args), 0);
@@ -726,7 +729,7 @@ void testLLVMStoreFloat() {
726729
std::vector<float> result_buffer = {0.0f};
727730
auto expr = Store::make(
728731
result, IntImm::make(0), FloatImm::make(3.14f), IntImm::make(1));
729-
LLVMCodeGen cg(expr, {&result});
732+
LLVMCodeGen cg(expr, {result});
730733
std::vector<void*> args({result_buffer.data()});
731734
ASSERT_EQ(cg.value<int>(args), 0);
732735
EXPECT_EQ(result_buffer[0], 3.14f);
@@ -739,7 +742,7 @@ void testLLVMSimpleMath01() {
739742
Schedule sch = Schedule::make({tensor});
740743
Stmt stmt = sch.Lower();
741744
Buffer f_buf(tensor.function().func_var(), kFloat32, {N});
742-
LLVMCodeGen cg(stmt, {&f_buf});
745+
LLVMCodeGen cg(stmt, {f_buf});
743746

744747
PaddedBuffer<float> f_v(N, "f_v");
745748
std::vector<void*> args({f_v.data()});
@@ -764,7 +767,7 @@ void testLLVMComputeMul() {
764767
Schedule sch = Schedule::make({c});
765768
Stmt s = sch.Lower();
766769

767-
LLVMCodeGen cg(s, {&a, &b, &c_buf});
770+
LLVMCodeGen cg(s, {a, b, c_buf});
768771

769772
std::vector<float> a_vec(N, 21.0f);
770773
std::vector<float> b_vec(N, 2.0f);
@@ -789,7 +792,7 @@ void testLLVMBroadcastAdd() {
789792
Schedule sch = Schedule::make({c});
790793
Stmt s = sch.Lower();
791794

792-
LLVMCodeGen cg(s, {&a, &b, &c_buf});
795+
LLVMCodeGen cg(s, {a, b, c_buf});
793796

794797
std::vector<float> av(M * N);
795798
std::iota(av.begin(), av.end(), 0);
@@ -805,6 +808,30 @@ void testLLVMBroadcastAdd() {
805808
}
806809
}
807810
}
811+
812+
void testLLVMDynamicShapeAdd() {
813+
#if 0
814+
auto testWithSize = [](int32_t size) {
815+
Var n("n", kInt32);
816+
Buffer a(Var("a", kHandle), kFloat32, {n});
817+
Buffer b(Var("b", kHandle), kFloat32, {n});
818+
Buffer c(Var("c", kHandle), kFloat32, {n});
819+
Var i("i", kInt32);
820+
Stmt s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
821+
std::vector<float> aData(size, 1.0f);
822+
std::vector<float> bData(size, 2.0f);
823+
std::vector<float> cData(size, 0.0f);
824+
LLVMCodeGen cg(s, {a, b, c, n});
825+
std::vector<void*> args({aData.data(), bData.data(), cData.data(), size));
826+
cg.value<float>(args);
827+
ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
828+
};
829+
testWithSize(1);
830+
testWithSize(16);
831+
testWithSize(37);
832+
#endif
833+
}
834+
808835
} // namespace jit
809836
} // namespace torch
810837

test/cpp/tensorexpr/tests.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace jit {
1919
_(ExprMath01) \
2020
_(ExprUnaryMath01) \
2121
_(ExprBinaryMath01) \
22+
_(ExprDynamicShapeAdd) \
2223
_(IRPrinterBasicValueTest) \
2324
_(IRPrinterBasicValueTest02) \
2425
_(IRPrinterLetTest01) \
@@ -69,6 +70,7 @@ namespace jit {
6970
_(LLVMSimpleMath01) \
7071
_(LLVMComputeMul) \
7172
_(LLVMBroadcastAdd) \
73+
_(LLVMDynamicShapeAdd) \
7274
_(CudaTestVectorAdd01) \
7375
_(ATen_cast_Float) \
7476
_(ATennegInt) \

torch/csrc/jit/passes/tensorexpr_fuser.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -686,17 +686,12 @@ struct TensorExprKernel {
686686
}
687687
}
688688
Stmt stmt = sch.Lower();
689+
689690
#ifdef ENABLE_LLVM
690691
// Set up formal params (inputs, then outputs) for kernel.
691-
std::vector<Buffer*> params;
692-
for (auto& b : buffer_args) {
693-
params.push_back(&b);
694-
}
695-
Buffer outbuf(
696-
tensor_output->function().func_var(),
697-
tensor_output->dtype(),
698-
tensor_output->dims());
699-
params.push_back(&outbuf);
692+
std::vector<CodeGen::BufferArg> params(
693+
buffer_args.begin(), buffer_args.end());
694+
params.push_back(*tensor_output);
700695

701696
// Generate code.
702697
codegen = std::make_unique<LLVMCodeGen>(stmt, params);

torch/csrc/jit/tensorexpr/codegen.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class CodeGen::BufferArg {
6666
dtype_(tensor.function().body().dtype()) {}
6767
BufferArg(const Function& func)
6868
: var_(func.func_var()), dtype_(func.body().dtype()) {}
69+
BufferArg(const Var& var) : var_(var), dtype_(var.dtype()), isVar_(true) {}
70+
6971
const Var& var() const {
7072
return var_;
7173
}
@@ -76,9 +78,14 @@ class CodeGen::BufferArg {
7678
return dtype_;
7779
}
7880

81+
bool isVar() const {
82+
return isVar_;
83+
}
84+
7985
private:
8086
Var var_;
8187
Dtype dtype_;
88+
bool isVar_{false};
8289
};
8390

8491
class CodeGen::CallArg {
@@ -91,12 +98,28 @@ class CodeGen::CallArg {
9198

9299
CallArg(void* ptr) : ptr_(ptr) {}
93100

101+
CallArg(int32_t i) : ival_(i) {}
102+
103+
CallArg(float f) : fval_(f) {}
104+
94105
void* data() const {
95106
return ptr_;
96107
}
97108

109+
int32_t intData() const {
110+
return ival_;
111+
}
112+
113+
float floatData() const {
114+
return fval_;
115+
}
116+
98117
private:
99-
void* ptr_ = nullptr;
118+
union {
119+
void* ptr_;
120+
float fval_;
121+
int32_t ival_;
122+
};
100123
};
101124

102125
} // namespace tensorexpr

0 commit comments

Comments (0)