Skip to content

Commit feac27e

Browse files
authored
Chain type (#255)
* implement chain, chain2 and chain3 * remove chain2 and chain3, more tests * benchmarks and test flatten * fix schema * improve test coverage
1 parent 75dbe9a commit feac27e

File tree

6 files changed

+325
-0
lines changed

6 files changed

+325
-0
lines changed

pydantic_core/_pydantic_core.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ __all__ = '__version__', 'SchemaValidator', 'SchemaError', 'ValidationError', 'P
1212
__version__: str
1313

1414
class SchemaValidator:
15+
title: str
1516
def __init__(self, schema: CoreSchema, config: 'CoreConfig | None' = None) -> None: ...
1617
def validate_python(self, input: Any, strict: 'bool | None' = None, context: Any = None) -> Any: ...
1718
def isinstance_python(self, input: Any, strict: 'bool | None' = None, context: Any = None) -> bool: ...

pydantic_core/core_schema.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,12 @@ class WithDefaultSchema(TypedDict, total=False):
315315
ref: str
316316

317317

318+
class ChainSchema(TypedDict):
    """Schema for a ``chain`` validator: a list of step schemas plus an optional ``ref``."""

    # discriminator identifying this schema type
    type: Literal['chain']
    # the schemas of the individual steps
    steps: List[CoreSchema]
    # optional reference string
    ref: NotRequired[str]
322+
323+
318324
CoreSchema = Union[
319325
AnySchema,
320326
BoolSchema,
@@ -347,4 +353,5 @@ class WithDefaultSchema(TypedDict, total=False):
347353
ArgumentsSchema,
348354
CallSchema,
349355
WithDefaultSchema,
356+
ChainSchema,
350357
]

src/validators/chain.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
use pyo3::intern;
2+
use pyo3::prelude::*;
3+
use pyo3::types::{PyDict, PyList};
4+
5+
use crate::build_tools::{py_error, SchemaDict};
6+
use crate::errors::ValResult;
7+
use crate::input::Input;
8+
use crate::questions::Question;
9+
use crate::recursion_guard::RecursionGuard;
10+
11+
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
12+
13+
#[derive(Debug, Clone)]
14+
pub struct ChainValidator {
15+
steps: Vec<CombinedValidator>,
16+
name: String,
17+
}
18+
19+
impl BuildValidator for ChainValidator {
20+
const EXPECTED_TYPE: &'static str = "chain";
21+
22+
fn build(
23+
schema: &PyDict,
24+
config: Option<&PyDict>,
25+
build_context: &mut BuildContext,
26+
) -> PyResult<CombinedValidator> {
27+
let steps: Vec<CombinedValidator> = schema
28+
.get_as_req::<&PyList>(intern!(schema.py(), "steps"))?
29+
.iter()
30+
.map(|step| build_validator_steps(step, config, build_context))
31+
.collect::<PyResult<Vec<Vec<CombinedValidator>>>>()?
32+
.into_iter()
33+
.flatten()
34+
.collect::<Vec<CombinedValidator>>();
35+
36+
match steps.len() {
37+
0 => py_error!("One or more steps are required for a chain validator"),
38+
1 => {
39+
let step = steps.into_iter().next().unwrap();
40+
Ok(step)
41+
}
42+
_ => {
43+
let descr = steps.iter().map(|v| v.get_name()).collect::<Vec<_>>().join(",");
44+
45+
Ok(Self {
46+
steps,
47+
name: format!("{}[{}]", Self::EXPECTED_TYPE, descr),
48+
}
49+
.into())
50+
}
51+
}
52+
}
53+
}
54+
55+
// either a vec of the steps from a nested `ChainValidator`, or a length-1 vec containing the validator
56+
// to be flattened into `steps` above
57+
fn build_validator_steps<'a>(
58+
step: &'a PyAny,
59+
config: Option<&'a PyDict>,
60+
build_context: &mut BuildContext,
61+
) -> PyResult<Vec<CombinedValidator>> {
62+
let validator = build_validator(step, config, build_context)?;
63+
if let CombinedValidator::Chain(chain_validator) = validator {
64+
Ok(chain_validator.steps)
65+
} else {
66+
Ok(vec![validator])
67+
}
68+
}
69+
70+
impl Validator for ChainValidator {
71+
fn validate<'s, 'data>(
72+
&'s self,
73+
py: Python<'data>,
74+
input: &'data impl Input<'data>,
75+
extra: &Extra,
76+
slots: &'data [CombinedValidator],
77+
recursion_guard: &'s mut RecursionGuard,
78+
) -> ValResult<'data, PyObject> {
79+
let mut steps_iter = self.steps.iter();
80+
let first_step = steps_iter.next().unwrap();
81+
let value = first_step.validate(py, input, extra, slots, recursion_guard)?;
82+
83+
steps_iter.try_fold(value, |v, step| {
84+
step.validate(py, v.into_ref(py), extra, slots, recursion_guard)
85+
})
86+
}
87+
88+
fn get_name(&self) -> &str {
89+
&self.name
90+
}
91+
92+
fn ask(&self, question: &Question) -> bool {
93+
// any makes more sense since at the moment we only use ask for "return_fields_set", might need
94+
// more complex logic in future
95+
self.steps.iter().any(|v| v.ask(question))
96+
}
97+
98+
fn complete(&mut self, build_context: &BuildContext) -> PyResult<()> {
99+
self.steps.iter_mut().try_for_each(|v| v.complete(build_context))
100+
}
101+
}

src/validators/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ mod bool;
2222
mod bytes;
2323
mod call;
2424
mod callable;
25+
mod chain;
2526
mod date;
2627
mod datetime;
2728
mod dict;
@@ -51,6 +52,7 @@ pub struct SchemaValidator {
5152
validator: CombinedValidator,
5253
slots: Vec<CombinedValidator>,
5354
schema: PyObject,
55+
#[pyo3(get)]
5456
title: PyObject,
5557
}
5658

@@ -380,6 +382,8 @@ pub fn build_validator<'a>(
380382
arguments::ArgumentsValidator,
381383
// default value
382384
with_default::WithDefaultValidator,
385+
// chain validators
386+
chain::ChainValidator,
383387
)
384388
}
385389

@@ -490,6 +494,8 @@ pub enum CombinedValidator {
490494
Arguments(arguments::ArgumentsValidator),
491495
// default value
492496
WithDefault(with_default::WithDefaultValidator),
497+
// chain validators
498+
Chain(chain::ChainValidator),
493499
}
494500

495501
/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,

tests/benchmarks/test_micro_benchmarks.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import platform
77
from datetime import date, datetime, timedelta, timezone
8+
from decimal import Decimal
89
from typing import Dict, FrozenSet, List, Optional, Set, Union
910

1011
import pytest
@@ -890,3 +891,66 @@ def test_with_default(benchmark):
890891
def t():
891892
v.validate_python({'name': 'Foo'})
892893
v.validate_python({})
894+
895+
896+
@pytest.mark.benchmark(group='chain')
def test_chain_list(benchmark):
    """Benchmark a two-step chain: ``str`` validation followed by a plain function returning ``Decimal``."""
    to_decimal = {'type': 'function', 'mode': 'plain', 'function': lambda v, **kwargs: Decimal(v)}
    chain_schema = {'type': 'chain', 'steps': [{'type': 'str'}, to_decimal]}
    v = SchemaValidator(chain_schema)

    # sanity check before timing
    assert v.validate_python('42.42') == Decimal('42.42')

    benchmark(v.validate_python, '42.42')
910+
911+
912+
@pytest.mark.benchmark(group='chain')
def test_chain_function(benchmark):
    """Benchmark an ``after``-mode function wrapping a ``str`` schema, in the same group as ``test_chain_list``."""
    function_schema = {
        'type': 'function',
        'mode': 'after',
        'schema': {'type': 'str'},
        'function': lambda v, **kwargs: Decimal(v),
    }
    v = SchemaValidator(function_schema)

    # sanity check before timing
    assert v.validate_python('42.42') == Decimal('42.42')

    benchmark(v.validate_python, '42.42')
920+
921+
922+
@pytest.mark.benchmark(group='chain-functions')
def test_chain_two_functions(benchmark):
    """Benchmark a chain of ``str`` validation followed by two plain functions (to ``Decimal``, then doubling)."""
    to_decimal = {'type': 'function', 'mode': 'plain', 'function': lambda v, **kwargs: Decimal(v)}
    double = {'type': 'function', 'mode': 'plain', 'function': lambda v, **kwargs: v * 2}
    v = SchemaValidator({'type': 'chain', 'steps': [{'type': 'str'}, to_decimal, double]})

    # sanity check before timing
    assert v.validate_python('42.42') == Decimal('84.84')

    benchmark(v.validate_python, '42.42')
937+
938+
939+
@pytest.mark.benchmark(group='chain-functions')
def test_chain_nested_functions(benchmark):
    """Benchmark two nested ``after``-mode functions around a ``str`` schema (to ``Decimal``, then doubling)."""
    inner = {
        'type': 'function',
        'schema': {'type': 'str'},
        'mode': 'after',
        'function': lambda v, **kwargs: Decimal(v),
    }
    outer = {
        'type': 'function',
        'schema': inner,
        'mode': 'after',
        'function': lambda v, **kwargs: v * 2,
    }
    v = SchemaValidator(outer)

    # sanity check before timing
    assert v.validate_python('42.42') == Decimal('84.84')

    benchmark(v.validate_python, '42.42')

tests/validators/test_chain.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import re
2+
from decimal import Decimal
3+
4+
import pytest
5+
6+
from pydantic_core import SchemaError, SchemaValidator, ValidationError
7+
8+
from ..conftest import PyAndJson, plain_repr
9+
10+
11+
def test_chain():
    """A ``str`` step followed by a plain function yields a ``Decimal`` for both ``str`` and ``bytes`` input."""
    to_decimal = {'type': 'function', 'mode': 'plain', 'function': lambda v, **kwargs: Decimal(v)}
    v = SchemaValidator({'type': 'chain', 'steps': [{'type': 'str'}, to_decimal]})

    for raw in ('1.44', b'1.44'):
        assert v.validate_python(raw) == Decimal('1.44')
24+
25+
26+
def test_chain_many():
    """Four plain-function steps are applied in the order they are listed."""
    suffixers = [
        lambda v, **kwargs: f'{v}-1',
        lambda v, **kwargs: f'{v}-2',
        lambda v, **kwargs: f'{v}-3',
        lambda v, **kwargs: f'{v}-4',
    ]
    steps = [{'type': 'function', 'mode': 'plain', 'function': fn} for fn in suffixers]
    v = SchemaValidator({'type': 'chain', 'steps': steps})

    assert v.validate_python('input') == 'input-1-2-3-4'
40+
41+
42+
def test_chain_error():
    """A ``str`` -> ``int`` chain parses numeric text and reports the ``int`` step's error otherwise."""
    v = SchemaValidator({'type': 'chain', 'steps': [{'type': 'str'}, {'type': 'int'}]})

    for raw in ('123', b'123'):
        assert v.validate_python(raw) == 123

    expected_errors = [
        {
            'kind': 'int_parsing',
            'loc': [],
            'message': 'Input should be a valid integer, unable to parse string as an integer',
            'input_value': 'abc',
        }
    ]
    with pytest.raises(ValidationError) as exc_info:
        v.validate_python('abc')
    # insert_assert(exc_info.value.errors())
    assert exc_info.value.errors() == expected_errors
59+
60+
61+
@pytest.mark.parametrize(
    'input_value,expected',
    [
        ('1.44', Decimal('1.44')),
        (1, Decimal(1)),
        (1.44, pytest.approx(1.44)),
    ],
)
def test_json(py_and_json: PyAndJson, input_value, expected):
    """A ``str``/``float`` union followed by a ``Decimal`` conversion, run via the ``py_and_json`` fixture."""
    str_or_float = {'type': 'union', 'choices': [{'type': 'str'}, {'type': 'float'}]}
    to_decimal = {'type': 'function', 'mode': 'plain', 'function': lambda v, **kwargs: Decimal(v)}
    v = py_and_json({'type': 'chain', 'steps': [str_or_float, to_decimal]})

    result = v.validate_test(input_value)
    assert result == expected
    assert isinstance(result, Decimal)
77+
78+
79+
def test_flatten():
    """A chain nested inside a chain runs all steps in order and is titled as one flat three-step chain."""

    def plain(fn):
        # wrap a callable in a plain-mode function schema
        return {'type': 'function', 'mode': 'plain', 'function': fn}

    nested_chain = {
        'type': 'chain',
        'steps': [
            plain(lambda v, **kwargs: f'{v}-2'),
            plain(lambda v, **kwargs: f'{v}-3'),
        ],
    }
    v = SchemaValidator({'type': 'chain', 'steps': [plain(lambda v, **kwargs: f'{v}-1'), nested_chain]})

    assert v.validate_python('input') == 'input-1-2-3'
    assert v.title == 'chain[function-plain,function-plain,function-plain]'
98+
99+
100+
def test_chain_empty():
    """Building a chain with an empty ``steps`` list raises ``SchemaError``."""
    empty_chain = {'type': 'chain', 'steps': []}
    expected_message = 'One or more steps are required for a chain validator'

    with pytest.raises(SchemaError, match=expected_message):
        SchemaValidator(empty_chain)
103+
104+
105+
def test_chain_one():
    """A chain with a single step validates like that step and takes the step's title."""
    only_step = {'type': 'function', 'mode': 'plain', 'function': lambda v, **kwargs: f'{v}-1'}
    v = SchemaValidator({'type': 'chain', 'steps': [only_step]})

    assert v.validate_python('input') == 'input-1'
    assert v.title == 'function-plain'
111+
112+
113+
def test_ask():
    """A ``new-class`` validator wrapping a chain expects a fields set when a chain step returns one.

    The chain is a ``typed-dict`` step with ``return_fields_set`` enabled, followed by a plain function
    that records whatever value it receives and returns it unchanged.
    """

    class MyModel:
        __slots__ = '__dict__', '__fields_set__'

    calls = []

    def record(input_value, **kwargs):
        calls.append(input_value)
        return input_value

    typed_dict_step = {
        'type': 'typed-dict',
        'return_fields_set': True,
        'fields': {'field_a': {'schema': {'type': 'str'}}},
    }
    recording_step = {'type': 'function', 'mode': 'plain', 'function': record}
    v = SchemaValidator(
        {
            'type': 'new-class',
            'class_type': MyModel,
            'schema': {'type': 'chain', 'steps': [typed_dict_step, recording_step]},
        }
    )

    flag = re.search('expect_fields_set:(true|false)', plain_repr(v))
    assert flag.group(1) == 'true'

    model = v.validate_python({'field_a': 'abc'})
    assert isinstance(model, MyModel)
    assert model.field_a == 'abc'
    assert model.__fields_set__ == {'field_a'}
    # insert_assert(calls)
    assert calls == [({'field_a': 'abc'}, {'field_a'})]

0 commit comments

Comments
 (0)