Skip to content

Commit 4e78191

Browse files
committed
Updating the deserialization error for models.
1 parent 40f4f24 commit 4e78191

File tree

1 file changed

+70
-6
lines changed

1 file changed

+70
-6
lines changed

tokenizers/src/models/mod.rs

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ pub mod wordpiece;
88
use std::collections::HashMap;
99
use std::path::{Path, PathBuf};
1010

11-
use serde::{Deserialize, Serialize, Serializer};
11+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
1212

1313
use crate::models::bpe::{BpeTrainer, BPE};
1414
use crate::models::unigram::{Unigram, UnigramTrainer};
@@ -57,7 +57,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
5757
}
5858
}
5959

60-
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
60+
#[derive(Serialize, Debug, PartialEq, Clone)]
6161
#[serde(untagged)]
6262
pub enum ModelWrapper {
6363
BPE(BPE),
@@ -68,6 +68,73 @@ pub enum ModelWrapper {
6868
Unigram(Unigram),
6969
}
7070

71+
// Hand-written `Deserialize` implementation for `ModelWrapper`.
//
// Input is first matched against a tagged shape (a "type" key naming one of the four
// models). When that matches, the remaining fields are deserialized as exactly that model
// and any failure is returned with that model's own error message. Otherwise the input is
// handed to an untagged fallback that tries each model in declaration order.
impl<'de> Deserialize<'de> for ModelWrapper {
72+
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
73+
where
74+
D: Deserializer<'de>,
75+
{
76+
// Tagged shape: the "type" key selects the model, and every other key is gathered
// into `rest` through `flatten`.
#[derive(Deserialize)]
77+
pub struct Tagged {
78+
#[serde(rename = "type")]
79+
variant: EnumType,
80+
#[serde(flatten)]
81+
rest: serde_json::Value,
82+
}
83+
// Values accepted for the "type" key.
#[derive(Deserialize)]
84+
pub enum EnumType {
85+
BPE,
86+
WordPiece,
87+
WordLevel,
88+
Unigram,
89+
}
90+
91+
// Untagged, so `Tagged` is attempted first; if it does not match, the raw JSON value
// is kept as `Legacy`.
#[derive(Deserialize)]
92+
#[serde(untagged)]
93+
pub enum ModelHelper {
94+
Tagged(Tagged),
95+
Legacy(serde_json::Value),
96+
}
97+
98+
// Fallback for the `Legacy` case: each model is tried in declaration order.
#[derive(Deserialize)]
99+
#[serde(untagged)]
100+
pub enum ModelUntagged {
101+
BPE(BPE),
102+
// WordPiece must stay before WordLevel here for deserialization (for backward compatibility
103+
// with versions that did not include the "type"), since WordLevel is a subset of WordPiece
104+
WordPiece(WordPiece),
105+
WordLevel(WordLevel),
106+
Unigram(Unigram),
107+
}
108+
109+
let helper = ModelHelper::deserialize(deserializer)?;
110+
Ok(match helper {
111+
// `bpe` is the `Tagged` helper whichever model was selected. `rest` is converted into
// that model, and a conversion failure is forwarded through `serde::de::Error::custom`.
ModelHelper::Tagged(bpe) => match bpe.variant {
112+
EnumType::BPE => ModelWrapper::BPE(
113+
serde_json::from_value(bpe.rest).map_err(serde::de::Error::custom)?,
114+
),
115+
EnumType::WordPiece => ModelWrapper::WordPiece(
116+
serde_json::from_value(bpe.rest).map_err(serde::de::Error::custom)?,
117+
),
118+
EnumType::WordLevel => ModelWrapper::WordLevel(
119+
serde_json::from_value(bpe.rest).map_err(serde::de::Error::custom)?,
120+
),
121+
EnumType::Unigram => ModelWrapper::Unigram(
122+
serde_json::from_value(bpe.rest).map_err(serde::de::Error::custom)?,
123+
),
124+
},
125+
// `Tagged` did not match: let the untagged enum pick the model, then re-wrap the
// result in the corresponding `ModelWrapper` variant.
ModelHelper::Legacy(value) => {
126+
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
127+
match untagged {
128+
ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
129+
ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe),
130+
ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe),
131+
ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe),
132+
}
133+
}
134+
})
135+
}
136+
}
137+
71138
impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
72139
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
73140
impl_enum_from!(BPE, ModelWrapper, BPE);
@@ -263,10 +330,7 @@ mod tests {
263330
let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
264331
serde_json::from_str(invalid);
265332
match reconstructed {
266-
Err(err) => assert_eq!(
267-
err.to_string(),
268-
"data did not match any variant of untagged enum ModelWrapper"
269-
),
333+
Err(err) => assert_eq!(err.to_string(), "Merges text file invalid at line 1"),
270334
_ => panic!("Expected an error here"),
271335
}
272336
}

0 commit comments

Comments
 (0)