@@ -22,94 +22,133 @@ use tk::tokenizer::{
 
 
 #[pyclass(dict, module = "tokenizers")]
 pub struct AddedToken {
-    pub token: tk::tokenizer::AddedToken,
+    pub content: String,
+    pub is_special_token: bool,
+    pub single_word: Option<bool>,
+    pub lstrip: Option<bool>,
+    pub rstrip: Option<bool>,
+    pub normalized: Option<bool>,
 }
+impl AddedToken {
+    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+        Self {
+            content: content.into(),
+            is_special_token: is_special_token.unwrap_or(false),
+            single_word: None,
+            lstrip: None,
+            rstrip: None,
+            normalized: None,
+        }
+    }
+
+    pub fn get_token(&self) -> tk::tokenizer::AddedToken {
+        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+
+        if let Some(sw) = self.single_word {
+            token = token.single_word(sw);
+        }
+        if let Some(ls) = self.lstrip {
+            token = token.lstrip(ls);
+        }
+        if let Some(rs) = self.rstrip {
+            token = token.rstrip(rs);
+        }
+        if let Some(n) = self.normalized {
+            token = token.normalized(n);
+        }
+
+        token
+    }
+
+    pub fn as_pydict<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
+        let dict = PyDict::new(py);
+        let token = self.get_token();
+
+        dict.set_item("content", token.content)?;
+        dict.set_item("single_word", token.single_word)?;
+        dict.set_item("lstrip", token.lstrip)?;
+        dict.set_item("rstrip", token.rstrip)?;
+        dict.set_item("normalized", token.normalized)?;
+
+        Ok(dict)
+    }
+}
+
 #[pymethods]
 impl AddedToken {
     #[new]
     #[args(kwargs = "**")]
-    fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
-        let mut token = tk::tokenizer::AddedToken::from(content, is_special_token);
+    fn new(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
+        let mut token = AddedToken::from(content.unwrap_or(""), None);
 
         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
                 let key: &str = key.extract()?;
                 match key {
-                    "single_word" => token = token.single_word(value.extract()?),
-                    "lstrip" => token = token.lstrip(value.extract()?),
-                    "rstrip" => token = token.rstrip(value.extract()?),
-                    "normalized" => token = token.normalized(value.extract()?),
+                    "single_word" => token.single_word = Some(value.extract()?),
+                    "lstrip" => token.lstrip = Some(value.extract()?),
+                    "rstrip" => token.rstrip = Some(value.extract()?),
+                    "normalized" => token.normalized = Some(value.extract()?),
                     _ => println!("Ignored unknown kwarg option {}", key),
                 }
             }
         }
 
-        Ok(AddedToken { token })
+        Ok(token)
     }
 
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = serde_json::to_string(&self.token).map_err(|e| {
-            exceptions::Exception::py_err(format!(
-                "Error while attempting to pickle AddedToken: {}",
-                e.to_string()
-            ))
-        })?;
-        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
+    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
+        self.as_pydict(py)
     }
 
     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
-        match state.extract::<&PyBytes>(py) {
-            Ok(s) => {
-                self.token = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
-                        "Error while attempting to unpickle AddedToken: {}",
-                        e.to_string()
-                    ))
-                })?;
+        match state.extract::<&PyDict>(py) {
+            Ok(state) => {
+                for (key, value) in state {
+                    let key: &str = key.extract()?;
+                    match key {
+                        "single_word" => self.single_word = Some(value.extract()?),
+                        "lstrip" => self.lstrip = Some(value.extract()?),
+                        "rstrip" => self.rstrip = Some(value.extract()?),
+                        "normalized" => self.normalized = Some(value.extract()?),
+                        _ => {}
+                    }
+                }
                 Ok(())
             }
             Err(e) => Err(e),
         }
     }
 
-    fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
-        // We don't really care about the values of `content` & `is_special_token` here because
-        // they will get overriden by `__setstate__`
-        let content: PyObject = "".into_py(py);
-        let is_special_token: PyObject = false.into_py(py);
-        let args = PyTuple::new(py, vec![content, is_special_token]);
-        Ok(args)
-    }
-
     #[getter]
     fn get_content(&self) -> &str {
-        &self.token.content
+        &self.content
     }
 
     #[getter]
     fn get_rstrip(&self) -> bool {
-        self.token.rstrip
+        self.get_token().rstrip
     }
 
     #[getter]
     fn get_lstrip(&self) -> bool {
-        self.token.lstrip
+        self.get_token().lstrip
     }
 
     #[getter]
     fn get_single_word(&self) -> bool {
-        self.token.single_word
+        self.get_token().single_word
     }
 
     #[getter]
     fn get_normalized(&self) -> bool {
-        self.token.normalized
+        self.get_token().normalized
     }
 }
 #[pyproto]
 impl PyObjectProtocol for AddedToken {
     fn __str__(&'p self) -> PyResult<&'p str> {
-        Ok(&self.token.content)
+        Ok(&self.content)
     }
 
     fn __repr__(&self) -> PyResult<String> {
@@ -118,13 +157,14 @@ impl PyObjectProtocol for AddedToken {
             false => "False",
         };
 
+        let token = self.get_token();
         Ok(format!(
             "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
-            self.token.content,
-            bool_to_python(self.token.rstrip),
-            bool_to_python(self.token.lstrip),
-            bool_to_python(self.token.single_word),
-            bool_to_python(self.token.normalized)
+            self.content,
+            bool_to_python(token.rstrip),
+            bool_to_python(token.lstrip),
+            bool_to_python(token.single_word),
+            bool_to_python(token.normalized)
         ))
     }
 }
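Not part of the diff itself, but worth spelling out: the `#[pyclass]` wrapper now stores only the user's explicit choices as `Option<bool>` and materializes the underlying `tk::tokenizer::AddedToken` lazily in `get_token()`, so the defaults can depend on `is_special_token` at the moment the token is actually used. Below is a minimal, self-contained sketch of that pattern; `CoreToken`, `Wrapper`, and the chosen defaults are illustrative stand-ins, not the real `tk` or pyo3 API.

```rust
// Stand-in for `tk::tokenizer::AddedToken`; the defaults are assumptions
// for illustration only.
#[derive(Debug)]
struct CoreToken {
    content: String,
    single_word: bool,
    lstrip: bool,
    rstrip: bool,
    normalized: bool,
}

impl CoreToken {
    fn from(content: &str, special: bool) -> Self {
        CoreToken {
            content: content.to_string(),
            single_word: false,
            lstrip: false,
            rstrip: false,
            // Illustrative default: special tokens are not normalized.
            normalized: !special,
        }
    }
    fn single_word(mut self, v: bool) -> Self { self.single_word = v; self }
    fn lstrip(mut self, v: bool) -> Self { self.lstrip = v; self }
    fn rstrip(mut self, v: bool) -> Self { self.rstrip = v; self }
    fn normalized(mut self, v: bool) -> Self { self.normalized = v; self }
}

// Analogue of the pyclass above: only explicit overrides are stored.
struct Wrapper {
    content: String,
    is_special_token: bool,
    single_word: Option<bool>,
    lstrip: Option<bool>,
    rstrip: Option<bool>,
    normalized: Option<bool>,
}

impl Wrapper {
    // Materialize the core token, applying only the overrides that were set.
    fn get_token(&self) -> CoreToken {
        let mut token = CoreToken::from(&self.content, self.is_special_token);
        if let Some(v) = self.single_word { token = token.single_word(v); }
        if let Some(v) = self.lstrip { token = token.lstrip(v); }
        if let Some(v) = self.rstrip { token = token.rstrip(v); }
        if let Some(v) = self.normalized { token = token.normalized(v); }
        token
    }
}

fn main() {
    let mut wrapped = Wrapper {
        content: "[MASK]".into(),
        is_special_token: false,
        single_word: Some(true),
        lstrip: None,
        rstrip: None,
        normalized: None,
    };
    println!("{:?}", wrapped.get_token()); // normalized: true (plain-token default)

    // Flipping the flag later changes the defaults applied at
    // materialization time without losing the explicit overrides.
    wrapped.is_special_token = true;
    println!("{:?}", wrapped.get_token()); // normalized: false, single_word still true
}
```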
@@ -583,9 +623,10 @@ impl Tokenizer {
             .into_iter()
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken::from(content, false))
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                    Ok(AddedToken::from(content, Some(false)).get_token())
+                } else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
+                    token.is_special_token = false;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "Input must be a List[Union[str, AddedToken]]",
@@ -603,8 +644,9 @@ impl Tokenizer {
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
                     Ok(tk::tokenizer::AddedToken::from(content, true))
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                } else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
+                    token.is_special_token = true;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "Input must be a List[Union[str, AddedToken]]",