
Commit 45af9a9

fywkevin and Yuanwei Fang authored
[triton][tool] A CLI Tool for Tensor Layout Printing (#4486)
A CLI tool to print the layout of a tensor. Currently, only the `triton_gpu` dialect's `DistributedEncoding` layouts (not `SharedEncoding`) can be printed, via the `getLayoutStr` API exposed by the dialect library. In the future, we could also add tensor layout printing for other backend HW targets (e.g., CPU).

Example usage:

```
triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"

triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt

triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view
```

An input file usually looks like:

```
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
```

The core Triton team is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.**

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: Yuanwei Fang <[email protected]>
1 parent 1a20556 commit 45af9a9
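
To make the file-based flow from the description concrete, here is a hedged end-to-end sketch: it writes the example input above to `input.mlir` and runs the tool with the flags this PR defines (it assumes `triton-tensor-layout` is on `PATH`; the output filename is arbitrary):

```sh
# Write the example layout aliases from the description into input.mlir.
cat > input.mlir <<'EOF'
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
EOF

# Print the #blocked and #mma aliases for the given tensor type,
# from the hardware (warp) point of view, into output.txt.
triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" \
  -o output.txt -alias-names="blocked,mma" -use-hw-view
```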

File tree

3 files changed: +294 −0 lines changed


bin/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
```diff
@@ -90,3 +90,10 @@ target_link_libraries(triton-llvm-opt PRIVATE
   LLVMCodeGen
 )
 export_executable_symbols_for_plugins(triton-llvm-opt)
+
+
+add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
+target_link_libraries(triton-tensor-layout PRIVATE
+  TritonGPUIR
+  ${triton_libs}
+)
```
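
With this target in place, building and running the tool might look like the sketch below (the `build/` directory and Ninja generator are assumptions about the local setup, not something this PR prescribes; the layout string is the one from the PR description):

```sh
# Assumed CMake/Ninja build tree at ./build; adjust to your own setup.
ninja -C build triton-tensor-layout

# Print a layout directly from an attribute string; LLVM tools are
# typically emitted under <build>/bin by add_llvm_executable.
./build/bin/triton-tensor-layout \
  -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" \
  -t "tensor<128x256xf16>"
```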

bin/triton-tensor-layout.cpp

Lines changed: 229 additions & 0 deletions
```cpp
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/MLIRContext.h"

#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;
using namespace mlir;

// A CLI tool to print the layout of a tensor.
//
// clang-format off
// Example usage:
//
// triton-tensor-layout -l "#triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
//
// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt
//
// triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt -alias-names="blocked,mma" -use-hw-view
//
// An input file usually looks like:
// '''
// #mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
// #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
// '''
// clang-format on

//===--------------------------------------------------------------------===//
// CLI options
//===--------------------------------------------------------------------===//

cl::OptionCategory PrinterCategory("Available Print Options",
                                   "Options for the tensor layout printing.");

static cl::opt<std::string> InputFile(
    "i", cl::desc("File that contains the tensor data layout attributes"),
    cl::init(""), cl::value_desc("filename"), cl::cat(PrinterCategory));

static cl::opt<std::string>
    OutputFile("o", cl::desc("Output file to write the layout into"),
               cl::init(""), cl::value_desc("filename"),
               cl::cat(PrinterCategory));

static cl::opt<std::string>
    DataLayoutStr("l", cl::desc("Tensor data layout attribute in string"),
                  cl::value_desc("layout-string"), cl::init(""),
                  cl::cat(PrinterCategory));

static cl::list<std::string>
    AliasName("alias-names",
              cl::desc("A list of alias names (separated by comma) of the "
                       "layout attributes in the input file"),
              cl::value_desc("name1,name2,name3,..."), cl::CommaSeparated,
              cl::ZeroOrMore, cl::cat(PrinterCategory));

static cl::opt<bool> UseHWPointOfView(
    "use-hw-view",
    llvm::cl::desc(
        "Print the layout in hardware point of view. This means the output is "
        "from the warp's perspective. Otherwise, the output is from the "
        "tensor's perspective (e.g., each element maps to xxx thread)."),
    cl::init(false), cl::cat(PrinterCategory));

static cl::opt<std::string> TensorStr(
    "t", cl::desc("Tensor shape and element type (e.g., tensor<2x2xf32>)"),
    cl::init(""), cl::value_desc("tensor-type"), cl::cat(PrinterCategory));

//===--------------------------------------------------------------------===//
// Helper functions
//===--------------------------------------------------------------------===//

LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
  StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace();

  // Dispatch to the corresponding dialect helper function to print the layout.
  if (dialectName == "triton_gpu") {
    os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
    return success();
  }

  llvm::errs() << "Unsupported tensor layout attribute: "
               << tensorType.getEncoding() << "\n";
  return failure();
}

LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
                                  ArrayRef<std::string> names,
                                  TensorType tensorTy, raw_string_ostream &ss) {
  if (filename.empty())
    return success();

  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
      llvm::MemoryBuffer::getFileOrSTDIN(filename);
  if (std::error_code ec = fileOrErr.getError()) {
    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
    return failure();
  }

  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
  ParserConfig config(context);
  auto asmState = AsmParserState();

  Block parsedIR;
  if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
    llvm::errs() << "Fail to parse the input file: " << filename << "\n";
    return failure();
  }

  auto printLambda = [&](StringRef name, Attribute attr) {
    ss << "Print layout attribute: #" << name << " = " << attr << "\n";

    auto rankedTensorTy = RankedTensorType::get(
        tensorTy.getShape(), tensorTy.getElementType(), attr);

    return layoutPrint(rankedTensorTy, ss);
  };

  if (names.empty())
    // If no alias name is given, we print all layout attributes in the file.
    for (auto def : asmState.getAttributeAliasDefs()) {
      if (failed(printLambda(def.name, def.value)))
        return failure();
    }
  else {
    // Print the layout attributes with the given alias names.
    for (auto alias : names) {
      auto def = asmState.getAttributeAliasDef(alias);
      if (!def) {
        llvm::errs() << "Can't find the layout attribute: " << alias << "\n";
        return failure();
      }

      if (failed(printLambda(alias, def->value)))
        return failure();

      ss << "\n";
    }
  }

  return success();
}

LogicalResult printLayoutFromString(MLIRContext *context,
                                    StringRef layoutAttrStr,
                                    TensorType tensorTy,
                                    raw_string_ostream &ss) {
  if (layoutAttrStr.empty())
    return success();

  Attribute layout = parseAttribute(layoutAttrStr, context);
  if (!layout) {
    llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n";
    return failure();
  }

  auto rankedTensorTy = RankedTensorType::get(
      tensorTy.getShape(), tensorTy.getElementType(), layout);

  ss << "Print layout attribute: " << layout << "\n";

  return layoutPrint(rankedTensorTy, ss);
}

//===--------------------------------------------------------------------===//
// Main entry point
//===--------------------------------------------------------------------===//

int main(int argc, char **argv) {
  cl::HideUnrelatedOptions(PrinterCategory);
  cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n");

  DialectRegistry registry;
  // Register all dialects that can print tensor layout.
  registry.insert<triton::gpu::TritonGPUDialect>();

  MLIRContext ctx(registry);
  ctx.loadAllAvailableDialects();

  if (TensorStr.empty()) {
    llvm::errs() << "Must specify the tensor type argument\n";
    return 1;
  }

  Type parsedTy = parseType(TensorStr, &ctx);
  if (!parsedTy) {
    llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr
                 << "\n";
    return 1;
  }

  TensorType tensorType = dyn_cast<TensorType>(parsedTy);
  if (!tensorType) {
    llvm::errs() << "Invalid tensor type argument: " << TensorStr << "\n";
    return 1;
  }

  std::string storage;
  raw_string_ostream ss(storage);

  if (failed(printLayoutFromFile(&ctx, InputFile, AliasName, tensorType, ss)))
    return 1;

  if (failed(printLayoutFromString(&ctx, DataLayoutStr, tensorType, ss)))
    return 1;

  if (OutputFile.empty()) {
    llvm::outs() << ss.str();
  } else {
    std::error_code ec;
    llvm::raw_fd_ostream outFs(OutputFile, ec, llvm::sys::fs::OF_Text);
    if (ec) {
      llvm::errs() << "Error: " << ec.message() << " : unable to open "
                   << OutputFile << " for output\n";
      return 1;
    }
    outFs << ss.str();
    outFs.close();
  }

  return 0;
}
```
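
One implementation detail worth noting: `printLayoutFromFile` reads its input through `llvm::MemoryBuffer::getFileOrSTDIN`, which maps the filename `-` to standard input. So piping the layout definitions in should also work, as sketched below (inferred from the LLVM API, not exercised by this PR's tests):

```sh
# "-" is resolved to stdin by MemoryBuffer::getFileOrSTDIN.
cat input.mlir | triton-tensor-layout -i - -t "tensor<1x128x128xf16>"
```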
