-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][EmitC] Add pass to wrap a func in class #141158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@@ -22,6 +22,7 @@ | |||
#include "llvm/ADT/StringExtras.h" | |||
#include "llvm/ADT/StringMap.h" | |||
#include "llvm/ADT/TypeSwitch.h" | |||
#include "llvm/Support/CommandLine.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to be deleted when cleaning up. Had it when trying to bring the cl options to this file.
std::map<std::string, Value> fields; | ||
os << "std::map<std::string, char*> _buffer_map {"; | ||
if (argAttrs) { | ||
bool isFirst = true; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I think we can always add a ,
after an element in the dictionary initializer list.
Operation *operation = functionOp.getOperation(); | ||
if (emitter.shouldPrintClass()) { | ||
if (functionOp.isExternal()) | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: should it fail or should it warn and skip over?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm making it fail since it would fail later on, anyway. Also printing out a "Warning".
I've left a TO-DO to allow more discussion within the community in case there is something better that can be done.(As we discussed)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it has to be an error, then don't label it as a warning. Typically we have diagnostic handlers that you should prefer over writing to a stream directly. Mircea and I will have to look over the convention in the MLIR codebase to figure out what that is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should have access to emitError()
, emitWarning()
, and emitRemark()
from Diagnostics.h
. There's an emitError()
example in this file, but you can look at the way things are handled in Target to get a better picture.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why would it fail later on - you can skip over it, no?
os << " { \"" << name << "\"" << ", reinterpret_cast<char*>(" | ||
<< emitter.getOrCreateName(v) << ") }, "; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What you're doing is fine, but you may want to use llvm::formatv()
if things get more complicated. https://llvm.org/docs/ProgrammersManual.html#formatting-strings-the-formatv-function
I'm just pointing this out as an FYI. No need to change your code unless you think it would help/be easier.
Operation *operation = functionOp.getOperation(); | ||
if (emitter.shouldPrintClass()) { | ||
if (functionOp.isExternal()) | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it has to be an error, then don't label it as a warning. Typically we have diagnostic handlers that you should prefer over writing to a stream directly. Mircea and I will have to look over the convention in the MLIR codebase to figure out what that is.
@@ -0,0 +1,15 @@ | |||
/// The function has no argument attributes | |||
// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you expecting an error to be printed? 2>&1 | FileCheck ...
is a common pattern to make sure you can check the error stream for diagnostics.
There's also https://mlir.llvm.org/getting_started/TestingGuide/, which I haven't gone through much, but https://mlir.llvm.org/getting_started/TestingGuide/#diagnostic-tests seems apropos.
Operation *operation = functionOp.getOperation(); | ||
if (emitter.shouldPrintClass()) { | ||
if (functionOp.isExternal()) | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why would it fail later on - you can skip over it, no?
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-emitc Author: None (Jaddyen) ChangesFull diff: https://github.com/llvm/llvm-project/pull/141158.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
index 7c5747a888261..d1a6c1dc12d4c 100644
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -28,7 +28,9 @@ namespace emitc {
/// with matching id are emitted.
LogicalResult translateToCpp(Operation *op, raw_ostream &os,
bool declareVariablesAtTop = false,
- StringRef fileId = {});
+ StringRef fileId = {}, bool emitClass = false,
+ StringRef className = {},
+ StringRef fieldNameAttribute = {});
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 2108ffd414c56..9e1533d34f6ea 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -33,13 +33,50 @@ void registerToCppTranslation() {
"file-id", llvm::cl::desc("Emit emitc.file ops with matching id"),
llvm::cl::init(""));
+ static llvm::cl::opt<bool> emitClass(
+ "emit-class",
+ llvm::cl::desc("If specified, the output will be a class where "
+ "the function(s) in the module are methods "
+ "Enables class-related options"),
+ llvm::cl::init(false));
+
+ static llvm::cl::opt<std::string> className(
+ "class-name",
+ llvm::cl::desc("Mandatory class name if --emit-class is set"),
+ llvm::cl::init(""));
+
+ static llvm::cl::opt<std::string> fieldNameAttribute(
+ "field-name-attribute",
+ llvm::cl::desc("Mandatory name of the attribute to use as field name if "
+ "--emit-class is set(default=tf_saved_model.index_path)"),
+ llvm::cl::init("tf_saved_model.index_path"));
+
TranslateFromMLIRRegistration reg(
"mlir-to-cpp", "translate from mlir to cpp",
[](Operation *op, raw_ostream &output) {
+ if (emitClass) {
+ if (className.empty()) {
+ llvm::errs() << "Error: --class-name is mandatory when "
+ "--emit-class is set.\n";
+ return mlir::failure();
+ }
+ if (fieldNameAttribute.empty()) {
+ llvm::errs() << "Error: --field-name-attribute is mandatory when "
+ "--emit-class is set.\n";
+ return mlir::failure();
+ }
+ return emitc::translateToCpp(
+ op, output,
+ /*declareVariablesAtTop=*/declareVariablesAtTop,
+ /*fileId=*/fileId, /*emitClass=*/emitClass,
+ /*className=*/className,
+ /*fieldNameAttribute=*/fieldNameAttribute);
+ }
return emitc::translateToCpp(
op, output,
/*declareVariablesAtTop=*/declareVariablesAtTop,
- /*fileId=*/fileId);
+ /*fileId=*/fileId, /*emitClass=*/emitClass, /*className=*/className,
+ /*fieldNameAttribute=*/fieldNameAttribute);
},
[](DialectRegistry ®istry) {
// clang-format off
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 0c4975a13d301..46891d0aca556 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -68,6 +68,14 @@ inline LogicalResult interleaveCommaWithError(const Container &c,
return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
}
+template <typename Container, typename UnaryFunctor>
+inline LogicalResult interleaveWithNewLineWithError(const Container &c,
+ raw_ostream &os,
+ UnaryFunctor eachFn) {
+ return interleaveWithError(c.begin(), c.end(), eachFn,
+ [&]() { os << ";\n"; });
+}
+
/// Return the precedence of a operator as an integer, higher values
/// imply higher precedence.
static FailureOr<int> getOperatorPrecedence(Operation *operation) {
@@ -116,7 +124,8 @@ namespace {
/// Emitter that uses dialect specific emitters to emit C++ code.
struct CppEmitter {
explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
- StringRef fileId);
+ StringRef fileId, bool emitClass, StringRef className,
+ StringRef fieldNameAttribute);
/// Emits attribute or returns failure.
LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -233,6 +242,15 @@ struct CppEmitter {
/// be declared at the beginning of a function.
bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
+ // Returns whether we should emit a C++ class
+ bool shouldPrintClass() { return emitClass; };
+
+ // Returns the class name to emit
+ std::string getClassName() { return className; };
+
+ // Returns the field name to use in the map
+ std::string getfieldNameAttribute() { return fieldNameAttribute; };
+
/// Returns whether this file op should be emitted
bool shouldEmitFile(FileOp file) {
return !fileId.empty() && file.getId() == fileId;
@@ -268,6 +286,18 @@ struct CppEmitter {
/// Only emit file ops whos id matches this value.
std::string fileId;
+ /// Controls whether the output should be a C++ class.
+ /// If true, the generated C++ code will be encapsulated within a class,
+ /// and functions from the input module will become its member functions.
+ const bool emitClass;
+
+ /// The specified name for the generated C++ class
+ const std::string className;
+
+ /// Name of the MLIR attribute to use as a field name within the generated
+ /// class
+ const std::string fieldNameAttribute;
+
/// Map from value to name of C++ variable that contain the name.
ValueMapper valueMapper;
@@ -1025,6 +1055,17 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
}));
}
+static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
+ Region::BlockArgListType arguments) {
+ raw_indented_ostream &os = emitter.ostream();
+
+ return (interleaveWithNewLineWithError(
+ arguments, os, [&](BlockArgument arg) -> LogicalResult {
+ return emitter.emitVariableDeclaration(
+ functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
+ }));
+}
+
static LogicalResult printFunctionBody(CppEmitter &emitter,
Operation *functionOp,
Region::BlockListType &blocks) {
@@ -1129,6 +1170,45 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}
+static LogicalResult emitClassFields(CppEmitter &emitter,
+ emitc::FuncOp functionOp) {
+ raw_indented_ostream &os = emitter.ostream();
+ auto argAttrs = functionOp.getArgAttrs();
+ Operation *operation = functionOp.getOperation();
+ if (failed(printFields(emitter, operation, functionOp.getArguments())))
+ return failure();
+ os << ";\n";
+
+ std::map<std::string, Value> fields;
+ os << "\nstd::map<std::string, char*> _buffer_map {";
+ if (argAttrs) {
+ for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
+ if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
+ auto nv = da.getNamed(emitter.getfieldNameAttribute())->getValue();
+ auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
+ auto Ins = fields.insert({name, v});
+ if (!Ins.second)
+ return failure();
+ os << " { \"" << name << "\"" << ", reinterpret_cast<char*>("
+ << emitter.getOrCreateName(v) << ") }, ";
+ }
+ }
+ } else
+ return failure();
+
+ os << "};\n";
+ os << "char* getBufferForName(const std::string& name) const {\n";
+ os.indent();
+ os.indent();
+ os << "auto it = _buffer_map.find(name);\n";
+ os << "return (it == _buffer_map.end()) ? nullptr : it->second;\n";
+ os.unindent();
+ os.unindent();
+ os << "}\n\n";
+
+ return success();
+}
+
static LogicalResult printOperation(CppEmitter &emitter,
emitc::FuncOp functionOp) {
// We need to declare variables at top if the function has multiple blocks.
@@ -1140,6 +1220,29 @@ static LogicalResult printOperation(CppEmitter &emitter,
CppEmitter::Scope scope(emitter);
raw_indented_ostream &os = emitter.ostream();
+ Operation *operation = functionOp.getOperation();
+ if (emitter.shouldPrintClass()) {
+ if (functionOp.isExternal()) {
+ // TODO: Determine the best long-term strategy for external functions.
+ // Currently, we're skipping over this functionOp.
+ // We have considered using emitWarning() which would return
+ // InFlightDiagnostic which seems can be automatically converted to LogicalResult since
+ // this is done in emitAttributes where emitError is converted to LogicalResult. However, it requires that we pass in a
+ // location which at first glance we don't have in this scope. Open to
+ // further discussion on this.
+ os << "Warning: Cannot process external function '"
+ << functionOp.getName() << "'. "
+ << "This functionOp lacks a body so we will skip over it.";
+ return success();
+ }
+ os << "class " << emitter.getClassName() << " final {\n";
+ os << "public: \n";
+ os.indent();
+
+ if (failed(emitClassFields(emitter, functionOp)))
+ return failure();
+ }
+
if (functionOp.getSpecifiers()) {
for (Attribute specifier : functionOp.getSpecifiersAttr()) {
os << cast<StringAttr>(specifier).str() << " ";
@@ -1149,23 +1252,37 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (failed(emitter.emitTypes(functionOp.getLoc(),
functionOp.getFunctionType().getResults())))
return failure();
+ // TODO: We may wanna consider having the name of the function be execute in
+ // the case that we want to emit a class instead of main. Leaving as is for
+ // now to make the change smaller.
os << " " << functionOp.getName();
os << "(";
- Operation *operation = functionOp.getOperation();
- if (functionOp.isExternal()) {
- if (failed(printFunctionArgs(emitter, operation,
- functionOp.getArgumentTypes())))
+
+ if (!emitter.shouldPrintClass()) {
+ if (functionOp.isExternal()) {
+ if (failed(printFunctionArgs(emitter, operation,
+ functionOp.getArgumentTypes())))
+ return failure();
+ os << ");";
+ return success();
+ }
+ if (failed(
+ printFunctionArgs(emitter, operation, functionOp.getArguments())))
return failure();
- os << ");";
- return success();
}
- if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
- return failure();
os << ") {\n";
+
if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
return failure();
- os << "}\n";
+
+ if (emitter.shouldPrintClass()) {
+ os << "}\n";
+ os.unindent();
+ os << "};\n";
+ } else {
+ os << "}\n";
+ }
return success();
}
@@ -1202,9 +1319,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
}
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
- StringRef fileId)
+ StringRef fileId, bool emitClass, StringRef className,
+ StringRef fieldNameAttribute)
: os(os), declareVariablesAtTop(declareVariablesAtTop),
- fileId(fileId.str()) {
+ fileId(fileId.str()), emitClass(emitClass), className(className.str()),
+ fieldNameAttribute(fieldNameAttribute.str()) {
valueInScopeCount.push(0);
labelInScopeCount.push(0);
}
@@ -1787,7 +1906,10 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
bool declareVariablesAtTop,
- StringRef fileId) {
- CppEmitter emitter(os, declareVariablesAtTop, fileId);
+ StringRef fileId, bool emitClass,
+ StringRef className,
+ StringRef fieldNameAttribute) {
+ CppEmitter emitter(os, declareVariablesAtTop, fileId, emitClass, className,
+ fieldNameAttribute);
return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
}
diff --git a/mlir/test/mlir-translate/emit-class-neg-external.mlir b/mlir/test/mlir-translate/emit-class-neg-external.mlir
new file mode 100644
index 0000000000000..c34a1652abd3f
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class-neg-external.mlir
@@ -0,0 +1,8 @@
+/// An external function - has no body
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+ emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
+}
+
+// CHECK: Warning: Cannot process external function 'extern_func'. This functionOp lacks a body so we will skip over it.
diff --git a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
new file mode 100644
index 0000000000000..6d43fa953a946
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
@@ -0,0 +1,15 @@
+/// The function has no argument attributes
+// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+ emitc.func @foo(%arg0 : i32) {
+ emitc.call_opaque "bar" (%arg0) : (i32) -> ()
+ emitc.return
+ }
+}
+
+// CHECK: class ArgAttrs final {
+// CHECK-NEXT: public:
+// CHECK-NEXT: int32_t v1;
+// CHECK-EMPTY:
+// CHECK-NEXT: std::map<std::string, char*> _buffer_map {
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
new file mode 100644
index 0000000000000..2779cb315ed41
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+ emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %2 = load %1 : <f32>
+ %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %4 = load %3 : <f32>
+ %5 = add %2, %4 : (f32, f32) -> f32
+ %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ assign %5 : f32 to %6 : <f32>
+ return
+ }
+}
+
+// CHECK: class MyAdder final {
+// CHECK-NEXT: public:
+// CHECK-NEXT: float v1[1];
+// CHECK-NEXT: float v2[1];
+// CHECK-NEXT: float v3[1];
+// CHECK-EMPTY:
+// CHECK-NEXT: std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) },
+// CHECK-SAME: { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
+// CHECK-NEXT: char* getBufferForName(const std::string& name) const {
+// CHECK-NEXT: auto it = _buffer_map.find(name);
+// CHECK-NEXT: return (it == _buffer_map.end()) ? nullptr : it->second;
+// CHECK-NEXT: }
+// CHECK-EMPTY:
+// CHECK-NEXT: void main() {
+// CHECK-NEXT: size_t v4 = 0;
+// CHECK-NEXT: float v5 = v2[v4];
+// CHECK-NEXT: float v6 = v1[v4];
+// CHECK-NEXT: float v7 = v5 + v6;
+// CHECK-NEXT: v3[v4] = v7;
+// CHECK-NEXT: return;
+// CHECK-NEXT: }
+// CHECK-NEXT: };
+
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd have to find it again, but I believe there was a previous discussion on this (it was emitc specific but there is also a more general approach behind, https://www.youtube.com/watch?v=hIt6J1_E21c)..) In short the principle is that translates are simple from MLIR IR to some external format, ones that are rather trivial to check that correct and these input/output dialects are close to format so that more complex operations are plain MLIR IR -> IR transforms (where one has all the regular tooling, testing & debugging). Which is why you'll see most of the translate functions make no large decisions. Now, wrt C++ translation, it simply today does what is represented in Emit today. EmitC provides the script for the emission. So I'd expect this to be a pass before translate really.
StringRef fileId = {}); | ||
StringRef fileId = {}, bool emitClass = false, | ||
StringRef className = {}, | ||
StringRef fieldNameAttribute = {}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please make sure this is all very clearly documented in the function description?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since I have rewritten this to be a pass(in mlir-opt
) first. I will make sure to clearly document this when updating mlir-translate
.
llvm::cl::desc("If specified, the output will be a class where " | ||
"the function(s) in the module are methods " | ||
"Enables class-related options"), | ||
llvm::cl::init(false)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this useful?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are aiming to add the functionality of getting classes from emitc. This would enable one to instantiate a class once and load it with weights that it can then use to execute.
static llvm::cl::opt<std::string> className( | ||
"class-name", | ||
llvm::cl::desc("Mandatory class name if --emit-class is set"), | ||
llvm::cl::init("")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I don't think we need two options: just name it -emit-class=<classname>
and check if this isn't empty.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will take care of this when I circle back to mlir-translate
changes.
"field-name-attribute", | ||
llvm::cl::desc("Mandatory name of the attribute to use as field name if " | ||
"--emit-class is set(default=tf_saved_model.index_path)"), | ||
llvm::cl::init("tf_saved_model.index_path")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not comfortable with a default in MLIR that refers to tensor flow or any ad-hoc system.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To do in next step.
return emitc::translateToCpp( | ||
op, output, | ||
/*declareVariablesAtTop=*/declareVariablesAtTop, | ||
/*fileId=*/fileId); | ||
/*fileId=*/fileId, /*emitClass=*/emitClass, /*className=*/className, | ||
/*fieldNameAttribute=*/fieldNameAttribute); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two lines would be better unchanged I believe.
Alternatively, you can remove line 68-73 entirely.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To do in next step.
/*fileId=*/fileId, /*emitClass=*/emitClass, | ||
/*className=*/className, | ||
/*fieldNameAttribute=*/fieldNameAttribute); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some sanity check that isn't done here is that fieldNameAttribute
better be unset if emitClass
isn't set either.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To do in next step.
os << "}\n\n"; | ||
|
||
return success(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is the contract about this map and what it's doing documented?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack, will ensure in next step.
// further discussion on this. | ||
os << "Warning: Cannot process external function '" | ||
<< functionOp.getName() << "'. " | ||
<< "This functionOp lacks a body so we will skip over it."; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Every op has a location. The way to emit warning is simply:
functionOp->emitWarning() << ...;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
@@ -1149,23 +1253,37 @@ static LogicalResult printOperation(CppEmitter &emitter, | |||
if (failed(emitter.emitTypes(functionOp.getLoc(), | |||
functionOp.getFunctionType().getResults()))) | |||
return failure(); | |||
// TODO: We may wanna consider having the name of the function be execute in | |||
// the case that we want to emit a class instead of main. Leaving as is for | |||
// now to make the change smaller. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please trim this comment and make it more clear, I don't quite get what is intended here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack, this file is no longer affected.
if (!emitter.shouldPrintClass()) { | ||
if (functionOp.isExternal()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (!emitter.shouldPrintClass()) { | |
if (functionOp.isExternal()) { | |
if (!emitter.shouldPrintClass() && functionOp.isExternal()) { |
Or maybe even better (because of the early return like 1237):
if (!emitter.shouldPrintClass()) { | |
if (functionOp.isExternal()) { | |
if (functionOp.isExternal()) { | |
assert(!emitter.shouldPrintClass()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack, this file is no longer affected.
<< "This functionOp lacks a body so we will skip over it."; | ||
return success(); | ||
} | ||
os << "class " << emitter.getClassName() << " final {\n"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems completely broken to me when there are multiple functions in the input IR (please provide such a test)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To do in next step.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left some initial comments, mostly about code style. I do wonder if a simpler initial transform followed by a successive lowering would make things easier, like initially creating the class w/ it's fields, then generating accessors, and finally methods. I haven't thought too much about it, but I wonder if that would make things simpler.
Example without initial value: | ||
```mlir | ||
emitc.class @MyModelClass { | ||
emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>> | ||
} | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This example can be folded into the one above (output_0 has no initial value, right?
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Outdated
if (!getInitialValue().has_value()) { | ||
return success(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getInitialValue() returns an option, right? you don't need to call has_value()
, since it's operator () will handle that in the conditional.
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Outdated
return success(); | ||
} | ||
|
||
Attribute initValue = getInitialValue().value(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is more idiomatic.
Attribute initValue = getInitialValue().value(); | |
Attribute initValue = *getInitialValue() |
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Outdated
// Check that the type of the initial value is compatible with the type of | ||
// the global variable. | ||
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) { | ||
auto initialValueType = elementsAttr.getType(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally if the type isn't crazy, prefer using it over auto
. The auto
above is one of the places where the coding standard points out it helps with clarity. There isn't a hard rule on this, but usually prefer the type if its simple enough, and auto
when the result is obvious (like the cast) or when the type is complicated (template result or iterator).
https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Outdated
<< initialValueType << " is not compatible with field type " | ||
<< fieldType << " its inner type " << innerFieldType; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably add single quotes around the type names.
<< initialValueType << " is not compatible with field type " | |
<< fieldType << " its inner type " << innerFieldType; | |
<< initialValueType << " is not compatible with field type '" | |
<< fieldType << "' its inner type '" << innerFieldType << "'"; |
|
||
auto argAttrs = funcOp.getArgAttrs(); | ||
if (argAttrs) { | ||
for (const auto [arg, val] : zip(*argAttrs, funcOp.getArguments())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want to copy arg
and val
? if not you need &[...]
if (argAttrs) { | ||
for (const auto [arg, val] : zip(*argAttrs, funcOp.getArguments())) { | ||
if (auto da = dyn_cast<mlir::DictionaryAttr>(arg)) { | ||
auto nv = da.getNamed("tf_saved_model.index_path")->getValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be hard coded? shouldn't this work even w/o tensorflow?
for (const auto [arg, val] : zip(*argAttrs, funcOp.getArguments())) { | ||
if (auto da = dyn_cast<mlir::DictionaryAttr>(arg)) { | ||
auto nv = da.getNamed("tf_saved_model.index_path")->getValue(); | ||
auto fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks a bit questionable.
builder.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr, | ||
/* attributes*/ dictAttr); | ||
// 5. Get the pointers to the class fields | ||
auto pointer = emitc::PointerType::get( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe a more descriptive name than pointer
, field_ptr
maybe?
if (isa<emitc::ConstantOp>(opToClone) || | ||
isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) || | ||
isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) || | ||
isa<emitc::ReturnOp>(opToClone)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A helper function for this condition may be nicer/easier to follow.
bae7f35
to
4780ab3
Compare
Looks like the last rebase went wrong? There are unrelated commits and I'd expect changes to the emitter if operations are added to the dialect. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, overall looks good and good start.
def EmitC_ClassOp | ||
: EmitC_Op<"class", [AutomaticAllocationScope, IsolatedFromAbove, | ||
OpAsmOpInterface, SymbolTable, | ||
Symbol]#GraphRegionNoTerminator.traits> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: space around #
to make the concat easier to see.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
clang-format gets rid of the space around #
.
do we need to ignore clang-format on this?
It creates a distinct scope, isolating its contents from the surrounding | ||
MLIR region, similar to how C++ classes encapsulate their internals. | ||
|
||
Example: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: newline in between to ensure the Markdown highlighter using mlir-www triggers.
|
||
let regions = (region AnyRegion:$body); | ||
|
||
let builders = []; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feel free to drop as this is the default.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed.
thanks for the pointer!
let description = [{ | ||
The `emitc.field` operation declares a named field within an `emitc.class` | ||
operation. The field's type must be an EmitC type. | ||
If the corresponding function argument has attributes (accessed via `argAttrs`), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a property of the pass and not the op/this op exists independent of this behavior and best to document on pass.
these attributes are attached to the field operation. | ||
Otherwise, the field is created without additional attributes. | ||
|
||
Example of func argument with attributes: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here and below about empty line between text and code markdown (but this moves to pass too).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed.
thanks for the pointer!
FunctionType funcType = FunctionType::get(funcContext, inputTypes, results); | ||
Location loc = funcOp.getLoc(); | ||
FuncOp newFuncOp = rewriter.create<emitc::FuncOp>( | ||
loc, rewriter.getStringAttr("execute"), funcType); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you document this name in the pass description? (I think it would be good to document the API generated there are ll together)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed.
thanks for the pointer!
FunctionType funcType = FunctionType::get(funcContext, inputTypes, results); | ||
Location loc = funcOp.getLoc(); | ||
FuncOp newFuncOp = rewriter.create<emitc::FuncOp>( | ||
loc, rewriter.getStringAttr("execute"), funcType); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could be able to just specify execute rather than make stringattr explicitly (a "unwrapped" builder is generated for the op)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed.
thanks for the pointer!
newFuncOp.getBody().takeBody(funcOp.getBody()); | ||
|
||
rewriter.setInsertionPointToStart(&newFuncOp.getBody().front()); | ||
std::vector<Value> newArguments; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reserve required size to avoid resizing vector in loop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed
|
||
llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true); | ||
if (failed(newFuncOp.eraseArguments(argsToErase))) { | ||
newFuncOp->emitOpError("Failed to erase all arguments using BitVector."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Error convention is sentence fragments (start lower case, no trailing period (https://llvm.org/docs/CodingStandards.html#error-and-warning-messages)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed
// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s | ||
|
||
module attributes { } { | ||
emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]}, %arg1: !emitc.array<1xf32> {emitc.opaque = ["some_feature"]}, %arg2: !emitc.array<1xf32> {emitc.opaque = ["output_0"]}) attributes { } { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can break this over lines to make easier to read.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed
ack, will address in the next step! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've took a course look over the changes. Overall this looks quite good.
}]; | ||
|
||
let arguments = (ins FlatSymbolRefAttr:$field_name); | ||
let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is in line with the get_global
op. The lvalue type models assignable l values only.
//===----------------------------------------------------------------------===// | ||
// FieldOp | ||
//===----------------------------------------------------------------------===// | ||
LogicalResult FieldOp::verify() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you share some of this with the verifier of the GlobalOp?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not quite. Currently, the FieldOp is different from the GlobalOp with respect to what we need to verify. This is due to the change I've made in the most recent commit. We no longer have an initial value in the FieldOp. Instead, we jus have attributes. Due to this, the FieldOp would be different from the GlobalOp and the verify would be different.
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Outdated
|
||
Attribute initValue = *getInitialValue(); | ||
// Check that the type of the initial value is compatible with the type of | ||
// the global variable. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// the global variable. | |
// the field. |
LogicalResult matchAndRewrite(emitc::FuncOp funcOp, | ||
PatternRewriter &rewriter) const override { | ||
if (funcOp->getParentOfType<emitc::ClassOp>()) { | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is used to skip functions already nested in the class op, which you could encounter when running the pass twice.
What should happen when the func op has results?
MLIRContext *funcContext = funcOp.getContext(); | ||
ArrayRef<Type> inputTypes = funcOp.getFunctionType().getInputs(); | ||
ArrayRef<Type> results = funcOp.getFunctionType().getResults(); | ||
FunctionType funcType = FunctionType::get(funcContext, inputTypes, results); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just use funcOp.getType() directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed.
TypeAttr typeAttr = TypeAttr::get(val.getType()); | ||
fields.push_back({fieldName, typeAttr}); | ||
rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr, | ||
argAttr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The last argument is the initializer, right? Why is the name hint used as the initial value of the field?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed.
@@ -0,0 +1,38 @@ | |||
// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you chose a different name, maybe emitc.name_hint or something. This looks confusing as it morrors the name of the opaque type/attribute.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed. thanks for the pointer!
|
||
// CHECK: module { | ||
// CHECK-NEXT: emitc.class @modelClass { | ||
// CHECK-NEXT: emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above: a dictionary attribute doesn't make sense to me as an initializer. I assume the emitter will print an error when encountering this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeap! thanks for pointing this out.
we can make the fieldop
have attributes instead of initializing. we plan on using these attributes to generate a buffer map later on when translating to cpp.
i.e, emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = ["input_tensor"]}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this change causes us to end up with a different verify for FieldOp from the GlobalOp verify.
11ed1f6
to
009f137
Compare
11560b4
to
bbf6775
Compare
Looks like this is ready to land? I'll wait a bit and merge today, unless I hear back. |
Drive-by: I think the PR title could be updated to |
mlir-opt
to wrap a func in class
Addressed. Thanks for the pointer! |
@Jaddyen Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
Goal: Enable using C++ classes to AOT compile models for MLGO. This commit introduces a transformation pass that converts standalone `emitc.func` operations into `emitc.class `structures to support class-based C++ code generation for MLGO. Transformation details: - Wrap `emitc.func @func_name` into `emitc.class @Myfunc_nameClass` - Converts function arguments to class fields with preserved attributes - Transforms function body into an `execute()` method with no arguments - Replaces argument references with `get_field` operations Before: emitc.func @model(%arg0, %arg1, %arg2) with direct argument access After: emitc.class with fields and execute() method using get_field operations This enables generating C++ classes that can be instantiated and executed as self-contained model objects for AOT compilation workflows.
Goal: Enable using C++ classes to AOT compile models for MLGO. This commit introduces a transformation pass that converts standalone `emitc.func` operations into `emitc.class `structures to support class-based C++ code generation for MLGO. Transformation details: - Wrap `emitc.func @func_name` into `emitc.class @Myfunc_nameClass` - Converts function arguments to class fields with preserved attributes - Transforms function body into an `execute()` method with no arguments - Replaces argument references with `get_field` operations Before: emitc.func @model(%arg0, %arg1, %arg2) with direct argument access After: emitc.class with fields and execute() method using get_field operations This enables generating C++ classes that can be instantiated and executed as self-contained model objects for AOT compilation workflows.
…tfieldop` (llvm#145605) Add support to the emitter for `ClassOp`, `FieldOp` and `GetFieldOp`. These ops were introduced in llvm#141158
Goal: Enable using C++ classes to AOT compile models for MLGO.
This commit introduces a transformation pass that converts standalone
emitc.func
operations intoemitc.class
structures to support class-based C++ code generation for MLGO.Transformation details:
emitc.func @func_name
intoemitc.class @Myfunc_nameClass
execute()
method with no argumentsget_field
operationsBefore: emitc.func @model(%arg0, %arg1, %arg2) with direct argument access
After: emitc.class with fields and execute() method using get_field operations
This enables generating C++ classes that can be instantiated and executed as self-contained model objects for AOT compilation workflows.