Skip to content

Commit d9da7de

Browse files
committed
refactor(types): TestParameterGroup into a pydantic model
1 parent e0857e5 commit d9da7de

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

src/ethereum_test_types/helpers.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Helper functions/classes used to generate Ethereum tests."""
22

3-
from dataclasses import MISSING, dataclass, fields
43
from typing import List, SupportsBytes
54

65
import ethereum_rlp as eth_rlp
6+
from pydantic import BaseModel, ConfigDict
77

88
from ethereum_test_base_types.base_types import Address, Bytes, Hash
99
from ethereum_test_base_types.conversions import BytesConvertible, FixedSizeBytesConvertible
@@ -92,8 +92,7 @@ def add_kzg_version(
9292
return kzg_versioned_hashes
9393

9494

95-
@dataclass(kw_only=True, frozen=True, repr=False)
96-
class TestParameterGroup:
95+
class TestParameterGroup(BaseModel):
9796
"""
9897
Base class for grouping test parameters in a dataclass. Provides a generic
9998
__repr__ method to generate clean test ids, including only non-default
@@ -102,18 +101,18 @@ class TestParameterGroup:
102101

103102
__test__ = False # explicitly prevent pytest collecting this class
104103

104+
model_config = ConfigDict(frozen=True, repr=False, validate_default=True)
105+
105106
def __repr__(self):
106107
"""
107108
Generate repr string, intended to be used as a test id, based on the class
108109
name and the values of the non-default optional fields.
109110
"""
110111
class_name = self.__class__.__name__
111-
field_strings = []
112-
113-
for field in fields(self):
114-
value = getattr(self, field.name)
112+
field_strings = [
113+
f"{field}_{value}"
115114
# Include the field only if it is not optional or not set to its default value
116-
if field.default is MISSING or field.default != value:
117-
field_strings.append(f"{field.name}_{value}")
115+
for field, value in self.model_dump(exclude_defaults=True, exclude_unset=True).items()
116+
]
118117

119118
return f"{class_name}_{'-'.join(field_strings)}"

0 commit comments

Comments
 (0)