Skip to content

Add type annotations for all returns and enable ruff annotation checks #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ extend-select = [
"UP035",
# Missing function argument type-annotation
"ANN001",
"ANN002",
"ANN003",
"ANN201",
"ANN202",
"ANN204",
"ANN205",
"ANN206",
# Using except without specifying an exception type to catch
"BLE001",
]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
realisation_to_srf,
],
)
def test_invocation_of_script(script: Callable):
def test_invocation_of_script(script: Callable) -> None:
"""Basic check that the scripts can be invoked."""
runner = CliRunner()
result = runner.invoke(script.app, ["--help"])
Expand Down
30 changes: 15 additions & 15 deletions tests/test_log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@


@log_utils.log_call()
def foo(a: int, b: int):
def foo(a: int, b: int) -> int:
return a + b


def test_basic_log():
def test_basic_log() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -35,11 +35,11 @@ def test_basic_log():


@log_utils.log_call(exclude_args={"b"})
def foo_less_b(a: int, b: int):
def foo_less_b(a: int, b: int) -> int:
return a + b


def test_excluded_log():
def test_excluded_log() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -59,11 +59,11 @@ def test_excluded_log():


@log_utils.log_call(action_name="FOOBAR")
def bar(a: Any):
def bar(a: Any) -> None:
pass


def test_renamed_bar():
def test_renamed_bar() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -81,11 +81,11 @@ def test_renamed_bar():


@log_utils.log_call(include_result=False)
def baz(a: Any):
def baz(a: Any) -> int:
return 1


def test_no_result():
def test_no_result() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -104,11 +104,11 @@ def test_no_result():


@log_utils.log_call()
def failing_function():
def failing_function() -> None:
raise ValueError("This function should fail!")


def test_failing_function():
def test_failing_function() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -122,7 +122,7 @@ def test_failing_function():
assert "error" in return_log


def test_successful_check_call_log(tmp_path: Path):
def test_successful_check_call_log(tmp_path: Path) -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -140,7 +140,7 @@ def test_successful_check_call_log(tmp_path: Path):
assert "stdout" in completion_message and "test.txt" in completion_message["stdout"]


def test_failing_check_call_log():
def test_failing_check_call_log() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -156,7 +156,7 @@ def test_failing_check_call_log():
)


def test_repeated_logs():
def test_repeated_logs() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand All @@ -177,12 +177,12 @@ def test_repeated_logs():
)


def _thread_worker(logger_name: str):
def _thread_worker(logger_name: str) -> None:
logger = log_utils.get_logger(logger_name)
logger.info("Threaded log message")


def test_thread_safety():
def test_thread_safety() -> None:
log_capture = structlog.testing.LogCapture()
structlog.configure(processors=[log_capture])

Expand Down
44 changes: 22 additions & 22 deletions tests/test_realisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from workflow import defaults, realisations


def test_bounding_box_example(tmp_path: Path):
def test_bounding_box_example(tmp_path: Path) -> None:
domain_parameters = realisations.DomainParameters(
resolution=0.1, # a 0.1km resolution
domain=bounding_box.BoundingBox.from_centroid_bearing_extents(
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_bounding_box_example(tmp_path: Path):
).all()


def test_domain_parameters_properties():
def test_domain_parameters_properties() -> None:
domain_parameters = realisations.DomainParameters(
resolution=0.1, # a 0.1km resolution
domain=bounding_box.BoundingBox.from_centroid_bearing_extents(
Expand All @@ -78,7 +78,7 @@ def test_domain_parameters_properties():
assert domain_parameters.nz == 400


def test_srf_config_example(tmp_path: Path):
def test_srf_config_example(tmp_path: Path) -> None:
domain_parameters = realisations.DomainParameters(
resolution=0.1, # a 0.1km resolution
domain=bounding_box.BoundingBox.from_centroid_bearing_extents(
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_srf_config_example(tmp_path: Path):
assert realisations.SRFConfig.read_from_realisation(realisation_ffp) == srf_config


def test_bad_domain_parameters(tmp_path: Path):
def test_bad_domain_parameters(tmp_path: Path) -> None:
bad_json = tmp_path / "bad_domain_parameters.json"
bad_json.write_text(
json.dumps(
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_bad_domain_parameters(tmp_path: Path):
realisations.DomainParameters.read_from_realisation(bad_json)


def test_bad_config_key(tmp_path: Path):
def test_bad_config_key(tmp_path: Path) -> None:
bad_json = tmp_path / "bad_domain_parameters.json"
bad_json.write_text(
json.dumps(
Expand Down Expand Up @@ -196,7 +196,7 @@ def test_bad_config_key(tmp_path: Path):
realisations.DomainParameters.read_from_realisation(bad_json)


def test_metadata(tmp_path: Path):
def test_metadata(tmp_path: Path) -> None:
metadata = realisations.RealisationMetadata(
name="consecutive write test",
version="1",
Expand All @@ -220,7 +220,7 @@ def test_metadata(tmp_path: Path):
)


def test_velocity_model(tmp_path: Path):
def test_velocity_model(tmp_path: Path) -> None:
velocity_model = realisations.VelocityModelParameters(
min_vs=1.0,
version="2.06",
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_velocity_model(tmp_path: Path):
)


def test_rupture_prop_config(tmp_path: Path):
def test_rupture_prop_config(tmp_path: Path) -> None:
rup_prop = realisations.RupturePropagationConfig(
rupture_causality_tree={"A": None, "B": "A", "C": "B"},
jump_points={
Expand Down Expand Up @@ -307,7 +307,7 @@ def test_rupture_prop_config(tmp_path: Path):
assert rupture_prop_config.hypocentre.tolist() == [0.0, 0.6]


def test_rupture_prop_properties():
def test_rupture_prop_properties() -> None:
rup_prop = realisations.RupturePropagationConfig(
rupture_causality_tree={"A": None, "B": "A", "C": "B"},
jump_points={
Expand All @@ -325,7 +325,7 @@ def test_rupture_prop_properties():
assert rup_prop.initial_fault == "A"


def test_hf_config(tmp_path: Path):
def test_hf_config(tmp_path: Path) -> None:
test_realisation = tmp_path / "realisation.json"
test_realisation.write_text("{}")
hf_config = realisations.HFConfig.read_from_realisation_or_defaults(
Expand All @@ -344,7 +344,7 @@ def test_hf_config(tmp_path: Path):
)


def test_emod3d(tmp_path: Path):
def test_emod3d(tmp_path: Path) -> None:
test_realisation = tmp_path / "realisation.json"
test_realisation.write_text("{}")
emod3d = realisations.EMOD3DParameters.read_from_realisation_or_defaults(
Expand All @@ -363,7 +363,7 @@ def test_emod3d(tmp_path: Path):
)


def test_broadband_parameters(tmp_path: Path):
def test_broadband_parameters(tmp_path: Path) -> None:
test_realisation = tmp_path / "realisation.json"
broadband_parameters = realisations.BroadbandParameters(
flo=0.5, dt=0.005, fmidbot=0.5, fmin=0.25, site_amp_version="2014"
Expand All @@ -385,14 +385,14 @@ def test_broadband_parameters(tmp_path: Path):
)


def test_logtrail_init_empty():
def test_logtrail_init_empty() -> None:
"""Test LogTrail initialization with no log provided."""
trail = realisations.LogTrail([])
assert trail.log == []
assert trail._config_key == "log_trail"


def test_logtrail_init_with_log_entries():
def test_logtrail_init_with_log_entries() -> None:
"""Test LogTrail initialization with a list of LogEntry objects."""
entry1 = realisations.LogEntry(
utility="util1", args=["a"], version="1", timestamp=datetime.now()
Expand All @@ -404,7 +404,7 @@ def test_logtrail_init_with_log_entries():
assert trail.log == [entry1, entry2]


def test_logtrail_init_with_dicts_post_init():
def test_logtrail_init_with_dicts_post_init() -> None:
"""Test LogTrail post_init conversion of dicts to LogEntry objects."""
log_data = [
{
Expand All @@ -431,7 +431,7 @@ def test_logtrail_init_with_dicts_post_init():
assert trail.log[1].args == ["b"]


def test_logtrail_log_entry_method():
def test_logtrail_log_entry_method() -> None:
"""Test adding an entry using the log_entry method."""
trail = realisations.LogTrail([])
trail.log_entry("my_util", ["--flag", "value"])
Expand All @@ -442,7 +442,7 @@ def test_logtrail_log_entry_method():
assert isinstance(trail.log[0].timestamp, datetime)


def test_logtrail_to_dict():
def test_logtrail_to_dict() -> None:
"""Test converting LogTrail to a dictionary."""
ts = datetime.now()
entry1 = realisations.LogEntry(
Expand Down Expand Up @@ -480,7 +480,7 @@ def test_logtrail_to_dict():

def test_append_log_entry_file_exists_no_key(
tmp_path: Path,
):
) -> None:
"""Test append_log_entry when file exists but lacks the 'log_trail' key."""
realisation_file = tmp_path / "test_realisation.json"
# Create a file with unrelated content
Expand All @@ -504,15 +504,15 @@ def test_append_log_entry_file_exists_no_key(
assert data["log_trail"]["log"][0]["utility"] == "script_name.py"


def test_seeds():
def test_seeds() -> None:
seeds = realisations.Seeds.random_seeds()
assert all(
0 <= seed <= 2 ** (struct.Struct("i").size * 8 - 1) - 1
for seed in seeds.to_dict().values()
)


def test_velocity_model_1d(tmp_path: Path):
def test_velocity_model_1d(tmp_path: Path) -> None:
velocity_model_1d = realisations.VelocityModel1D(
model=pd.DataFrame(
{
Expand Down Expand Up @@ -564,7 +564,7 @@ def test_velocity_model_1d(tmp_path: Path):
)


def test_intensity_measure_calculation_parameters(tmp_path: Path):
def test_intensity_measure_calculation_parameters(tmp_path: Path) -> None:
im_calc_params = realisations.IntensityMeasureCalculationParameters(
ims=[im_calculation.IM("PGA"), im_calculation.IM("PGV")],
valid_periods=np.array([0.1, 0.2, 0.3]),
Expand Down Expand Up @@ -605,5 +605,5 @@ def test_defaults_are_loadable(
tmp_path: Path,
realisation_config: realisations.RealisationConfiguration,
defaults_version: defaults.DefaultsVersion,
):
) -> None:
realisation_config.read_from_defaults(defaults_version)
8 changes: 5 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from workflow import utils


def test_get_available_cores_slurm_cpus_on_node():
def test_get_available_cores_slurm_cpus_on_node() -> None:
with patch.dict(os.environ, {"SLURM_CPUS_ON_NODE": "4"}):
assert utils.get_available_cores() == 4

def get_available_cores_slurm_nprocs():

def get_available_cores_slurm_nprocs() -> None:
with patch.dict(os.environ, {"SLURM_NPROCS": "8"}):
assert utils.get_available_cores() == 8

def get_available_cores_no_slurm():

def get_available_cores_no_slurm() -> None:
with patch("multiprocessing.cpu_count", return_value=16):
assert utils.get_available_cores() == 16
4 changes: 3 additions & 1 deletion workflow/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def log_call(

def decorator(f: Callable) -> Callable: # numpydoc ignore=GL08
@functools.wraps(f)
def wrapper(*args, **kwargs): # numpydoc ignore=GL08
def wrapper(
*args: list[Any], **kwargs: dict[str, Any]
) -> Any: # numpydoc ignore=GL08
nonlocal exclude_args
signature = inspect.signature(f)
function_id = str(uuid.uuid4())
Expand Down
8 changes: 4 additions & 4 deletions workflow/realisations.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ class Seeds(RealisationConfiguration):

@classmethod
def read_from_realisation_or_defaults(
cls, realisation_ffp: Path, *args
cls, realisation_ffp: Path, *args: list[Any]
) -> Self: # *args is to maintain compat with superclass (remove this and see the error in mypy).
"""Read seeds configuration from a realisation file or generate random seeds if not present.

Expand All @@ -256,7 +256,7 @@ def read_from_realisation_or_defaults(
----------
realisation_ffp : Path
The realisation filepath to read from.
*args : Any
*args : list
Ignored arguments.

Returns
Expand Down Expand Up @@ -315,7 +315,7 @@ class SourceConfig(RealisationConfiguration):
source_geometries: dict[str, IsSource]
"""Dictionary mapping source names to their definitions."""

def to_dict(self):
def to_dict(self) -> dict[str, Any]:
"""
Convert the object to a dictionary representation.

Expand Down Expand Up @@ -522,7 +522,7 @@ class VelocityModel1D(RealisationConfiguration):

model: pd.DataFrame

def write_velocity_model(self, velocity_model_path: Path):
def write_velocity_model(self, velocity_model_path: Path) -> None:
"""Write a 1D velocity model to the specified path.

Parameters
Expand Down
Loading