Skip to content

Commit 6417664

Browse files
committed
Completing refactor of VLA CodeGen
Closes 3eb4f59d0eb Things done: - Emit QUAL.OSS.CAPTURED bundle in clang - Gather QUAL.OSS.CAPTURED info in analysis - assert if a value inside task body is not in QUAL.OSS.CAPTURED (instead of asserting with VLA dims) - assert if QUAL.OSS.VLA.DIMS are not in QUAL.OSS.CAPTURED - Add a flag to analysis to output all those QUAL.OSS.VLA.DIMS that are not in QUAL.OSS.CAPTURED
1 parent 237c2ab commit 6417664

File tree

6 files changed

+98
-22
lines changed

6 files changed

+98
-22
lines changed

llvm/include/llvm/Analysis/OmpSsRegionAnalysis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct TaskDSAInfo {
2626

2727
// <VLA, VLA_dims>
2828
using TaskVLADimsInfo = MapVector<Value *, SetVector<Value *>>;
29+
using TaskCapturedInfo = SetVector<Value *>;
2930

3031
struct DependInfo {
3132
int SymbolIndex;
@@ -56,6 +57,7 @@ struct TaskInfo {
5657
TaskDependsInfo DependsInfo;
5758
Value *Final = nullptr;
5859
Value *If = nullptr;
60+
TaskCapturedInfo CapturedInfo;
5961
Instruction *Entry;
6062
Instruction *Exit;
6163
};

llvm/include/llvm/IR/LLVMContext.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class LLVMContext {
9898
OB_oss_dep_weakinout = 12, // "OB_oss_dep_weakinout"
9999
OB_oss_final = 13, // "OB_oss_final"
100100
OB_oss_if = 14, // "OB_oss_if"
101+
OB_oss_captured = 15, // "OB_oss_captured"
101102
};
102103

103104
/// getMDKindID - Return a unique non-zero ID for the specified metadata kind.

llvm/lib/Analysis/OmpSsRegionAnalysis.cpp

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ enum PrintVerbosity {
3434
PV_Uses,
3535
PV_UnpackAndConst,
3636
PV_DsaMissing,
37-
PV_DsaVLADimsMissing
37+
PV_DsaVLADimsMissing,
38+
PV_VLADimsCaptureMissing
3839
};
3940

4041
static cl::opt<PrintVerbosity>
@@ -46,7 +47,8 @@ PrintVerboseLevel("print-verbosity",
4647
clEnumValN(PV_Uses, "uses", "Print task layout with uses"),
4748
clEnumValN(PV_UnpackAndConst, "unpack", "Print task layout with unpack instructions/constexprs needed in dependencies"),
4849
clEnumValN(PV_DsaMissing, "dsa_missing", "Print task layout with uses without DSA"),
49-
clEnumValN(PV_DsaVLADimsMissing, "dsa_vla_dims_missing", "Print task layout with DSAs without VLA info or VLA info without DSAs")));
50+
clEnumValN(PV_DsaVLADimsMissing, "dsa_vla_dims_missing", "Print task layout with DSAs without VLA info or VLA info without DSAs"),
51+
clEnumValN(PV_VLADimsCaptureMissing, "vla_dims_capture_missing", "Print task layout with VLA dimensions without capture")));
5052

5153
char OmpSsRegionAnalysisPass::ID = 0;
5254

@@ -77,6 +79,11 @@ static bool valueInVLADimsBundles(const TaskVLADimsInfo& VLADimsInfo,
7779
return false;
7880
}
7981

82+
static bool valueInCapturedBundle(const TaskCapturedInfo& CapturedInfo,
83+
Value *const V) {
84+
return CapturedInfo.count(V);
85+
}
86+
8087
void OmpSsRegionAnalysisPass::print(raw_ostream &OS, const Module *M) const {
8188
for (auto it = TaskProgramOrder.begin(); it != TaskProgramOrder.end(); ++it) {
8289
Instruction *I = it->first;
@@ -140,14 +147,25 @@ void OmpSsRegionAnalysisPass::print(raw_ostream &OS, const Module *M) const {
140147
DSAVLADimsFreqMap[VLAWithDimsMap.first]++;
141148
}
142149
for (const auto &Pair : DSAVLADimsFreqMap) {
143-
// It's expected to have only two VLA bundles, the DSA and de dimensions
150+
// It's expected to have only two VLA bundles, the DSA and dimensions
144151
if (Pair.second != 2) {
145152
dbgs() << "\n";
146153
dbgs() << std::string((Depth + 1) * PrintSpaceMultiplier, ' ');
147154
Pair.first->printAsOperand(dbgs(), false);
148155
}
149156
}
150157
}
158+
else if (PrintVerboseLevel == PV_VLADimsCaptureMissing) {
159+
for (auto &VLAWithDimsMap : Info.VLADimsInfo) {
160+
for (Value *const &V : VLAWithDimsMap.second) {
161+
if (!valueInCapturedBundle(Info.CapturedInfo, V)) {
162+
dbgs() << "\n";
163+
dbgs() << std::string((Depth + 1) * PrintSpaceMultiplier, ' ');
164+
V->printAsOperand(dbgs(), false);
165+
}
166+
}
167+
}
168+
}
151169

152170
dbgs() << "\n";
153171
}
@@ -245,7 +263,7 @@ static bool insertUniqInstInProgramOrder(SmallVectorImpl<Instruction *> &InstLis
245263
}
246264

247265
static void gatherUnpackInstructions(const TaskDSAInfo &DSAInfo,
248-
const TaskVLADimsInfo &VLADimsInfo,
266+
const TaskCapturedInfo &CapturedInfo,
249267
const OrderedInstructions &OI,
250268
DependInfo &DI,
251269
TaskAnalysisInfo &TAI,
@@ -265,13 +283,12 @@ static void gatherUnpackInstructions(const TaskDSAInfo &DSAInfo,
265283
Value *Dep = It->second;
266284
WorkList.erase(It);
267285
bool IsDSA = valueInDSABundles(DSAInfo, Cur);
268-
// TODO: this will be get from captured info
269-
bool IsVLADim = valueInVLADimsBundles(VLADimsInfo, Cur);
286+
bool IsCaptured = valueInCapturedBundle(CapturedInfo, Cur);
270287
// Go over all uses until:
271288
// 1. We get a DSA so assign a symbol index
272-
// 2. We get a VLA dimension, so we're done. We don't want to move
273-
// instructions that generate the vla dimension
274-
if (!IsDSA && !IsVLADim) {
289+
// 2. We get a Capture value, so we're done. We don't want to move
290+
// instructions that generate this
291+
if (!IsDSA && !IsCaptured) {
275292
if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Cur)) {
276293
for (Use &U : CE->operands()) {
277294
WorkList.emplace_back(U.get(), Dep);
@@ -300,7 +317,7 @@ static void gatherUnpackInstructions(const TaskDSAInfo &DSAInfo,
300317
static void gatherDependsInfoFromBundles(const SmallVectorImpl<OperandBundleDef> &OpBundles,
301318
const OrderedInstructions &OI,
302319
const TaskDSAInfo &DSAInfo,
303-
const TaskVLADimsInfo &VLADimsInfo,
320+
const TaskCapturedInfo &CapturedInfo,
304321
TaskAnalysisInfo &TAI,
305322
SmallVectorImpl<DependInfo> &DependsList,
306323
SmallVectorImpl<Instruction *> &UnpackInsts,
@@ -317,7 +334,7 @@ static void gatherDependsInfoFromBundles(const SmallVectorImpl<OperandBundleDef>
317334
DI.Dims.push_back(OBArgs[i]);
318335
}
319336

320-
gatherUnpackInstructions(DSAInfo, VLADimsInfo, OI, DI, TAI, UnpackInsts, UnpackConsts);
337+
gatherUnpackInstructions(DSAInfo, CapturedInfo, OI, DI, TAI, UnpackInsts, UnpackConsts);
321338

322339
DependsList.push_back(DI);
323340
}
@@ -327,50 +344,50 @@ static void gatherDependsInfoFromBundles(const SmallVectorImpl<OperandBundleDef>
327344
static void gatherDependsInfoWithID(const IntrinsicInst *I,
328345
const OrderedInstructions &OI,
329346
const TaskDSAInfo &DSAInfo,
330-
const TaskVLADimsInfo &VLADimsInfo,
347+
const TaskCapturedInfo &CapturedInfo,
331348
TaskAnalysisInfo &TAI,
332349
SmallVectorImpl<DependInfo> &DependsList,
333350
SmallVectorImpl<Instruction *> &UnpackInsts,
334351
SetVector<ConstantExpr *> &UnpackConsts,
335352
uint64_t Id) {
336353
SmallVector<OperandBundleDef, 4> OpBundles;
337354
getOperandBundlesAsDefsWithID(I, OpBundles, Id);
338-
gatherDependsInfoFromBundles(OpBundles, OI, DSAInfo, VLADimsInfo, TAI, DependsList, UnpackInsts, UnpackConsts);
355+
gatherDependsInfoFromBundles(OpBundles, OI, DSAInfo, CapturedInfo, TAI, DependsList, UnpackInsts, UnpackConsts);
339356
}
340357

341358
// Gathers all dependencies needed information
342359
static void gatherDependsInfo(const IntrinsicInst *I, TaskInfo &TI,
343360
TaskAnalysisInfo &TAI,
344361
const OrderedInstructions &OI) {
345-
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.VLADimsInfo, TAI,
362+
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.CapturedInfo, TAI,
346363
TI.DependsInfo.Ins,
347364
TI.DependsInfo.UnpackInstructions,
348365
TI.DependsInfo.UnpackConstants,
349366
LLVMContext::OB_oss_dep_in);
350-
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.VLADimsInfo, TAI,
367+
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.CapturedInfo, TAI,
351368
TI.DependsInfo.Outs,
352369
TI.DependsInfo.UnpackInstructions,
353370
TI.DependsInfo.UnpackConstants,
354371
LLVMContext::OB_oss_dep_out);
355-
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.VLADimsInfo, TAI,
372+
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.CapturedInfo, TAI,
356373
TI.DependsInfo.Inouts,
357374
TI.DependsInfo.UnpackInstructions,
358375
TI.DependsInfo.UnpackConstants,
359376
LLVMContext::OB_oss_dep_inout);
360377

361-
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.VLADimsInfo, TAI,
378+
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.CapturedInfo, TAI,
362379
TI.DependsInfo.WeakIns,
363380
TI.DependsInfo.UnpackInstructions,
364381
TI.DependsInfo.UnpackConstants,
365382
LLVMContext::OB_oss_dep_weakin);
366383

367-
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.VLADimsInfo, TAI,
384+
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.CapturedInfo, TAI,
368385
TI.DependsInfo.WeakOuts,
369386
TI.DependsInfo.UnpackInstructions,
370387
TI.DependsInfo.UnpackConstants,
371388
LLVMContext::OB_oss_dep_weakout);
372389

373-
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.VLADimsInfo, TAI,
390+
gatherDependsInfoWithID(I, OI, TI.DSAInfo, TI.CapturedInfo, TAI,
374391
TI.DependsInfo.WeakInouts,
375392
TI.DependsInfo.UnpackInstructions,
376393
TI.DependsInfo.UnpackConstants,
@@ -383,6 +400,20 @@ static void gatherIfFinalInfo(const IntrinsicInst *I, TaskInfo &TI) {
383400
getValueFromOperandBundleWithID(I, TI.If, LLVMContext::OB_oss_if);
384401
}
385402

403+
// It's expected to have VLA dims info before calling this
404+
static void gatherCapturedInfo(const IntrinsicInst *I, TaskInfo &TI) {
405+
getValueListFromOperandBundlesWithID(I, TI.CapturedInfo, LLVMContext::OB_oss_captured);
406+
if (!DisableChecks) {
407+
// VLA Dims that are not Captured is an error
408+
for (auto &VLAWithDimsMap : TI.VLADimsInfo) {
409+
for (Value *const &V : VLAWithDimsMap.second) {
410+
if (!valueInCapturedBundle(TI.CapturedInfo, V))
411+
llvm_unreachable("VLA dimension has not been captured");
412+
}
413+
}
414+
}
415+
}
416+
386417
void OmpSsRegionAnalysisPass::getOmpSsFunctionInfo(
387418
Function &F, DominatorTree &DT, FunctionInfo &FI,
388419
TaskFunctionAnalysisInfo &TFAI,
@@ -417,6 +448,7 @@ void OmpSsRegionAnalysisPass::getOmpSsFunctionInfo(
417448

418449
gatherDSAInfo(II, T.Info);
419450
gatherVLADimsInfo(II, T.Info);
451+
gatherCapturedInfo(II, T.Info);
420452
gatherDependsInfo(II, T.Info, T.AnalysisInfo, OI);
421453
gatherIfFinalInfo(II, T.Info);
422454

@@ -448,7 +480,7 @@ void OmpSsRegionAnalysisPass::getOmpSsFunctionInfo(
448480
T.AnalysisInfo.UsesBeforeEntry.insert(I2);
449481
if (!DisableChecks
450482
&& !valueInDSABundles(T.Info.DSAInfo, I2)
451-
&& !valueInVLADimsBundles(T.Info.VLADimsInfo, I2)) {
483+
&& !valueInCapturedBundle(T.Info.CapturedInfo, I2)) {
452484
llvm_unreachable("Value supposed to be inside task entry "
453485
"OperandBundle not found.");
454486
}
@@ -457,7 +489,7 @@ void OmpSsRegionAnalysisPass::getOmpSsFunctionInfo(
457489
T.AnalysisInfo.UsesBeforeEntry.insert(A);
458490
if (!DisableChecks
459491
&& !valueInDSABundles(T.Info.DSAInfo, A)
460-
&& !valueInVLADimsBundles(T.Info.VLADimsInfo, A)) {
492+
&& !valueInCapturedBundle(T.Info.CapturedInfo, A)) {
461493
llvm_unreachable("Value supposed to be inside task entry "
462494
"OperandBundle not found.");
463495
}

llvm/lib/IR/LLVMContext.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ LLVMContext::LLVMContext() : pImpl(new LLVMContextImpl(*this)) {
122122
assert(OSSIfEntry->second == LLVMContext::OB_oss_if &&
123123
"oss_if operand bundle id drifted!");
124124
(void)OSSIfEntry;
125+
126+
auto *OSSCapturedEntry = pImpl->getOrInsertBundleTag("QUAL.OSS.CAPTURED");
127+
assert(OSSCapturedEntry->second == LLVMContext::OB_oss_captured &&
128+
"oss_captured operand bundle id drifted!");
129+
(void)OSSCapturedEntry;
125130
// END OmpSs IDs
126131

127132
SyncScope::ID SingleThreadSSID =

llvm/test/Analysis/OmpSsRegionAnalysis/task_layout_with_unpack.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ entry:
105105
%9 = mul i64 %4, 4
106106
%10 = mul i64 %7, 4
107107
; to here
108-
%11 = call token @llvm.directive.region.entry() [ "DIR.OSS"([5 x i8] c"TASK\00"), "QUAL.OSS.SHARED"(i32* %vla), "QUAL.OSS.VLA.DIMS"(i32* %vla, i64 %1), "QUAL.OSS.FIRSTPRIVATE"(i32* %k.addr), "QUAL.OSS.FIRSTPRIVATE"(i32* %j.addr), "QUAL.OSS.DEP.IN"(i32* %vla, i64 %8, i64 %9, i64 %10) ]
108+
%11 = call token @llvm.directive.region.entry() [ "DIR.OSS"([5 x i8] c"TASK\00"), "QUAL.OSS.SHARED"(i32* %vla), "QUAL.OSS.VLA.DIMS"(i32* %vla, i64 %1), "QUAL.OSS.CAPTURED"(i64 %1), "QUAL.OSS.FIRSTPRIVATE"(i32* %k.addr), "QUAL.OSS.FIRSTPRIVATE"(i32* %j.addr), "QUAL.OSS.DEP.IN"(i32* %vla, i64 %8, i64 %9, i64 %10) ]
109109
call void @llvm.directive.region.exit(token %11)
110110
%12 = load i8*, i8** %saved_stack, align 8
111111
call void @llvm.stackrestore(i8* %12)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
; RUN: opt -ompss-2-regions -analyze -disable-checks -print-verbosity=vla_dims_capture_missing < %s 2>&1 | FileCheck %s
2+
3+
define dso_local void @vla_section_dep(i32 %n) {
4+
entry:
5+
%n.addr = alloca i32, align 4
6+
%saved_stack = alloca i8*, align 8
7+
%__vla_expr0 = alloca i64, align 8
8+
%__vla_expr1 = alloca i64, align 8
9+
store i32 %n, i32* %n.addr, align 4
10+
%0 = load i32, i32* %n.addr, align 4
11+
%add = add nsw i32 %0, 1
12+
%1 = zext i32 %add to i64
13+
%2 = load i32, i32* %n.addr, align 4
14+
%add1 = add nsw i32 %2, 2
15+
%3 = zext i32 %add1 to i64
16+
%4 = call i8* @llvm.stacksave()
17+
store i8* %4, i8** %saved_stack, align 8
18+
%5 = mul nuw i64 %1, %3
19+
%vla = alloca i32, i64 %5, align 16
20+
store i64 %1, i64* %__vla_expr0, align 8
21+
store i64 %3, i64* %__vla_expr1, align 8
22+
%6 = call token @llvm.directive.region.entry() [ "DIR.OSS"([5 x i8] c"TASK\00"), "QUAL.OSS.SHARED"(i32* %vla), "QUAL.OSS.VLA.DIMS"(i32* %vla, i64 %1, i64 %3) ]
23+
call void @llvm.directive.region.exit(token %6)
24+
%7 = load i8*, i8** %saved_stack, align 8
25+
call void @llvm.stackrestore(i8* %7)
26+
ret void
27+
}
28+
29+
; CHECK: [0] %6
30+
; CHECK-NEXT: %1
31+
; CHECK-NEXT: %3
32+
33+
declare i8* @llvm.stacksave()
34+
declare token @llvm.directive.region.entry()
35+
declare void @llvm.directive.region.exit(token)
36+
declare void @llvm.stackrestore(i8*)

0 commit comments

Comments
 (0)