
Commit c02d4e2

Python - Improve AddedToken interface
1 parent a14cd7b commit c02d4e2

File tree

5 files changed: +125 -87 lines changed


bindings/python/src/tokenizer.rs

Lines changed: 91 additions & 49 deletions
@@ -22,94 +22,133 @@ use tk::tokenizer::{
 
 #[pyclass(dict, module = "tokenizers")]
 pub struct AddedToken {
-    pub token: tk::tokenizer::AddedToken,
+    pub content: String,
+    pub is_special_token: bool,
+    pub single_word: Option<bool>,
+    pub lstrip: Option<bool>,
+    pub rstrip: Option<bool>,
+    pub normalized: Option<bool>,
 }
+impl AddedToken {
+    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+        Self {
+            content: content.into(),
+            is_special_token: is_special_token.unwrap_or(false),
+            single_word: None,
+            lstrip: None,
+            rstrip: None,
+            normalized: None,
+        }
+    }
+
+    pub fn get_token(&self) -> tk::tokenizer::AddedToken {
+        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+
+        if let Some(sw) = self.single_word {
+            token = token.single_word(sw);
+        }
+        if let Some(ls) = self.lstrip {
+            token = token.lstrip(ls);
+        }
+        if let Some(rs) = self.rstrip {
+            token = token.rstrip(rs);
+        }
+        if let Some(n) = self.normalized {
+            token = token.normalized(n);
+        }
+
+        token
+    }
+
+    pub fn as_pydict<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
+        let dict = PyDict::new(py);
+        let token = self.get_token();
+
+        dict.set_item("content", token.content)?;
+        dict.set_item("single_word", token.single_word)?;
+        dict.set_item("lstrip", token.lstrip)?;
+        dict.set_item("rstrip", token.rstrip)?;
+        dict.set_item("normalized", token.normalized)?;
+
+        Ok(dict)
+    }
+}
+
 #[pymethods]
 impl AddedToken {
     #[new]
     #[args(kwargs = "**")]
-    fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
-        let mut token = tk::tokenizer::AddedToken::from(content, is_special_token);
+    fn new(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
+        let mut token = AddedToken::from(content.unwrap_or(""), None);
 
         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
                 let key: &str = key.extract()?;
                 match key {
-                    "single_word" => token = token.single_word(value.extract()?),
-                    "lstrip" => token = token.lstrip(value.extract()?),
-                    "rstrip" => token = token.rstrip(value.extract()?),
-                    "normalized" => token = token.normalized(value.extract()?),
+                    "single_word" => token.single_word = Some(value.extract()?),
+                    "lstrip" => token.lstrip = Some(value.extract()?),
+                    "rstrip" => token.rstrip = Some(value.extract()?),
+                    "normalized" => token.normalized = Some(value.extract()?),
                     _ => println!("Ignored unknown kwarg option {}", key),
                 }
             }
         }
 
-        Ok(AddedToken { token })
+        Ok(token)
     }
 
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = serde_json::to_string(&self.token).map_err(|e| {
-            exceptions::Exception::py_err(format!(
-                "Error while attempting to pickle AddedToken: {}",
-                e.to_string()
-            ))
-        })?;
-        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
+    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
+        self.as_pydict(py)
     }
 
     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
-        match state.extract::<&PyBytes>(py) {
-            Ok(s) => {
-                self.token = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
-                        "Error while attempting to unpickle AddedToken: {}",
-                        e.to_string()
-                    ))
-                })?;
+        match state.extract::<&PyDict>(py) {
+            Ok(state) => {
+                for (key, value) in state {
+                    let key: &str = key.extract()?;
+                    match key {
+                        "single_word" => self.single_word = Some(value.extract()?),
+                        "lstrip" => self.lstrip = Some(value.extract()?),
+                        "rstrip" => self.rstrip = Some(value.extract()?),
+                        "normalized" => self.normalized = Some(value.extract()?),
+                        _ => {}
+                    }
+                }
                 Ok(())
             }
             Err(e) => Err(e),
         }
     }
 
-    fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
-        // We don't really care about the values of `content` & `is_special_token` here because
-        // they will get overriden by `__setstate__`
-        let content: PyObject = "".into_py(py);
-        let is_special_token: PyObject = false.into_py(py);
-        let args = PyTuple::new(py, vec![content, is_special_token]);
-        Ok(args)
-    }
-
     #[getter]
     fn get_content(&self) -> &str {
-        &self.token.content
+        &self.content
     }
 
     #[getter]
     fn get_rstrip(&self) -> bool {
-        self.token.rstrip
+        self.get_token().rstrip
    }
 
     #[getter]
     fn get_lstrip(&self) -> bool {
-        self.token.lstrip
+        self.get_token().lstrip
     }
 
     #[getter]
     fn get_single_word(&self) -> bool {
-        self.token.single_word
+        self.get_token().single_word
     }
 
     #[getter]
     fn get_normalized(&self) -> bool {
-        self.token.normalized
+        self.get_token().normalized
     }
 }
 #[pyproto]
 impl PyObjectProtocol for AddedToken {
     fn __str__(&'p self) -> PyResult<&'p str> {
-        Ok(&self.token.content)
+        Ok(&self.content)
     }
 
     fn __repr__(&self) -> PyResult<String> {
@@ -118,13 +157,14 @@ impl PyObjectProtocol for AddedToken {
             false => "False",
         };
 
+        let token = self.get_token();
         Ok(format!(
             "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
-            self.token.content,
-            bool_to_python(self.token.rstrip),
-            bool_to_python(self.token.lstrip),
-            bool_to_python(self.token.single_word),
-            bool_to_python(self.token.normalized)
+            self.content,
+            bool_to_python(token.rstrip),
+            bool_to_python(token.lstrip),
+            bool_to_python(token.single_word),
+            bool_to_python(token.normalized)
         ))
     }
 }
@@ -583,9 +623,10 @@ impl Tokenizer {
             .into_iter()
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken::from(content, false))
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                    Ok(AddedToken::from(content, Some(false)).get_token())
+                } else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
+                    token.is_special_token = false;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "Input must be a List[Union[str, AddedToken]]",
@@ -603,8 +644,9 @@ impl Tokenizer {
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
                     Ok(tk::tokenizer::AddedToken::from(content, true))
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                } else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
+                    token.is_special_token = true;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "Input must be a List[Union[str, AddedToken]]",

bindings/python/src/trainers.rs

Lines changed: 12 additions & 12 deletions
@@ -36,12 +36,12 @@ impl BpeTrainer {
                 .into_iter()
                 .map(|token| {
                     if let Ok(content) = token.extract::<String>() {
-                        Ok(tk::tokenizer::AddedToken {
-                            content,
-                            ..Default::default()
-                        })
-                    } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                        Ok(token.token.clone())
+                        Ok(AddedToken::from(content, Some(true)).get_token())
+                    } else if let Ok(mut token) =
+                        token.extract::<PyRefMut<AddedToken>>()
+                    {
+                        token.is_special_token = true;
+                        Ok(token.get_token())
                     } else {
                         Err(exceptions::Exception::py_err(
                             "special_tokens must be a List[Union[str, AddedToken]]",
@@ -105,12 +105,12 @@ impl WordPieceTrainer {
                 .into_iter()
                 .map(|token| {
                     if let Ok(content) = token.extract::<String>() {
-                        Ok(tk::tokenizer::AddedToken {
-                            content,
-                            ..Default::default()
-                        })
-                    } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                        Ok(token.token.clone())
+                        Ok(AddedToken::from(content, Some(true)).get_token())
+                    } else if let Ok(mut token) =
+                        token.extract::<PyRefMut<AddedToken>>()
+                    {
+                        token.is_special_token = true;
+                        Ok(token.get_token())
                     } else {
                         Err(exceptions::Exception::py_err(
                             "special_tokens must be a List[Union[str, AddedToken]]",

bindings/python/tests/bindings/test_tokenizer.py

Lines changed: 20 additions & 19 deletions
@@ -12,49 +12,46 @@
 
 class TestAddedToken:
     def test_instantiate_with_content_only(self):
-        added_token = AddedToken("<mask>", True)
+        added_token = AddedToken("<mask>")
         assert type(added_token) == AddedToken
         assert str(added_token) == "<mask>"
         assert (
             repr(added_token)
-            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False)'
+            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True)'
         )
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == False
-        assert added_token.normalized == False
+        assert added_token.normalized == True
         assert isinstance(pickle.loads(pickle.dumps(added_token)), AddedToken)
 
     def test_can_set_rstrip(self):
-        added_token = AddedToken("<mask>", True, rstrip=True)
+        added_token = AddedToken("<mask>", rstrip=True)
         assert added_token.rstrip == True
         assert added_token.lstrip == False
         assert added_token.single_word == False
+        assert added_token.normalized == True
 
     def test_can_set_lstrip(self):
-        added_token = AddedToken("<mask>", True, lstrip=True)
+        added_token = AddedToken("<mask>", lstrip=True)
         assert added_token.rstrip == False
         assert added_token.lstrip == True
         assert added_token.single_word == False
+        assert added_token.normalized == True
 
     def test_can_set_single_world(self):
-        added_token = AddedToken("<mask>", True, single_word=True)
+        added_token = AddedToken("<mask>", single_word=True)
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == True
+        assert added_token.normalized == True
 
     def test_can_set_normalized(self):
-        added_token = AddedToken("<mask>", True, normalized=True)
+        added_token = AddedToken("<mask>", normalized=False)
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == False
-        assert added_token.normalized == True
-
-    def test_second_argument_defines_normalized(self):
-        added_token = AddedToken("<mask>", True)
         assert added_token.normalized == False
-        added_token = AddedToken("<mask>", False)
-        assert added_token.normalized == True
 
 
 class TestTokenizer:
@@ -91,10 +88,12 @@ def test_add_tokens(self):
         added = tokenizer.add_tokens(["my", "name", "is", "john"])
         assert added == 4
 
-        added = tokenizer.add_tokens(
-            [AddedToken("the", False), AddedToken("quick", False, rstrip=True)]
-        )
+        tokens = [AddedToken("the"), AddedToken("quick", normalized=False), AddedToken()]
+        assert tokens[0].normalized == True
+        added = tokenizer.add_tokens(tokens)
         assert added == 2
+        assert tokens[0].normalized == True
+        assert tokens[1].normalized == False
 
     def test_add_special_tokens(self):
         tokenizer = Tokenizer(BPE())
@@ -104,10 +103,12 @@ def test_add_special_tokens(self):
         assert added == 4
 
         # Can add special tokens as `AddedToken`
-        added = tokenizer.add_special_tokens(
-            [AddedToken("the", False), AddedToken("quick", False, rstrip=True)]
-        )
+        tokens = [AddedToken("the"), AddedToken("quick", normalized=True), AddedToken()]
+        assert tokens[0].normalized == True
+        added = tokenizer.add_special_tokens(tokens)
         assert added == 2
+        assert tokens[0].normalized == False
+        assert tokens[1].normalized == True
 
     def test_encode(self):
         tokenizer = Tokenizer(BPE())

bindings/python/tokenizers/__init__.pyi

Lines changed: 1 addition & 6 deletions
@@ -201,8 +201,7 @@ class AddedToken:
 
     def __new__(
         cls,
-        content: str,
-        is_special_token: bool,
+        content: str = "",
         single_word: bool = False,
         lstrip: bool = False,
         rstrip: bool = False,
@@ -214,10 +213,6 @@ class AddedToken:
         content: str:
             The content of the token
 
-        is_special_token: bool:
-            Whether this token is a special token. This has an impact on the default value for
-            `normalized` which is False for special tokens, but True for others.
-
         single_word: bool
             Whether this token should only match against single words. If True,
             this token will never match inside of a word. For example the token `ing` would

tokenizers/src/tokenizer/added_vocabulary.rs

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@ impl AddedVocabulary {
         normalizer: Option<&dyn Normalizer>,
     ) -> usize {
         for token in tokens {
-            if !self.special_tokens_set.contains(&token.content) {
+            if !token.content.is_empty() && !self.special_tokens_set.contains(&token.content) {
                 self.special_tokens.push(token.to_owned());
                 self.special_tokens_set.insert(token.content.clone());
             }
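The added `!token.content.is_empty()` guard keeps a default-constructed `AddedToken()`, whose content is the empty string, from being registered as a special token. A small sketch mirroring the updated tests above ("[CLS]" is just an illustrative token string):

from tokenizers import Tokenizer, AddedToken
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
# The empty-content token is skipped, so only one token ends up being added.
added = tokenizer.add_special_tokens([AddedToken("[CLS]"), AddedToken()])
assert added == 1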
