Skip to content

Commit f3d1b85

Browse files
authored
API key client option (#486)
Fixes #482
1 parent 477aa31 commit f3d1b85

File tree

8 files changed

+414
-217
lines changed

8 files changed

+414
-217
lines changed

temporalio/bridge/Cargo.lock

Lines changed: 282 additions & 171 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

temporalio/bridge/Cargo.toml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ crate-type = ["cdylib"]
1010
[dependencies]
1111
futures = "0.3"
1212
log = "0.4"
13-
once_cell = "1.16.0"
14-
parking_lot = "0.12"
15-
prost = "0.11"
16-
prost-types = "0.11"
13+
once_cell = "1.16"
14+
prost = "0.12"
15+
prost-types = "0.12"
1716
pyo3 = { version = "0.19", features = ["extension-module", "abi3-py38"] }
1817
pyo3-asyncio = { version = "0.19", features = ["tokio-runtime"] }
1918
pythonize = "0.19"
@@ -23,7 +22,7 @@ temporal-sdk-core-api = { version = "0.1.0", path = "./sdk-core/core-api" }
2322
temporal-sdk-core-protos = { version = "0.1.0", path = "./sdk-core/sdk-core-protos" }
2423
tokio = "1.26"
2524
tokio-stream = "0.1"
26-
tonic = "0.9"
25+
tonic = "0.11"
2726
tracing = "0.1"
2827
url = "2.2"
2928

temporalio/bridge/client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class ClientConfig:
5252

5353
target_url: str
5454
metadata: Mapping[str, str]
55+
api_key: Optional[str]
5556
identity: str
5657
tls_config: Optional[ClientTlsConfig]
5758
retry_config: Optional[ClientRetryConfig]
@@ -102,6 +103,10 @@ def update_metadata(self, metadata: Mapping[str, str]) -> None:
102103
"""Update underlying metadata on Core client."""
103104
self._ref.update_metadata(metadata)
104105

106+
def update_api_key(self, api_key: Optional[str]) -> None:
107+
"""Update underlying API key on Core client."""
108+
self._ref.update_api_key(api_key)
109+
105110
async def call(
106111
self,
107112
*,

temporalio/bridge/src/client.rs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
use parking_lot::RwLock;
21
use pyo3::exceptions::{PyException, PyRuntimeError, PyValueError};
32
use pyo3::prelude::*;
43
use std::collections::HashMap;
54
use std::str::FromStr;
6-
use std::sync::Arc;
75
use std::time::Duration;
86
use temporal_client::{
97
ClientKeepAliveConfig as CoreClientKeepAliveConfig, ClientOptions, ClientOptionsBuilder,
@@ -31,6 +29,7 @@ pub struct ClientConfig {
3129
client_name: String,
3230
client_version: String,
3331
metadata: HashMap<String, String>,
32+
api_key: Option<String>,
3433
identity: String,
3534
tls_config: Option<ClientTlsConfig>,
3635
retry_config: Option<ClientRetryConfig>,
@@ -75,20 +74,12 @@ pub fn connect_client<'a>(
7574
runtime_ref: &runtime::RuntimeRef,
7675
config: ClientConfig,
7776
) -> PyResult<&'a PyAny> {
78-
let headers = if config.metadata.is_empty() {
79-
None
80-
} else {
81-
Some(Arc::new(RwLock::new(config.metadata.clone())))
82-
};
8377
let opts: ClientOptions = config.try_into()?;
8478
let runtime = runtime_ref.runtime.clone();
8579
runtime_ref.runtime.future_into_py(py, async move {
8680
Ok(ClientRef {
8781
retry_client: opts
88-
.connect_no_namespace(
89-
runtime.core.telemetry().get_temporal_metric_meter(),
90-
headers,
91-
)
82+
.connect_no_namespace(runtime.core.telemetry().get_temporal_metric_meter())
9283
.await
9384
.map_err(|err| {
9485
PyRuntimeError::new_err(format!("Failed client connect: {}", err))
@@ -114,6 +105,10 @@ impl ClientRef {
114105
self.retry_client.get_client().set_headers(headers);
115106
}
116107

108+
fn update_api_key(&self, api_key: Option<String>) {
109+
self.retry_client.get_client().set_api_key(api_key);
110+
}
111+
117112
fn call_workflow_service<'p>(&self, py: Python<'p>, call: RpcCall) -> PyResult<&'p PyAny> {
118113
let mut retry_client = self.retry_client.clone();
119114
self.runtime.future_into_py(py, async move {
@@ -396,7 +391,9 @@ impl TryFrom<ClientConfig> for ClientOptions {
396391
opts.retry_config
397392
.map_or(RetryConfig::default(), |c| c.into()),
398393
)
399-
.keep_alive(opts.keep_alive_config.map(Into::into));
394+
.keep_alive(opts.keep_alive_config.map(Into::into))
395+
.headers(Some(opts.metadata))
396+
.api_key(opts.api_key);
400397
// Builder does not allow us to set option here, so we have to make
401398
// a conditional to even call it
402399
if let Some(tls_config) = opts.tls_config {

temporalio/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ async def connect(
9797
target_host: str,
9898
*,
9999
namespace: str = "default",
100+
api_key: Optional[str] = None,
100101
data_converter: temporalio.converter.DataConverter = temporalio.converter.DataConverter.default,
101102
interceptors: Sequence[Interceptor] = [],
102103
default_workflow_query_reject_condition: Optional[
@@ -116,6 +117,9 @@ async def connect(
116117
target_host: ``host:port`` for the Temporal server. For local
117118
development, this is often "localhost:7233".
118119
namespace: Namespace to use for client calls.
120+
api_key: API key for Temporal. This becomes the "Authorization"
121+
HTTP header with "Bearer " prepended. This is only set if RPC
122+
metadata doesn't already have an "authorization" key.
119123
data_converter: Data converter to use for all data conversions
120124
to/from payloads.
121125
interceptors: Set of interceptors that are chained together to allow
@@ -152,6 +156,7 @@ async def connect(
152156
"""
153157
connect_config = temporalio.service.ConnectConfig(
154158
target_host=target_host,
159+
api_key=api_key,
155160
tls=tls,
156161
retry_config=retry_config,
157162
keep_alive_config=keep_alive_config,
@@ -261,6 +266,22 @@ def rpc_metadata(self, value: Mapping[str, str]) -> None:
261266
self.service_client.config.rpc_metadata = value
262267
self.service_client.update_rpc_metadata(value)
263268

269+
@property
270+
def api_key(self) -> Optional[str]:
271+
"""API key for every call made by this client."""
272+
return self.service_client.config.api_key
273+
274+
@api_key.setter
275+
def api_key(self, value: Optional[str]) -> None:
276+
"""Update the API key for this client.
277+
278+
This is only set if RPCmetadata doesn't already have an "authorization"
279+
key.
280+
"""
281+
# Update config and perform update
282+
self.service_client.config.api_key = value
283+
self.service_client.update_api_key(value)
284+
264285
# Overload for no-param workflow
265286
@overload
266287
async def start_workflow(

temporalio/service.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class ConnectConfig:
120120
"""Config for connecting to the server."""
121121

122122
target_host: str
123+
api_key: Optional[str] = None
123124
tls: Union[bool, TLSConfig] = False
124125
retry_config: Optional[RetryConfig] = None
125126
keep_alive_config: Optional[KeepAliveConfig] = KeepAliveConfig.default
@@ -161,6 +162,7 @@ def _to_bridge_config(self) -> temporalio.bridge.client.ClientConfig:
161162

162163
return temporalio.bridge.client.ClientConfig(
163164
target_url=target_url,
165+
api_key=self.api_key,
164166
tls_config=tls_config,
165167
retry_config=self.retry_config._to_bridge_config()
166168
if self.retry_config
@@ -238,6 +240,11 @@ def update_rpc_metadata(self, metadata: Mapping[str, str]) -> None:
238240
"""Update service client's RPC metadata."""
239241
raise NotImplementedError
240242

243+
@abstractmethod
244+
def update_api_key(self, api_key: Optional[str]) -> None:
245+
"""Update service client's API key."""
246+
raise NotImplementedError
247+
241248
@abstractmethod
242249
async def _rpc_call(
243250
self,
@@ -740,6 +747,14 @@ def update_rpc_metadata(self, metadata: Mapping[str, str]) -> None:
740747
if self._bridge_client:
741748
self._bridge_client.update_metadata(metadata)
742749

750+
def update_api_key(self, api_key: Optional[str]) -> None:
751+
"""Update Core client API key."""
752+
# Mutate the bridge config and then only mutate the running client
753+
# metadata if already connected
754+
self._bridge_config.api_key = api_key
755+
if self._bridge_client:
756+
self._bridge_client.update_api_key(api_key)
757+
743758
async def _rpc_call(
744759
self,
745760
rpc: str,

tests/api/test_grpc_stub.py

Lines changed: 77 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import timedelta
2+
from typing import Mapping
23

34
from google.protobuf.empty_pb2 import Empty
45
from google.protobuf.timestamp_pb2 import Timestamp
@@ -27,12 +28,6 @@
2728
from temporalio.client import Client
2829

2930

30-
def assert_metadata(context: ServicerContext, **kwargs) -> None:
31-
metadata = dict(context.invocation_metadata())
32-
for k, v in kwargs.items():
33-
assert metadata.get(k) == v
34-
35-
3631
def assert_time_remaining(context: ServicerContext, expected: int) -> None:
3732
# Give or take 5 seconds
3833
assert expected - 5 <= context.time_remaining() <= expected + 5
@@ -41,24 +36,26 @@ def assert_time_remaining(context: ServicerContext, expected: int) -> None:
4136
class SimpleWorkflowServer(WorkflowServiceServicer):
4237
def __init__(self) -> None:
4338
super().__init__()
44-
self.expected_client_key_value = "client_value"
39+
self.last_metadata: Mapping[str, str] = {}
40+
41+
def assert_last_metadata(self, expected: Mapping[str, str]) -> None:
42+
for k, v in expected.items():
43+
assert self.last_metadata.get(k) == v
4544

4645
async def GetSystemInfo( # type: ignore # https://github.com/nipunn1313/mypy-protobuf/issues/216
4746
self,
4847
request: GetSystemInfoRequest,
4948
context: ServicerContext,
5049
) -> GetSystemInfoResponse:
51-
assert_metadata(context, client_key=self.expected_client_key_value)
50+
self.last_metadata = dict(context.invocation_metadata())
5251
return GetSystemInfoResponse()
5352

5453
async def CountWorkflowExecutions( # type: ignore # https://github.com/nipunn1313/mypy-protobuf/issues/216
5554
self,
5655
request: CountWorkflowExecutionsRequest,
5756
context: ServicerContext,
5857
) -> CountWorkflowExecutionsResponse:
59-
assert_metadata(
60-
context, client_key=self.expected_client_key_value, rpc_key="rpc_value"
61-
)
58+
self.last_metadata = dict(context.invocation_metadata())
6259
assert_time_remaining(context, 123)
6360
assert request.namespace == "my namespace"
6461
assert request.query == "my query"
@@ -71,7 +68,6 @@ async def DeleteNamespace( # type: ignore # https://github.com/nipunn1313/mypy-
7168
request: DeleteNamespaceRequest,
7269
context: ServicerContext,
7370
) -> DeleteNamespaceResponse:
74-
assert_metadata(context, client_key="client_value", rpc_key="rpc_value")
7571
assert_time_remaining(context, 123)
7672
assert request.namespace == "my namespace"
7773
return DeleteNamespaceResponse(deleted_namespace="my namespace response")
@@ -83,7 +79,6 @@ async def GetCurrentTime( # type: ignore # https://github.com/nipunn1313/mypy-p
8379
request: Empty,
8480
context: ServicerContext,
8581
) -> GetCurrentTimeResponse:
86-
assert_metadata(context, client_key="client_value", rpc_key="rpc_value")
8782
assert_time_remaining(context, 123)
8883
return GetCurrentTimeResponse(time=Timestamp(seconds=123))
8984

@@ -101,34 +96,88 @@ async def test_python_grpc_stub():
10196
await server.start()
10297

10398
# Use our client to make a call to each service
104-
client = await Client.connect(
105-
f"localhost:{port}", rpc_metadata={"client_key": "client_value"}
106-
)
107-
metadata = {"rpc_key": "rpc_value"}
99+
client = await Client.connect(f"localhost:{port}")
108100
timeout = timedelta(seconds=123)
109101
count_resp = await client.workflow_service.count_workflow_executions(
110102
CountWorkflowExecutionsRequest(namespace="my namespace", query="my query"),
111-
metadata=metadata,
112103
timeout=timeout,
113104
)
114105
assert count_resp.count == 123
115106
del_resp = await client.operator_service.delete_namespace(
116107
DeleteNamespaceRequest(namespace="my namespace"),
117-
metadata=metadata,
118108
timeout=timeout,
119109
)
120110
assert del_resp.deleted_namespace == "my namespace response"
121-
time_resp = await client.test_service.get_current_time(
122-
Empty(), metadata=metadata, timeout=timeout
123-
)
111+
time_resp = await client.test_service.get_current_time(Empty(), timeout=timeout)
124112
assert time_resp.time.seconds == 123
125113

126-
# Make another call to get system info after changing the client-level
127-
# header
128-
new_metadata = dict(client.rpc_metadata)
129-
new_metadata["client_key"] = "changed_value"
130-
client.rpc_metadata = new_metadata
131-
workflow_server.expected_client_key_value = "changed_value"
114+
await server.stop(grace=None)
115+
116+
117+
async def test_grpc_metadata():
118+
# Start server
119+
server = grpc_server()
120+
workflow_server = SimpleWorkflowServer() # type: ignore[abstract]
121+
add_WorkflowServiceServicer_to_server(workflow_server, server)
122+
port = server.add_insecure_port("[::]:0")
123+
await server.start()
124+
125+
# Connect and confirm metadata of get system info call
126+
client = await Client.connect(
127+
f"localhost:{port}",
128+
api_key="my-api-key",
129+
rpc_metadata={"my-meta-key": "my-meta-val"},
130+
)
131+
workflow_server.assert_last_metadata(
132+
{
133+
"authorization": "Bearer my-api-key",
134+
"my-meta-key": "my-meta-val",
135+
}
136+
)
137+
138+
# Overwrite API key via client RPC metadata, confirm there
139+
client.rpc_metadata = {
140+
"authorization": "my-auth-val1",
141+
"my-meta-key": "my-meta-val",
142+
}
132143
await client.workflow_service.get_system_info(GetSystemInfoRequest())
144+
workflow_server.assert_last_metadata(
145+
{
146+
"authorization": "my-auth-val1",
147+
"my-meta-key": "my-meta-val",
148+
}
149+
)
150+
client.rpc_metadata = {"my-meta-key": "my-meta-val"}
151+
152+
# Overwrite API key via call RPC metadata, confirm there
153+
await client.workflow_service.get_system_info(
154+
GetSystemInfoRequest(), metadata={"authorization": "my-auth-val2"}
155+
)
156+
workflow_server.assert_last_metadata(
157+
{
158+
"authorization": "my-auth-val2",
159+
"my-meta-key": "my-meta-val",
160+
}
161+
)
162+
163+
# Update API key, confirm updated
164+
client.api_key = "my-new-api-key"
165+
await client.workflow_service.get_system_info(GetSystemInfoRequest())
166+
workflow_server.assert_last_metadata(
167+
{
168+
"authorization": "Bearer my-new-api-key",
169+
"my-meta-key": "my-meta-val",
170+
}
171+
)
172+
173+
# Remove API key, confirm removed
174+
client.api_key = None
175+
await client.workflow_service.get_system_info(GetSystemInfoRequest())
176+
workflow_server.assert_last_metadata(
177+
{
178+
"my-meta-key": "my-meta-val",
179+
}
180+
)
181+
assert "authorization" not in workflow_server.last_metadata
133182

134183
await server.stop(grace=None)

0 commit comments

Comments
 (0)