Skip to content

Commit d952a50

Browse files
[BUG] Handle duplicates in chroma-load. (chroma-core#4423)
2 parents eb232da + 66db645 commit d952a50

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

rust/load/src/data_sets.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,10 @@ impl DataSet for ReferencingDataSet {
760760
let mut keys = vec![];
761761
let num_keys = gq.limit.sample(guac);
762762
for _ in 0..num_keys {
763-
keys.push(KeySelector::Random(gq.skew).select(guac, self));
763+
let key = KeySelector::Random(gq.skew).select(guac, self);
764+
if !keys.contains(&key) {
765+
keys.push(key);
766+
}
764767
}
765768
let collection = client.get_collection(&self.operates_on).await?;
766769
// TODO(rescrv): from the reference collection, pull the documents and embeddings and
@@ -787,7 +790,10 @@ impl DataSet for ReferencingDataSet {
787790
let mut keys = vec![];
788791
let num_keys = qq.limit.sample(guac);
789792
for _ in 0..num_keys {
790-
keys.push(KeySelector::Random(qq.skew).select(guac, self));
793+
let key = KeySelector::Random(qq.skew).select(guac, self);
794+
if !keys.contains(&key) {
795+
keys.push(key);
796+
}
791797
}
792798
let keys = keys.iter().map(|k| k.as_str()).collect::<Vec<_>>();
793799
if let Some(res) = self.references.get_by_key(client, &keys).await? {
@@ -835,7 +841,10 @@ impl DataSet for ReferencingDataSet {
835841
let collection = client.get_collection(&self.operates_on).await?;
836842
let mut keys = vec![];
837843
for offset in 0..uq.batch_size {
838-
keys.push(uq.key.select_from_reference(self, offset));
844+
let key = uq.key.select_from_reference(self, offset);
845+
if !keys.contains(&key) {
846+
keys.push(key);
847+
}
839848
}
840849
let keys = keys.iter().map(|k| k.as_str()).collect::<Vec<_>>();
841850
if let Some(res) = self.references.get_by_key(client, &keys).await? {
@@ -1019,7 +1028,10 @@ impl DataSet for VerifyingDataSet {
10191028
}
10201029

10211030
for _ in 0..num_keys {
1022-
keys.push(KeySelector::Random(gq.skew).select(guac, self));
1031+
let key = KeySelector::Random(gq.skew).select(guac, self);
1032+
if !keys.contains(&key) {
1033+
keys.push(key);
1034+
}
10231035
}
10241036

10251037
let reference_collection = client
@@ -1208,7 +1220,10 @@ impl DataSet for VerifyingDataSet {
12081220
);
12091221

12101222
for offset in 0..uq.batch_size {
1211-
keys.push(uq.key.select_from_reference(self, offset));
1223+
let key = uq.key.select_from_reference(self, offset);
1224+
if !keys.contains(&key) {
1225+
keys.push(key)
1226+
}
12121227
}
12131228
let keys = keys.iter().map(|k| k.as_str()).collect::<Vec<_>>();
12141229
if let Some(res) = self.reference_data_set.get_by_key(client, &keys).await? {

0 commit comments

Comments
 (0)