@@ -101,7 +101,7 @@ void testLLVMBufferTest() {
101
101
std::vector<int32_t > v (5 );
102
102
std::vector<void *> args ({v.data ()});
103
103
auto rv = IntImm::make (0 );
104
- LLVMCodeGen cg (rv, {& a});
104
+ LLVMCodeGen cg (rv, {a});
105
105
EXPECT_EQ (cg.value <int >(args), 0 );
106
106
}
107
107
@@ -116,7 +116,7 @@ void testLLVMBlockTest() {
116
116
Store::make (a, IntImm::make (0 ), IntImm::make (4 ), IntImm::make (1 )),
117
117
});
118
118
119
- LLVMCodeGen cg (block, {& a});
119
+ LLVMCodeGen cg (block, {a});
120
120
EXPECT_EQ (cg.value <int >(args), 0 );
121
121
EXPECT_EQ (v[0 ], 4 );
122
122
EXPECT_EQ (v[1 ], 4 );
@@ -133,7 +133,7 @@ void testLLVMLoadStoreTest() {
133
133
IntImm::make (0 ),
134
134
Load::make (a, IntImm::make (0 ), IntImm::make (1 )),
135
135
IntImm::make (1 ));
136
- LLVMCodeGen cg (store, {& a, & b});
136
+ LLVMCodeGen cg (store, {a, b});
137
137
std::vector<void *> args ({a_buffer.data (), b_buffer.data ()});
138
138
EXPECT_EQ (cg.value <int >(args), 0 );
139
139
EXPECT_EQ (a_buffer[0 ], 42 );
@@ -151,7 +151,7 @@ void testLLVMVecLoadStoreTest() {
151
151
Ramp::make (0 , 1 , 4 ),
152
152
Load::make (a, Ramp::make (0 , 1 , 4 ), Broadcast::make (IntImm::make (1 ), 4 )),
153
153
Broadcast::make (IntImm::make (1 ), 4 ));
154
- LLVMCodeGen cg (store, {& a, & b});
154
+ LLVMCodeGen cg (store, {a, b});
155
155
std::vector<void *> args ({a_buffer.data (), b_buffer.data ()});
156
156
EXPECT_EQ (cg.value <int >(args), 0 );
157
157
EXPECT_EQ (a_buffer[0 ], 1 );
@@ -176,7 +176,7 @@ void testLLVMMemcpyTest() {
176
176
auto expr =
177
177
For::make (i, 0 , N, Store::make (b, i, Load::make (a, i, mask), mask));
178
178
179
- LLVMCodeGen cg (expr, {& a, & b});
179
+ LLVMCodeGen cg (expr, {a, b});
180
180
181
181
std::vector<void *> args ({a_buffer.data (), b_buffer.data ()});
182
182
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -194,10 +194,9 @@ void testLLVMBzeroTest() {
194
194
195
195
auto mask = IntImm::make (1 );
196
196
Var i (" i" , kInt32 );
197
- auto expr =
198
- For::make (i, 0 , N, Store::make (b, i, IntImm::make (0 ), mask));
197
+ auto expr = For::make (i, 0 , N, Store::make (b, i, IntImm::make (0 ), mask));
199
198
200
- LLVMCodeGen cg (expr, {& b});
199
+ LLVMCodeGen cg (expr, {b});
201
200
202
201
std::vector<void *> args ({b_buffer.data ()});
203
202
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -227,7 +226,7 @@ void testLLVMElemwiseAdd() {
227
226
Add::make (Load::make (a, i, mask), Load::make (b, i, mask)),
228
227
mask));
229
228
230
- LLVMCodeGen cg (expr, {& a, & b, & c});
229
+ LLVMCodeGen cg (expr, {a, b, c});
231
230
232
231
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
233
232
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -257,7 +256,7 @@ void testLLVMElemwiseAddFloat() {
257
256
N,
258
257
Store::make (c, i, Load::make (a, i, mask) + Load::make (b, i, mask), mask));
259
258
260
- LLVMCodeGen cg (expr, {& a, & b, & c});
259
+ LLVMCodeGen cg (expr, {a, b, c});
261
260
262
261
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
263
262
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -282,10 +281,14 @@ void testLLVMElemwiseLog10Float() {
282
281
auto expr = For::make (
283
282
i,
284
283
0 ,
285
- N/4 ,
286
- Store::make (b, Ramp::make (i * 4 , 1 , 4 ), log10 (Load::make (a, Ramp::make (i * 4 , 1 , 4 ), mask)), mask));
284
+ N / 4 ,
285
+ Store::make (
286
+ b,
287
+ Ramp::make (i * 4 , 1 , 4 ),
288
+ log10 (Load::make (a, Ramp::make (i * 4 , 1 , 4 ), mask)),
289
+ mask));
287
290
288
- LLVMCodeGen cg (expr, {& a, & b});
291
+ LLVMCodeGen cg (expr, {a, b});
289
292
290
293
std::vector<void *> args ({a_buffer.data (), b_buffer.data ()});
291
294
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -317,7 +320,7 @@ void testLLVMElemwiseMaxInt() {
317
320
Max::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
318
321
mask));
319
322
320
- LLVMCodeGen cg (expr, {& a, & b, & c});
323
+ LLVMCodeGen cg (expr, {a, b, c});
321
324
322
325
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
323
326
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -351,7 +354,7 @@ void testLLVMElemwiseMinInt() {
351
354
Min::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
352
355
mask));
353
356
354
- LLVMCodeGen cg (expr, {& a, & b, & c});
357
+ LLVMCodeGen cg (expr, {a, b, c});
355
358
356
359
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
357
360
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -385,7 +388,7 @@ void testLLVMElemwiseMaxNumFloat() {
385
388
Max::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
386
389
mask));
387
390
388
- LLVMCodeGen cg (expr, {& a, & b, & c});
391
+ LLVMCodeGen cg (expr, {a, b, c});
389
392
390
393
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
391
394
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -419,7 +422,7 @@ void testLLVMElemwiseMaxNumNaNFloat() {
419
422
Max::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
420
423
mask));
421
424
422
- LLVMCodeGen cg (expr, {& a, & b, & c});
425
+ LLVMCodeGen cg (expr, {a, b, c});
423
426
424
427
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
425
428
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -452,7 +455,7 @@ void testLLVMElemwiseMinNumFloat() {
452
455
Min::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
453
456
mask));
454
457
455
- LLVMCodeGen cg (expr, {& a, & b, & c});
458
+ LLVMCodeGen cg (expr, {a, b, c});
456
459
457
460
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
458
461
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -486,7 +489,7 @@ void testLLVMElemwiseMinNumNaNFloat() {
486
489
Min::make (Load::make (a, i, mask), Load::make (b, i, mask), false ),
487
490
mask));
488
491
489
- LLVMCodeGen cg (expr, {& a, & b, & c});
492
+ LLVMCodeGen cg (expr, {a, b, c});
490
493
491
494
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
492
495
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -520,7 +523,7 @@ void testLLVMElemwiseMaximumFloat() {
520
523
Max::make (Load::make (a, i, mask), Load::make (b, i, mask), true ),
521
524
mask));
522
525
523
- LLVMCodeGen cg (expr, {& a, & b, & c});
526
+ LLVMCodeGen cg (expr, {a, b, c});
524
527
525
528
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
526
529
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -554,7 +557,7 @@ void testLLVMElemwiseMaximumNaNFloat() {
554
557
Max::make (Load::make (a, i, mask), Load::make (b, i, mask), true ),
555
558
mask));
556
559
557
- LLVMCodeGen cg (expr, {& a, & b, & c});
560
+ LLVMCodeGen cg (expr, {a, b, c});
558
561
559
562
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
560
563
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -589,7 +592,7 @@ void testLLVMElemwiseMinimumFloat() {
589
592
Min::make (Load::make (a, i, mask), Load::make (b, i, mask), true ),
590
593
mask));
591
594
592
- LLVMCodeGen cg (expr, {& a, & b, & c});
595
+ LLVMCodeGen cg (expr, {a, b, c});
593
596
594
597
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
595
598
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -623,7 +626,7 @@ void testLLVMElemwiseMinimumNaNFloat() {
623
626
Min::make (Load::make (a, i, mask), Load::make (b, i, mask), true ),
624
627
mask));
625
628
626
- LLVMCodeGen cg (expr, {& a, & b, & c});
629
+ LLVMCodeGen cg (expr, {a, b, c});
627
630
628
631
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
629
632
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -668,7 +671,7 @@ void testLLVMCompareSelectIntEQ() {
668
671
CompareSelectOperation::kEQ ),
669
672
mask));
670
673
671
- LLVMCodeGen cg (expr, {& a, & b, & c});
674
+ LLVMCodeGen cg (expr, {a, b, c});
672
675
673
676
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
674
677
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -707,7 +710,7 @@ void testLLVMCompareSelectFloatEQ() {
707
710
CompareSelectOperation::kEQ ),
708
711
mask));
709
712
710
- LLVMCodeGen cg (expr, {& a, & b, & c});
713
+ LLVMCodeGen cg (expr, {a, b, c});
711
714
712
715
std::vector<void *> args ({a_buffer.data (), b_buffer.data (), c_buffer.data ()});
713
716
ASSERT_EQ (cg.value <int >(args), 0 );
@@ -726,7 +729,7 @@ void testLLVMStoreFloat() {
726
729
std::vector<float > result_buffer = {0 .0f };
727
730
auto expr = Store::make (
728
731
result, IntImm::make (0 ), FloatImm::make (3 .14f ), IntImm::make (1 ));
729
- LLVMCodeGen cg (expr, {& result});
732
+ LLVMCodeGen cg (expr, {result});
730
733
std::vector<void *> args ({result_buffer.data ()});
731
734
ASSERT_EQ (cg.value <int >(args), 0 );
732
735
EXPECT_EQ (result_buffer[0 ], 3 .14f );
@@ -739,7 +742,7 @@ void testLLVMSimpleMath01() {
739
742
Schedule sch = Schedule::make ({tensor});
740
743
Stmt stmt = sch.Lower ();
741
744
Buffer f_buf (tensor.function ().func_var (), kFloat32 , {N});
742
- LLVMCodeGen cg (stmt, {& f_buf});
745
+ LLVMCodeGen cg (stmt, {f_buf});
743
746
744
747
PaddedBuffer<float > f_v (N, " f_v" );
745
748
std::vector<void *> args ({f_v.data ()});
@@ -764,7 +767,7 @@ void testLLVMComputeMul() {
764
767
Schedule sch = Schedule::make ({c});
765
768
Stmt s = sch.Lower ();
766
769
767
- LLVMCodeGen cg (s, {& a, & b, & c_buf});
770
+ LLVMCodeGen cg (s, {a, b, c_buf});
768
771
769
772
std::vector<float > a_vec (N, 21 .0f );
770
773
std::vector<float > b_vec (N, 2 .0f );
@@ -789,7 +792,7 @@ void testLLVMBroadcastAdd() {
789
792
Schedule sch = Schedule::make ({c});
790
793
Stmt s = sch.Lower ();
791
794
792
- LLVMCodeGen cg (s, {& a, & b, & c_buf});
795
+ LLVMCodeGen cg (s, {a, b, c_buf});
793
796
794
797
std::vector<float > av (M * N);
795
798
std::iota (av.begin (), av.end (), 0 );
@@ -805,6 +808,30 @@ void testLLVMBroadcastAdd() {
805
808
}
806
809
}
807
810
}
811
+
812
+ void testLLVMDynamicShapeAdd () {
813
+ #if 0
814
+ auto testWithSize = [](int32_t size) {
815
+ Var n("n", kInt32);
816
+ Buffer a(Var("a", kHandle), kFloat32, {n});
817
+ Buffer b(Var("b", kHandle), kFloat32, {n});
818
+ Buffer c(Var("c", kHandle), kFloat32, {n});
819
+ Var i("i", kInt32);
820
+ Stmt s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1));
821
+ std::vector<float> aData(size, 1.0f);
822
+ std::vector<float> bData(size, 2.0f);
823
+ std::vector<float> cData(size, 0.0f);
824
+ LLVMCodeGen cg(s, {a, b, c, n});
825
+ std::vector<void*> args({aData.data(), bData.data(), cData.data(), size));
826
+ cg.value<float>(args);
827
+ ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
828
+ };
829
+ testWithSize(1);
830
+ testWithSize(16);
831
+ testWithSize(37);
832
+ #endif
833
+ }
834
+
808
835
} // namespace jit
809
836
} // namespace torch
810
837
0 commit comments