|
16 | 16 |
|
17 | 17 | # type: ignore
|
18 | 18 |
|
19 |
| - |
20 | 19 | from abc import ABC, abstractmethod
|
| 20 | +from collections import abc as collections_abc |
21 | 21 | from logging import getLogger
|
22 | 22 | from re import compile as compile_
|
23 |
| -from types import GeneratorType |
| 23 | +from typing import Callable, Generator, Iterable, Union |
24 | 24 |
|
25 | 25 | from opentelemetry.metrics.measurement import Measurement
|
26 | 26 |
|
| 27 | +_TInstrumentCallback = Callable[[], Iterable[Measurement]] |
| 28 | +_TInstrumentCallbackGenerator = Generator[Iterable[Measurement], None, None] |
| 29 | +TCallback = Union[_TInstrumentCallback, _TInstrumentCallbackGenerator] |
| 30 | + |
| 31 | + |
27 | 32 | _logger = getLogger(__name__)
|
28 | 33 |
|
29 | 34 |
|
@@ -75,30 +80,65 @@ class Synchronous(Instrument):
|
75 | 80 | class Asynchronous(Instrument):
|
76 | 81 | @abstractmethod
|
77 | 82 | def __init__(
|
78 |
| - self, name, callback, *args, unit="", description="", **kwargs |
| 83 | + self, |
| 84 | + name, |
| 85 | + callback: TCallback, |
| 86 | + *args, |
| 87 | + unit="", |
| 88 | + description="", |
| 89 | + **kwargs |
79 | 90 | ):
|
80 |
| - super().__init__(name, *args, unit=unit, description="", **kwargs) |
81 |
| - |
82 |
| - if not isinstance(callback, GeneratorType): |
83 |
| - _logger.error("callback must be a generator") |
| 91 | + super().__init__( |
| 92 | + name, *args, unit=unit, description=description, **kwargs |
| 93 | + ) |
84 | 94 |
|
85 |
| - else: |
86 |
| - super().__init__( |
87 |
| - name, unit=unit, description=description, *args, **kwargs |
88 |
| - ) |
| 95 | + if isinstance(callback, collections_abc.Callable): |
89 | 96 | self._callback = callback
|
| 97 | + elif isinstance(callback, collections_abc.Generator): |
| 98 | + self._callback = self._wrap_generator_callback(callback) |
| 99 | + else: |
| 100 | + _logger.error("callback must be a callable or generator") |
| 101 | + |
| 102 | + def _wrap_generator_callback( |
| 103 | + self, |
| 104 | + generator_callback: _TInstrumentCallbackGenerator, |
| 105 | + ) -> _TInstrumentCallback: |
| 106 | + """Wraps a generator style callback into a callable one""" |
| 107 | + has_items = True |
| 108 | + |
| 109 | + def inner() -> Iterable[Measurement]: |
| 110 | + nonlocal has_items |
| 111 | + if not has_items: |
| 112 | + return [] |
| 113 | + |
| 114 | + try: |
| 115 | + return next(generator_callback) |
| 116 | + except StopIteration: |
| 117 | + has_items = False |
| 118 | + _logger.error( |
| 119 | + "callback generator for instrument %s ran out of measurements", |
| 120 | + self._name, |
| 121 | + ) |
| 122 | + return [] |
| 123 | + |
| 124 | + return inner |
90 | 125 |
|
91 |
| - @property |
92 | 126 | def callback(self):
|
93 |
| - def function(): |
94 |
| - measurement = next(self._callback) |
| 127 | + measurements = self._callback() |
| 128 | + if not isinstance(measurements, collections_abc.Iterable): |
| 129 | + _logger.error( |
| 130 | + "Callback must return an iterable of Measurement, got %s", |
| 131 | + type(measurements), |
| 132 | + ) |
| 133 | + return |
| 134 | + for measurement in measurements: |
95 | 135 | if not isinstance(measurement, Measurement):
|
96 |
| - _logger.error("Callback must return a Measurement") |
97 |
| - return None |
98 |
| - |
99 |
| - return measurement |
100 |
| - |
101 |
| - return function |
| 136 | + _logger.error( |
| 137 | + "Callback must return an iterable of Measurement, " |
| 138 | + "iterable contained type %s", |
| 139 | + type(measurement), |
| 140 | + ) |
| 141 | + yield measurement |
102 | 142 |
|
103 | 143 |
|
104 | 144 | class _Adding(Instrument):
|
@@ -147,18 +187,13 @@ def add(self, amount, attributes=None):
|
147 | 187 |
|
148 | 188 |
|
149 | 189 | class ObservableCounter(_Monotonic, Asynchronous):
|
150 |
| - @property |
151 | 190 | def callback(self):
|
152 |
| - def function(): |
153 |
| - measurement = super(ObservableCounter, self).callback() |
| 191 | + measurements = super().callback() |
154 | 192 |
|
155 |
| - if measurement is not None and measurement.value < 0: |
| 193 | + for measurement in measurements: |
| 194 | + if measurement.value < 0: |
156 | 195 | _logger.error("Amount must be non-negative")
|
157 |
| - return None |
158 |
| - |
159 |
| - return measurement |
160 |
| - |
161 |
| - return function |
| 196 | + yield measurement |
162 | 197 |
|
163 | 198 |
|
164 | 199 | class DefaultObservableCounter(ObservableCounter):
|
|
0 commit comments