Skip to content

Commit 4bb595b

Browse files
committed
feat: add __setitem__ impl to post_processor::PySequence
1 parent 1a31fc9 commit 4bb595b

File tree

3 files changed

+76
-63
lines changed

3 files changed

+76
-63
lines changed

bindings/python/src/pre_tokenizers.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,8 @@ impl PySequence {
510510
}
511511

512512
fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> {
513-
let norm: PyPreTokenizer = value.extract()?;
514-
let PyPreTokenizerTypeWrapper::Single(norm) = norm.pretok else { return Err(PyException::new_err("normalizer should not be a sequence")); };
513+
let pretok: PyPreTokenizer = value.extract()?;
514+
let PyPreTokenizerTypeWrapper::Single(norm) = pretok.pretok else { return Err(PyException::new_err("normalizer should not be a sequence")); };
515515
match &self_.as_ref().pretok {
516516
PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) {
517517
Some(item) => {
@@ -524,7 +524,7 @@ impl PySequence {
524524
}
525525
},
526526
PyPreTokenizerTypeWrapper::Single(_) => {
527-
return Err(PyException::new_err("normalizer is not a sequence"))
527+
return Err(PyException::new_err("pre tokenizer is not a sequence"))
528528
}
529529
};
530530
Ok(())

bindings/python/src/processors.rs

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -52,28 +52,34 @@ impl PyPostProcessor {
5252

5353
pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult<PyObject> {
5454
let base = self.clone();
55-
Ok(match self.processor.read().unwrap().clone() {
56-
PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?
57-
.into_pyobject(py)?
58-
.into_any()
59-
.into(),
60-
PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?
61-
.into_pyobject(py)?
62-
.into_any()
63-
.into(),
64-
PostProcessorWrapper::Roberta(_) => Py::new(py, (PyRobertaProcessing {}, base))?
65-
.into_pyobject(py)?
66-
.into_any()
67-
.into(),
68-
PostProcessorWrapper::Template(_) => Py::new(py, (PyTemplateProcessing {}, base))?
69-
.into_pyobject(py)?
70-
.into_any()
71-
.into(),
72-
PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
73-
.into_pyobject(py)?
74-
.into_any()
75-
.into(),
76-
})
55+
Ok(
56+
match &*self
57+
.processor
58+
.read()
59+
.map_err(|_| PyException::new_err("pre tokenizer rwlock is poisoned"))?
60+
{
61+
PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?
62+
.into_pyobject(py)?
63+
.into_any()
64+
.into(),
65+
PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?
66+
.into_pyobject(py)?
67+
.into_any()
68+
.into(),
69+
PostProcessorWrapper::Roberta(_) => Py::new(py, (PyRobertaProcessing {}, base))?
70+
.into_pyobject(py)?
71+
.into_any()
72+
.into(),
73+
PostProcessorWrapper::Template(_) => Py::new(py, (PyTemplateProcessing {}, base))?
74+
.into_pyobject(py)?
75+
.into_any()
76+
.into(),
77+
PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?
78+
.into_pyobject(py)?
79+
.into_any()
80+
.into(),
81+
},
82+
)
7783
}
7884
}
7985

@@ -538,19 +544,14 @@ impl PySequence {
538544
}
539545

540546
fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
541-
let super_ = self_.as_ref();
542-
let mut wrapper = super_.processor.write().unwrap();
543-
// if let PostProcessorWrapper::Sequence(ref mut post) = *wrapper {
544-
// match post.get(index) {
545-
// Some(item) => PyPostProcessor::new(Arc::clone(item)).get_as_subtype(py),
546-
// _ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
547-
// "Index not found",
548-
// )),
549-
// }
550-
// }
547+
let wrapper = self_
548+
.as_ref()
549+
.processor
550+
.read()
551+
.map_err(|_| PyException::new_err("post processor rwlock is poisoned"))?;
551552

552553
match *wrapper {
553-
PostProcessorWrapper::Sequence(ref mut inner) => match inner.get_mut(index) {
554+
PostProcessorWrapper::Sequence(ref inner) => match inner.get(index) {
554555
Some(item) => {
555556
PyPostProcessor::new(Arc::new(RwLock::new(item.to_owned()))).get_as_subtype(py)
556557
}
@@ -564,32 +565,31 @@ impl PySequence {
564565
}
565566
}
566567

567-
fn __setitem__(
568-
self_: PyRefMut<'_, Self>,
569-
index: usize,
570-
value: PyRef<'_, PyPostProcessor>,
571-
) -> PyResult<()> {
572-
let super_ = self_.as_ref();
573-
let mut wrapper = super_.processor.write().unwrap();
574-
let value = value.processor.read().unwrap().clone();
568+
fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> {
569+
let processor: PyPostProcessor = value.extract()?;
570+
let mut wrapper = self_
571+
.as_ref()
572+
.processor
573+
.write()
574+
.map_err(|_| PyException::new_err("post processor rwlock is poisoned"))?;
575575
match *wrapper {
576-
PostProcessorWrapper::Sequence(ref mut inner) => {
577-
// Convert the Py<PyAny> into the appropriate Rust type
578-
// Ensure we can set an item at the given index
579-
if index < inner.get_processors().len() {
580-
inner.set_mut(index, value); // Assuming you want to wrap the new item in Arc<RwLock>
581-
582-
Ok(())
583-
} else {
584-
Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
585-
"Index out of bounds",
576+
PostProcessorWrapper::Sequence(ref mut inner) => match inner.get_mut(index) {
577+
Some(item) => {
578+
*item = processor.processor.read().unwrap().clone();
579+
}
580+
_ => {
581+
return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
582+
"Index not found",
586583
))
587584
}
585+
},
586+
_ => {
587+
return Err(PyException::new_err(
588+
"This processor is not a Sequence, it does not support __setitem__",
589+
))
588590
}
589-
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
590-
"This processor is not a Sequence, it does not support __setitem__",
591-
)),
592-
}
591+
};
592+
Ok(())
593593
}
594594
}
595595

tokenizers/src/processors/sequence.rs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,40 @@ impl Sequence {
1414
Self { processors }
1515
}
1616

17-
pub fn get(&self, index: usize) -> Option<& PostProcessorWrapper> {
18-
self.processors.get(index as usize)
17+
pub fn get(&self, index: usize) -> Option<&PostProcessorWrapper> {
18+
self.processors.get(index)
1919
}
2020

2121
/// Returns a mutable reference to the processor at `index`, or `None` when
/// `index` is out of bounds.
pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> {
    let processors = &mut self.processors;
    processors.get_mut(index)
}
2424

2525
pub fn set_mut(&mut self, index: usize, post_proc: PostProcessorWrapper) {
26-
self.processors[index as usize] = post_proc;
26+
self.processors[index] = post_proc;
2727
}
28+
}
2829

29-
pub fn get_processors(&self) -> &[PostProcessorWrapper] {
30+
/// Read-only slice view over the processors held by the sequence.
impl AsRef<[PostProcessorWrapper]> for Sequence {
    fn as_ref(&self) -> &[PostProcessorWrapper] {
        self.processors.as_slice()
    }
}
3235

33-
pub fn get_processors_mut(&mut self) -> &mut [PostProcessorWrapper] {
36+
/// Mutable slice view over the processors held by the sequence.
impl AsMut<[PostProcessorWrapper]> for Sequence {
    fn as_mut(&mut self) -> &mut [PostProcessorWrapper] {
        self.processors.as_mut_slice()
    }
}
3741

42+
impl IntoIterator for Sequence {
43+
type Item = PostProcessorWrapper;
44+
type IntoIter = std::vec::IntoIter<Self::Item>;
45+
46+
fn into_iter(self) -> Self::IntoIter {
47+
self.processors.into_iter()
48+
}
49+
}
50+
3851
impl PostProcessor for Sequence {
3952
fn added_tokens(&self, is_pair: bool) -> usize {
4053
self.processors

0 commit comments

Comments
 (0)