Skip to content

Commit 7867202

Browse files
aabmassocelotllzchenowais
authored
Fix race in set_tracer_provider() (open-telemetry#2182)
* Fix race in set_tracer_provider * refactor _reset_globals to a test util * get rid of "Mixin" name and simplify code a bit * add some comments to concurrency_test.py * actually respect log option Co-authored-by: Diego Hurtado <[email protected]> Co-authored-by: Leighton Chen <[email protected]> Co-authored-by: Owais Lone <[email protected]>
1 parent 0770fcd commit 7867202

File tree

12 files changed

+313
-61
lines changed

12 files changed

+313
-61
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

77
## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.5.0-0.24b0...HEAD)
8+
- Fix race in `set_tracer_provider()`
9+
([#2182](https://github.com/open-telemetry/opentelemetry-python/pull/2182))
810
- Automatically load OTEL environment variables as options for `opentelemetry-instrument`
911
([#1969](https://github.com/open-telemetry/opentelemetry-python/pull/1969))
1012
- `opentelemetry-semantic-conventions` Update to semantic conventions v1.6.1

exporter/opentelemetry-exporter-jaeger-thrift/tests/test_jaeger_exporter_thrift.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import unittest
1717
from unittest import mock
18-
from unittest.mock import patch
1918

2019
# pylint:disable=no-name-in-module
2120
# pylint:disable=import-error
@@ -38,6 +37,7 @@
3837
from opentelemetry.sdk.resources import SERVICE_NAME
3938
from opentelemetry.sdk.trace import Resource, TracerProvider
4039
from opentelemetry.sdk.util.instrumentation import InstrumentationInfo
40+
from opentelemetry.test.globals_test import TraceGlobalsTest
4141
from opentelemetry.test.spantestutil import (
4242
get_span_with_dropped_attributes_events_links,
4343
)
@@ -53,7 +53,7 @@ def _translate_spans_with_dropped_attributes():
5353
return translate._translate(ThriftTranslator(max_tag_value_length=5))
5454

5555

56-
class TestJaegerExporter(unittest.TestCase):
56+
class TestJaegerExporter(TraceGlobalsTest, unittest.TestCase):
5757
def setUp(self):
5858
# create and save span to be used in tests
5959
self.context = trace_api.SpanContext(
@@ -73,7 +73,6 @@ def setUp(self):
7373
self._test_span.end(end_time=3)
7474
# pylint: disable=protected-access
7575

76-
@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
7776
def test_constructor_default(self):
7877
# pylint: disable=protected-access
7978
"""Test the default values assigned by constructor."""
@@ -98,7 +97,6 @@ def test_constructor_default(self):
9897
self.assertTrue(exporter._agent_client is not None)
9998
self.assertIsNone(exporter._max_tag_value_length)
10099

101-
@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
102100
def test_constructor_explicit(self):
103101
# pylint: disable=protected-access
104102
"""Test the constructor passing all the options."""
@@ -143,7 +141,6 @@ def test_constructor_explicit(self):
143141
self.assertTrue(exporter._collector_http_client.auth is None)
144142
self.assertEqual(exporter._max_tag_value_length, 42)
145143

146-
@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
147144
def test_constructor_by_environment_variables(self):
148145
# pylint: disable=protected-access
149146
"""Test the constructor using Environment Variables."""
@@ -198,7 +195,6 @@ def test_constructor_by_environment_variables(self):
198195
self.assertTrue(exporter._collector_http_client.auth is None)
199196
environ_patcher.stop()
200197

201-
@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
202198
def test_constructor_with_no_traceprovider_resource(self):
203199

204200
"""Test the constructor when there is no resource attached to trace_provider"""
@@ -480,7 +476,6 @@ def test_translate_to_jaeger(self):
480476

481477
self.assertEqual(spans, expected_spans)
482478

483-
@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
484479
def test_export(self):
485480

486481
"""Test that agent and/or collector are invoked"""
@@ -511,9 +506,7 @@ def test_export(self):
511506
exporter.export((self._test_span,))
512507
self.assertEqual(agent_client_mock.emit.call_count, 1)
513508
self.assertEqual(collector_mock.submit.call_count, 1)
514-
# trace_api._TRACER_PROVIDER = None
515509

516-
@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
517510
def test_export_span_service_name(self):
518511
trace_api.set_tracer_provider(
519512
TracerProvider(

exporter/opentelemetry-exporter-opencensus/tests/test_otcollector_trace_exporter.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,12 @@
2929
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
3030
from opentelemetry.sdk.trace import TracerProvider
3131
from opentelemetry.sdk.trace.export import SpanExportResult
32+
from opentelemetry.test.globals_test import TraceGlobalsTest
3233
from opentelemetry.trace import TraceFlags
3334

3435

3536
# pylint: disable=no-member
36-
class TestCollectorSpanExporter(unittest.TestCase):
37-
@mock.patch(
38-
"opentelemetry.exporter.opencensus.trace_exporter.trace._TRACER_PROVIDER",
39-
None,
40-
)
37+
class TestCollectorSpanExporter(TraceGlobalsTest, unittest.TestCase):
4138
def test_constructor(self):
4239
mock_get_node = mock.Mock()
4340
patch = mock.patch(
@@ -329,10 +326,6 @@ def test_export(self):
329326
getattr(output_identifier, "host_name"), "testHostName"
330327
)
331328

332-
@mock.patch(
333-
"opentelemetry.exporter.opencensus.trace_exporter.trace._TRACER_PROVIDER",
334-
None,
335-
)
336329
def test_export_service_name(self):
337330
trace_api.set_tracer_provider(
338331
TracerProvider(

opentelemetry-api/src/opentelemetry/trace/__init__.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
)
109109
from opentelemetry.trace.status import Status, StatusCode
110110
from opentelemetry.util import types
111+
from opentelemetry.util._once import Once
111112
from opentelemetry.util._providers import _load_provider
112113

113114
logger = getLogger(__name__)
@@ -452,8 +453,9 @@ def start_as_current_span(
452453
yield INVALID_SPAN
453454

454455

455-
_TRACER_PROVIDER = None
456-
_PROXY_TRACER_PROVIDER = None
456+
_TRACER_PROVIDER_SET_ONCE = Once()
457+
_TRACER_PROVIDER: Optional[TracerProvider] = None
458+
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
457459

458460

459461
def get_tracer(
@@ -476,40 +478,40 @@ def get_tracer(
476478
)
477479

478480

481+
def _set_tracer_provider(tracer_provider: TracerProvider, log: bool) -> None:
482+
def set_tp() -> None:
483+
global _TRACER_PROVIDER # pylint: disable=global-statement
484+
_TRACER_PROVIDER = tracer_provider
485+
486+
did_set = _TRACER_PROVIDER_SET_ONCE.do_once(set_tp)
487+
488+
if log and not did_set:
489+
logger.warning("Overriding of current TracerProvider is not allowed")
490+
491+
479492
def set_tracer_provider(tracer_provider: TracerProvider) -> None:
480493
"""Sets the current global :class:`~.TracerProvider` object.
481494
482495
This can only be done once, a warning will be logged if any furter attempt
483496
is made.
484497
"""
485-
global _TRACER_PROVIDER # pylint: disable=global-statement
486-
487-
if _TRACER_PROVIDER is not None:
488-
logger.warning("Overriding of current TracerProvider is not allowed")
489-
return
490-
491-
_TRACER_PROVIDER = tracer_provider
498+
_set_tracer_provider(tracer_provider, log=True)
492499

493500

494501
def get_tracer_provider() -> TracerProvider:
495502
"""Gets the current global :class:`~.TracerProvider` object."""
496-
# pylint: disable=global-statement
497-
global _TRACER_PROVIDER
498-
global _PROXY_TRACER_PROVIDER
499-
500503
if _TRACER_PROVIDER is None:
501504
# if a global tracer provider has not been set either via code or env
502505
# vars, return a proxy tracer provider
503506
if OTEL_PYTHON_TRACER_PROVIDER not in os.environ:
504-
if not _PROXY_TRACER_PROVIDER:
505-
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
506507
return _PROXY_TRACER_PROVIDER
507508

508-
_TRACER_PROVIDER = cast( # type: ignore
509-
"TracerProvider",
510-
_load_provider(OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"),
509+
tracer_provider: TracerProvider = _load_provider(
510+
OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"
511511
)
512-
return _TRACER_PROVIDER
512+
_set_tracer_provider(tracer_provider, log=False)
513+
# _TRACER_PROVIDER will have been set by one thread
514+
return cast("TracerProvider", _TRACER_PROVIDER)
513515

514516

515517
@contextmanager # type: ignore
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from threading import Lock
16+
from typing import Callable
17+
18+
19+
class Once:
20+
"""Execute a function exactly once and block all callers until the function returns
21+
22+
Same as golang's `sync.Once <https://pkg.go.dev/sync#Once>`_
23+
"""
24+
25+
def __init__(self) -> None:
26+
self._lock = Lock()
27+
self._done = False
28+
29+
def do_once(self, func: Callable[[], None]) -> bool:
30+
"""Execute ``func`` if it hasn't been executed or return.
31+
32+
Will block until ``func`` has been called by one thread.
33+
34+
Returns:
35+
Whether or not ``func`` was executed in this call
36+
"""
37+
38+
# fast path, try to avoid locking
39+
if self._done:
40+
return False
41+
42+
with self._lock:
43+
if not self._done:
44+
func()
45+
self._done = True
46+
return True
47+
return False

opentelemetry-api/tests/trace/test_globals.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import unittest
2-
from unittest.mock import patch
2+
from unittest.mock import Mock, patch
33

44
from opentelemetry import context, trace
5+
from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
6+
from opentelemetry.test.globals_test import TraceGlobalsTest
57
from opentelemetry.trace.status import Status, StatusCode
68

79

@@ -25,25 +27,60 @@ def record_exception(
2527
self.recorded_exception = exception
2628

2729

28-
class TestGlobals(unittest.TestCase):
29-
def setUp(self):
30-
self._patcher = patch("opentelemetry.trace._TRACER_PROVIDER")
31-
self._mock_tracer_provider = self._patcher.start()
32-
33-
def tearDown(self) -> None:
34-
self._patcher.stop()
35-
36-
def test_get_tracer(self):
30+
class TestGlobals(TraceGlobalsTest, unittest.TestCase):
31+
@staticmethod
32+
@patch("opentelemetry.trace._TRACER_PROVIDER")
33+
def test_get_tracer(mock_tracer_provider): # type: ignore
3734
"""trace.get_tracer should proxy to the global tracer provider."""
3835
trace.get_tracer("foo", "var")
39-
self._mock_tracer_provider.get_tracer.assert_called_with(
40-
"foo", "var", None
41-
)
42-
mock_provider = unittest.mock.Mock()
36+
mock_tracer_provider.get_tracer.assert_called_with("foo", "var", None)
37+
mock_provider = Mock()
4338
trace.get_tracer("foo", "var", mock_provider)
4439
mock_provider.get_tracer.assert_called_with("foo", "var", None)
4540

4641

42+
class TestGlobalsConcurrency(TraceGlobalsTest, ConcurrencyTestBase):
43+
@patch("opentelemetry.trace.logger")
44+
def test_set_tracer_provider_many_threads(self, mock_logger) -> None: # type: ignore
45+
mock_logger.warning = MockFunc()
46+
47+
def do_concurrently() -> Mock:
48+
# first get a proxy tracer
49+
proxy_tracer = trace.ProxyTracerProvider().get_tracer("foo")
50+
51+
# try to set the global tracer provider
52+
mock_tracer_provider = Mock(get_tracer=MockFunc())
53+
trace.set_tracer_provider(mock_tracer_provider)
54+
55+
# start a span through the proxy which will call through to the mock provider
56+
proxy_tracer.start_span("foo")
57+
58+
return mock_tracer_provider
59+
60+
num_threads = 100
61+
mock_tracer_providers = self.run_with_many_threads(
62+
do_concurrently,
63+
num_threads=num_threads,
64+
)
65+
66+
# despite trying to set tracer provider many times, only one of the
67+
# mock_tracer_providers should have stuck and been called from
68+
# proxy_tracer.start_span()
69+
mock_tps_with_any_call = [
70+
mock
71+
for mock in mock_tracer_providers
72+
if mock.get_tracer.call_count > 0
73+
]
74+
75+
self.assertEqual(len(mock_tps_with_any_call), 1)
76+
self.assertEqual(
77+
mock_tps_with_any_call[0].get_tracer.call_count, num_threads
78+
)
79+
80+
# should have warned everytime except for the successful set
81+
self.assertEqual(mock_logger.warning.call_count, num_threads - 1)
82+
83+
4784
class TestTracer(unittest.TestCase):
4885
def setUp(self):
4986
# pylint: disable=protected-access

opentelemetry-api/tests/trace/test_proxy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import unittest
1818

1919
from opentelemetry import trace
20+
from opentelemetry.test.globals_test import TraceGlobalsTest
2021
from opentelemetry.trace.span import INVALID_SPAN_CONTEXT, NonRecordingSpan
2122

2223

@@ -39,10 +40,8 @@ class TestSpan(NonRecordingSpan):
3940
pass
4041

4142

42-
class TestProxy(unittest.TestCase):
43+
class TestProxy(TraceGlobalsTest, unittest.TestCase):
4344
def test_proxy_tracer(self):
44-
original_provider = trace._TRACER_PROVIDER
45-
4645
provider = trace.get_tracer_provider()
4746
# proxy provider
4847
self.assertIsInstance(provider, trace.ProxyTracerProvider)
@@ -60,6 +59,9 @@ def test_proxy_tracer(self):
6059
# set a real provider
6160
trace.set_tracer_provider(TestProvider())
6261

62+
# get_tracer_provider() now returns the real provider
63+
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)
64+
6365
# tracer provider now returns real instance
6466
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)
6567

@@ -71,5 +73,3 @@ def test_proxy_tracer(self):
7173
# creates real spans
7274
with tracer.start_span("") as span:
7375
self.assertIsInstance(span, TestSpan)
74-
75-
trace._TRACER_PROVIDER = original_provider

0 commit comments

Comments
 (0)