Skip to content

Implement Stop button while streaming #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 1, 2025
1 change: 1 addition & 0 deletions crates/code_assistant/assets/icons/circle_stop.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions crates/code_assistant/assets/icons/file_icons/file_types.json
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@
},
"send": {
"icon": "icons/send.svg"
},
"stop": {
"icon": "icons/stop.svg"
}
}
}
1 change: 1 addition & 0 deletions crates/code_assistant/assets/icons/stop.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
63 changes: 45 additions & 18 deletions crates/code_assistant/src/agent/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,18 @@ impl Agent {
}

/// Handles the interaction with the LLM to get the next assistant message.
/// Appends the assistant's message to the history.
/// Appends the assistant's message to the history only if it has content.
async fn obtain_llm_response(&mut self, messages: Vec<Message>) -> Result<llm::LLMResponse> {
let llm_response = self.get_next_assistant_message(messages).await?;
self.append_message(Message {
role: MessageRole::Assistant,
content: MessageContent::Structured(llm_response.content.clone()),
})?;

// Only add to message history if there's actual content
if !llm_response.content.is_empty() {
self.append_message(Message {
role: MessageRole::Assistant,
content: MessageContent::Structured(llm_response.content.clone()),
})?;
}

Ok(llm_response)
}

Expand Down Expand Up @@ -227,14 +232,7 @@ impl Agent {
request_counter += 1;

// 1. Obtain LLM response (includes adding assistant message to history)
let llm_response = match self.obtain_llm_response(messages).await {
Ok(response) => response,
Err(e) => {
// Log critical error and break loop
tracing::error!("Critical error obtaining LLM response: {}", e);
return Err(e);
}
};
let llm_response = self.obtain_llm_response(messages).await?;

// 2. Extract tool requests from LLM response and determine the next flow action
let (tool_requests, flow) = self
Expand Down Expand Up @@ -594,20 +592,49 @@ impl Agent {

// Create a StreamProcessor and use it to process streaming chunks
let ui = Arc::clone(&self.ui);
let processor = Arc::new(Mutex::new(create_stream_processor(self.tool_mode, ui)));
let processor = Arc::new(Mutex::new(create_stream_processor(
self.tool_mode,
ui.clone(),
)));

let streaming_callback: StreamingCallback = Box::new(move |chunk: &StreamingChunk| {
// Check if streaming should continue
if !ui.should_streaming_continue() {
debug!("Streaming should stop - user requested cancellation");
return Err(anyhow::anyhow!("Streaming cancelled by user"));
}

let mut processor_guard = processor.lock().unwrap();
processor_guard
.process(chunk)
.map_err(|e| anyhow::anyhow!("Failed to process streaming chunk: {}", e))
});

// Send message to LLM provider
let response = self
let response = match self
.llm_provider
.send_message(request, Some(&streaming_callback))
.await?;
.await
{
Ok(response) => response,
Err(e) => {
// Check for streaming cancelled error
if e.to_string().contains("Streaming cancelled by user") {
debug!("Streaming cancelled by user in LLM request {}", request_id);
// End LLM request with cancelled=true
let _ = self.ui.end_llm_request(request_id, true).await;
// Return empty response
return Ok(llm::LLMResponse {
content: Vec::new(),
usage: llm::Usage::zero(),
});
}

// For other errors, still end the request but not cancelled
let _ = self.ui.end_llm_request(request_id, false).await;
return Err(e);
}
};

// Print response for debugging
debug!("Raw LLM response:");
Expand All @@ -631,8 +658,8 @@ impl Agent {
response.usage.cache_read_input_tokens
);

// Inform UI that the LLM request has completed
let _ = self.ui.end_llm_request(request_id).await;
// Inform UI that the LLM request has completed (normal completion)
let _ = self.ui.end_llm_request(request_id, false).await;
debug!("Completed LLM request with ID: {}", request_id);

Ok(response)
Expand Down
7 changes: 6 additions & 1 deletion crates/code_assistant/src/tests/mocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,15 @@ impl UserInterface for MockUI {
Ok(42)
}

async fn end_llm_request(&self, _request_id: u64) -> Result<(), UIError> {
async fn end_llm_request(&self, _request_id: u64, _cancelled: bool) -> Result<(), UIError> {
// Mock implementation does nothing with request completion
Ok(())
}

fn should_streaming_continue(&self) -> bool {
// Mock implementation always continues streaming
true
}
}

// Mock Explorer
Expand Down
96 changes: 86 additions & 10 deletions crates/code_assistant/src/ui/gpui/elements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,55 @@ pub enum MessageRole {
pub struct MessageContainer {
elements: Arc<Mutex<Vec<Entity<BlockView>>>>,
role: MessageRole,
current_request_id: Arc<Mutex<u64>>,
waiting_for_content: Arc<Mutex<bool>>,
}

impl MessageContainer {
pub fn with_role(role: MessageRole, _cx: &mut Context<Self>) -> Self {
Self {
elements: Arc::new(Mutex::new(Vec::new())),
role,
current_request_id: Arc::new(Mutex::new(0)),
waiting_for_content: Arc::new(Mutex::new(false)),
}
}

// Set the current request ID for this message container
pub fn set_current_request_id(&self, request_id: u64) {
*self.current_request_id.lock().unwrap() = request_id;
}

// Set waiting for content flag
pub fn set_waiting_for_content(&self, waiting: bool) {
*self.waiting_for_content.lock().unwrap() = waiting;
}

// Check if waiting for content
pub fn is_waiting_for_content(&self) -> bool {
*self.waiting_for_content.lock().unwrap()
}

// Remove all blocks with the given request ID
pub fn remove_blocks_with_request_id(&self, request_id: u64, cx: &mut Context<Self>) {
let mut elements = self.elements.lock().unwrap();
let mut blocks_to_remove = Vec::new();

// Find indices of blocks to remove
for (index, element) in elements.iter().enumerate() {
let should_remove = element.read(cx).request_id == request_id;
if should_remove {
blocks_to_remove.push(index);
}
}

// Remove blocks in reverse order to maintain indices
for &index in blocks_to_remove.iter().rev() {
elements.remove(index);
}

if !blocks_to_remove.is_empty() {
cx.notify();
}
}

Expand All @@ -46,11 +88,16 @@ impl MessageContainer {
// Add a new text block
pub fn add_text_block(&self, content: impl Into<String>, cx: &mut Context<Self>) {
self.finish_any_thinking_blocks(cx);

// Clear waiting_for_content flag on first content
self.set_waiting_for_content(false);

let request_id = *self.current_request_id.lock().unwrap();
let mut elements = self.elements.lock().unwrap();
let block = BlockData::TextBlock(TextBlock {
content: content.into(),
});
let view = cx.new(|cx| BlockView::new(block, cx));
let view = cx.new(|cx| BlockView::new(block, request_id, cx));
elements.push(view);
cx.notify();
}
Expand All @@ -59,9 +106,14 @@ impl MessageContainer {
#[allow(dead_code)]
pub fn add_thinking_block(&self, content: impl Into<String>, cx: &mut Context<Self>) {
self.finish_any_thinking_blocks(cx);

// Clear waiting_for_content flag on first content
self.set_waiting_for_content(false);

let request_id = *self.current_request_id.lock().unwrap();
let mut elements = self.elements.lock().unwrap();
let block = BlockData::ThinkingBlock(ThinkingBlock::new(content.into()));
let view = cx.new(|cx| BlockView::new(block, cx));
let view = cx.new(|cx| BlockView::new(block, request_id, cx));
elements.push(view);
cx.notify();
}
Expand All @@ -74,6 +126,11 @@ impl MessageContainer {
cx: &mut Context<Self>,
) {
self.finish_any_thinking_blocks(cx);

// Clear waiting_for_content flag on first content
self.set_waiting_for_content(false);

let request_id = *self.current_request_id.lock().unwrap();
let mut elements = self.elements.lock().unwrap();
let block = BlockData::ToolUse(ToolUseBlock {
name: name.into(),
Expand All @@ -84,7 +141,7 @@ impl MessageContainer {
output: None,
is_collapsed: false, // Default to expanded
});
let view = cx.new(|cx| BlockView::new(block, cx));
let view = cx.new(|cx| BlockView::new(block, request_id, cx));
elements.push(view);
cx.notify();
}
Expand Down Expand Up @@ -131,6 +188,9 @@ impl MessageContainer {
pub fn add_or_append_to_text_block(&self, content: impl Into<String>, cx: &mut Context<Self>) {
self.finish_any_thinking_blocks(cx);

// Clear waiting_for_content flag on first content
self.set_waiting_for_content(false);

let content = content.into();
let mut elements = self.elements.lock().unwrap();

Expand All @@ -151,10 +211,11 @@ impl MessageContainer {
}

// If we reach here, we need to add a new text block
let request_id = *self.current_request_id.lock().unwrap();
let block = BlockData::TextBlock(TextBlock {
content: content.to_string(),
});
let view = cx.new(|cx| BlockView::new(block, cx));
let view = cx.new(|cx| BlockView::new(block, request_id, cx));
elements.push(view);
cx.notify();
}
Expand All @@ -165,6 +226,9 @@ impl MessageContainer {
content: impl Into<String>,
cx: &mut Context<Self>,
) {
// Clear waiting_for_content flag on first content
self.set_waiting_for_content(false);

let content = content.into();
let mut elements = self.elements.lock().unwrap();

Expand All @@ -185,8 +249,9 @@ impl MessageContainer {
}

// If we reach here, we need to add a new thinking block
let request_id = *self.current_request_id.lock().unwrap();
let block = BlockData::ThinkingBlock(ThinkingBlock::new(content.to_string()));
let view = cx.new(|cx| BlockView::new(block, cx));
let view = cx.new(|cx| BlockView::new(block, request_id, cx));
elements.push(view);
cx.notify();
}
Expand Down Expand Up @@ -260,6 +325,7 @@ impl MessageContainer {

// If we didn't find a matching tool, create a new one with this parameter
if !tool_found {
let request_id = *self.current_request_id.lock().unwrap();
let mut tool = ToolUseBlock {
name: "unknown".to_string(), // Default name since we only have ID
id: tool_id.clone(),
Expand All @@ -276,7 +342,7 @@ impl MessageContainer {
});

let block = BlockData::ToolUse(tool);
let view = cx.new(|cx| BlockView::new(block, cx));
let view = cx.new(|cx| BlockView::new(block, request_id, cx));
elements.push(view);
cx.notify();
}
Expand Down Expand Up @@ -368,11 +434,12 @@ impl BlockData {
/// Entity view for a block
pub struct BlockView {
block: BlockData,
request_id: u64,
}

impl BlockView {
pub fn new(block: BlockData, _cx: &mut Context<Self>) -> Self {
Self { block }
pub fn new(block: BlockData, request_id: u64, _cx: &mut Context<Self>) -> Self {
Self { block, request_id }
}

fn toggle_thinking_collapsed(&mut self, cx: &mut Context<Self>) {
Expand Down Expand Up @@ -721,13 +788,20 @@ impl Render for BlockView {
if let Some(output_content) = &block.output {
if !output_content.is_empty() {
// Also check if output is not empty
let output_color =
if block.status == crate::ui::ToolStatus::Error {
cx.theme().danger
} else {
cx.theme().foreground
};

elements.push(
div()
.id(SharedString::from(block.id.clone()))
.p_2()
.mt_1()
.w_full()
.text_color(cx.theme().foreground)
.text_color(output_color)
.text_size(px(13.))
.child(output_content.clone())
.into_any(),
Expand All @@ -736,9 +810,11 @@ impl Render for BlockView {
}
}

// Error message (always shown for error status, regardless of collapsed state)
// Error message (only shown for error status when collapsed, or when there's no output)
if block.status == crate::ui::ToolStatus::Error
&& block.status_message.is_some()
&& (block.is_collapsed
|| block.output.as_ref().map_or(true, |o| o.is_empty()))
{
elements.push(
div()
Expand Down
1 change: 1 addition & 0 deletions crates/code_assistant/src/ui/gpui/file_icons.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub const THEME_DARK: &str = "theme_dark"; // theme_dark.svg
pub const THEME_LIGHT: &str = "theme_light"; // theme_light.svg

pub const SEND: &str = "send"; // send.svg
pub const STOP: &str = "stop"; // circle_stop.svg

// Tool-specific icon mappings to actual SVG files
// These are direct constants defining the paths to SVG icons or existing types
Expand Down
Loading