Skip to content

Commit 1fa18fc

Browse files
committed
0.23 upstream merge fix part 4:
* Fix stream object * Add assistant streaming + func call example * Fix old OpenAI chat example
2 parents 72ea9c9 + c64d80b commit 1fa18fc

File tree

15 files changed

+533
-126
lines changed

15 files changed

+533
-126
lines changed

async-openai-wasm/src/client.rs

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -446,10 +446,10 @@ impl<C: Config> Client<C> {
446446
path: &str,
447447
request: I,
448448
event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
449-
) -> Pin<Box<dyn Stream<Item=Result<O, OpenAIError>> + Send>>
449+
) -> OpenAIEventMappedStream<O>
450450
where
451451
I: Serialize,
452-
O: DeserializeOwned + Send + 'static,
452+
O: DeserializeOwned + Send + 'static
453453
{
454454
let event_source = self
455455
.http_client
@@ -460,8 +460,7 @@ impl<C: Config> Client<C> {
460460
.eventsource()
461461
.unwrap();
462462

463-
// stream_mapped_raw_events(event_source, event_mapper).await
464-
todo!()
463+
OpenAIEventMappedStream::new(event_source, event_mapper)
465464
}
466465

467466
/// Make HTTP GET request to receive SSE
@@ -491,19 +490,21 @@ impl<C: Config> Client<C> {
491490
/// Request which responds with SSE.
492491
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
493492
#[pin_project]
494-
pub struct OpenAIEventStream<O> {
493+
pub struct OpenAIEventStream<O: DeserializeOwned + Send + 'static> {
495494
#[pin]
496495
stream: Filter<EventSource, future::Ready<bool>, fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>>,
496+
done: bool,
497497
_phantom_data: PhantomData<O>,
498498
}
499499

500-
impl<O> OpenAIEventStream<O> {
500+
impl<O: DeserializeOwned + Send + 'static> OpenAIEventStream<O> {
501501
pub(crate) fn new(event_source: EventSource) -> Self {
502502
Self {
503503
stream: event_source.filter(|result|
504504
// filter out the first event which is always Event::Open
505505
future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))
506506
),
507+
done: false,
507508
_phantom_data: PhantomData,
508509
}
509510
}
@@ -514,6 +515,9 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
514515

515516
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
516517
let this = self.project();
518+
if *this.done {
519+
return Poll::Ready(None);
520+
}
517521
let stream: Pin<&mut _> = this.stream;
518522
match stream.poll_next(cx) {
519523
Poll::Ready(response) => {
@@ -524,17 +528,24 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
524528
Event::Open => unreachable!(), // it has been filtered out
525529
Event::Message(message) => {
526530
if message.data == "[DONE]" {
531+
*this.done = true;
527532
Poll::Ready(None) // end of the stream, defined by OpenAI
528533
} else {
529534
// deserialize the data
530535
match serde_json::from_str::<O>(&message.data) {
531-
Err(e) => Poll::Ready(Some(Err(map_deserialization_error(e, &message.data.as_bytes())))),
536+
Err(e) => {
537+
*this.done = true;
538+
Poll::Ready(Some(Err(map_deserialization_error(e, &message.data.as_bytes()))))
539+
}
532540
Ok(output) => Poll::Ready(Some(Ok(output))),
533541
}
534542
}
535543
}
536544
}
537-
Err(e) => Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
545+
Err(e) => {
546+
*this.done = true;
547+
Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
548+
}
538549
}
539550
}
540551
}
@@ -543,6 +554,77 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
543554
}
544555
}
545556

557+
#[pin_project]
558+
pub struct OpenAIEventMappedStream<O>
559+
where O: Send + 'static
560+
{
561+
#[pin]
562+
stream: Filter<EventSource, future::Ready<bool>, fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>>,
563+
event_mapper: Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>,
564+
done: bool,
565+
_phantom_data: PhantomData<O>,
566+
}
567+
568+
impl<O> OpenAIEventMappedStream<O>
569+
where O: Send + 'static
570+
{
571+
pub(crate) fn new<M>(event_source: EventSource, event_mapper: M) -> Self
572+
where M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static {
573+
Self {
574+
stream: event_source.filter(|result|
575+
// filter out the first event which is always Event::Open
576+
future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))
577+
),
578+
done: false,
579+
event_mapper: Box::new(event_mapper),
580+
_phantom_data: PhantomData,
581+
}
582+
}
583+
}
584+
585+
586+
impl<O> Stream for OpenAIEventMappedStream<O>
587+
where O: Send + 'static
588+
{
589+
type Item = Result<O, OpenAIError>;
590+
591+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
592+
let this = self.project();
593+
if *this.done {
594+
return Poll::Ready(None);
595+
}
596+
let stream: Pin<&mut _> = this.stream;
597+
match stream.poll_next(cx) {
598+
Poll::Ready(response) => {
599+
match response {
600+
None => Poll::Ready(None), // end of the stream
601+
Some(result) => match result {
602+
Ok(event) => match event {
603+
Event::Open => unreachable!(), // it has been filtered out
604+
Event::Message(message) => {
605+
if message.data == "[DONE]" {
606+
*this.done = true;
607+
}
608+
let response = (this.event_mapper)(message);
609+
match response {
610+
Ok(output) => Poll::Ready(Some(Ok(output))),
611+
Err(_) => Poll::Ready(None)
612+
}
613+
}
614+
}
615+
Err(e) => {
616+
*this.done = true;
617+
Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
618+
}
619+
}
620+
}
621+
}
622+
Poll::Pending => Poll::Pending
623+
}
624+
}
625+
}
626+
627+
546628
// pub(crate) async fn stream_mapped_raw_events<O>(
547629
// mut event_source: EventSource,
548630
// event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,

async-openai-wasm/src/types/assistant_stream.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
use std::pin::Pin;
2-
3-
use futures::Stream;
41
use serde::Deserialize;
52

3+
use crate::client::OpenAIEventMappedStream;
64
use crate::error::{ApiError, map_deserialization_error, OpenAIError};
75

86
use super::{
@@ -28,7 +26,6 @@ use super::{
2826
/// We may add additional events over time, so we recommend handling unknown events gracefully
2927
/// in your code. See the [Assistants API quickstart](https://platform.openai.com/docs/assistants/overview) to learn how to
3028
/// integrate the Assistants API with streaming.
31-
3229
#[derive(Debug, Deserialize, Clone)]
3330
#[serde(tag = "event", content = "data")]
3431
#[non_exhaustive]
@@ -110,8 +107,7 @@ pub enum AssistantStreamEvent {
110107
Done(String),
111108
}
112109

113-
pub type AssistantEventStream =
114-
Pin<Box<dyn Stream<Item = Result<AssistantStreamEvent, OpenAIError>> + Send>>;
110+
pub type AssistantEventStream = OpenAIEventMappedStream<AssistantStreamEvent>;
115111

116112
impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
117113
type Error = OpenAIError;
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[package]
2+
name = "openai-web-assistant-chat"
3+
version = "0.1.0"
4+
edition = "2021"
5+
publish = false
6+
7+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
8+
9+
[dependencies]
10+
dioxus = {version = "~0.5", features = ["web"]}
11+
futures = "0.3.30"
12+
async-openai-wasm = { path = "../../async-openai-wasm" }
13+
# Debug
14+
tracing = "0.1.40"
15+
dioxus-logger = "~0.5"
16+
serde_json = "1.0.117"
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
[application]
2+
3+
# App (Project) Name
4+
name = "openai-web-app-assistant-dioxus"
5+
6+
# Dioxus App Default Platform
7+
# desktop, web
8+
default_platform = "web"
9+
10+
# `build` & `serve` dist path
11+
out_dir = "dist"
12+
13+
[web.app]
14+
15+
# HTML title tag content
16+
title = "openai-web-app-assistant-dioxus"
17+
18+
[web.watcher]
19+
20+
# when watcher trigger, regenerate the `index.html`
21+
reload_html = true
22+
23+
# which files or dirs will be watcher monitoring
24+
watch_path = ["src"]
25+
26+
# include `assets` in web platform
27+
[web.resource]
28+
29+
# CSS style file
30+
31+
style = []
32+
33+
# Javascript code file
34+
script = []
35+
36+
[web.resource.dev]
37+
38+
# Javascript code file
39+
# serve: [dev-server] only
40+
script = []
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# OpenAI Web App - Assistant
2+
3+
This builds a `dioxus` web App that uses OpenAI Assistant APIs to generate text.
4+
5+
To run it, you need:
6+
1. Set OpenAI secrets in `./src/main.rs`. Please do NOT take this demo into production without using a secure secret store
7+
2. Install `dioxus-cli` by `cargo install dioxus-cli`.
8+
3. Run `dx serve`
9+
10+
Note: Safari may not work due to CORS issues. Please use Chrome or Edge.
11+
12+
## Reference
13+
14+
The code is adapted from [assistant-func-call-stream example in async-openai](https://github.com/64bit/async-openai/tree/main/examples/assistants-func-call-stream).
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#![allow(non_snake_case)]
2+
3+
use dioxus::prelude::*;
4+
use dioxus_logger::tracing::{error, info, Level};
5+
use futures::stream::StreamExt;
6+
7+
use async_openai_wasm::types::{AssistantStreamEvent, CreateMessageRequest, CreateRunRequest, CreateThreadRequest, MessageRole};
8+
9+
use crate::utils::*;
10+
11+
mod utils;
12+
13+
pub const API_BASE: &str = "...";
14+
pub const API_KEY: &str = "...";
15+
16+
17+
pub fn App() -> Element {
18+
const QUERY: &str = "What's the weather in San Francisco today and the likelihood it'll rain?";
19+
let reply = use_signal(String::new);
20+
let _run_assistant: Coroutine<()> = use_coroutine(|_rx| {
21+
let client = get_client();
22+
async move {
23+
//
24+
// Step 1: Define functions
25+
//
26+
let assistant = client
27+
.assistants()
28+
.create(create_assistant_request())
29+
.await
30+
.expect("failed to create assistant");
31+
//
32+
// Step 2: Create a Thread and add Messages
33+
//
34+
let thread = client
35+
.threads()
36+
.create(CreateThreadRequest::default())
37+
.await
38+
.expect("failed to create thread");
39+
let _message = client
40+
.threads()
41+
.messages(&thread.id)
42+
.create(CreateMessageRequest {
43+
role: MessageRole::User,
44+
content: QUERY.into(),
45+
..Default::default()
46+
})
47+
.await
48+
.expect("failed to create message");
49+
//
50+
// Step 3: Initiate a Run
51+
//
52+
let mut event_stream = client
53+
.threads()
54+
.runs(&thread.id)
55+
.create_stream(CreateRunRequest {
56+
assistant_id: assistant.id.clone(),
57+
stream: Some(true),
58+
..Default::default()
59+
})
60+
.await
61+
.expect("failed to create run");
62+
63+
64+
while let Some(event) = event_stream.next().await {
65+
match event {
66+
Ok(event) => match event {
67+
AssistantStreamEvent::ThreadRunRequiresAction(run_object) => {
68+
info!("thread.run.requires_action: run_id:{}", run_object.id);
69+
handle_requires_action(&client, run_object, reply.to_owned()).await
70+
}
71+
_ => info!("\nEvent: {event:?}\n"),
72+
},
73+
Err(e) => {
74+
error!("Error: {e}");
75+
}
76+
}
77+
}
78+
79+
client.threads().delete(&thread.id).await.expect("failed to delete thread");
80+
client.assistants().delete(&assistant.id).await.expect("failed to delete assistant");
81+
info!("Done!");
82+
}
83+
});
84+
85+
rsx! {
86+
div {
87+
p { "Using OpenAI" }
88+
p { "User: {QUERY}" }
89+
p { "Expected Stats (Debug): temperature = {TEMPERATURE}, rain_probability = {RAIN_PROBABILITY}" }
90+
p { "Assistant: {reply}" }
91+
}
92+
}
93+
}
94+
95+
fn main() {
96+
dioxus_logger::init(Level::INFO).expect("failed to init logger");
97+
launch(App);
98+
}

0 commit comments

Comments
 (0)