indices_tuple: Add assertion that each pair should be either positive or negative #743

Merged
merged 3 commits into from
Apr 20, 2025

Conversation

lucamarini22

While using indices_tuple to compute multi-label similarity, I noticed that the same two embeddings could be listed as both a positive pair and a negative pair without an error being raised. For example:

a1 = torch.tensor([0, 0, 1])
p = torch.tensor([2, 3, 4])
a2 = torch.tensor([1, 1, 0])
n = torch.tensor([2, 3, 2])

Here, the 1st positive pair is (a1[0], p[0]) = (0, 2), which corresponds to (embeddings[0], embeddings[2]).
The 3rd negative pair is (a2[2], n[2]), which is also (0, 2), so it corresponds to the same (embeddings[0], embeddings[2]).

Therefore, I added an assertion to make sure that each pair is either positive or negative when using indices_tuple.
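
To make the conflict in the example concrete, here is a minimal sketch in plain Python. It uses lists in place of the tensors, and `conflicting_pairs` is a hypothetical helper written for illustration only; it is not part of the library or of this PR.

```python
# Index lists from the example above (plain lists stand in for the tensors).
a1 = [0, 0, 1]  # anchors of the positive pairs
p = [2, 3, 4]   # positives
a2 = [1, 1, 0]  # anchors of the negative pairs
n = [2, 3, 2]   # negatives


def conflicting_pairs(a1, p, a2, n):
    """Hypothetical helper: return the sorted (anchor, other) pairs that are
    listed as both a positive pair and a negative pair."""
    positive_pairs = set(zip(a1, p))
    negative_pairs = set(zip(a2, n))
    return sorted(positive_pairs & negative_pairs)


conflicts = conflicting_pairs(a1, p, a2, n)
print(conflicts)  # [(0, 2)]
```

The only overlap is (0, 2), the pair described above. Note that this sketch treats (0, 2) and (2, 0) as different pairs.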

@KevinMusgrave KevinMusgrave changed the base branch from master to dev April 4, 2025 13:36
@KevinMusgrave KevinMusgrave changed the base branch from dev to master April 4, 2025 13:37
@KevinMusgrave KevinMusgrave changed the base branch from master to dev April 4, 2025 13:37
@KevinMusgrave
Owner

Good idea, thanks!

Looks like there might be a bug in one of my tests. I will check it out later.

@KevinMusgrave
Owner

KevinMusgrave commented Apr 4, 2025

@lucamarini22 Grok suggested the following as being more efficient. It seems correct to me. Also Claude confirmed. What do you think?

def _assert_either_pos_or_neg(pos_mask, neg_mask):
    assert not torch.any((pos_mask != 0) & (neg_mask != 0)), "Each pair should be either positive or negative"

@lucamarini22
Author

@KevinMusgrave thanks for the reply!
Yes, I agree. It also looks cleaner.

I just pushed the change.

@lucamarini22
Author

Hey @KevinMusgrave, I can also take a look at the unit tests, but I think they need to be re-run first.

@KevinMusgrave
Owner

KevinMusgrave commented Apr 10, 2025

I've rerun them now. I think the problem is in CrossBatchMemory, either with the input indices_tuple or the internal indices_tuple. I'm planning to remove the input indices_tuple argument, so if that's the easiest fix then I'll just do that: #614

@KevinMusgrave
Owner

Yeah, I think it's buggy, so I've removed that part of the test in anticipation of removing the argument (#614).

@KevinMusgrave KevinMusgrave merged commit 8e1f952 into KevinMusgrave:dev Apr 20, 2025
21 checks passed

2 participants