From b3728649baee2d33ed51e5940e9910298ccf19c3 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 24 May 2025 17:35:36 +0100 Subject: [PATCH 1/4] avoid adding trailing slash to URLs --- python/pydantic_core/_pydantic_core.pyi | 22 +++++-- python/pydantic_core/core_schema.py | 4 ++ src/url.rs | 83 +++++++++++++++++-------- src/validators/url.rs | 80 +++++++++++++++--------- tests/validators/test_url.py | 53 +++++++++++++++- 5 files changed, 180 insertions(+), 62 deletions(-) diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 398020bb1..20144f49e 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -522,8 +522,15 @@ class Url(SupportsAllComparisons): by Mozilla. """ - def __init__(self, url: str) -> None: ... - def __new__(cls, url: str) -> Self: ... + def __init__(self, url: str, *, extra_trailing_slash: bool = True) -> None: + """Initialize a new URL object. + + Args: + url: The URL string to parse. + extra_trailing_slash: Whether to add an extra trailing slash to the URL, defaults to `True` for + backward compatibility, default will change to `False` in v3 version. + """ + def __new__(cls, url: str, *, extra_trailing_slash: bool = True) -> Self: ... @property def scheme(self) -> str: ... @property @@ -568,8 +575,15 @@ class MultiHostUrl(SupportsAllComparisons): by Mozilla. """ - def __init__(self, url: str) -> None: ... - def __new__(cls, url: str) -> Self: ... + def __init__(self, url: str, *, extra_trailing_slash: bool = True) -> None: + """Initialize a new MultiHostUrl object. + + Args: + url: The URL string to parse. + extra_trailing_slash: Whether to add an extra trailing slash to the URL, defaults to `True` for + backward compatibility, default will change to `False` in v3 version. + """ + def __new__(cls, url: str, *, extra_trailing_slash: bool = True) -> Self: ... @property def scheme(self) -> str: ... @property diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index f03daac36..4af684240 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3838,6 +3838,7 @@ def url_schema( default_host: str | None = None, default_port: int | None = None, default_path: str | None = None, + extra_trailing_slash: bool | None = None, strict: bool | None = None, ref: str | None = None, metadata: dict[str, Any] | None = None, @@ -3862,6 +3863,8 @@ def url_schema( default_host: The default host to use if the URL does not have a host default_port: The default port to use if the URL does not have a port default_path: The default path to use if the URL does not have a path + extra_trailing_slash: Whether to add an extra trailing slash to the URL, defaults to `True` for + backward compatibility, default will change to `False` in v3 version. strict: Whether to use strict URL parsing ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -3879,6 +3882,7 @@ def url_schema( ref=ref, metadata=metadata, serialization=serialization, + extra_trailing_slash=extra_trailing_slash, ) diff --git a/src/url.rs b/src/url.rs index da63be5af..909e5aa1c 100644 --- a/src/url.rs +++ b/src/url.rs @@ -14,38 +14,65 @@ use url::Url; use crate::tools::SchemaDict; use crate::SchemaValidator; -static SCHEMA_DEFINITION_URL: GILOnceCell = GILOnceCell::new(); - #[pyclass(name = "Url", module = "pydantic_core._pydantic_core", subclass, frozen)] #[derive(Clone, Hash)] #[cfg_attr(debug_assertions, derive(Debug))] pub struct PyUrl { lib_url: Url, + remove_trailing_slash: bool, } impl PyUrl { - pub fn new(lib_url: Url) -> Self { - Self { lib_url } + pub fn new(lib_url: Url, remove_trailing_slash: bool) -> Self { + Self { + lib_url, + remove_trailing_slash, + } } pub fn url(&self) -> &Url { &self.lib_url } + + pub fn mut_url(&mut self) -> &mut Url { + &mut self.lib_url + } +} + +impl From for Url { + fn from(value: PyUrl) -> Url { + value.lib_url + } +} + +static SCHEMA_URL_SINGLE_TRUE: GILOnceCell = GILOnceCell::new(); +static SCHEMA_URL_SINGLE_FALSE: GILOnceCell = GILOnceCell::new(); +static SCHEMA_URL_MULTI_TRUE: GILOnceCell = GILOnceCell::new(); +static SCHEMA_URL_MULTI_FALSE: GILOnceCell = GILOnceCell::new(); + +fn get_schema_validator(py: Python<'_>, multi_host: bool, extra_trailing_slash: bool) -> &SchemaValidator { + match (multi_host, extra_trailing_slash) { + (false, true) => SCHEMA_URL_SINGLE_TRUE.get_or_init(py, || build_schema_validator(py, "url", true)), + (false, false) => SCHEMA_URL_SINGLE_FALSE.get_or_init(py, || build_schema_validator(py, "url", false)), + (true, true) => SCHEMA_URL_MULTI_TRUE.get_or_init(py, || build_schema_validator(py, "multi-host-url", true)), + (true, false) => SCHEMA_URL_MULTI_FALSE.get_or_init(py, || build_schema_validator(py, "multi-host-url", false)), + } } -fn build_schema_validator(py: Python, schema_type: &str) -> SchemaValidator { +fn build_schema_validator(py: Python, schema_type: &str, extra_trailing_slash: bool) -> SchemaValidator { let schema = PyDict::new(py); schema.set_item("type", schema_type).unwrap(); + schema.set_item("extra_trailing_slash", extra_trailing_slash).unwrap(); SchemaValidator::py_new(py, &schema, None).unwrap() } #[pymethods] impl PyUrl { #[new] - pub fn py_new(py: Python, url: &Bound<'_, PyAny>) -> PyResult { - let schema_obj = SCHEMA_DEFINITION_URL - .get_or_init(py, || build_schema_validator(py, "url")) - .validate_python(py, url, None, None, None, None, false.into(), None, None)?; + #[pyo3(signature = (url, *, extra_trailing_slash=false))] + pub fn py_new(py: Python, url: &Bound<'_, PyAny>, extra_trailing_slash: bool) -> PyResult { + let schema_validator = get_schema_validator(py, false, extra_trailing_slash); + let schema_obj = schema_validator.validate_python(py, url, None, None, None, None, false.into(), None, None)?; schema_obj.extract(py) } @@ -114,11 +141,15 @@ impl PyUrl { // string representation of the URL, with punycode decoded when appropriate pub fn unicode_string(&self) -> String { - unicode_url(&self.lib_url) + unicode_url(&self.lib_url, self.remove_trailing_slash) } pub fn __str__(&self) -> &str { - self.lib_url.as_str() + let mut s = self.lib_url.as_str(); + if self.remove_trailing_slash && s.ends_with('/') { + s = &s[..s.len() - 1]; + } + s } pub fn __repr__(&self) -> String { @@ -201,11 +232,8 @@ pub struct PyMultiHostUrl { } impl PyMultiHostUrl { - pub fn new(ref_url: Url, extra_urls: Option>) -> Self { - Self { - ref_url: PyUrl::new(ref_url), - extra_urls, - } + pub fn new(ref_url: PyUrl, extra_urls: Option>) -> Self { + Self { ref_url, extra_urls } } pub fn lib_url(&self) -> &Url { @@ -217,15 +245,13 @@ impl PyMultiHostUrl { } } -static SCHEMA_DEFINITION_MULTI_HOST_URL: GILOnceCell = GILOnceCell::new(); - #[pymethods] impl PyMultiHostUrl { #[new] - pub fn py_new(py: Python, url: &Bound<'_, PyAny>) -> PyResult { - let schema_obj = SCHEMA_DEFINITION_MULTI_HOST_URL - .get_or_init(py, || build_schema_validator(py, "multi-host-url")) - .validate_python(py, url, None, None, None, None, false.into(), None, None)?; + #[pyo3(signature = (url, *, extra_trailing_slash=false))] + pub fn py_new(py: Python, url: &Bound<'_, PyAny>, extra_trailing_slash: bool) -> PyResult { + let schema_validator = get_schema_validator(py, true, extra_trailing_slash); + let schema_obj = schema_validator.validate_python(py, url, None, None, None, None, false.into(), None, None)?; schema_obj.extract(py) } @@ -279,13 +305,12 @@ impl PyMultiHostUrl { // special urls will have had a trailing slash added, non-special urls will not // hence we need to remove the last char if the schema is special - #[allow(clippy::bool_to_int_with_if)] - let sub = if schema_is_special(schema) { 1 } else { 0 }; + let sub: usize = (!self.ref_url.remove_trailing_slash && schema_is_special(schema)).into(); let hosts = extra_urls .iter() .map(|url| { - let str = unicode_url(url); + let str = unicode_url(url, self.ref_url.remove_trailing_slash); str[host_offset..str.len() - sub].to_string() }) .collect::>() @@ -477,10 +502,10 @@ fn host_to_dict<'a>(py: Python<'a>, lib_url: &Url) -> PyResult Ok(dict) } -fn unicode_url(lib_url: &Url) -> String { +fn unicode_url(lib_url: &Url, remove_trailing_slash: bool) -> String { let mut s = lib_url.to_string(); - match lib_url.host() { + s = match lib_url.host() { Some(url::Host::Domain(domain)) if is_punnycode_domain(lib_url, domain) => { if let Some(decoded) = decode_punycode(domain) { // replace the range containing the punycode domain with the decoded domain @@ -490,7 +515,11 @@ fn unicode_url(lib_url: &Url) -> String { s } _ => s, + }; + if remove_trailing_slash && s.ends_with('/') { + s.pop(); } + s } fn decode_punycode(domain: &str) -> Option { diff --git a/src/validators/url.rs b/src/validators/url.rs index d220d60ed..9fd0529b9 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -34,6 +34,7 @@ pub struct UrlValidator { default_port: Option, default_path: Option, name: String, + extra_trailing_slash: bool, } impl BuildValidator for UrlValidator { @@ -55,6 +56,9 @@ impl BuildValidator for UrlValidator { default_path: schema.get_as(intern!(schema.py(), "default_path"))?, allowed_schemes, name, + extra_trailing_slash: schema + .get_as(intern!(schema.py(), "extra_trailing_slash"))? + .unwrap_or(true), } .into()) } @@ -69,7 +73,7 @@ impl Validator for UrlValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let mut either_url = self.get_url(input, state.strict_or(self.strict))?; + let mut either_url = self.get_url(input, state.strict_or(self.strict), self.extra_trailing_slash)?; if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { if !allowed_schemes.contains(either_url.url().scheme()) { @@ -106,7 +110,12 @@ impl Validator for UrlValidator { } impl UrlValidator { - fn get_url<'py>(&self, input: &(impl Input<'py> + ?Sized), strict: bool) -> ValResult> { + fn get_url<'py>( + &self, + input: &(impl Input<'py> + ?Sized), + strict: bool, + extra_trailing_slash: bool, + ) -> ValResult> { match input.validate_str(strict, false) { Ok(val_match) => { let either_str = val_match.into_inner(); @@ -115,19 +124,19 @@ impl UrlValidator { self.check_length(input, url_str)?; - parse_url(url_str, input, strict).map(EitherUrl::Rust) + parse_url(url_str, input, strict, extra_trailing_slash).map(EitherUrl::Owned) } Err(_) => { // we don't need to worry about whether the url was parsed in strict mode before, // even if it was, any syntax errors would have been fixed by the first validation if let Some(py_url) = downcast_python_input::(input) { self.check_length(input, py_url.get().url().as_str())?; - Ok(EitherUrl::Py(py_url.clone())) + Ok(EitherUrl::Bound(py_url.clone())) } else if let Some(multi_host_url) = downcast_python_input::(input) { let url_str = multi_host_url.get().__str__(); self.check_length(input, &url_str)?; - parse_url(&url_str, input, strict).map(EitherUrl::Rust) + parse_url(&url_str, input, strict, extra_trailing_slash).map(EitherUrl::Owned) } else { Err(ValError::new(ErrorTypeDefaults::UrlType, input)) } @@ -151,9 +160,10 @@ impl UrlValidator { } } +// TODO do we still need this? enum EitherUrl<'py> { - Py(Bound<'py, PyUrl>), - Rust(Url), + Bound(Bound<'py, PyUrl>), + Owned(PyUrl), } impl<'py> IntoPyObject<'py> for EitherUrl<'py> { @@ -163,8 +173,8 @@ impl<'py> IntoPyObject<'py> for EitherUrl<'py> { fn into_pyobject(self, py: Python<'py>) -> PyResult { match self { - EitherUrl::Py(py_url) => Ok(py_url), - EitherUrl::Rust(rust_url) => Bound::new(py, PyUrl::new(rust_url)), + EitherUrl::Bound(py_url) => Ok(py_url), + EitherUrl::Owned(py_url) => Bound::new(py, py_url), } } } @@ -172,18 +182,18 @@ impl<'py> IntoPyObject<'py> for EitherUrl<'py> { impl CopyFromPyUrl for EitherUrl<'_> { fn url(&self) -> &Url { match self { - EitherUrl::Py(py_url) => py_url.get().url(), - EitherUrl::Rust(rust_url) => rust_url, + EitherUrl::Bound(py_url) => py_url.get().url(), + EitherUrl::Owned(rust_url) => rust_url.url(), } } fn url_mut(&mut self) -> &mut Url { - if let EitherUrl::Py(py_url) = self { - *self = EitherUrl::Rust(py_url.get().url().clone()); + if let EitherUrl::Bound(py_url) = self { + *self = EitherUrl::Owned(py_url.get().clone()); } match self { - EitherUrl::Py(_) => unreachable!(), - EitherUrl::Rust(rust_url) => rust_url, + EitherUrl::Bound(_) => unreachable!(), + EitherUrl::Owned(ref mut rust_url) => rust_url.mut_url(), } } } @@ -198,6 +208,7 @@ pub struct MultiHostUrlValidator { default_port: Option, default_path: Option, name: String, + extra_trailing_slash: bool, } impl BuildValidator for MultiHostUrlValidator { @@ -225,6 +236,9 @@ impl BuildValidator for MultiHostUrlValidator { default_port: schema.get_as(intern!(schema.py(), "default_port"))?, default_path: schema.get_as(intern!(schema.py(), "default_path"))?, name, + extra_trailing_slash: schema + .get_as(intern!(schema.py(), "extra_trailing_slash"))? + .unwrap_or(false), } .into()) } @@ -239,7 +253,7 @@ impl Validator for MultiHostUrlValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let mut multi_url = self.get_url(input, state.strict_or(self.strict))?; + let mut multi_url = self.get_url(input, state.strict_or(self.strict), self.extra_trailing_slash)?; if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { if !allowed_schemes.contains(multi_url.url().scheme()) { @@ -275,7 +289,12 @@ impl Validator for MultiHostUrlValidator { } impl MultiHostUrlValidator { - fn get_url<'py>(&self, input: &(impl Input<'py> + ?Sized), strict: bool) -> ValResult> { + fn get_url<'py>( + &self, + input: &(impl Input<'py> + ?Sized), + strict: bool, + extra_trailing_slash: bool, + ) -> ValResult> { match input.validate_str(strict, false) { Ok(val_match) => { let either_str = val_match.into_inner(); @@ -284,7 +303,7 @@ impl MultiHostUrlValidator { self.check_length(input, || url_str.len())?; - parse_multihost_url(url_str, input, strict).map(EitherMultiHostUrl::Rust) + parse_multihost_url(url_str, input, strict, extra_trailing_slash).map(EitherMultiHostUrl::Rust) } Err(_) => { // we don't need to worry about whether the url was parsed in strict mode before, @@ -295,7 +314,7 @@ impl MultiHostUrlValidator { } else if let Some(py_url) = downcast_python_input::(input) { self.check_length(input, || py_url.get().url().as_str().len())?; Ok(EitherMultiHostUrl::Rust(PyMultiHostUrl::new( - py_url.get().url().clone(), + py_url.get().clone(), None, ))) } else { @@ -365,6 +384,7 @@ fn parse_multihost_url<'py>( url_str: &str, input: &(impl Input<'py> + ?Sized), strict: bool, + extra_trailing_slash: bool, ) -> ValResult { macro_rules! parsing_err { ($parse_error:expr) => { @@ -454,21 +474,21 @@ fn parse_multihost_url<'py>( // with just one host, for consistent behaviour, we parse the URL the same as with multiple hosts let reconstructed_url = format!("{prefix}{}", &url_str[start..]); - let ref_url = parse_url(&reconstructed_url, input, strict)?; + let ref_url = parse_url(&reconstructed_url, input, strict, extra_trailing_slash)?; if hosts.is_empty() { // if there's no one host (e.g. no `,`), we allow it to be empty to allow for default hosts Ok(PyMultiHostUrl::new(ref_url, None)) } else { // with more than one host, none of them can be empty - if !ref_url.has_host() { + if !ref_url.url().has_host() { return parsing_err!(ParseError::EmptyHost); } let extra_urls: Vec = hosts .iter() .map(|host| { let reconstructed_url = format!("{prefix}{host}"); - parse_url(&reconstructed_url, input, strict) + parse_url(&reconstructed_url, input, strict, extra_trailing_slash).map(Into::into) }) .collect::>()?; @@ -480,7 +500,7 @@ fn parse_multihost_url<'py>( } } -fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool) -> ValResult { +fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool, extra_trailing_slash: bool) -> ValResult { if url_str.is_empty() { return Err(ValError::new( ErrorType::UrlParsing { @@ -490,8 +510,9 @@ fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool) -> ValResult input, )); } + let remove_trailing_slash = !extra_trailing_slash && !url_str.ends_with('/'); - // if we're in strict mode, we collect consider a syntax violation as an error + // if we're in strict mode, we consider a syntax violation as an error if strict { // we could build a vec of syntax violations and return them all, but that seems like overkill // and unlike other parser style validators @@ -517,7 +538,7 @@ fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool) -> ValResult input, )) } else { - Ok(url) + Ok(PyUrl::new(url, remove_trailing_slash)) } } Err(e) => Err(ValError::new( @@ -529,15 +550,16 @@ fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool) -> ValResult )), } } else { - Url::parse(url_str).map_err(move |e| { - ValError::new( + match Url::parse(url_str) { + Ok(url) => Ok(PyUrl::new(url, remove_trailing_slash)), + Err(e) => Err(ValError::new( ErrorType::UrlParsing { error: e.to_string(), context: None, }, input, - ) - }) + )), + } } } diff --git a/tests/validators/test_url.py b/tests/validators/test_url.py index 13c01182a..ae0e473ea 100644 --- a/tests/validators/test_url.py +++ b/tests/validators/test_url.py @@ -1,6 +1,6 @@ import re from copy import deepcopy -from typing import Optional, Union +from typing import Any, Optional, Union import pytest from dirty_equals import HasRepr, IsInstance @@ -60,7 +60,10 @@ def url_validator_fixture(): def url_test_case_helper( - url: str, expected: Union[Err, str], validator_mode: str, url_validator: Optional[SchemaValidator] = None + url: str, + expected: Union[Err, str, dict[str, Any]], + validator_mode: str, + url_validator: Optional[SchemaValidator] = None, ): if isinstance(expected, Err): with pytest.raises(ValidationError) as exc_info: @@ -271,6 +274,52 @@ def test_url_cases(url_validator, url, expected, mode): url_test_case_helper(url, expected, mode, url_validator) +@pytest.fixture(scope='module', name='url_validator_trailing_slash') +def url_url_validator_trailing_slash() -> SchemaValidator: + return SchemaValidator(core_schema.url_schema(extra_trailing_slash=False)) + + +@pytest.mark.parametrize( + 'url,expected', + [ + ('http://example.com', 'http://example.com'), + ('http:example.com', 'http://example.com'), + ('http://example.com/', 'http://example.com/'), + ('http:example.com/', 'http://example.com/'), + ('http://example.com/path', 'http://example.com/path'), + ('http://example.com/path/', 'http://example.com/path/'), + ('http://example.com/path/?x=1', 'http://example.com/path/?x=1'), + ], +) +def test_trailing_slash(url_validator_trailing_slash: SchemaValidator, url: str, expected: str): + url1 = Url(url, extra_trailing_slash=False) + assert str(url1) == expected + assert url1.unicode_string() == expected + + url2 = url_validator_trailing_slash.validate_python(url) + assert str(url2) == expected + assert url2.unicode_string() == expected + + +@pytest.mark.parametrize( + 'url,expected', + [ + ('http://example.com', 'http://example.com'), + ('http://example.com/', 'http://example.com/'), + ('http://example.com/path', 'http://example.com/path'), + ('http://example.com/path/', 'http://example.com/path/'), + # ('http://localhost,127.0.0.1', 'http://localhost,127.0.0.1'), + # ('http://localhost,127.0.0.1/', 'http://localhost,127.0.0.1/'), + ('http://localhost,127.0.0.1/path', 'http://localhost,127.0.0.1/path'), + # ('http://localhost,127.0.0.1/path/', 'http://localhost,127.0.0.1/path/'), + ], +) +def test_multi_trailing_slash(url: str, expected: str): + url1 = MultiHostUrl(url, extra_trailing_slash=False) + assert str(url1) == expected + assert url1.unicode_string() == expected + + @pytest.mark.parametrize( 'validator_kwargs,url,expected', [ From 5f59c518655bd379dd693cc99e42aff6c10067f6 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 24 May 2025 18:07:09 +0100 Subject: [PATCH 2/4] tests passing --- pyproject.toml | 5 ++++- python/pydantic_core/core_schema.py | 6 ++++++ src/url.rs | 15 +++++++-------- src/validators/url.rs | 2 +- tests/validators/test_url.py | 25 +++++++++++++++++++------ 5 files changed, 37 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d09c62a10..eb0fd6284 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ { name = 'David Montague', email = 'david@pydantic.dev' }, { name = 'David Hewitt', email = 'mail@davidhewitt.dev' }, { name = 'Sydney Runkle', email = 'sydneymarierunkle@gmail.com' }, - { name = 'Victorien Plot', email='contact@vctrn.dev' }, + { name = 'Victorien Plot', email = 'contact@vctrn.dev' }, ] classifiers = [ 'Development Status :: 3 - Alpha', @@ -149,6 +149,9 @@ require_change_file = false [tool.pyright] include = ['python/pydantic_core', 'tests/test_typing.py'] reportUnnecessaryTypeIgnoreComment = true +executionEnvironments = [ + { root = "tests", reportPrivateImportUsage = false, reportMissingParameterType = false }, +] [tool.inline-snapshot.shortcuts] fix = ["create", "fix"] diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 4af684240..4fbe664a4 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3824,6 +3824,7 @@ class UrlSchema(TypedDict, total=False): default_host: str default_port: int default_path: str + extra_trailing_slash: bool strict: bool ref: str metadata: dict[str, Any] @@ -3894,6 +3895,7 @@ class MultiHostUrlSchema(TypedDict, total=False): default_host: str default_port: int default_path: str + extra_trailing_slash: bool strict: bool ref: str metadata: dict[str, Any] @@ -3908,6 +3910,7 @@ def multi_host_url_schema( default_host: str | None = None, default_port: int | None = None, default_path: str | None = None, + extra_trailing_slash: bool | None = None, strict: bool | None = None, ref: str | None = None, metadata: dict[str, Any] | None = None, @@ -3932,6 +3935,8 @@ def multi_host_url_schema( default_host: The default host to use if the URL does not have a host default_port: The default port to use if the URL does not have a port default_path: The default path to use if the URL does not have a path + extra_trailing_slash: Whether to add an extra trailing slash to the URL, defaults to `True` for + backward compatibility, default will change to `False` in v3 version. strict: Whether to use strict URL parsing ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -3945,6 +3950,7 @@ def multi_host_url_schema( default_host=default_host, default_port=default_port, default_path=default_path, + extra_trailing_slash=extra_trailing_slash, strict=strict, ref=ref, metadata=metadata, diff --git a/src/url.rs b/src/url.rs index 909e5aa1c..505830f4c 100644 --- a/src/url.rs +++ b/src/url.rs @@ -69,7 +69,7 @@ fn build_schema_validator(py: Python, schema_type: &str, extra_trailing_slash: b #[pymethods] impl PyUrl { #[new] - #[pyo3(signature = (url, *, extra_trailing_slash=false))] + #[pyo3(signature = (url, *, extra_trailing_slash=true))] pub fn py_new(py: Python, url: &Bound<'_, PyAny>, extra_trailing_slash: bool) -> PyResult { let schema_validator = get_schema_validator(py, false, extra_trailing_slash); let schema_obj = schema_validator.validate_python(py, url, None, None, None, None, false.into(), None, None)?; @@ -248,7 +248,7 @@ impl PyMultiHostUrl { #[pymethods] impl PyMultiHostUrl { #[new] - #[pyo3(signature = (url, *, extra_trailing_slash=false))] + #[pyo3(signature = (url, *, extra_trailing_slash=true))] pub fn py_new(py: Python, url: &Bound<'_, PyAny>, extra_trailing_slash: bool) -> PyResult { let schema_validator = get_schema_validator(py, true, extra_trailing_slash); let schema_obj = schema_validator.validate_python(py, url, None, None, None, None, false.into(), None, None)?; @@ -305,12 +305,12 @@ impl PyMultiHostUrl { // special urls will have had a trailing slash added, non-special urls will not // hence we need to remove the last char if the schema is special - let sub: usize = (!self.ref_url.remove_trailing_slash && schema_is_special(schema)).into(); + let sub: usize = schema_is_special(schema).into(); let hosts = extra_urls .iter() .map(|url| { - let str = unicode_url(url, self.ref_url.remove_trailing_slash); + let str = unicode_url(url, false); str[host_offset..str.len() - sub].to_string() }) .collect::>() @@ -327,13 +327,12 @@ impl PyMultiHostUrl { let schema = self.ref_url.lib_url.scheme(); let host_offset = schema.len() + 3; - let mut full_url = self.ref_url.lib_url.to_string(); + let mut full_url = self.ref_url.__str__().to_string(); full_url.insert(host_offset, ','); // special urls will have had a trailing slash added, non-special urls will not // hence we need to remove the last char if the schema is special - #[allow(clippy::bool_to_int_with_if)] - let sub = if schema_is_special(schema) { 1 } else { 0 }; + let sub: usize = schema_is_special(schema).into(); let hosts = extra_urls .iter() @@ -341,7 +340,7 @@ impl PyMultiHostUrl { let str = url.as_str(); &str[host_offset..str.len() - sub] }) - .collect::>() + .collect::>() .join(","); full_url.insert_str(host_offset, &hosts); full_url diff --git a/src/validators/url.rs b/src/validators/url.rs index 9fd0529b9..6d514a93c 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -238,7 +238,7 @@ impl BuildValidator for MultiHostUrlValidator { name, extra_trailing_slash: schema .get_as(intern!(schema.py(), "extra_trailing_slash"))? - .unwrap_or(false), + .unwrap_or(true), } .into()) } diff --git a/tests/validators/test_url.py b/tests/validators/test_url.py index ae0e473ea..3ca0c67a8 100644 --- a/tests/validators/test_url.py +++ b/tests/validators/test_url.py @@ -55,7 +55,7 @@ def url_validator_fixture(): SCHEMA_VALIDATOR_MODE = 'SCHEMA_VALIDATOR' -URL_CLASS_MODE = 'URI_CLASS' +URL_CLASS_MODE = 'URL_CLASS' MULTI_URL_CLASS_MODE = 'MULTI_URL_CLASS' @@ -86,6 +86,7 @@ def url_test_case_helper( output_url = MultiHostUrl(url) else: raise ValueError(f'Unknown validator mode: {validator_mode}') + assert isinstance(output_url, (Url, MultiHostUrl)) if isinstance(expected, str): assert str(output_url) == expected @@ -275,7 +276,7 @@ def test_url_cases(url_validator, url, expected, mode): @pytest.fixture(scope='module', name='url_validator_trailing_slash') -def url_url_validator_trailing_slash() -> SchemaValidator: +def url_validator_trailing_slash() -> SchemaValidator: return SchemaValidator(core_schema.url_schema(extra_trailing_slash=False)) @@ -301,6 +302,11 @@ def test_trailing_slash(url_validator_trailing_slash: SchemaValidator, url: str, assert url2.unicode_string() == expected +@pytest.fixture(scope='module', name='multi_url_validator_trailing_slash') +def multi_url_validator_trailing_slash() -> SchemaValidator: + return SchemaValidator(core_schema.multi_host_url_schema(extra_trailing_slash=False)) + + @pytest.mark.parametrize( 'url,expected', [ @@ -308,17 +314,24 @@ def test_trailing_slash(url_validator_trailing_slash: SchemaValidator, url: str, ('http://example.com/', 'http://example.com/'), ('http://example.com/path', 'http://example.com/path'), ('http://example.com/path/', 'http://example.com/path/'), - # ('http://localhost,127.0.0.1', 'http://localhost,127.0.0.1'), - # ('http://localhost,127.0.0.1/', 'http://localhost,127.0.0.1/'), + ('http://example.com,example.org', 'http://example.com,example.org'), + ('http://example.com,example.org/', 'http://example.com,example.org/'), + ('http://localhost,127.0.0.1', 'http://localhost,127.0.0.1'), + ('http://localhost,127.0.0.1/', 'http://localhost,127.0.0.1/'), + ('http:localhost,127.0.0.1', 'http://localhost,127.0.0.1'), ('http://localhost,127.0.0.1/path', 'http://localhost,127.0.0.1/path'), - # ('http://localhost,127.0.0.1/path/', 'http://localhost,127.0.0.1/path/'), + ('http://localhost,127.0.0.1/path/', 'http://localhost,127.0.0.1/path/'), ], ) -def test_multi_trailing_slash(url: str, expected: str): +def test_multi_trailing_slash(multi_url_validator_trailing_slash: SchemaValidator, url: str, expected: str): url1 = MultiHostUrl(url, extra_trailing_slash=False) assert str(url1) == expected assert url1.unicode_string() == expected + url2 = multi_url_validator_trailing_slash.validate_python(url) + assert str(url2) == expected + assert url2.unicode_string() == expected + @pytest.mark.parametrize( 'validator_kwargs,url,expected', From 4e19c63aaa7b9ca1744a7e810433b6ba9256bdf1 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 24 May 2025 18:24:54 +0100 Subject: [PATCH 3/4] add config --- pyproject.toml | 2 +- python/pydantic_core/_pydantic_core.pyi | 12 +++---- python/pydantic_core/core_schema.py | 19 +++++----- src/url.rs | 20 +++++------ src/validators/url.rs | 48 ++++++++++++++----------- tests/validators/test_url.py | 19 +++++++--- 6 files changed, 70 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index eb0fd6284..06a8ec9fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,7 +150,7 @@ require_change_file = false include = ['python/pydantic_core', 'tests/test_typing.py'] reportUnnecessaryTypeIgnoreComment = true executionEnvironments = [ - { root = "tests", reportPrivateImportUsage = false, reportMissingParameterType = false }, + { root = "tests", reportPrivateImportUsage = false, reportMissingParameterType = false, reportAny = false }, ] [tool.inline-snapshot.shortcuts] diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 20144f49e..2cd297da9 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -522,15 +522,15 @@ class Url(SupportsAllComparisons): by Mozilla. """ - def __init__(self, url: str, *, extra_trailing_slash: bool = True) -> None: + def __init__(self, url: str, *, add_trailing_slash: bool = True) -> None: """Initialize a new URL object. Args: url: The URL string to parse. - extra_trailing_slash: Whether to add an extra trailing slash to the URL, defaults to `True` for + add_trailing_slash: Whether to add an extra trailing slash to some URLs, defaults to `True` for backward compatibility, default will change to `False` in v3 version. """ - def __new__(cls, url: str, *, extra_trailing_slash: bool = True) -> Self: ... + def __new__(cls, url: str, *, add_trailing_slash: bool = True) -> Self: ... @property def scheme(self) -> str: ... @property @@ -575,15 +575,15 @@ class MultiHostUrl(SupportsAllComparisons): by Mozilla. """ - def __init__(self, url: str, *, extra_trailing_slash: bool = True) -> None: + def __init__(self, url: str, *, add_trailing_slash: bool = True) -> None: """Initialize a new MultiHostUrl object. Args: url: The URL string to parse. - extra_trailing_slash: Whether to add an extra trailing slash to the URL, defaults to `True` for + add_trailing_slash: Whether to add an extra trailing slash to some URLs, defaults to `True` for backward compatibility, default will change to `False` in v3 version. """ - def __new__(cls, url: str, *, extra_trailing_slash: bool = True) -> Self: ... + def __new__(cls, url: str, *, add_trailing_slash: bool = True) -> Self: ... @property def scheme(self) -> str: ... @property diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 4fbe664a4..cc0a56246 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -75,6 +75,8 @@ class CoreConfig(TypedDict, total=False): validate_by_alias: Whether to use the field's alias when validating against the provided input data. Default is `True`. validate_by_name: Whether to use the field's name when validating against the provided input data. Default is `False`. Replacement for `populate_by_name`. serialize_by_alias: Whether to serialize by alias. Default is `False`, expected to change to `True` in V3. + url_add_trailing_slash: Whether to add an extra trailing slash to some URLs, defaults to `True` for + backward compatibility, default will change to `False` in v3 version. """ title: str @@ -114,6 +116,7 @@ class CoreConfig(TypedDict, total=False): validate_by_alias: bool # default: True validate_by_name: bool # default: False serialize_by_alias: bool # default: False + url_add_trailing_slash: bool # default: True IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None' @@ -3824,7 +3827,7 @@ class UrlSchema(TypedDict, total=False): default_host: str default_port: int default_path: str - extra_trailing_slash: bool + add_trailing_slash: bool strict: bool ref: str metadata: dict[str, Any] @@ -3839,7 +3842,7 @@ def url_schema( default_host: str | None = None, default_port: int | None = None, default_path: str | None = None, - extra_trailing_slash: bool | None = None, + add_trailing_slash: bool | None = None, strict: bool | None = None, ref: str | None = None, metadata: dict[str, Any] | None = None, @@ -3864,7 +3867,7 @@ def url_schema( default_host: The default host to use if the URL does not have a host default_port: The default port to use if the URL does not have a port default_path: The default path to use if the URL does not have a path - extra_trailing_slash: Whether to add an extra trailing slash to the URL, defaults to `True` for + add_trailing_slash: Whether to add an extra trailing slash to some URLs, defaults to `True` for backward compatibility, default will change to `False` in v3 version. strict: Whether to use strict URL parsing ref: optional unique identifier of the schema, used to reference the schema in other places @@ -3883,7 +3886,7 @@ def url_schema( ref=ref, metadata=metadata, serialization=serialization, - extra_trailing_slash=extra_trailing_slash, + add_trailing_slash=add_trailing_slash, ) @@ -3895,7 +3898,7 @@ class MultiHostUrlSchema(TypedDict, total=False): default_host: str default_port: int default_path: str - extra_trailing_slash: bool + add_trailing_slash: bool strict: bool ref: str metadata: dict[str, Any] @@ -3910,7 +3913,7 @@ def multi_host_url_schema( default_host: str | None = None, default_port: int | None = None, default_path: str | None = None, - extra_trailing_slash: bool | None = None, + add_trailing_slash: bool | None = None, strict: bool | None = None, ref: str | None = None, metadata: dict[str, Any] | None = None, @@ -3935,7 +3938,7 @@ def multi_host_url_schema( default_host: The default host to use if the URL does not have a host default_port: The default port to use if the URL does not have a port default_path: The default path to use if the URL does not have a path - extra_trailing_slash: Whether to add an extra trailing slash to the URL, defaults to `True` for + add_trailing_slash: Whether to add an extra trailing slash to some URLs, defaults to `True` for backward compatibility, default will change to `False` in v3 version. strict: Whether to use strict URL parsing ref: optional unique identifier of the schema, used to reference the schema in other places @@ -3950,7 +3953,7 @@ def multi_host_url_schema( default_host=default_host, default_port=default_port, default_path=default_path, - extra_trailing_slash=extra_trailing_slash, + add_trailing_slash=add_trailing_slash, strict=strict, ref=ref, metadata=metadata, diff --git a/src/url.rs b/src/url.rs index 505830f4c..f6bd7e4d7 100644 --- a/src/url.rs +++ b/src/url.rs @@ -50,8 +50,8 @@ static SCHEMA_URL_SINGLE_FALSE: GILOnceCell = GILOnceCell::new( static SCHEMA_URL_MULTI_TRUE: GILOnceCell = GILOnceCell::new(); static SCHEMA_URL_MULTI_FALSE: GILOnceCell = GILOnceCell::new(); -fn get_schema_validator(py: Python<'_>, multi_host: bool, extra_trailing_slash: bool) -> &SchemaValidator { - match (multi_host, extra_trailing_slash) { +fn get_schema_validator(py: Python<'_>, multi_host: bool, add_trailing_slash: bool) -> &SchemaValidator { + match (multi_host, add_trailing_slash) { (false, true) => SCHEMA_URL_SINGLE_TRUE.get_or_init(py, || build_schema_validator(py, "url", true)), (false, false) => SCHEMA_URL_SINGLE_FALSE.get_or_init(py, || build_schema_validator(py, "url", false)), (true, true) => SCHEMA_URL_MULTI_TRUE.get_or_init(py, || build_schema_validator(py, "multi-host-url", true)), @@ -59,19 +59,19 @@ fn get_schema_validator(py: Python<'_>, multi_host: bool, extra_trailing_slash: } } -fn build_schema_validator(py: Python, schema_type: &str, extra_trailing_slash: bool) -> SchemaValidator { +fn build_schema_validator(py: Python, schema_type: &str, add_trailing_slash: bool) -> SchemaValidator { let schema = PyDict::new(py); schema.set_item("type", schema_type).unwrap(); - schema.set_item("extra_trailing_slash", extra_trailing_slash).unwrap(); + schema.set_item("add_trailing_slash", add_trailing_slash).unwrap(); SchemaValidator::py_new(py, &schema, None).unwrap() } #[pymethods] impl PyUrl { #[new] - #[pyo3(signature = (url, *, extra_trailing_slash=true))] - pub fn py_new(py: Python, url: &Bound<'_, PyAny>, extra_trailing_slash: bool) -> PyResult { - let schema_validator = get_schema_validator(py, false, extra_trailing_slash); + #[pyo3(signature = (url, *, add_trailing_slash=true))] + pub fn py_new(py: Python, url: &Bound<'_, PyAny>, add_trailing_slash: bool) -> PyResult { + let schema_validator = get_schema_validator(py, false, add_trailing_slash); let schema_obj = schema_validator.validate_python(py, url, None, None, None, None, false.into(), None, None)?; schema_obj.extract(py) } @@ -248,9 +248,9 @@ impl PyMultiHostUrl { #[pymethods] impl PyMultiHostUrl { #[new] - #[pyo3(signature = (url, *, extra_trailing_slash=true))] - pub fn py_new(py: Python, url: &Bound<'_, PyAny>, extra_trailing_slash: bool) -> PyResult { - let schema_validator = get_schema_validator(py, true, extra_trailing_slash); + #[pyo3(signature = (url, *, add_trailing_slash=true))] + pub fn py_new(py: Python, url: &Bound<'_, PyAny>, add_trailing_slash: bool) -> PyResult { + let schema_validator = get_schema_validator(py, true, add_trailing_slash); let schema_obj = schema_validator.validate_python(py, url, None, None, None, None, false.into(), None, None)?; schema_obj.extract(py) } diff --git a/src/validators/url.rs b/src/validators/url.rs index 6d514a93c..a50f0d50f 100644 --- a/src/validators/url.rs +++ b/src/validators/url.rs @@ -10,7 +10,7 @@ use ahash::AHashSet; use pyo3::IntoPyObjectExt; use url::{ParseError, SyntaxViolation, Url}; -use crate::build_tools::{is_strict, py_schema_err}; +use crate::build_tools::{is_strict, py_schema_err, schema_or_config}; use crate::errors::ToErrorValue; use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult}; use crate::input::downcast_python_input; @@ -34,7 +34,17 @@ pub struct UrlValidator { default_port: Option, default_path: Option, name: String, - extra_trailing_slash: bool, + add_trailing_slash: bool, +} + +fn get_add_trailing_slash(schema: &Bound<'_, PyDict>, config: Option<&Bound<'_, PyDict>>) -> PyResult { + schema_or_config( + schema, + config, + intern!(schema.py(), "add_trailing_slash"), + intern!(schema.py(), "url_add_trailing_slash"), + ) + .map(|v| v.unwrap_or(true)) } impl BuildValidator for UrlValidator { @@ -56,9 +66,7 @@ impl BuildValidator for UrlValidator { default_path: schema.get_as(intern!(schema.py(), "default_path"))?, allowed_schemes, name, - extra_trailing_slash: schema - .get_as(intern!(schema.py(), "extra_trailing_slash"))? - .unwrap_or(true), + add_trailing_slash: get_add_trailing_slash(schema, config)?, } .into()) } @@ -73,7 +81,7 @@ impl Validator for UrlValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let mut either_url = self.get_url(input, state.strict_or(self.strict), self.extra_trailing_slash)?; + let mut either_url = self.get_url(input, state.strict_or(self.strict), self.add_trailing_slash)?; if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { if !allowed_schemes.contains(either_url.url().scheme()) { @@ -114,7 +122,7 @@ impl UrlValidator { &self, input: &(impl Input<'py> + ?Sized), strict: bool, - extra_trailing_slash: bool, + add_trailing_slash: bool, ) -> ValResult> { match input.validate_str(strict, false) { Ok(val_match) => { @@ -124,7 +132,7 @@ impl UrlValidator { self.check_length(input, url_str)?; - parse_url(url_str, input, strict, extra_trailing_slash).map(EitherUrl::Owned) + parse_url(url_str, input, strict, add_trailing_slash).map(EitherUrl::Owned) } Err(_) => { // we don't need to worry about whether the url was parsed in strict mode before, @@ -136,7 +144,7 @@ impl UrlValidator { let url_str = multi_host_url.get().__str__(); self.check_length(input, &url_str)?; - parse_url(&url_str, input, strict, extra_trailing_slash).map(EitherUrl::Owned) + parse_url(&url_str, input, strict, add_trailing_slash).map(EitherUrl::Owned) } else { Err(ValError::new(ErrorTypeDefaults::UrlType, input)) } @@ -208,7 +216,7 @@ pub struct MultiHostUrlValidator { default_port: Option, default_path: Option, name: String, - extra_trailing_slash: bool, + add_trailing_slash: bool, } impl BuildValidator for MultiHostUrlValidator { @@ -236,9 +244,7 @@ impl BuildValidator for MultiHostUrlValidator { default_port: schema.get_as(intern!(schema.py(), "default_port"))?, default_path: schema.get_as(intern!(schema.py(), "default_path"))?, name, - extra_trailing_slash: schema - .get_as(intern!(schema.py(), "extra_trailing_slash"))? - .unwrap_or(true), + add_trailing_slash: get_add_trailing_slash(schema, config)?, } .into()) } @@ -253,7 +259,7 @@ impl Validator for MultiHostUrlValidator { input: &(impl Input<'py> + ?Sized), state: &mut ValidationState<'_, 'py>, ) -> ValResult { - let mut multi_url = self.get_url(input, state.strict_or(self.strict), self.extra_trailing_slash)?; + let mut multi_url = self.get_url(input, state.strict_or(self.strict), self.add_trailing_slash)?; if let Some((ref allowed_schemes, ref expected_schemes_repr)) = self.allowed_schemes { if !allowed_schemes.contains(multi_url.url().scheme()) { @@ -293,7 +299,7 @@ impl MultiHostUrlValidator { &self, input: &(impl Input<'py> + ?Sized), strict: bool, - extra_trailing_slash: bool, + add_trailing_slash: bool, ) -> ValResult> { match input.validate_str(strict, false) { Ok(val_match) => { @@ -303,7 +309,7 @@ impl MultiHostUrlValidator { self.check_length(input, || url_str.len())?; - parse_multihost_url(url_str, input, strict, extra_trailing_slash).map(EitherMultiHostUrl::Rust) + parse_multihost_url(url_str, input, strict, add_trailing_slash).map(EitherMultiHostUrl::Rust) } Err(_) => { // we don't need to worry about whether the url was parsed in strict mode before, @@ -384,7 +390,7 @@ fn parse_multihost_url<'py>( url_str: &str, input: &(impl Input<'py> + ?Sized), strict: bool, - extra_trailing_slash: bool, + add_trailing_slash: bool, ) -> ValResult { macro_rules! parsing_err { ($parse_error:expr) => { @@ -474,7 +480,7 @@ fn parse_multihost_url<'py>( // with just one host, for consistent behaviour, we parse the URL the same as with multiple hosts let reconstructed_url = format!("{prefix}{}", &url_str[start..]); - let ref_url = parse_url(&reconstructed_url, input, strict, extra_trailing_slash)?; + let ref_url = parse_url(&reconstructed_url, input, strict, add_trailing_slash)?; if hosts.is_empty() { // if there's no one host (e.g. no `,`), we allow it to be empty to allow for default hosts @@ -488,7 +494,7 @@ fn parse_multihost_url<'py>( .iter() .map(|host| { let reconstructed_url = format!("{prefix}{host}"); - parse_url(&reconstructed_url, input, strict, extra_trailing_slash).map(Into::into) + parse_url(&reconstructed_url, input, strict, add_trailing_slash).map(Into::into) }) .collect::>()?; @@ -500,7 +506,7 @@ fn parse_multihost_url<'py>( } } -fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool, extra_trailing_slash: bool) -> ValResult { +fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool, add_trailing_slash: bool) -> ValResult { if url_str.is_empty() { return Err(ValError::new( ErrorType::UrlParsing { @@ -510,7 +516,7 @@ fn parse_url(url_str: &str, input: impl ToErrorValue, strict: bool, extra_traili input, )); } - let remove_trailing_slash = !extra_trailing_slash && !url_str.ends_with('/'); + let remove_trailing_slash = !add_trailing_slash && !url_str.ends_with('/'); // if we're in strict mode, we consider a syntax violation as an error if strict { diff --git a/tests/validators/test_url.py b/tests/validators/test_url.py index 3ca0c67a8..ab5ff46cf 100644 --- a/tests/validators/test_url.py +++ b/tests/validators/test_url.py @@ -277,7 +277,7 @@ def test_url_cases(url_validator, url, expected, mode): @pytest.fixture(scope='module', name='url_validator_trailing_slash') def url_validator_trailing_slash() -> SchemaValidator: - return SchemaValidator(core_schema.url_schema(extra_trailing_slash=False)) + return SchemaValidator(core_schema.url_schema(add_trailing_slash=False)) @pytest.mark.parametrize( @@ -293,7 +293,7 @@ def url_validator_trailing_slash() -> SchemaValidator: ], ) def test_trailing_slash(url_validator_trailing_slash: SchemaValidator, url: str, expected: str): - url1 = Url(url, extra_trailing_slash=False) + url1 = Url(url, add_trailing_slash=False) assert str(url1) == expected assert url1.unicode_string() == expected @@ -304,7 +304,7 @@ def test_trailing_slash(url_validator_trailing_slash: SchemaValidator, url: str, @pytest.fixture(scope='module', name='multi_url_validator_trailing_slash') def multi_url_validator_trailing_slash() -> SchemaValidator: - return SchemaValidator(core_schema.multi_host_url_schema(extra_trailing_slash=False)) + return SchemaValidator(core_schema.multi_host_url_schema(add_trailing_slash=False)) @pytest.mark.parametrize( @@ -324,7 +324,7 @@ def multi_url_validator_trailing_slash() -> SchemaValidator: ], ) def test_multi_trailing_slash(multi_url_validator_trailing_slash: SchemaValidator, url: str, expected: str): - url1 = MultiHostUrl(url, extra_trailing_slash=False) + url1 = MultiHostUrl(url, add_trailing_slash=False) assert str(url1) == expected assert url1.unicode_string() == expected @@ -333,6 +333,17 @@ def test_multi_trailing_slash(multi_url_validator_trailing_slash: SchemaValidato assert url2.unicode_string() == expected +def test_multi_trailing_slash_config(): + s = SchemaValidator(core_schema.url_schema(), CoreConfig(url_add_trailing_slash=False)) + url1 = s.validate_python('http://example.com') + assert str(url1) == 'http://example.com' + assert url1.unicode_string() == 'http://example.com' + + url2 = s.validate_python('http://example.com/') + assert str(url2) == 'http://example.com/' + assert url2.unicode_string() == 'http://example.com/' + + @pytest.mark.parametrize( 'validator_kwargs,url,expected', [ From 660b112962d4d256d9fbecbde2f9ed12aa987d1e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 24 May 2025 18:43:58 +0100 Subject: [PATCH 4/4] improve building global sceham validators --- src/url.rs | 51 ++++++++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/src/url.rs b/src/url.rs index f6bd7e4d7..22bb8c443 100644 --- a/src/url.rs +++ b/src/url.rs @@ -45,33 +45,12 @@ impl From for Url { } } -static SCHEMA_URL_SINGLE_TRUE: GILOnceCell = GILOnceCell::new(); -static SCHEMA_URL_SINGLE_FALSE: GILOnceCell = GILOnceCell::new(); -static SCHEMA_URL_MULTI_TRUE: GILOnceCell = GILOnceCell::new(); -static SCHEMA_URL_MULTI_FALSE: GILOnceCell = GILOnceCell::new(); - -fn get_schema_validator(py: Python<'_>, multi_host: bool, add_trailing_slash: bool) -> &SchemaValidator { - match (multi_host, add_trailing_slash) { - (false, true) => SCHEMA_URL_SINGLE_TRUE.get_or_init(py, || build_schema_validator(py, "url", true)), - (false, false) => SCHEMA_URL_SINGLE_FALSE.get_or_init(py, || build_schema_validator(py, "url", false)), - (true, true) => SCHEMA_URL_MULTI_TRUE.get_or_init(py, || build_schema_validator(py, "multi-host-url", true)), - (true, false) => SCHEMA_URL_MULTI_FALSE.get_or_init(py, || build_schema_validator(py, "multi-host-url", false)), - } -} - -fn build_schema_validator(py: Python, schema_type: &str, add_trailing_slash: bool) -> SchemaValidator { - let schema = PyDict::new(py); - schema.set_item("type", schema_type).unwrap(); - schema.set_item("add_trailing_slash", add_trailing_slash).unwrap(); - SchemaValidator::py_new(py, &schema, None).unwrap() -} - #[pymethods] impl PyUrl { #[new] #[pyo3(signature = (url, *, add_trailing_slash=true))] pub fn py_new(py: Python, url: &Bound<'_, PyAny>, add_trailing_slash: bool) -> PyResult { - let schema_validator = get_schema_validator(py, false, add_trailing_slash); + let schema_validator = get_schema_validator(py, false, add_trailing_slash)?; let schema_obj = schema_validator.validate_python(py, url, None, None, None, None, false.into(), None, None)?; schema_obj.extract(py) } @@ -250,7 +229,7 @@ impl PyMultiHostUrl { #[new] #[pyo3(signature = (url, *, add_trailing_slash=true))] pub fn py_new(py: Python, url: &Bound<'_, PyAny>, add_trailing_slash: bool) -> PyResult { - let schema_validator = get_schema_validator(py, true, add_trailing_slash); + let schema_validator = get_schema_validator(py, true, add_trailing_slash)?; let schema_obj = schema_validator.validate_python(py, url, None, None, None, None, false.into(), None, None)?; schema_obj.extract(py) } @@ -545,3 +524,29 @@ fn is_punnycode_domain(lib_url: &Url, domain: &str) -> bool { pub fn schema_is_special(schema: &str) -> bool { matches!(schema, "http" | "https" | "ws" | "wss" | "ftp" | "file") } + +static SCHEMA_URL_SINGLE_TRUE: GILOnceCell = GILOnceCell::new(); +static SCHEMA_URL_SINGLE_FALSE: GILOnceCell = GILOnceCell::new(); +static SCHEMA_URL_MULTI_TRUE: GILOnceCell = GILOnceCell::new(); +static SCHEMA_URL_MULTI_FALSE: GILOnceCell = GILOnceCell::new(); + +macro_rules! make_schema_val { + ($py:ident, $schema_type:literal, $add_trailing_slash:literal) => {{ + let schema = PyDict::new($py); + schema.set_item(intern!($py, "type"), intern!($py, $schema_type))?; + // add_trailing_slash defaults to true, so only set it if false + if !$add_trailing_slash { + schema.set_item(intern!($py, "add_trailing_slash"), false)?; + } + SchemaValidator::py_new($py, &schema, None) + }}; +} + +fn get_schema_validator(py: Python<'_>, multi_host: bool, add_trailing_slash: bool) -> PyResult<&SchemaValidator> { + match (multi_host, add_trailing_slash) { + (false, true) => SCHEMA_URL_SINGLE_TRUE.get_or_try_init(py, || make_schema_val!(py, "url", true)), + (false, false) => SCHEMA_URL_SINGLE_FALSE.get_or_try_init(py, || make_schema_val!(py, "url", false)), + (true, true) => SCHEMA_URL_MULTI_TRUE.get_or_try_init(py, || make_schema_val!(py, "multi-host-url", true)), + (true, false) => SCHEMA_URL_MULTI_FALSE.get_or_try_init(py, || make_schema_val!(py, "multi-host-url", false)), + } +}