@@ -8,7 +8,7 @@ pub mod wordpiece;
8
8
use std:: collections:: HashMap ;
9
9
use std:: path:: { Path , PathBuf } ;
10
10
11
- use serde:: { Deserialize , Serialize , Serializer } ;
11
+ use serde:: { Deserialize , Deserializer , Serialize , Serializer } ;
12
12
13
13
use crate :: models:: bpe:: { BpeTrainer , BPE } ;
14
14
use crate :: models:: unigram:: { Unigram , UnigramTrainer } ;
@@ -57,7 +57,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
57
57
}
58
58
}
59
59
60
- #[ derive( Deserialize , Serialize , Debug , PartialEq , Clone ) ]
60
+ #[ derive( Serialize , Debug , PartialEq , Clone ) ]
61
61
#[ serde( untagged) ]
62
62
pub enum ModelWrapper {
63
63
BPE ( BPE ) ,
@@ -68,6 +68,73 @@ pub enum ModelWrapper {
68
68
Unigram ( Unigram ) ,
69
69
}
70
70
71
+ impl < ' de > Deserialize < ' de > for ModelWrapper {
72
+ fn deserialize < D > ( deserializer : D ) -> std:: result:: Result < Self , D :: Error >
73
+ where
74
+ D : Deserializer < ' de > ,
75
+ {
76
+ #[ derive( Deserialize ) ]
77
+ pub struct Tagged {
78
+ #[ serde( rename = "type" ) ]
79
+ variant : EnumType ,
80
+ #[ serde( flatten) ]
81
+ rest : serde_json:: Value ,
82
+ }
83
+ #[ derive( Deserialize ) ]
84
+ pub enum EnumType {
85
+ BPE ,
86
+ WordPiece ,
87
+ WordLevel ,
88
+ Unigram ,
89
+ }
90
+
91
+ #[ derive( Deserialize ) ]
92
+ #[ serde( untagged) ]
93
+ pub enum ModelHelper {
94
+ Tagged ( Tagged ) ,
95
+ Legacy ( serde_json:: Value ) ,
96
+ }
97
+
98
+ #[ derive( Deserialize ) ]
99
+ #[ serde( untagged) ]
100
+ pub enum ModelUntagged {
101
+ BPE ( BPE ) ,
102
+ // WordPiece must stay before WordLevel here for deserialization (for retrocompatibility
103
+ // with the versions not including the "type"), since WordLevel is a subset of WordPiece
104
+ WordPiece ( WordPiece ) ,
105
+ WordLevel ( WordLevel ) ,
106
+ Unigram ( Unigram ) ,
107
+ }
108
+
109
+ let helper = ModelHelper :: deserialize ( deserializer) ?;
110
+ Ok ( match helper {
111
+ ModelHelper :: Tagged ( bpe) => match bpe. variant {
112
+ EnumType :: BPE => ModelWrapper :: BPE (
113
+ serde_json:: from_value ( bpe. rest ) . map_err ( serde:: de:: Error :: custom) ?,
114
+ ) ,
115
+ EnumType :: WordPiece => ModelWrapper :: WordPiece (
116
+ serde_json:: from_value ( bpe. rest ) . map_err ( serde:: de:: Error :: custom) ?,
117
+ ) ,
118
+ EnumType :: WordLevel => ModelWrapper :: WordLevel (
119
+ serde_json:: from_value ( bpe. rest ) . map_err ( serde:: de:: Error :: custom) ?,
120
+ ) ,
121
+ EnumType :: Unigram => ModelWrapper :: Unigram (
122
+ serde_json:: from_value ( bpe. rest ) . map_err ( serde:: de:: Error :: custom) ?,
123
+ ) ,
124
+ } ,
125
+ ModelHelper :: Legacy ( value) => {
126
+ let untagged = serde_json:: from_value ( value) . map_err ( serde:: de:: Error :: custom) ?;
127
+ match untagged {
128
+ ModelUntagged :: BPE ( bpe) => ModelWrapper :: BPE ( bpe) ,
129
+ ModelUntagged :: WordPiece ( bpe) => ModelWrapper :: WordPiece ( bpe) ,
130
+ ModelUntagged :: WordLevel ( bpe) => ModelWrapper :: WordLevel ( bpe) ,
131
+ ModelUntagged :: Unigram ( bpe) => ModelWrapper :: Unigram ( bpe) ,
132
+ }
133
+ }
134
+ } )
135
+ }
136
+ }
137
+
71
138
impl_enum_from ! ( WordLevel , ModelWrapper , WordLevel ) ;
72
139
impl_enum_from ! ( WordPiece , ModelWrapper , WordPiece ) ;
73
140
impl_enum_from ! ( BPE , ModelWrapper , BPE ) ;
@@ -263,10 +330,7 @@ mod tests {
263
330
let reconstructed: std:: result:: Result < ModelWrapper , serde_json:: Error > =
264
331
serde_json:: from_str ( invalid) ;
265
332
match reconstructed {
266
- Err ( err) => assert_eq ! (
267
- err. to_string( ) ,
268
- "data did not match any variant of untagged enum ModelWrapper"
269
- ) ,
333
+ Err ( err) => assert_eq ! ( err. to_string( ) , "Merges text file invalid at line 1" ) ,
270
334
_ => panic ! ( "Expected an error here" ) ,
271
335
}
272
336
}
0 commit comments