-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[IR2Vec] Overloading operator+
for Embeddings
#145118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/svkeerthy/06-20-increasing_tolerance_in_approximatelyequals
Are you sure you want to change the base?
[IR2Vec] Overloading operator+
for Embeddings
#145118
Conversation
Warning This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.
This stack of pull requests is managed by Graphite. Learn more about stacking. |
operator+
for `Embeddings
operator+
for `Embeddingsoperator+
for Embeddings
@llvm/pr-subscribers-mlgo @llvm/pr-subscribers-llvm-analysis Author: S. VenkataKeerthy (svkeerthy) ChangesAdd out-of-place addition operator for Embedding class in IR2Vec. This is used in subsequent patches. Full diff: https://github.com/llvm/llvm-project/pull/145118.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 480b834077b86..f6c40d36f8026 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -106,6 +106,7 @@ struct Embedding {
const std::vector<double> &getData() const { return Data; }
/// Arithmetic operators
+ Embedding operator+(const Embedding &RHS) const;
Embedding &operator+=(const Embedding &RHS);
Embedding &operator-=(const Embedding &RHS);
Embedding &operator*=(double Factor);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 27cc2a4109879..d5d27db8bd2bf 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -71,6 +71,14 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
// Embedding
//===----------------------------------------------------------------------===//
+Embedding Embedding::operator+(const Embedding &RHS) const {
+ assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+ Embedding Result(*this);
+ std::transform(this->begin(), this->end(), RHS.begin(), Result.begin(),
+ std::plus<double>());
+ return Result;
+}
+
Embedding &Embedding::operator+=(const Embedding &RHS) {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 33ac16828eb6c..50eb7f73c6f50 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) {
}
}
+TEST(EmbeddingTest, AddVectorsOutOfPlace) {
+ Embedding E1 = {1.0, 2.0, 3.0};
+ Embedding E2 = {0.5, 1.5, -1.0};
+
+ Embedding E3 = E1 + E2;
+ EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0));
+
+ // Check that E1 and E2 are unchanged
+ EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+ EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
TEST(EmbeddingTest, AddVectors) {
Embedding E1 = {1.0, 2.0, 3.0};
Embedding E2 = {0.5, 1.5, -1.0};
@@ -180,6 +192,12 @@ TEST(EmbeddingTest, AccessOutOfBounds) {
EXPECT_DEATH(E[4] = 4.0, "Index out of bounds");
}
+TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) {
+ Embedding E1 = {1.0, 2.0};
+ Embedding E2 = {1.0};
+ EXPECT_DEATH(E1 + E2, "Vectors must have the same dimension");
+}
+
TEST(EmbeddingTest, MismatchedDimensionsAddVectors) {
Embedding E1 = {1.0, 2.0};
Embedding E2 = {1.0};
|
Add out-of-place addition operator for Embedding class in IR2Vec.
This is used in subsequent patches.
(Tracking issue - #141817)