Skip to content

Commit 0f617cc

Browse files
committed
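Merges cannot handle tokens containing spaces.
This commit fixes that by serializing each merge as a pair of tokens, while keeping backward compatibility with the legacy space-separated string format. We don't want to merge this change blindly.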
1 parent c74e9e6 commit 0f617cc

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

tokenizers/src/models/bpe/serialization.rs

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -28,14 +28,14 @@ impl Serialize for BPE {
2828
.map(|(pair, (rank, _))| (pair, rank))
2929
.collect();
3030
merges.sort_unstable_by_key(|k| *k.1);
31-
let merges_str = merges
31+
let merges = merges
3232
.into_iter()
33-
.map(|(pair, _)| format!("{} {}", self.vocab_r[&pair.0], self.vocab_r[&pair.1]))
33+
.map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
3434
.collect::<Vec<_>>();
3535
let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
3636

3737
model.serialize_field("vocab", &ordered_vocab)?;
38-
model.serialize_field("merges", &merges_str)?;
38+
model.serialize_field("merges", &merges)?;
3939

4040
model.end()
4141
}
@@ -77,7 +77,14 @@ impl<'de> Visitor<'de> for BPEVisitor {
7777
{
7878
let mut builder = BpeBuilder::new();
7979
let mut vocab: Option<HashMap<String, u32>> = None;
80-
let mut merges: Option<Vec<String>> = None;
80+
81+
#[derive(Debug, Deserialize)]
82+
#[serde(untagged)]
83+
enum MergeType {
84+
Tuple(Vec<(String, String)>),
85+
Legacy(Vec<String>),
86+
}
87+
let mut merges: Option<MergeType> = None;
8188
while let Some(key) = map.next_key::<String>()? {
8289
match key.as_ref() {
8390
"dropout" => {
@@ -120,8 +127,12 @@ impl<'de> Visitor<'de> for BPEVisitor {
120127
}
121128
}
122129
if let (Some(vocab), Some(merges)) = (vocab, merges) {
123-
let merges =
124-
convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?;
130+
let merges = match merges {
131+
MergeType::Tuple(merges) => merges,
132+
MergeType::Legacy(merges) => {
133+
convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?
134+
}
135+
};
125136
builder = builder.vocab_and_merges(vocab, merges);
126137
Ok(builder.build().map_err(Error::custom)?)
127138
} else {

0 commit comments

Comments
 (0)