diff --git a/temporalio/bridge/runtime.py b/temporalio/bridge/runtime.py index 6235d8468..afc79f0f5 100644 --- a/temporalio/bridge/runtime.py +++ b/temporalio/bridge/runtime.py @@ -80,6 +80,7 @@ class PrometheusConfig: counters_total_suffix: bool unit_suffix: bool durations_as_seconds: bool + histogram_bucket_overrides: Optional[Mapping[str, Sequence[float]]] = None @dataclass(frozen=True) diff --git a/temporalio/bridge/src/runtime.rs b/temporalio/bridge/src/runtime.rs index 195f67864..9e1a83532 100644 --- a/temporalio/bridge/src/runtime.rs +++ b/temporalio/bridge/src/runtime.rs @@ -83,6 +83,7 @@ pub struct PrometheusConfig { counters_total_suffix: bool, unit_suffix: bool, durations_as_seconds: bool, + histogram_bucket_overrides: Option>>, } const FORWARD_LOG_BUFFER_SIZE: usize = 2048; @@ -347,6 +348,11 @@ impl TryFrom for Arc { if let Some(global_tags) = conf.global_tags { build.global_tags(global_tags); } + if let Some(overrides) = prom_conf.histogram_bucket_overrides { + build.histogram_bucket_overrides(temporal_sdk_core_api::telemetry::HistogramBucketOverrides { + overrides, + }); + } let prom_options = build.build().map_err(|err| { PyValueError::new_err(format!("Invalid Prometheus config: {}", err)) })?; diff --git a/temporalio/runtime.py b/temporalio/runtime.py index fe6a26ca9..809c06346 100644 --- a/temporalio/runtime.py +++ b/temporalio/runtime.py @@ -277,6 +277,7 @@ class PrometheusConfig: counters_total_suffix: bool = False unit_suffix: bool = False durations_as_seconds: bool = False + histogram_bucket_overrides: Optional[Mapping[str, Sequence[float]]] = None def _to_bridge_config(self) -> temporalio.bridge.runtime.PrometheusConfig: return temporalio.bridge.runtime.PrometheusConfig( @@ -284,6 +285,7 @@ def _to_bridge_config(self) -> temporalio.bridge.runtime.PrometheusConfig: counters_total_suffix=self.counters_total_suffix, unit_suffix=self.unit_suffix, durations_as_seconds=self.durations_as_seconds, + histogram_bucket_overrides=self.histogram_bucket_overrides, ) diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 31b713f83..4505ebfcf 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -1,7 +1,9 @@ import logging import logging.handlers import queue +import re import uuid +from datetime import timedelta from typing import List, cast from urllib.request import urlopen @@ -16,7 +18,7 @@ TelemetryFilter, ) from temporalio.worker import Worker -from tests.helpers import assert_eq_eventually, find_free_port +from tests.helpers import assert_eq_eventually, assert_eventually, find_free_port @workflow.defn @@ -181,3 +183,74 @@ async def has_log() -> bool: assert record.levelno == logging.WARNING assert record.name == f"{logger.name}-sdk_core::temporal_sdk_core::worker::workflow" assert record.temporal_log.fields["run_id"] == handle.result_run_id # type: ignore + + +async def test_prometheus_histogram_bucket_overrides(client: Client): + # Set up a Prometheus configuration with custom histogram bucket overrides + prom_addr = f"127.0.0.1:{find_free_port()}" + special_value = float(1234.5678) + histogram_overrides = { + "temporal_long_request_latency": [special_value / 2, special_value], + "custom_histogram": [special_value / 2, special_value], + } + + runtime = Runtime( + telemetry=TelemetryConfig( + metrics=PrometheusConfig( + bind_address=prom_addr, + counters_total_suffix=False, + unit_suffix=False, + durations_as_seconds=False, + histogram_bucket_overrides=histogram_overrides, + ), + ), + ) + + # Create a custom histogram metric + custom_histogram = runtime.metric_meter.create_histogram( + "custom_histogram", "Custom histogram", "ms" + ) + + # Record a value to the custom histogram + custom_histogram.record(600) + + # Create client with overrides + client_with_overrides = await Client.connect( + client.service_client.config.target_host, + namespace=client.namespace, + runtime=runtime, + ) + + async def run_workflow(client: Client): + task_queue = f"task-queue-{uuid.uuid4()}" + async with Worker( + client, + task_queue=task_queue, + workflows=[HelloWorkflow], + ): + assert "Hello, World!" == await client.execute_workflow( + HelloWorkflow.run, + "World", + id=f"workflow-{uuid.uuid4()}", + task_queue=task_queue, + ) + + await run_workflow(client_with_overrides) + + async def check_metrics() -> None: + with urlopen(url=f"http://{prom_addr}/metrics") as f: + metrics_output = f.read().decode("utf-8") + + for key, buckets in histogram_overrides.items(): + assert ( + key in metrics_output + ), f"Missing {key} in full output: {metrics_output}" + for bucket in buckets: + # expect to have {key}_bucket and le={bucket} in the same line with arbitrary strings between them + regex = re.compile(f'{key}_bucket.*le="{bucket}"') + assert regex.search( + metrics_output + ), f"Missing bucket for {key} in full output: {metrics_output}" + + # Wait for metrics to appear and match the expected buckets + await assert_eventually(check_metrics)