Skip to content

Commit b8dff93

Browse files
authored
chore: Fix Streaming Callback types (#9441)
* Fix types
* Add select_streaming_callback
1 parent c82a337 commit b8dff93

File tree

6 files changed

+34
-28
lines changed

6 files changed

+34
-28
lines changed

haystack/components/generators/azure.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import os
6-
from typing import Any, Callable, Dict, Optional
6+
from typing import Any, Dict, Optional
77

88
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
99

1010
from haystack import component, default_from_dict, default_to_dict
1111
from haystack.components.generators import OpenAIGenerator
12-
from haystack.dataclasses import StreamingChunk
12+
from haystack.dataclasses import StreamingCallbackT
1313
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1414
from haystack.utils.http_client import init_http_client
1515

@@ -63,7 +63,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
6363
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
6464
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
6565
organization: Optional[str] = None,
66-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
66+
streaming_callback: Optional[StreamingCallbackT] = None,
6767
system_prompt: Optional[str] = None,
6868
timeout: Optional[float] = None,
6969
max_retries: Optional[int] = None,

haystack/components/generators/chat/hugging_face_local.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
1111

1212
from haystack import component, default_from_dict, default_to_dict, logging
13-
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
13+
from haystack.dataclasses import ChatMessage, StreamingCallbackT, ToolCall, select_streaming_callback
1414
from haystack.lazy_imports import LazyImport
1515
from haystack.tools import (
1616
Tool,
@@ -130,7 +130,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
130130
generation_kwargs: Optional[Dict[str, Any]] = None,
131131
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
132132
stop_words: Optional[List[str]] = None,
133-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
133+
streaming_callback: Optional[StreamingCallbackT] = None,
134134
tools: Optional[Union[List[Tool], Toolset]] = None,
135135
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
136136
async_executor: Optional[ThreadPoolExecutor] = None,
@@ -330,7 +330,7 @@ def run(
330330
self,
331331
messages: List[ChatMessage],
332332
generation_kwargs: Optional[Dict[str, Any]] = None,
333-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
333+
streaming_callback: Optional[StreamingCallbackT] = None,
334334
tools: Optional[Union[List[Tool], Toolset]] = None,
335335
):
336336
"""
@@ -492,7 +492,7 @@ async def run_async(
492492
self,
493493
messages: List[ChatMessage],
494494
generation_kwargs: Optional[Dict[str, Any]] = None,
495-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
495+
streaming_callback: Optional[StreamingCallbackT] = None,
496496
tools: Optional[Union[List[Tool], Toolset]] = None,
497497
):
498498
"""
@@ -546,7 +546,7 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
546546
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
547547
generation_kwargs: Dict[str, Any],
548548
stop_words: Optional[List[str]],
549-
streaming_callback: Callable[[StreamingChunk], None],
549+
streaming_callback: StreamingCallbackT,
550550
):
551551
"""
552552
Handles async streaming generation of responses.

haystack/components/generators/hugging_face_api.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
from dataclasses import asdict
66
from datetime import datetime
7-
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, cast
7+
from typing import Any, Dict, Iterable, List, Optional, Union, cast
88

99
from haystack import component, default_from_dict, default_to_dict
10-
from haystack.dataclasses import StreamingChunk
10+
from haystack.dataclasses import StreamingCallbackT, StreamingChunk, select_streaming_callback
1111
from haystack.lazy_imports import LazyImport
1212
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1313
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
@@ -80,7 +80,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
8080
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
8181
generation_kwargs: Optional[Dict[str, Any]] = None,
8282
stop_words: Optional[List[str]] = None,
83-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
83+
streaming_callback: Optional[StreamingCallbackT] = None,
8484
):
8585
"""
8686
Initialize the HuggingFaceAPIGenerator instance.
@@ -180,7 +180,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIGenerator":
180180
def run(
181181
self,
182182
prompt: str,
183-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
183+
streaming_callback: Optional[StreamingCallbackT] = None,
184184
generation_kwargs: Optional[Dict[str, Any]] = None,
185185
):
186186
"""
@@ -200,7 +200,9 @@ def run(
200200
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
201201

202202
# check if streaming_callback is passed
203-
streaming_callback = streaming_callback or self.streaming_callback
203+
streaming_callback = select_streaming_callback(
204+
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
205+
)
204206

205207
hf_output = self._client.text_generation(
206208
prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
@@ -213,7 +215,7 @@ def run(
213215
return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))
214216

215217
def _stream_and_build_response(
216-
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
218+
self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: StreamingCallbackT
217219
):
218220
chunks: List[StreamingChunk] = []
219221
first_chunk_time = None

haystack/components/generators/hugging_face_local.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any, Callable, Dict, List, Literal, Optional, cast
5+
from typing import Any, Dict, List, Literal, Optional, cast
66

77
from haystack import component, default_from_dict, default_to_dict, logging
8-
from haystack.dataclasses import StreamingChunk
8+
from haystack.dataclasses import StreamingCallbackT, select_streaming_callback
99
from haystack.lazy_imports import LazyImport
1010
from haystack.utils import (
1111
ComponentDevice,
@@ -63,7 +63,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
6363
generation_kwargs: Optional[Dict[str, Any]] = None,
6464
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
6565
stop_words: Optional[List[str]] = None,
66-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
66+
streaming_callback: Optional[StreamingCallbackT] = None,
6767
):
6868
"""
6969
Creates an instance of a HuggingFaceLocalGenerator.
@@ -210,7 +210,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalGenerator":
210210
def run(
211211
self,
212212
prompt: str,
213-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
213+
streaming_callback: Optional[StreamingCallbackT] = None,
214214
generation_kwargs: Optional[Dict[str, Any]] = None,
215215
):
216216
"""
@@ -239,7 +239,9 @@ def run(
239239
updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
240240

241241
# check if streaming_callback is passed
242-
streaming_callback = streaming_callback or self.streaming_callback
242+
streaming_callback = select_streaming_callback(
243+
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
244+
)
243245

244246
if streaming_callback:
245247
num_responses = updated_generation_kwargs.get("num_return_sequences", 1)

haystack/components/generators/openai.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
import os
66
from datetime import datetime
7-
from typing import Any, Callable, Dict, List, Optional, Union
7+
from typing import Any, Dict, List, Optional, Union
88

99
from openai import OpenAI, Stream
1010
from openai.types.chat import ChatCompletion, ChatCompletionChunk
1111

1212
from haystack import component, default_from_dict, default_to_dict, logging
13-
from haystack.dataclasses import ChatMessage, StreamingChunk
13+
from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk, select_streaming_callback
1414
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1515
from haystack.utils.http_client import init_http_client
1616

@@ -54,7 +54,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
5454
self,
5555
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
5656
model: str = "gpt-4o-mini",
57-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
57+
streaming_callback: Optional[StreamingCallbackT] = None,
5858
api_base_url: Optional[str] = None,
5959
organization: Optional[str] = None,
6060
system_prompt: Optional[str] = None,
@@ -178,7 +178,7 @@ def run(
178178
self,
179179
prompt: str,
180180
system_prompt: Optional[str] = None,
181-
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
181+
streaming_callback: Optional[StreamingCallbackT] = None,
182182
generation_kwargs: Optional[Dict[str, Any]] = None,
183183
):
184184
"""
@@ -211,7 +211,9 @@ def run(
211211
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
212212

213213
# check if streaming_callback is passed
214-
streaming_callback = streaming_callback or self.streaming_callback
214+
streaming_callback = select_streaming_callback(
215+
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
216+
)
215217

216218
# adapt ChatMessage(s) to the format expected by the OpenAI API
217219
openai_formatted_messages = [message.to_openai_dict_format() for message in messages]

haystack/utils/hf.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
import copy
66
from enum import Enum
7-
from typing import Any, Callable, Dict, List, Optional, Union
7+
from typing import Any, Dict, List, Optional, Union
88

99
from haystack import logging
10-
from haystack.dataclasses import ChatMessage, StreamingChunk
10+
from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk
1111
from haystack.lazy_imports import LazyImport
1212
from haystack.utils.auth import Secret
1313
from haystack.utils.device import ComponentDevice
@@ -349,15 +349,15 @@ class HFTokenStreamingHandler(TextStreamer):
349349
Streaming handler for HuggingFaceLocalGenerator and HuggingFaceLocalChatGenerator.
350350
351351
Note: This is a helper class for HuggingFaceLocalGenerator & HuggingFaceLocalChatGenerator enabling streaming
352-
of generated text via Haystack Callable[StreamingChunk, None] callbacks.
352+
of generated text via Haystack StreamingCallbackT callbacks.
353353
354354
Do not use this class directly.
355355
"""
356356

357357
def __init__(
358358
self,
359359
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
360-
stream_handler: Callable[[StreamingChunk], None],
360+
stream_handler: StreamingCallbackT,
361361
stop_words: Optional[List[str]] = None,
362362
):
363363
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore

0 commit comments

Comments (0)