diff --git a/packages/components/nodes/agents/AirtableAgent/AirtableAgent.ts b/packages/components/nodes/agents/AirtableAgent/AirtableAgent.ts index 7b8349e80c1..e6428f14b7d 100644 --- a/packages/components/nodes/agents/AirtableAgent/AirtableAgent.ts +++ b/packages/components/nodes/agents/AirtableAgent/AirtableAgent.ts @@ -2,7 +2,7 @@ import axios from 'axios' import { BaseLanguageModel } from '@langchain/core/language_models/base' import { AgentExecutor } from 'langchain/agents' import { LLMChain } from 'langchain/chains' -import { ICommonObject, INode, INodeData, INodeParams, PromptTemplate } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer, PromptTemplate } from '../../../src/Interface' import { getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { LoadPyodide, finalSystemPrompt, systemPrompt } from './core' @@ -104,11 +104,17 @@ class Airtable_Agents implements INode { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - //streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + // if (options.shouldStreamResponse) { + // streamResponse(options.sseStreamer, options.chatId, e.message) + // } return formatResponse(e.message) } } + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + const credentialData = await getCredentialData(nodeData.credential ?? '', options) const accessToken = getCredentialParam('accessToken', credentialData, nodeData) @@ -123,7 +129,6 @@ class Airtable_Agents implements INode { let base64String = Buffer.from(JSON.stringify(airtableData)).toString('base64') const loggerHandler = new ConsoleCallbackHandler(options.logger) - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) const callbacks = await additionalCallbacks(nodeData, options) const pyodide = await LoadPyodide() @@ -194,7 +199,8 @@ json.dumps(my_dict)` answer: finalResult } - if (options.socketIO && options.socketIOClientId) { + if (options.shouldStreamResponse) { + const handler = new CustomChainHandler(shouldStreamResponse ? sseStreamer : undefined, chatId) const result = await chain.call(inputs, [loggerHandler, handler, ...callbacks]) return result?.text } else { diff --git a/packages/components/nodes/agents/AutoGPT/AutoGPT.ts b/packages/components/nodes/agents/AutoGPT/AutoGPT.ts index 4c1d962c3e4..c41a52965fd 100644 --- a/packages/components/nodes/agents/AutoGPT/AutoGPT.ts +++ b/packages/components/nodes/agents/AutoGPT/AutoGPT.ts @@ -113,7 +113,9 @@ class AutoGPT_Agents implements INode { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - //streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + // if (options.shouldStreamResponse) { + // streamResponse(options.sseStreamer, options.chatId, e.message) + // } return formatResponse(e.message) } } diff --git a/packages/components/nodes/agents/BabyAGI/BabyAGI.ts b/packages/components/nodes/agents/BabyAGI/BabyAGI.ts index bfc910b7952..87d5cd28923 100644 --- a/packages/components/nodes/agents/BabyAGI/BabyAGI.ts +++ b/packages/components/nodes/agents/BabyAGI/BabyAGI.ts @@ -73,7 +73,9 @@ class BabyAGI_Agents implements INode { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - //streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + // if (options.shouldStreamResponse) { + // streamResponse(options.sseStreamer, options.chatId, e.message) + // } return formatResponse(e.message) } } diff --git a/packages/components/nodes/agents/CSVAgent/CSVAgent.ts b/packages/components/nodes/agents/CSVAgent/CSVAgent.ts index 70c35cd967c..b0bc0eabe81 100644 --- a/packages/components/nodes/agents/CSVAgent/CSVAgent.ts +++ b/packages/components/nodes/agents/CSVAgent/CSVAgent.ts @@ -2,7 +2,7 @@ import { BaseLanguageModel } from '@langchain/core/language_models/base' import { AgentExecutor } from 'langchain/agents' import { LLMChain } from 'langchain/chains' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' -import { ICommonObject, INode, INodeData, INodeParams, PromptTemplate } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer, PromptTemplate } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' import { LoadPyodide, finalSystemPrompt, systemPrompt } from './core' import { checkInputs, Moderation } from '../../moderation/Moderation' @@ -90,13 +90,18 @@ class CSV_Agents implements INode { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - //streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + // if (options.shouldStreamResponse) { + // streamResponse(options.sseStreamer, options.chatId, e.message) + // } return formatResponse(e.message) } } const loggerHandler = new ConsoleCallbackHandler(options.logger) - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + const callbacks = await additionalCallbacks(nodeData, options) let files: string[] = [] @@ -203,7 +208,8 @@ json.dumps(my_dict)` answer: finalResult } - if (options.socketIO && options.socketIOClientId) { + if (options.shouldStreamResponse) { + const handler = new CustomChainHandler(shouldStreamResponse ? sseStreamer : undefined, chatId) const result = await chain.call(inputs, [loggerHandler, handler, ...callbacks]) return result?.text } else { diff --git a/packages/components/nodes/agents/ConversationalAgent/ConversationalAgent.ts b/packages/components/nodes/agents/ConversationalAgent/ConversationalAgent.ts index 2d8a2d73ed2..f9541b369e0 100644 --- a/packages/components/nodes/agents/ConversationalAgent/ConversationalAgent.ts +++ b/packages/components/nodes/agents/ConversationalAgent/ConversationalAgent.ts @@ -9,7 +9,16 @@ import { RunnableSequence } from '@langchain/core/runnables' import { ChatConversationalAgent } from 'langchain/agents' import { getBaseClasses } from '../../../src/utils' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' -import { IVisionChatModal, FlowiseMemory, ICommonObject, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface' +import { + IVisionChatModal, + FlowiseMemory, + ICommonObject, + INode, + INodeData, + INodeParams, + IUsedTool, + IServerSideEventStreamer +} from '../../../src/Interface' import { AgentExecutor } from '../../../src/agents' import { addImagesToMessages, llmSupportsVision } from '../../../src/multiModalUtils' import { checkInputs, Moderation } from '../../moderation/Moderation' @@ -106,12 +115,18 @@ class ConversationalAgent_Agents implements INode { const memory = nodeData.inputs?.memory as FlowiseMemory const moderations = nodeData.inputs?.inputModeration as Moderation[] + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId if (moderations && moderations.length > 0) { try { // Use the output of the moderation chain as input for the BabyAGI agent input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) + // if (options.shouldStreamResponse) { + // streamResponse(options.sseStreamer, options.chatId, e.message) + // } //streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) return formatResponse(e.message) } @@ -125,15 +140,17 @@ class ConversationalAgent_Agents implements INode { let sourceDocuments: ICommonObject[] = [] let usedTools: IUsedTool[] = [] - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + if (options.shouldStreamResponse) { + const handler = new CustomChainHandler(shouldStreamResponse ? sseStreamer : undefined, chatId) res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] }) if (res.sourceDocuments) { - options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments)) + if (options.sseStreamer) { + sseStreamer.streamSourceDocumentsEvent(options.chatId, flatten(res.sourceDocuments)) + } sourceDocuments = res.sourceDocuments } if (res.usedTools) { - options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools) + sseStreamer.streamUsedToolsEvent(options.chatId, res.usedTools) usedTools = res.usedTools } // If the tool is set to returnDirect, stream the output to the client @@ -142,11 +159,14 @@ class ConversationalAgent_Agents implements INode { inputTools = flatten(inputTools) for (const tool of res.usedTools) { const inputTool = inputTools.find((inputTool: Tool) => inputTool.name === tool.tool) - if (inputTool && inputTool.returnDirect) { - options.socketIO.to(options.socketIOClientId).emit('token', tool.toolOutput) + if (inputTool && inputTool.returnDirect && options.sseStreamer) { + sseStreamer.streamTokenEvent(options.chatId, tool.toolOutput) } } } + if (sseStreamer) { + sseStreamer.streamEndEvent(options.chatId) + } } else { res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) if (res.sourceDocuments) { diff --git a/packages/components/nodes/agents/ConversationalRetrievalToolAgent/ConversationalRetrievalToolAgent.ts b/packages/components/nodes/agents/ConversationalRetrievalToolAgent/ConversationalRetrievalToolAgent.ts index 54013ac551b..894c7b13dbe 100644 --- a/packages/components/nodes/agents/ConversationalRetrievalToolAgent/ConversationalRetrievalToolAgent.ts +++ b/packages/components/nodes/agents/ConversationalRetrievalToolAgent/ConversationalRetrievalToolAgent.ts @@ -7,7 +7,16 @@ import { ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate, Pr import { formatToOpenAIToolMessages } from 'langchain/agents/format_scratchpad/openai_tools' import { getBaseClasses } from '../../../src/utils' import { type ToolsAgentStep } from 'langchain/agents/openai/output_parser' -import { FlowiseMemory, ICommonObject, INode, INodeData, INodeParams, IUsedTool, IVisionChatModal } from '../../../src/Interface' +import { + FlowiseMemory, + ICommonObject, + INode, + INodeData, + INodeParams, + IServerSideEventStreamer, + IUsedTool, + IVisionChatModal +} from '../../../src/Interface' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { AgentExecutor, ToolCallingAgentOutputParser } from '../../../src/agents' import { Moderation, checkInputs, streamResponse } from '../../moderation/Moderation' @@ -104,7 +113,9 @@ class ConversationalRetrievalToolAgent_Agents implements INode { const memory = nodeData.inputs?.memory as FlowiseMemory const moderations = nodeData.inputs?.inputModeration as Moderation[] - const isStreamable = options.socketIO && options.socketIOClientId + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId if (moderations && moderations.length > 0) { try { @@ -112,8 +123,9 @@ class ConversationalRetrievalToolAgent_Agents implements INode { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - if (isStreamable) - streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + streamResponse(sseStreamer, chatId, e.message) + } return formatResponse(e.message) } } @@ -127,15 +139,15 @@ class ConversationalRetrievalToolAgent_Agents implements INode { let sourceDocuments: ICommonObject[] = [] let usedTools: IUsedTool[] = [] - if (isStreamable) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] }) if (res.sourceDocuments) { - options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments)) + sseStreamer.streamSourceDocumentsEvent(chatId, flatten(res.sourceDocuments)) sourceDocuments = res.sourceDocuments } if (res.usedTools) { - options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools) + sseStreamer.streamUsedToolsEvent(chatId, res.usedTools) usedTools = res.usedTools } } else { diff --git a/packages/components/nodes/agents/LlamaIndexAgents/OpenAIToolAgent/OpenAIToolAgent_LlamaIndex.ts b/packages/components/nodes/agents/LlamaIndexAgents/OpenAIToolAgent/OpenAIToolAgent_LlamaIndex.ts index ed9895de0ed..f25b9677bfa 100644 --- a/packages/components/nodes/agents/LlamaIndexAgents/OpenAIToolAgent/OpenAIToolAgent_LlamaIndex.ts +++ b/packages/components/nodes/agents/LlamaIndexAgents/OpenAIToolAgent/OpenAIToolAgent_LlamaIndex.ts @@ -1,7 +1,16 @@ import { flatten } from 'lodash' import { ChatMessage, OpenAI, OpenAIAgent } from 'llamaindex' import { getBaseClasses } from '../../../../src/utils' -import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../../src/Interface' +import { + FlowiseMemory, + ICommonObject, + IMessage, + INode, + INodeData, + INodeParams, + IServerSideEventStreamer, + IUsedTool +} from '../../../../src/Interface' class OpenAIFunctionAgent_LlamaIndex_Agents implements INode { label: string @@ -67,7 +76,9 @@ class OpenAIFunctionAgent_LlamaIndex_Agents implements INode { let tools = nodeData.inputs?.tools tools = flatten(tools) - const isStreamingEnabled = options.socketIO && options.socketIOClientId + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId const chatHistory = [] as ChatMessage[] @@ -104,7 +115,7 @@ class OpenAIFunctionAgent_LlamaIndex_Agents implements INode { let isStreamingStarted = false const usedTools: IUsedTool[] = [] - if (isStreamingEnabled) { + if (shouldStreamResponse) { const stream = await agent.chat({ message: input, chatHistory, @@ -116,7 +127,9 @@ class OpenAIFunctionAgent_LlamaIndex_Agents implements INode { text += chunk.response.delta if (!isStreamingStarted) { isStreamingStarted = true - options.socketIO.to(options.socketIOClientId).emit('start', chunk.response.delta) + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, chunk.response.delta) + } if (chunk.sources.length) { for (const sourceTool of chunk.sources) { usedTools.push({ @@ -125,11 +138,14 @@ class OpenAIFunctionAgent_LlamaIndex_Agents implements INode { toolOutput: sourceTool.output as any }) } - options.socketIO.to(options.socketIOClientId).emit('usedTools', usedTools) + if (sseStreamer) { + sseStreamer.streamUsedToolsEvent(chatId, usedTools) + } } } - - options.socketIO.to(options.socketIOClientId).emit('token', chunk.response.delta) + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, chunk.response.delta) + } } } else { const response = await agent.chat({ message: input, chatHistory, verbose: process.env.DEBUG === 'true' ? true : false }) diff --git a/packages/components/nodes/agents/OpenAIAssistant/OpenAIAssistant.ts b/packages/components/nodes/agents/OpenAIAssistant/OpenAIAssistant.ts index 98e76fedc86..4778b1b744f 100644 --- a/packages/components/nodes/agents/OpenAIAssistant/OpenAIAssistant.ts +++ b/packages/components/nodes/agents/OpenAIAssistant/OpenAIAssistant.ts @@ -1,4 +1,13 @@ -import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeOptionsValue, INodeParams, IUsedTool } from '../../../src/Interface' +import { + ICommonObject, + IDatabaseEntity, + INode, + INodeData, + INodeOptionsValue, + INodeParams, + IServerSideEventStreamer, + IUsedTool +} from '../../../src/Interface' import OpenAI from 'openai' import { DataSource } from 'typeorm' import { getCredentialData, getCredentialParam } from '../../../src/utils' @@ -176,16 +185,19 @@ class OpenAIAssistant_Agents implements INode { const moderations = nodeData.inputs?.inputModeration as Moderation[] const _toolChoice = nodeData.inputs?.toolChoice as string const parallelToolCalls = nodeData.inputs?.parallelToolCalls as boolean - const isStreaming = options.socketIO && options.socketIOClientId - const socketIO = isStreaming ? options.socketIO : undefined - const socketIOClientId = isStreaming ? options.socketIOClientId : '' + + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId if (moderations && moderations.length > 0) { try { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - streamResponse(isStreaming, e.message, socketIO, socketIOClientId) + if (shouldStreamResponse) { + streamResponse(sseStreamer, chatId, e.message) + } return formatResponse(e.message) } } @@ -307,7 +319,7 @@ class OpenAIAssistant_Agents implements INode { } } - if (isStreaming) { + if (shouldStreamResponse) { const streamThread = await openai.beta.threads.runs.create(threadId, { assistant_id: retrievedAssistant.id, stream: true, @@ -389,26 +401,37 @@ class OpenAIAssistant_Agents implements INode { if (message_content.value) { if (!isStreamingStarted) { isStreamingStarted = true - socketIO.to(socketIOClientId).emit('start', message_content.value) + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, message_content.value) + } + } + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, message_content.value) } - socketIO.to(socketIOClientId).emit('token', message_content.value) } if (fileAnnotations.length) { if (!isStreamingStarted) { isStreamingStarted = true - socketIO.to(socketIOClientId).emit('start', '') + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, ' ') + } + } + if (sseStreamer) { + sseStreamer.streamFileAnnotationsEvent(chatId, fileAnnotations) } - socketIO.to(socketIOClientId).emit('fileAnnotations', fileAnnotations) } } else { text += chunk.text?.value if (!isStreamingStarted) { isStreamingStarted = true - socketIO.to(socketIOClientId).emit('start', chunk.text?.value) + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, chunk.text?.value || '') + } + } + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, chunk.text?.value || '') } - - socketIO.to(socketIOClientId).emit('token', chunk.text?.value) } } @@ -425,10 +448,13 @@ class OpenAIAssistant_Agents implements INode { if (!isStreamingStarted) { isStreamingStarted = true - socketIO.to(socketIOClientId).emit('start', imgHTML) + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, imgHTML) + } + } + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, imgHTML) } - - socketIO.to(socketIOClientId).emit('token', imgHTML) } } @@ -495,15 +521,19 @@ class OpenAIAssistant_Agents implements INode { text += chunk.text.value if (!isStreamingStarted) { isStreamingStarted = true - socketIO.to(socketIOClientId).emit('start', chunk.text.value) + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, chunk.text.value) + } + } + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, chunk.text.value) } - - socketIO.to(socketIOClientId).emit('token', chunk.text.value) } } } - - socketIO.to(socketIOClientId).emit('usedTools', usedTools) + if (sseStreamer) { + sseStreamer.streamUsedToolsEvent(chatId, usedTools) + } } catch (error) { console.error('Error submitting tool outputs:', error) await openai.beta.threads.runs.cancel(threadId, runThreadId) @@ -574,7 +604,9 @@ class OpenAIAssistant_Agents implements INode { // Start tool analytics const toolIds = await analyticHandlers.onToolStart(tool.name, actions[i].toolInput, parentIds) - if (socketIO && socketIOClientId) socketIO.to(socketIOClientId).emit('tool', tool.name) + if (shouldStreamResponse && sseStreamer) { + sseStreamer.streamToolEvent(chatId, tool.name) + } try { const toolOutput = await tool.call(actions[i].toolInput, undefined, undefined, { diff --git a/packages/components/nodes/agents/ReActAgentChat/ReActAgentChat.ts b/packages/components/nodes/agents/ReActAgentChat/ReActAgentChat.ts index 227f1070675..c0732d3f713 100644 --- a/packages/components/nodes/agents/ReActAgentChat/ReActAgentChat.ts +++ b/packages/components/nodes/agents/ReActAgentChat/ReActAgentChat.ts @@ -88,7 +88,9 @@ class ReActAgentChat_Agents implements INode { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - //streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + // if (options.shouldStreamResponse) { + // streamResponse(options.sseStreamer, options.chatId, e.message) + // } return formatResponse(e.message) } } diff --git a/packages/components/nodes/agents/ReActAgentLLM/ReActAgentLLM.ts b/packages/components/nodes/agents/ReActAgentLLM/ReActAgentLLM.ts index 7547c807d91..bc7b0e94580 100644 --- a/packages/components/nodes/agents/ReActAgentLLM/ReActAgentLLM.ts +++ b/packages/components/nodes/agents/ReActAgentLLM/ReActAgentLLM.ts @@ -77,7 +77,9 @@ class ReActAgentLLM_Agents implements INode { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - //streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + // if (options.shouldStreamResponse) { + // streamResponse(options.sseStreamer, options.chatId, e.message) + // } return formatResponse(e.message) } } diff --git a/packages/components/nodes/agents/ToolAgent/ToolAgent.ts b/packages/components/nodes/agents/ToolAgent/ToolAgent.ts index c56138da75a..012e18d6eb5 100644 --- a/packages/components/nodes/agents/ToolAgent/ToolAgent.ts +++ b/packages/components/nodes/agents/ToolAgent/ToolAgent.ts @@ -8,7 +8,16 @@ import { ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate, Pr import { formatToOpenAIToolMessages } from 'langchain/agents/format_scratchpad/openai_tools' import { type ToolsAgentStep } from 'langchain/agents/openai/output_parser' import { getBaseClasses, handleEscapeCharacters } from '../../../src/utils' -import { FlowiseMemory, ICommonObject, INode, INodeData, INodeParams, IUsedTool, IVisionChatModal } from '../../../src/Interface' +import { + FlowiseMemory, + ICommonObject, + INode, + INodeData, + INodeParams, + IServerSideEventStreamer, + IUsedTool, + IVisionChatModal +} from '../../../src/Interface' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { AgentExecutor, ToolCallingAgentOutputParser } from '../../../src/agents' import { Moderation, checkInputs, streamResponse } from '../../moderation/Moderation' @@ -100,7 +109,9 @@ class ToolAgent_Agents implements INode { const memory = nodeData.inputs?.memory as FlowiseMemory const moderations = nodeData.inputs?.inputModeration as Moderation[] - const isStreamable = options.socketIO && options.socketIOClientId + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId if (moderations && moderations.length > 0) { try { @@ -108,8 +119,9 @@ class ToolAgent_Agents implements INode { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - if (isStreamable) - streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + streamResponse(sseStreamer, chatId, e.message) + } return formatResponse(e.message) } } @@ -123,15 +135,19 @@ class ToolAgent_Agents implements INode { let sourceDocuments: ICommonObject[] = [] let usedTools: IUsedTool[] = [] - if (isStreamable) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] }) if (res.sourceDocuments) { - options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments)) + if (sseStreamer) { + sseStreamer.streamSourceDocumentsEvent(chatId, flatten(res.sourceDocuments)) + } sourceDocuments = res.sourceDocuments } if (res.usedTools) { - options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools) + if (sseStreamer) { + sseStreamer.streamUsedToolsEvent(chatId, flatten(res.usedTools)) + } usedTools = res.usedTools } // If the tool is set to returnDirect, stream the output to the client @@ -140,8 +156,8 @@ class ToolAgent_Agents implements INode { inputTools = flatten(inputTools) for (const tool of res.usedTools) { const inputTool = inputTools.find((inputTool: Tool) => inputTool.name === tool.tool) - if (inputTool && inputTool.returnDirect) { - options.socketIO.to(options.socketIOClientId).emit('token', tool.toolOutput) + if (inputTool && inputTool.returnDirect && shouldStreamResponse) { + sseStreamer.streamTokenEvent(chatId, tool.toolOutput) } } } diff --git a/packages/components/nodes/agents/XMLAgent/XMLAgent.ts b/packages/components/nodes/agents/XMLAgent/XMLAgent.ts index 886f2bcbb91..d28e3439ebf 100644 --- a/packages/components/nodes/agents/XMLAgent/XMLAgent.ts +++ b/packages/components/nodes/agents/XMLAgent/XMLAgent.ts @@ -7,7 +7,16 @@ import { Tool } from '@langchain/core/tools' import { ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts' import { formatLogToMessage } from 'langchain/agents/format_scratchpad/log_to_message' import { getBaseClasses } from '../../../src/utils' -import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface' +import { + FlowiseMemory, + ICommonObject, + IMessage, + INode, + INodeData, + INodeParams, + IServerSideEventStreamer, + IUsedTool +} from '../../../src/Interface' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { AgentExecutor, XMLAgentOutputParser } from '../../../src/agents' import { Moderation, checkInputs } from '../../moderation/Moderation' @@ -112,13 +121,19 @@ class XMLAgent_Agents implements INode { const memory = nodeData.inputs?.memory as FlowiseMemory const moderations = nodeData.inputs?.inputModeration as Moderation[] + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + if (moderations && moderations.length > 0) { try { // Use the output of the moderation chain as input for the OpenAI Function Agent input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - //streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + // if (options.shouldStreamResponse) { + // streamResponse(options.sseStreamer, options.chatId, e.message) + // } return formatResponse(e.message) } } @@ -131,15 +146,19 @@ class XMLAgent_Agents implements INode { let sourceDocuments: ICommonObject[] = [] let usedTools: IUsedTool[] = [] - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] }) if (res.sourceDocuments) { - options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments)) + if (sseStreamer) { + sseStreamer.streamSourceDocumentsEvent(chatId, flatten(res.sourceDocuments)) + } sourceDocuments = res.sourceDocuments } if (res.usedTools) { - options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools) + if (sseStreamer) { + sseStreamer.streamUsedToolsEvent(chatId, flatten(res.usedTools)) + } usedTools = res.usedTools } // If the tool is set to returnDirect, stream the output to the client @@ -149,7 +168,9 @@ class XMLAgent_Agents implements INode { for (const tool of res.usedTools) { const inputTool = inputTools.find((inputTool: Tool) => inputTool.name === tool.tool) if (inputTool && inputTool.returnDirect) { - options.socketIO.to(options.socketIOClientId).emit('token', tool.toolOutput) + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, tool.toolOutput) + } } } } diff --git a/packages/components/nodes/chains/ApiChain/GETApiChain.ts b/packages/components/nodes/chains/ApiChain/GETApiChain.ts index cdf78ffc837..ac3ef7f4fdc 100644 --- a/packages/components/nodes/chains/ApiChain/GETApiChain.ts +++ b/packages/components/nodes/chains/ApiChain/GETApiChain.ts @@ -2,7 +2,7 @@ import { BaseLanguageModel } from '@langchain/core/language_models/base' import { PromptTemplate } from '@langchain/core/prompts' import { APIChain } from 'langchain/chains' import { getBaseClasses } from '../../../src/utils' -import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer } from '../../../src/Interface' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' export const API_URL_RAW_PROMPT_TEMPLATE = `You are given the below API Documentation: @@ -100,9 +100,12 @@ class GETApiChain_Chains implements INode { const chain = await getAPIChain(apiDocs, model, headers, urlPrompt, ansPrompt) const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) const res = await chain.run(input, [loggerHandler, handler, ...callbacks]) return res } else { diff --git a/packages/components/nodes/chains/ApiChain/OpenAPIChain.ts b/packages/components/nodes/chains/ApiChain/OpenAPIChain.ts index 27c1fe007e3..a04d1961d23 100644 --- a/packages/components/nodes/chains/ApiChain/OpenAPIChain.ts +++ b/packages/components/nodes/chains/ApiChain/OpenAPIChain.ts @@ -1,5 +1,5 @@ import { APIChain, createOpenAPIChain } from 'langchain/chains' -import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation' @@ -74,18 +74,24 @@ class OpenApiChain_Chains implements INode { const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) const moderations = nodeData.inputs?.inputModeration as Moderation[] + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + if (moderations && moderations.length > 0) { try { // Use the output of the moderation chain as input for the OpenAPI chain input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + if (options.shouldStreamResponse) { + streamResponse(sseStreamer, chatId, e.message) + } return formatResponse(e.message) } } - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) const res = await chain.run(input, [loggerHandler, handler, ...callbacks]) return res } else { diff --git a/packages/components/nodes/chains/ApiChain/POSTApiChain.ts b/packages/components/nodes/chains/ApiChain/POSTApiChain.ts index e6f0bd34994..b93c0bacdfd 100644 --- a/packages/components/nodes/chains/ApiChain/POSTApiChain.ts +++ b/packages/components/nodes/chains/ApiChain/POSTApiChain.ts @@ -2,7 +2,7 @@ import { BaseLanguageModel } from '@langchain/core/language_models/base' import { PromptTemplate } from '@langchain/core/prompts' import { API_RESPONSE_RAW_PROMPT_TEMPLATE, API_URL_RAW_PROMPT_TEMPLATE, APIChain } from './postCore' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' -import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' class POSTApiChain_Chains implements INode { @@ -90,8 +90,12 @@ class POSTApiChain_Chains implements INode { const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2) + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) const res = await chain.run(input, [loggerHandler, handler, ...callbacks]) return res } else { diff --git a/packages/components/nodes/chains/ConversationChain/ConversationChain.ts b/packages/components/nodes/chains/ConversationChain/ConversationChain.ts index 73dc9c68c71..07da2ee449b 100644 --- a/packages/components/nodes/chains/ConversationChain/ConversationChain.ts +++ b/packages/components/nodes/chains/ConversationChain/ConversationChain.ts @@ -23,7 +23,8 @@ import { INode, INodeData, INodeParams, - MessageContentImageUrl + MessageContentImageUrl, + IServerSideEventStreamer } from '../../../src/Interface' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { getBaseClasses, handleEscapeCharacters } from '../../../src/utils' @@ -114,13 +115,19 @@ class ConversationChain_Chains implements INode { const chain = await prepareChain(nodeData, options, this.sessionId) const moderations = nodeData.inputs?.inputModeration as Moderation[] + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + if (moderations && moderations.length > 0) { try { // Use the output of the moderation chain as input for the LLM chain input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + if (options.shouldStreamResponse) { + streamResponse(options.sseStreamer, options.chatId, e.message) + } return formatResponse(e.message) } } @@ -135,8 +142,8 @@ class ConversationChain_Chains implements INode { callbacks.push(new LCConsoleCallbackHandler()) } - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) callbacks.push(handler) res = await chain.invoke({ input }, { callbacks }) } else { diff --git a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts index a65775e9bdb..29528ae5c69 100644 --- a/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts +++ b/packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts @@ -22,7 +22,8 @@ import { INodeData, INodeParams, IDatabaseEntity, - MemoryMethods + MemoryMethods, + IServerSideEventStreamer } from '../../../src/Interface' import { QA_TEMPLATE, REPHRASE_TEMPLATE, RESPONSE_TEMPLATE } from './prompts' @@ -181,6 +182,10 @@ class ConversationalRetrievalQAChain_Chains implements INode { const databaseEntities = options.databaseEntities as IDatabaseEntity const chatflowid = options.chatflowid as string + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + let customResponsePrompt = responsePrompt // If the deprecated systemMessagePrompt is still exists if (systemMessagePrompt) { @@ -205,7 +210,9 @@ class ConversationalRetrievalQAChain_Chains implements INode { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + if (options.shouldStreamResponse) { + streamResponse(options.sseStreamer, options.chatId, e.message) + } return formatResponse(e.message) } } @@ -234,18 +241,22 @@ class ConversationalRetrievalQAChain_Chains implements INode { let sourceDocuments: ICommonObject[] = [] let text = '' let isStreamingStarted = false - const isStreamingEnabled = options.socketIO && options.socketIOClientId for await (const chunk of stream) { streamedResponse = applyPatch(streamedResponse, chunk.ops).newDocument if (streamedResponse.final_output) { text = streamedResponse.final_output?.output - if (isStreamingEnabled) options.socketIO.to(options.socketIOClientId).emit('end') if (Array.isArray(streamedResponse?.logs?.[sourceRunnableName]?.final_output?.output)) { sourceDocuments = streamedResponse?.logs?.[sourceRunnableName]?.final_output?.output - if (isStreamingEnabled && returnSourceDocuments) - options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', sourceDocuments) + if (shouldStreamResponse && returnSourceDocuments) { + if (sseStreamer) { + sseStreamer.streamSourceDocumentsEvent(chatId, sourceDocuments) + } + } + } + if (shouldStreamResponse && sseStreamer) { + sseStreamer.streamEndEvent(chatId) } } @@ -258,9 +269,17 @@ class ConversationalRetrievalQAChain_Chains implements INode { if (!isStreamingStarted) { isStreamingStarted = true - if (isStreamingEnabled) options.socketIO.to(options.socketIOClientId).emit('start', token) + if (shouldStreamResponse) { + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, token) + } + } + } + if (shouldStreamResponse) { + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, token) + } } - if (isStreamingEnabled) options.socketIO.to(options.socketIOClientId).emit('token', token) } } diff --git a/packages/components/nodes/chains/LLMChain/LLMChain.ts b/packages/components/nodes/chains/LLMChain/LLMChain.ts index 5a7d04569a6..57c0a907c25 100644 --- a/packages/components/nodes/chains/LLMChain/LLMChain.ts +++ b/packages/components/nodes/chains/LLMChain/LLMChain.ts @@ -4,7 +4,15 @@ import { HumanMessage } from '@langchain/core/messages' import { ChatPromptTemplate, FewShotPromptTemplate, HumanMessagePromptTemplate, PromptTemplate } from '@langchain/core/prompts' import { OutputFixingParser } from 'langchain/output_parsers' import { LLMChain } from 'langchain/chains' -import { IVisionChatModal, ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface' +import { + IVisionChatModal, + ICommonObject, + INode, + INodeData, + INodeOutputsValue, + INodeParams, + IServerSideEventStreamer +} from '../../../src/Interface' import { additionalCallbacks, ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler' import { getBaseClasses, handleEscapeCharacters } from '../../../src/utils' import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation' @@ -162,18 +170,22 @@ const runPrediction = async ( const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) - const isStreaming = !disableStreaming && options.socketIO && options.socketIOClientId - const socketIO = isStreaming ? options.socketIO : undefined - const socketIOClientId = isStreaming ? options.socketIOClientId : '' const moderations = nodeData.inputs?.inputModeration as Moderation[] + // this is true if the prediction is external and the client has requested streaming='true' + const shouldStreamResponse = !disableStreaming && options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + if (moderations && moderations.length > 0) { try { // Use the output of the moderation chain as input for the LLM chain input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - streamResponse(isStreaming, e.message, socketIO, socketIOClientId) + if (shouldStreamResponse) { + streamResponse(sseStreamer, chatId, e.message) + } return formatResponse(e.message) } } @@ -245,8 +257,8 @@ const runPrediction = async ( if (seen.length === 0) { // All inputVariables have fixed values specified const options = { ...promptValues } - if (isStreaming) { - const handler = new CustomChainHandler(socketIO, socketIOClientId) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) const res = await chain.call(options, [loggerHandler, handler, ...callbacks]) return formatResponse(res?.text) } else { @@ -261,8 +273,8 @@ const runPrediction = async ( ...promptValues, [lastValue]: input } - if (isStreaming) { - const handler = new CustomChainHandler(socketIO, socketIOClientId) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) const res = await chain.call(options, [loggerHandler, handler, ...callbacks]) return formatResponse(res?.text) } else { @@ -273,8 +285,9 @@ const runPrediction = async ( throw new Error(`Please provide Prompt Values for: ${seen.join(', ')}`) } } else { - if (isStreaming) { - const handler = new CustomChainHandler(socketIO, socketIOClientId) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) + const res = await chain.run(input, [loggerHandler, handler, ...callbacks]) return formatResponse(res) } else { diff --git a/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts b/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts index fa91bb205e4..d3d15e4a08a 100644 --- a/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts +++ b/packages/components/nodes/chains/MultiPromptChain/MultiPromptChain.ts @@ -1,6 +1,6 @@ import { BaseLanguageModel } from '@langchain/core/language_models/base' import { MultiPromptChain } from 'langchain/chains' -import { ICommonObject, INode, INodeData, INodeParams, PromptRetriever } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer, PromptRetriever } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation' @@ -75,13 +75,21 @@ class MultiPromptChain_Chains implements INode { async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const chain = nodeData.instance as MultiPromptChain const moderations = nodeData.inputs?.inputModeration as Moderation[] + + // this is true if the prediction is external and the client has requested streaming='true' + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + if (moderations && moderations.length > 0) { try { // Use the output of the moderation chain as input for the Multi Prompt Chain input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + if (options.shouldStreamResponse) { + streamResponse(options.sseStreamer, options.chatId, e.message) + } return formatResponse(e.message) } } @@ -90,8 +98,8 @@ class MultiPromptChain_Chains implements INode { const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId, 2) const res = await chain.call(obj, [loggerHandler, handler, ...callbacks]) return res?.text } else { diff --git a/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts b/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts index 71302d635af..8b7889e2450 100644 --- a/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts +++ b/packages/components/nodes/chains/MultiRetrievalQAChain/MultiRetrievalQAChain.ts @@ -1,6 +1,6 @@ import { BaseLanguageModel } from '@langchain/core/language_models/base' import { MultiRetrievalQAChain } from 'langchain/chains' -import { ICommonObject, INode, INodeData, INodeParams, VectorStoreRetriever } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer, VectorStoreRetriever } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation' @@ -83,13 +83,20 @@ class MultiRetrievalQAChain_Chains implements INode { const chain = nodeData.instance as MultiRetrievalQAChain const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean const moderations = nodeData.inputs?.inputModeration as Moderation[] + + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + if (moderations && moderations.length > 0) { try { // Use the output of the moderation chain as input for the Multi Retrieval QA Chain input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + if (options.shouldStreamResponse) { + streamResponse(options.sseStreamer, options.chatId, e.message) + } return formatResponse(e.message) } } @@ -97,8 +104,8 @@ class MultiRetrievalQAChain_Chains implements INode { const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2, returnSourceDocuments) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId, 2, returnSourceDocuments) const res = await chain.call(obj, [loggerHandler, handler, ...callbacks]) if (res.text && res.sourceDocuments) return res return res?.text diff --git a/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts b/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts index 9125f38fcae..5b43ffc1369 100644 --- a/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts +++ b/packages/components/nodes/chains/RetrievalQAChain/RetrievalQAChain.ts @@ -2,7 +2,7 @@ import { BaseRetriever } from '@langchain/core/retrievers' import { BaseLanguageModel } from '@langchain/core/language_models/base' import { RetrievalQAChain } from 'langchain/chains' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' -import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation' import { formatResponse } from '../../outputparsers/OutputParserHelpers' @@ -60,13 +60,20 @@ class RetrievalQAChain_Chains implements INode { async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const chain = nodeData.instance as RetrievalQAChain const moderations = nodeData.inputs?.inputModeration as Moderation[] + + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + if (moderations && moderations.length > 0) { try { // Use the output of the moderation chain as input for the Retrieval QA Chain input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + streamResponse(sseStreamer, chatId, e.message) + } return formatResponse(e.message) } } @@ -76,8 +83,8 @@ class RetrievalQAChain_Chains implements INode { const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) const res = await chain.call(obj, [loggerHandler, handler, ...callbacks]) return res?.text } else { diff --git a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts index f04102fd4a9..dbdd4698053 100644 --- a/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts +++ b/packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts @@ -4,7 +4,7 @@ import { BaseLanguageModel } from '@langchain/core/language_models/base' import { PromptTemplate, PromptTemplateInput } from '@langchain/core/prompts' import { SqlDatabaseChain, SqlDatabaseChainInput, DEFAULT_SQL_DATABASE_PROMPT } from 'langchain/chains/sql_db' import { SqlDatabase } from 'langchain/sql_db' -import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer } from '../../../src/Interface' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' import { getBaseClasses, getInputVariables } from '../../../src/utils' import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation' @@ -166,13 +166,21 @@ class SqlDatabaseChain_Chains implements INode { const topK = nodeData.inputs?.topK as number const customPrompt = nodeData.inputs?.customPrompt as string const moderations = nodeData.inputs?.inputModeration as Moderation[] + + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + if (moderations && moderations.length > 0) { try { // Use the output of the moderation chain as input for the Sql Database Chain input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + streamResponse(sseStreamer, chatId, e.message) + } + // streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) return formatResponse(e.message) } } @@ -190,8 +198,9 @@ class SqlDatabaseChain_Chains implements INode { const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId, 2) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId, 2) + const res = await chain.run(input, [loggerHandler, handler, ...callbacks]) return res } else { diff --git a/packages/components/nodes/chains/VectaraChain/VectaraChain.ts b/packages/components/nodes/chains/VectaraChain/VectaraChain.ts index e5427ca0f7d..dc34a166fb4 100644 --- a/packages/components/nodes/chains/VectaraChain/VectaraChain.ts +++ b/packages/components/nodes/chains/VectaraChain/VectaraChain.ts @@ -269,7 +269,9 @@ class VectaraChain_Chains implements INode { input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - //streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + // if (options.shouldStreamResponse) { + // streamResponse(options.sseStreamer, options.chatId, e.message) + // } return formatResponse(e.message) } } diff --git a/packages/components/nodes/chains/VectorDBQAChain/VectorDBQAChain.ts b/packages/components/nodes/chains/VectorDBQAChain/VectorDBQAChain.ts index 129eb46acdd..aab8d824e0c 100644 --- a/packages/components/nodes/chains/VectorDBQAChain/VectorDBQAChain.ts +++ b/packages/components/nodes/chains/VectorDBQAChain/VectorDBQAChain.ts @@ -2,7 +2,7 @@ import { BaseLanguageModel } from '@langchain/core/language_models/base' import { VectorStore } from '@langchain/core/vectorstores' import { VectorDBQAChain } from 'langchain/chains' import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' -import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeParams, IServerSideEventStreamer } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' import { checkInputs, Moderation } from '../../moderation/Moderation' import { formatResponse } from '../../outputparsers/OutputParserHelpers' @@ -64,13 +64,19 @@ class VectorDBQAChain_Chains implements INode { const chain = nodeData.instance as VectorDBQAChain const moderations = nodeData.inputs?.inputModeration as Moderation[] + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + if (moderations && moderations.length > 0) { try { // Use the output of the moderation chain as input for the VectorDB QA Chain input = await checkInputs(moderations, input) } catch (e) { await new Promise((resolve) => setTimeout(resolve, 500)) - //streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId) + // if (options.shouldStreamResponse) { + // streamResponse(options.sseStreamer, options.chatId, e.message) + // } return formatResponse(e.message) } } @@ -81,8 +87,8 @@ class VectorDBQAChain_Chains implements INode { const loggerHandler = new ConsoleCallbackHandler(options.logger) const callbacks = await additionalCallbacks(nodeData, options) - if (options.socketIO && options.socketIOClientId) { - const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) + if (shouldStreamResponse) { + const handler = new CustomChainHandler(sseStreamer, chatId) const res = await chain.call(obj, [loggerHandler, handler, ...callbacks]) return res?.text } else { diff --git a/packages/components/nodes/engine/ChatEngine/ContextChatEngine.ts b/packages/components/nodes/engine/ChatEngine/ContextChatEngine.ts index a5bacaad070..35b6ae069b4 100644 --- a/packages/components/nodes/engine/ChatEngine/ContextChatEngine.ts +++ b/packages/components/nodes/engine/ChatEngine/ContextChatEngine.ts @@ -1,4 +1,13 @@ -import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface' +import { + FlowiseMemory, + ICommonObject, + IMessage, + INode, + INodeData, + INodeOutputsValue, + INodeParams, + IServerSideEventStreamer +} from '../../../src/Interface' import { Metadata, BaseRetriever, LLM, ContextChatEngine, ChatMessage, NodeWithScore } from 'llamaindex' import { reformatSourceDocuments } from '../EngineUtils' @@ -103,24 +112,33 @@ class ContextChatEngine_LlamaIndex implements INode { let isStreamingStarted = false let sourceDocuments: ICommonObject[] = [] let sourceNodes: NodeWithScore[] = [] - const isStreamingEnabled = options.socketIO && options.socketIOClientId - if (isStreamingEnabled) { + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + + if (shouldStreamResponse) { const stream = await chatEngine.chat({ message: input, chatHistory, stream: true }) for await (const chunk of stream) { text += chunk.response if (chunk.sourceNodes) sourceNodes = chunk.sourceNodes if (!isStreamingStarted) { isStreamingStarted = true - options.socketIO.to(options.socketIOClientId).emit('start', chunk.response) + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, chunk.response) + } } - options.socketIO.to(options.socketIOClientId).emit('token', chunk.response) + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, chunk.response) + } } if (returnSourceDocuments) { sourceDocuments = reformatSourceDocuments(sourceNodes) - options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', sourceDocuments) + if (sseStreamer) { + sseStreamer.streamSourceDocumentsEvent(chatId, sourceDocuments) + } } } else { const response = await chatEngine.chat({ message: input, chatHistory }) diff --git a/packages/components/nodes/engine/ChatEngine/SimpleChatEngine.ts b/packages/components/nodes/engine/ChatEngine/SimpleChatEngine.ts index 5734288d1da..e6045fda6c6 100644 --- a/packages/components/nodes/engine/ChatEngine/SimpleChatEngine.ts +++ b/packages/components/nodes/engine/ChatEngine/SimpleChatEngine.ts @@ -1,4 +1,13 @@ -import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface' +import { + FlowiseMemory, + ICommonObject, + IMessage, + INode, + INodeData, + INodeOutputsValue, + INodeParams, + IServerSideEventStreamer +} from '../../../src/Interface' import { LLM, ChatMessage, SimpleChatEngine } from 'llamaindex' class SimpleChatEngine_LlamaIndex implements INode { @@ -86,18 +95,24 @@ class SimpleChatEngine_LlamaIndex implements INode { let text = '' let isStreamingStarted = false - const isStreamingEnabled = options.socketIO && options.socketIOClientId - if (isStreamingEnabled) { + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + + if (shouldStreamResponse) { const stream = await chatEngine.chat({ message: input, chatHistory, stream: true }) for await (const chunk of stream) { text += chunk.response if (!isStreamingStarted) { isStreamingStarted = true - options.socketIO.to(options.socketIOClientId).emit('start', chunk.response) + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, chunk.response) + } + } + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, chunk.response) } - - options.socketIO.to(options.socketIOClientId).emit('token', chunk.response) } } else { const response = await chatEngine.chat({ message: input, chatHistory }) diff --git a/packages/components/nodes/engine/QueryEngine/QueryEngine.ts b/packages/components/nodes/engine/QueryEngine/QueryEngine.ts index 7d8d4fe4cad..14eb3c5de1d 100644 --- a/packages/components/nodes/engine/QueryEngine/QueryEngine.ts +++ b/packages/components/nodes/engine/QueryEngine/QueryEngine.ts @@ -1,4 +1,4 @@ -import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams, IServerSideEventStreamer } from '../../../src/Interface' import { RetrieverQueryEngine, ResponseSynthesizer, @@ -71,24 +71,32 @@ class QueryEngine_LlamaIndex implements INode { let sourceDocuments: ICommonObject[] = [] let sourceNodes: NodeWithScore[] = [] let isStreamingStarted = false - const isStreamingEnabled = options.socketIO && options.socketIOClientId - if (isStreamingEnabled) { + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + + if (shouldStreamResponse) { const stream = await queryEngine.query({ query: input, stream: true }) for await (const chunk of stream) { text += chunk.response if (chunk.sourceNodes) sourceNodes = chunk.sourceNodes if (!isStreamingStarted) { isStreamingStarted = true - options.socketIO.to(options.socketIOClientId).emit('start', chunk.response) + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, chunk.response) + } + } + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, chunk.response) } - - options.socketIO.to(options.socketIOClientId).emit('token', chunk.response) } if (returnSourceDocuments) { sourceDocuments = reformatSourceDocuments(sourceNodes) - options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', sourceDocuments) + if (sseStreamer) { + sseStreamer.streamSourceDocumentsEvent(chatId, sourceDocuments) + } } } else { const response = await queryEngine.query({ query: input }) diff --git a/packages/components/nodes/engine/SubQuestionQueryEngine/SubQuestionQueryEngine.ts b/packages/components/nodes/engine/SubQuestionQueryEngine/SubQuestionQueryEngine.ts index eb6f85dccb4..6d8ceead9c4 100644 --- a/packages/components/nodes/engine/SubQuestionQueryEngine/SubQuestionQueryEngine.ts +++ b/packages/components/nodes/engine/SubQuestionQueryEngine/SubQuestionQueryEngine.ts @@ -1,5 +1,5 @@ import { flatten } from 'lodash' -import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams } from '../../../src/Interface' +import { ICommonObject, INode, INodeData, INodeOutputsValue, INodeParams, IServerSideEventStreamer } from '../../../src/Interface' import { TreeSummarize, SimpleResponseBuilder, @@ -88,24 +88,32 @@ class SubQuestionQueryEngine_LlamaIndex implements INode { let sourceDocuments: ICommonObject[] = [] let sourceNodes: NodeWithScore[] = [] let isStreamingStarted = false - const isStreamingEnabled = options.socketIO && options.socketIOClientId - if (isStreamingEnabled) { + const shouldStreamResponse = options.shouldStreamResponse + const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer + const chatId = options.chatId + + if (shouldStreamResponse) { const stream = await queryEngine.query({ query: input, stream: true }) for await (const chunk of stream) { text += chunk.response if (chunk.sourceNodes) sourceNodes = chunk.sourceNodes if (!isStreamingStarted) { isStreamingStarted = true - options.socketIO.to(options.socketIOClientId).emit('start', chunk.response) + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, chunk.response) + } + } + if (sseStreamer) { + sseStreamer.streamTokenEvent(chatId, chunk.response) } - - options.socketIO.to(options.socketIOClientId).emit('token', chunk.response) } if (returnSourceDocuments) { sourceDocuments = reformatSourceDocuments(sourceNodes) - options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', sourceDocuments) + if (sseStreamer) { + sseStreamer.streamSourceDocumentsEvent(chatId, sourceDocuments) + } } } else { const response = await queryEngine.query({ query: input }) diff --git a/packages/components/nodes/moderation/Moderation.ts b/packages/components/nodes/moderation/Moderation.ts index 9fd2bfde340..7e37ee5be17 100644 --- a/packages/components/nodes/moderation/Moderation.ts +++ b/packages/components/nodes/moderation/Moderation.ts @@ -1,4 +1,4 @@ -import { Server } from 'socket.io' +import { IServerSideEventStreamer } from '../../src' export abstract class Moderation { abstract checkForViolations(input: string): Promise @@ -13,15 +13,13 @@ export const checkInputs = async (inputModerations: Moderation[], input: string) // is this the correct location for this function? // should we have a utils files that all node components can use? -export const streamResponse = (isStreaming: any, response: string, socketIO: Server, socketIOClientId: string) => { - if (isStreaming) { - const result = response.split(/(\s+)/) - result.forEach((token: string, index: number) => { - if (index === 0) { - socketIO.to(socketIOClientId).emit('start', token) - } - socketIO.to(socketIOClientId).emit('token', token) - }) - socketIO.to(socketIOClientId).emit('end') - } +export const streamResponse = (sseStreamer: IServerSideEventStreamer, chatId: string, response: string) => { + const result = response.split(/(\s+)/) + result.forEach((token: string, index: number) => { + if (index === 0) { + sseStreamer.streamStartEvent(chatId, token) + } + sseStreamer.streamTokenEvent(chatId, token) + }) + sseStreamer.streamEndEvent(chatId) } diff --git a/packages/components/nodes/tools/ChainTool/core.ts b/packages/components/nodes/tools/ChainTool/core.ts index 60ba5977637..e43c126f8d5 100644 --- a/packages/components/nodes/tools/ChainTool/core.ts +++ b/packages/components/nodes/tools/ChainTool/core.ts @@ -1,6 +1,7 @@ import { DynamicTool, DynamicToolInput } from '@langchain/core/tools' import { BaseChain } from 'langchain/chains' import { handleEscapeCharacters } from '../../../src/utils' +import { CustomChainHandler } from '../../../src' export interface ChainToolInput extends Omit { chain: BaseChain @@ -13,13 +14,42 @@ export class ChainTool extends DynamicTool { super({ ...rest, func: async (input, runManager) => { - // To enable LLM Chain which has promptValues + // prevent sending SSE events of the sub-chain + const sseStreamer = runManager?.handlers.find((handler) => handler instanceof CustomChainHandler)?.sseStreamer + if (runManager) { + const callbacks = runManager.handlers + for (let i = 0; i < callbacks.length; i += 1) { + if (callbacks[i] instanceof CustomChainHandler) { + ;(callbacks[i] as any).sseStreamer = undefined + } + } + } + if ((chain as any).prompt && (chain as any).prompt.promptValues) { const promptValues = handleEscapeCharacters((chain as any).prompt.promptValues, true) + const values = await chain.call(promptValues, runManager?.getChild()) + if (runManager && sseStreamer) { + const callbacks = runManager.handlers + for (let i = 0; i < callbacks.length; i += 1) { + if (callbacks[i] instanceof CustomChainHandler) { + ;(callbacks[i] as any).sseStreamer = sseStreamer + } + } + } return values?.text } - return chain.run(input, runManager?.getChild()) + + const values = chain.run(input, runManager?.getChild()) + if (runManager && sseStreamer) { + const callbacks = runManager.handlers + for (let i = 0; i < callbacks.length; i += 1) { + if (callbacks[i] instanceof CustomChainHandler) { + ;(callbacks[i] as any).sseStreamer = sseStreamer + } + } + } + return values } }) this.chain = chain diff --git a/packages/components/nodes/tools/ChatflowTool/ChatflowTool.ts b/packages/components/nodes/tools/ChatflowTool/ChatflowTool.ts index 0a8d5516804..3d476cbe18a 100644 --- a/packages/components/nodes/tools/ChatflowTool/ChatflowTool.ts +++ b/packages/components/nodes/tools/ChatflowTool/ChatflowTool.ts @@ -7,6 +7,7 @@ import { StructuredTool } from '@langchain/core/tools' import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeOptionsValue, INodeParams } from '../../../src/Interface' import { availableDependencies, defaultAllowBuiltInDep, getCredentialData, getCredentialParam } from '../../../src/utils' import { v4 as uuidv4 } from 'uuid' +import { CustomChainHandler } from '../../../src' class ChatflowTool_Tools implements INode { label: string @@ -219,6 +220,15 @@ class ChatflowTool extends StructuredTool { } catch (e) { throw new Error(`Received tool input did not match expected schema: ${JSON.stringify(arg)}`) } + // iterate over the callbacks and the sse streamer + if (config.callbacks instanceof CallbackManager) { + const callbacks = config.callbacks.handlers + for (let i = 0; i < callbacks.length; i += 1) { + if (callbacks[i] instanceof CustomChainHandler) { + ;(callbacks[i] as any).sseStreamer = undefined + } + } + } const callbackManager_ = await CallbackManager.configure( config.callbacks, this.callbacks, diff --git a/packages/components/src/Interface.ts b/packages/components/src/Interface.ts index c3d2a72d409..ecbcb250e6d 100644 --- a/packages/components/src/Interface.ts +++ b/packages/components/src/Interface.ts @@ -400,3 +400,22 @@ export interface IStateWithMessages extends ICommonObject { messages: BaseMessage[] [key: string]: any } + +export interface IServerSideEventStreamer { + streamEvent(chatId: string, data: string): void + streamStartEvent(chatId: string, data: any): void + + streamTokenEvent(chatId: string, data: string): void + streamCustomEvent(chatId: string, eventType: string, data: any): void + + streamSourceDocumentsEvent(chatId: string, data: any): void + streamUsedToolsEvent(chatId: string, data: any): void + streamFileAnnotationsEvent(chatId: string, data: any): void + streamToolEvent(chatId: string, data: any): void + streamAgentReasoningEvent(chatId: string, data: any): void + streamNextAgentEvent(chatId: string, data: any): void + streamActionEvent(chatId: string, data: any): void + + streamAbortEvent(chatId: string): void + streamEndEvent(chatId: string): void +} diff --git a/packages/components/src/handler.ts b/packages/components/src/handler.ts index faf0b260ba2..2c0679c7a88 100644 --- a/packages/components/src/handler.ts +++ b/packages/components/src/handler.ts @@ -1,6 +1,5 @@ import { Logger } from 'winston' import { v4 as uuidv4 } from 'uuid' -import { Server } from 'socket.io' import { Client } from 'langsmith' import CallbackHandler from 'langfuse-langchain' import lunary from 'lunary' @@ -15,7 +14,7 @@ import { AgentAction } from '@langchain/core/agents' import { LunaryHandler } from '@langchain/community/callbacks/handlers/lunary' import { getCredentialData, getCredentialParam, getEnvironmentVariable } from './utils' -import { ICommonObject, INodeData } from './Interface' +import { ICommonObject, INodeData, IServerSideEventStreamer } from './Interface' import { LangWatch, LangWatchSpan, LangWatchTrace, autoconvertTypedValues } from 'langwatch' interface AgentRun extends Run { @@ -163,16 +162,16 @@ export class ConsoleCallbackHandler extends BaseTracer { export class CustomChainHandler extends BaseCallbackHandler { name = 'custom_chain_handler' isLLMStarted = false - socketIO: Server - socketIOClientId = '' skipK = 0 // Skip streaming for first K numbers of handleLLMStart returnSourceDocuments = false cachedResponse = true + chatId: string = '' + sseStreamer: IServerSideEventStreamer | undefined - constructor(socketIO: Server, socketIOClientId: string, skipK?: number, returnSourceDocuments?: boolean) { + constructor(sseStreamer: IServerSideEventStreamer | undefined, chatId: string, skipK?: number, returnSourceDocuments?: boolean) { super() - this.socketIO = socketIO - this.socketIOClientId = socketIOClientId + this.sseStreamer = sseStreamer + this.chatId = chatId this.skipK = skipK ?? this.skipK this.returnSourceDocuments = returnSourceDocuments ?? this.returnSourceDocuments } @@ -186,14 +185,20 @@ export class CustomChainHandler extends BaseCallbackHandler { if (this.skipK === 0) { if (!this.isLLMStarted) { this.isLLMStarted = true - this.socketIO.to(this.socketIOClientId).emit('start', token) + if (this.sseStreamer) { + this.sseStreamer.streamStartEvent(this.chatId, token) + } + } + if (this.sseStreamer) { + this.sseStreamer.streamTokenEvent(this.chatId, token) } - this.socketIO.to(this.socketIOClientId).emit('token', token) } } handleLLMEnd() { - this.socketIO.to(this.socketIOClientId).emit('end') + if (this.sseStreamer) { + this.sseStreamer.streamEndEvent(this.chatId) + } } handleChainEnd(outputs: ChainValues, _: string, parentRunId?: string): void | Promise { @@ -208,17 +213,23 @@ export class CustomChainHandler extends BaseCallbackHandler { const result = cachedValue.split(/(\s+)/) result.forEach((token: string, index: number) => { if (index === 0) { - this.socketIO.to(this.socketIOClientId).emit('start', token) + if (this.sseStreamer) { + this.sseStreamer.streamStartEvent(this.chatId, token) + } + } + if (this.sseStreamer) { + this.sseStreamer.streamTokenEvent(this.chatId, token) } - this.socketIO.to(this.socketIOClientId).emit('token', token) }) - if (this.returnSourceDocuments) { - this.socketIO.to(this.socketIOClientId).emit('sourceDocuments', outputs?.sourceDocuments) + if (this.returnSourceDocuments && this.sseStreamer) { + this.sseStreamer.streamSourceDocumentsEvent(this.chatId, outputs?.sourceDocuments) + } + if (this.sseStreamer) { + this.sseStreamer.streamEndEvent(this.chatId) } - this.socketIO.to(this.socketIOClientId).emit('end') } else { - if (this.returnSourceDocuments) { - this.socketIO.to(this.socketIOClientId).emit('sourceDocuments', outputs?.sourceDocuments) + if (this.returnSourceDocuments && this.sseStreamer) { + this.sseStreamer.streamSourceDocumentsEvent(this.chatId, outputs?.sourceDocuments) } } } diff --git a/packages/server/src/Interface.ts b/packages/server/src/Interface.ts index f4104576233..97a6d1d9bf9 100644 --- a/packages/server/src/Interface.ts +++ b/packages/server/src/Interface.ts @@ -216,7 +216,6 @@ export interface IMessage { export interface IncomingInput { question: string overrideConfig?: ICommonObject - socketIOClientId?: string chatId?: string stopNodeId?: string uploads?: IFileUpload[] diff --git a/packages/server/src/controllers/internal-predictions/index.ts b/packages/server/src/controllers/internal-predictions/index.ts index 1fea952d7a3..6539d76e0e7 100644 --- a/packages/server/src/controllers/internal-predictions/index.ts +++ b/packages/server/src/controllers/internal-predictions/index.ts @@ -1,16 +1,45 @@ import { Request, Response, NextFunction } from 'express' import { utilBuildChatflow } from '../../utils/buildChatflow' +import { getRunningExpressApp } from '../../utils/getRunningExpressApp' +import { getErrorMessage } from '../../errors/utils' // Send input message and get prediction result (Internal) const createInternalPrediction = async (req: Request, res: Response, next: NextFunction) => { try { - const apiResponse = await utilBuildChatflow(req, req.io, true) - return res.json(apiResponse) + if (req.body.streaming || req.body.streaming === 'true') { + createAndStreamInternalPrediction(req, res, next) + return + } else { + const apiResponse = await utilBuildChatflow(req, true) + return res.json(apiResponse) + } } catch (error) { next(error) } } +// Send input message and stream prediction result using SSE (Internal) +const createAndStreamInternalPrediction = async (req: Request, res: Response, next: NextFunction) => { + const chatId = req.body.chatId + const sseStreamer = getRunningExpressApp().sseStreamer + try { + sseStreamer.addClient(chatId, res) + res.setHeader('Content-Type', 'text/event-stream') + res.setHeader('Cache-Control', 'no-cache') + res.setHeader('Connection', 'keep-alive') + res.flushHeaders() + + const apiResponse = await utilBuildChatflow(req, true) + sseStreamer.streamMetadataEvent(apiResponse.chatId, apiResponse) + } catch (error) { + if (chatId) { + sseStreamer.streamErrorEvent(chatId, getErrorMessage(error)) + } + next(error) + } finally { + sseStreamer.removeClient(chatId) + } +} export default { createInternalPrediction } diff --git a/packages/server/src/controllers/predictions/index.ts b/packages/server/src/controllers/predictions/index.ts index 0e78c80f7cd..ae649b9bc81 100644 --- a/packages/server/src/controllers/predictions/index.ts +++ b/packages/server/src/controllers/predictions/index.ts @@ -5,6 +5,9 @@ import logger from '../../utils/logger' import predictionsServices from '../../services/predictions' import { InternalFlowiseError } from '../../errors/internalFlowiseError' import { StatusCodes } from 'http-status-codes' +import { getRunningExpressApp } from '../../utils/getRunningExpressApp' +import { v4 as uuidv4 } from 'uuid' +import { getErrorMessage } from '../../errors/utils' // Send input message and get prediction result (External) const createPrediction = async (req: Request, res: Response, next: NextFunction) => { @@ -46,9 +49,36 @@ const createPrediction = async (req: Request, res: Response, next: NextFunction) } } if (isDomainAllowed) { - //@ts-ignore - const apiResponse = await predictionsServices.buildChatflow(req, req?.io) - return res.json(apiResponse) + const streamable = await chatflowsService.checkIfChatflowIsValidForStreaming(req.params.id) + const isStreamingRequested = req.body.streaming === 'true' || req.body.streaming === true + if (streamable?.isStreaming && isStreamingRequested) { + const sseStreamer = getRunningExpressApp().sseStreamer + let chatId = req.body.chatId + if (!req.body.chatId) { + chatId = req.body.chatId ?? req.body.overrideConfig?.sessionId ?? uuidv4() + req.body.chatId = chatId + } + try { + sseStreamer.addExternalClient(chatId, res) + res.setHeader('Content-Type', 'text/event-stream') + res.setHeader('Cache-Control', 'no-cache') + res.setHeader('Connection', 'keep-alive') + res.flushHeaders() + + const apiResponse = await predictionsServices.buildChatflow(req) + sseStreamer.streamMetadataEvent(apiResponse.chatId, apiResponse) + } catch (error) { + if (chatId) { + sseStreamer.streamErrorEvent(chatId, getErrorMessage(error)) + } + next(error) + } finally { + sseStreamer.removeClient(chatId) + } + } else { + const apiResponse = await predictionsServices.buildChatflow(req) + return res.json(apiResponse) + } } else { throw new InternalFlowiseError(StatusCodes.UNAUTHORIZED, `This site is not allowed to access this chatbot`) } diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 406cb42d9c0..b14c83cf7be 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -20,6 +20,7 @@ import { sanitizeMiddleware, getCorsOptions, getAllowedIframeOrigins } from './u import { Telemetry } from './utils/telemetry' import flowiseApiV1Router from './routes' import errorHandlerMiddleware from './middlewares/errors' +import { SSEStreamer } from './utils/SSEStreamer' import { validateAPIKey } from './utils/validateKey' declare global { @@ -37,6 +38,7 @@ export class App { cachePool: CachePool telemetry: Telemetry AppDataSource: DataSource = getDataSource() + sseStreamer: SSEStreamer constructor() { this.app = express() @@ -200,6 +202,7 @@ export class App { } this.app.use('/api/v1', flowiseApiV1Router) + this.sseStreamer = new SSEStreamer(this.app) // ---------------------------------------- // Configure number of proxies in Host Environment diff --git a/packages/server/src/middlewares/errors/index.ts b/packages/server/src/middlewares/errors/index.ts index ea0ab513ae8..75cd2c21b24 100644 --- a/packages/server/src/middlewares/errors/index.ts +++ b/packages/server/src/middlewares/errors/index.ts @@ -12,8 +12,10 @@ async function errorHandlerMiddleware(err: InternalFlowiseError, req: Request, r // Provide error stack trace only in development stack: process.env.NODE_ENV === 'development' ? err.stack : {} } - res.setHeader('Content-Type', 'application/json') - res.status(displayedError.statusCode).json(displayedError) + if (!req.body.streaming || req.body.streaming === 'false') { + res.setHeader('Content-Type', 'application/json') + res.status(displayedError.statusCode).json(displayedError) + } } export default errorHandlerMiddleware diff --git a/packages/server/src/services/predictions/index.ts b/packages/server/src/services/predictions/index.ts index e7411fe6ac5..6f2dbe199c1 100644 --- a/packages/server/src/services/predictions/index.ts +++ b/packages/server/src/services/predictions/index.ts @@ -1,13 +1,12 @@ import { Request } from 'express' -import { Server } from 'socket.io' import { StatusCodes } from 'http-status-codes' import { utilBuildChatflow } from '../../utils/buildChatflow' import { InternalFlowiseError } from '../../errors/internalFlowiseError' import { getErrorMessage } from '../../errors/utils' -const buildChatflow = async (fullRequest: Request, ioServer: Server) => { +const buildChatflow = async (fullRequest: Request) => { try { - const dbResponse = await utilBuildChatflow(fullRequest, ioServer) + const dbResponse = await utilBuildChatflow(fullRequest) return dbResponse } catch (error) { throw new InternalFlowiseError( diff --git a/packages/server/src/utils/SSEStreamer.ts b/packages/server/src/utils/SSEStreamer.ts new file mode 100644 index 00000000000..0e81b8380ca --- /dev/null +++ b/packages/server/src/utils/SSEStreamer.ts @@ -0,0 +1,208 @@ +import express from 'express' +import { Response } from 'express' +import { IServerSideEventStreamer } from 'flowise-components' + +// define a new type that has a client type (INTERNAL or EXTERNAL) and Response type +type Client = { + // future use + clientType: 'INTERNAL' | 'EXTERNAL' + response: Response + // optional property with default value + started?: boolean +} + +export class SSEStreamer implements IServerSideEventStreamer { + clients: { [id: string]: Client } = {} + app: express.Application + + constructor(app: express.Application) { + this.app = app + } + + addExternalClient(chatId: string, res: Response) { + this.clients[chatId] = { clientType: 'EXTERNAL', response: res, started: false } + } + + addClient(chatId: string, res: Response) { + this.clients[chatId] = { clientType: 'INTERNAL', response: res, started: false } + } + + removeClient(chatId: string) { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'end', + data: '[DONE]' + } + client.response.write('message\ndata:' + JSON.stringify(clientResponse) + '\n\n') + client.response.end() + delete this.clients[chatId] + } + } + + // Send SSE message to a specific client + streamEvent(chatId: string, data: string) { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'start', + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + + streamCustomEvent(chatId: string, eventType: string, data: any) { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: eventType, + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + + streamStartEvent(chatId: string, data: string) { + const client = this.clients[chatId] + // prevent multiple start events being streamed to the client + if (client && !client.started) { + const clientResponse = { + event: 'start', + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + client.started = true + } + } + + streamTokenEvent(chatId: string, data: string) { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'token', + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + + streamSourceDocumentsEvent(chatId: string, data: any) { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'sourceDocuments', + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + streamUsedToolsEvent(chatId: string, data: any): void { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'usedTools', + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + streamFileAnnotationsEvent(chatId: string, data: any): void { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'fileAnnotations', + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + streamToolEvent(chatId: string, data: any): void { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'tool', + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + streamAgentReasoningEvent(chatId: string, data: any): void { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'agentReasoning', + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + streamNextAgentEvent(chatId: string, data: any): void { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'nextAgent', + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + streamActionEvent(chatId: string, data: any): void { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'action', + data: data + } + client.response.write('message:\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + + streamAbortEvent(chatId: string): void { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'abort', + data: '[DONE]' + } + client.response.write('message\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + + streamEndEvent(_: string) { + // placeholder for future use + } + + streamErrorEvent(chatId: string, msg: string) { + const client = this.clients[chatId] + if (client) { + const clientResponse = { + event: 'error', + data: msg + } + client.response.write('message\ndata:' + JSON.stringify(clientResponse) + '\n\n') + } + } + + streamMetadataEvent(chatId: string, apiResponse: any) { + const metadataJson: any = {} + if (apiResponse.chatId) { + metadataJson['chatId'] = apiResponse.chatId + } + if (apiResponse.chatMessageId) { + metadataJson['chatMessageId'] = apiResponse.chatMessageId + } + if (apiResponse.question) { + metadataJson['question'] = apiResponse.question + } + if (apiResponse.sessionId) { + metadataJson['sessionId'] = apiResponse.sessionId + } + if (apiResponse.memoryType) { + metadataJson['memoryType'] = apiResponse.memoryType + } + if (Object.keys(metadataJson).length > 0) { + this.streamCustomEvent(chatId, 'metadata', metadataJson) + } + } +} diff --git a/packages/server/src/utils/buildAgentGraph.ts b/packages/server/src/utils/buildAgentGraph.ts index 945bc07ce75..cd6fafe7a5a 100644 --- a/packages/server/src/utils/buildAgentGraph.ts +++ b/packages/server/src/utils/buildAgentGraph.ts @@ -9,9 +9,9 @@ import { ISeqAgentsState, ISeqAgentNode, IUsedTool, - IDocument + IDocument, + IServerSideEventStreamer } from 'flowise-components' -import { Server } from 'socket.io' import { omit, cloneDeep, flatten, uniq } from 'lodash' import { StateGraph, END, START } from '@langchain/langgraph' import { Document } from '@langchain/core/documents' @@ -53,7 +53,6 @@ import logger from './logger' * @param {ICommonObject} incomingInput * @param {boolean} isInternal * @param {string} baseURL - * @param {Server} socketIO */ export const buildAgentGraph = async ( chatflow: IChatFlow, @@ -62,7 +61,8 @@ export const buildAgentGraph = async ( incomingInput: IncomingInput, isInternal: boolean, baseURL?: string, - socketIO?: Server + sseStreamer?: IServerSideEventStreamer, + shouldStreamResponse?: boolean ): Promise => { try { const appServer = getRunningExpressApp() @@ -287,28 +287,31 @@ export const buildAgentGraph = async ( ? output[agentName].messages[output[agentName].messages.length - 1].content : lastWorkerResult - if (socketIO && incomingInput.socketIOClientId) { + if (shouldStreamResponse) { if (!isStreamingStarted) { isStreamingStarted = true - socketIO.to(incomingInput.socketIOClientId).emit('start', agentReasoning) + if (sseStreamer) { + sseStreamer.streamStartEvent(chatId, agentReasoning) + } } - socketIO.to(incomingInput.socketIOClientId).emit('agentReasoning', agentReasoning) + if (sseStreamer) { + sseStreamer.streamAgentReasoningEvent(chatId, agentReasoning) + } // Send loading next agent indicator if (reasoning.next && reasoning.next !== 'FINISH' && reasoning.next !== 'END') { - socketIO - .to(incomingInput.socketIOClientId) - .emit('nextAgent', mapNameToLabel[reasoning.next].label || reasoning.next) + if (sseStreamer) { + sseStreamer.streamNextAgentEvent(chatId, mapNameToLabel[reasoning.next].label || reasoning.next) + } } } } } else { finalResult = output.__end__.messages.length ? output.__end__.messages.pop()?.content : '' if (Array.isArray(finalResult)) finalResult = output.__end__.instructions - - if (socketIO && incomingInput.socketIOClientId) { - socketIO.to(incomingInput.socketIOClientId).emit('token', finalResult) + if (shouldStreamResponse && sseStreamer) { + sseStreamer.streamTokenEvent(chatId, finalResult) } } } @@ -321,9 +324,8 @@ export const buildAgentGraph = async ( if (!isSequential && !finalResult) { if (lastWorkerResult) finalResult = lastWorkerResult else if (finalSummarization) finalResult = finalSummarization - - if (socketIO && incomingInput.socketIOClientId) { - socketIO.to(incomingInput.socketIOClientId).emit('token', finalResult) + if (shouldStreamResponse && sseStreamer) { + sseStreamer.streamTokenEvent(chatId, finalResult) } } @@ -377,16 +379,16 @@ export const buildAgentGraph = async ( { type: 'reject-button', label: rejectButtonText } ] } - if (socketIO && incomingInput.socketIOClientId) { - socketIO.to(incomingInput.socketIOClientId).emit('token', finalResult) - socketIO.to(incomingInput.socketIOClientId).emit('action', finalAction) + if (shouldStreamResponse && sseStreamer) { + sseStreamer.streamTokenEvent(chatId, finalResult) + sseStreamer.streamActionEvent(chatId, finalAction) } } totalUsedTools.push(...mappedToolCalls) } else if (lastAgentReasoningMessage) { finalResult = lastAgentReasoningMessage - if (socketIO && incomingInput.socketIOClientId) { - socketIO.to(incomingInput.socketIOClientId).emit('token', finalResult) + if (shouldStreamResponse && sseStreamer) { + sseStreamer.streamTokenEvent(chatId, finalResult) } } } @@ -394,10 +396,10 @@ export const buildAgentGraph = async ( totalSourceDocuments = uniq(flatten(totalSourceDocuments)) totalUsedTools = uniq(flatten(totalUsedTools)) - if (socketIO && incomingInput.socketIOClientId) { - socketIO.to(incomingInput.socketIOClientId).emit('usedTools', totalUsedTools) - socketIO.to(incomingInput.socketIOClientId).emit('sourceDocuments', totalSourceDocuments) - socketIO.to(incomingInput.socketIOClientId).emit('end') + if (shouldStreamResponse && sseStreamer) { + sseStreamer.streamUsedToolsEvent(chatId, totalUsedTools) + sseStreamer.streamSourceDocumentsEvent(chatId, totalSourceDocuments) + sseStreamer.streamEndEvent(chatId) } return { @@ -412,8 +414,8 @@ export const buildAgentGraph = async ( // clear agent memory because checkpoints were saved during runtime await clearSessionMemory(nodes, appServer.nodesPool.componentNodes, chatId, appServer.AppDataSource, sessionId) if (getErrorMessage(e).includes('Aborted')) { - if (socketIO && incomingInput.socketIOClientId) { - socketIO.to(incomingInput.socketIOClientId).emit('abort') + if (shouldStreamResponse && sseStreamer) { + sseStreamer.streamAbortEvent(chatId) } return { finalResult, agentReasoning } } diff --git a/packages/server/src/utils/buildChatflow.ts b/packages/server/src/utils/buildChatflow.ts index c7ab304c8cc..7cd7b68f8a8 100644 --- a/packages/server/src/utils/buildChatflow.ts +++ b/packages/server/src/utils/buildChatflow.ts @@ -5,7 +5,8 @@ import { ICommonObject, addSingleFileToStorage, addArrayFilesToStorage, - mapMimeTypeToInputField + mapMimeTypeToInputField, + IServerSideEventStreamer } from 'flowise-components' import { StatusCodes } from 'http-status-codes' import { @@ -22,7 +23,6 @@ import { } from '../Interface' import { InternalFlowiseError } from '../errors/internalFlowiseError' import { ChatFlow } from '../database/entities/ChatFlow' -import { Server } from 'socket.io' import { getRunningExpressApp } from '../utils/getRunningExpressApp' import { isFlowValidForStream, @@ -56,10 +56,9 @@ import { IAction } from 'flowise-components' /** * Build Chatflow * @param {Request} req - * @param {Server} socketIO * @param {boolean} isInternal */ -export const utilBuildChatflow = async (req: Request, socketIO?: Server, isInternal: boolean = false): Promise => { +export const utilBuildChatflow = async (req: Request, isInternal: boolean = false): Promise => { try { const appServer = getRunningExpressApp() const chatflowid = req.params.id @@ -78,7 +77,6 @@ export const utilBuildChatflow = async (req: Request, socketIO?: Server, isInter const chatId = incomingInput.chatId ?? incomingInput.overrideConfig?.sessionId ?? uuidv4() const userMessageDateTime = new Date() - if (!isInternal) { const isKeyValidated = await validateChatflowAPIKey(req, chatflow) if (!isKeyValidated) { @@ -161,8 +159,7 @@ export const utilBuildChatflow = async (req: Request, socketIO?: Server, isInter } incomingInput = { question: req.body.question ?? 'hello', - overrideConfig, - socketIOClientId: req.body.socketIOClientId + overrideConfig } } @@ -181,7 +178,6 @@ export const utilBuildChatflow = async (req: Request, socketIO?: Server, isInter const { graph, nodeDependencies } = constructGraphs(nodes, edges) const directedGraph = graph const endingNodes = getEndingNodes(nodeDependencies, directedGraph, nodes) - /*** If the graph is an agent graph, build the agent response ***/ if (endingNodes.filter((node) => node.data.category === 'Multi Agents' || node.data.category === 'Sequential Agents').length) { return await utilBuildAgentResponse( @@ -195,8 +191,9 @@ export const utilBuildChatflow = async (req: Request, socketIO?: Server, isInter incomingInput, nodes, edges, - socketIO, - baseURL + baseURL, + appServer.sseStreamer, + true ) } @@ -320,9 +317,7 @@ export const utilBuildChatflow = async (req: Request, socketIO?: Server, isInter cachePool: appServer.cachePool, isUpsert: false, uploads: incomingInput.uploads, - baseURL, - socketIO, - socketIOClientId: incomingInput.socketIOClientId + baseURL }) const nodeToExecute = @@ -373,9 +368,9 @@ export const utilBuildChatflow = async (req: Request, socketIO?: Server, isInter databaseEntities, analytic: chatflow.analytic, uploads: incomingInput.uploads, - socketIO, - socketIOClientId: incomingInput.socketIOClientId, - prependMessages + prependMessages, + sseStreamer: appServer.sseStreamer, + shouldStreamResponse: isStreamValid }) : await nodeInstance.run(nodeToExecuteData, incomingInput.question, { chatId, @@ -442,6 +437,8 @@ export const utilBuildChatflow = async (req: Request, socketIO?: Server, isInter result.question = incomingInput.question result.chatId = chatId result.chatMessageId = chatMessage?.id + result.isStreamValid = isStreamValid + if (sessionId) result.sessionId = sessionId if (memoryType) result.memoryType = memoryType @@ -467,12 +464,22 @@ const utilBuildAgentResponse = async ( incomingInput: IncomingInput, nodes: IReactFlowNode[], edges: IReactFlowEdge[], - socketIO?: Server, - baseURL?: string + baseURL?: string, + sseStreamer?: IServerSideEventStreamer, + shouldStreamResponse?: boolean ) => { try { const appServer = getRunningExpressApp() - const streamResults = await buildAgentGraph(agentflow, chatId, sessionId, incomingInput, isInternal, baseURL, socketIO) + const streamResults = await buildAgentGraph( + agentflow, + chatId, + sessionId, + incomingInput, + isInternal, + baseURL, + sseStreamer, + shouldStreamResponse + ) if (streamResults) { const { finalResult, finalAction, sourceDocuments, usedTools, agentReasoning } = streamResults const userMessage: Omit = { @@ -498,10 +505,10 @@ const utilBuildAgentResponse = async ( memoryType, sessionId } - if (sourceDocuments.length) apiMessage.sourceDocuments = JSON.stringify(sourceDocuments) - if (usedTools.length) apiMessage.usedTools = JSON.stringify(usedTools) - if (agentReasoning.length) apiMessage.agentReasoning = JSON.stringify(agentReasoning) - if (Object.keys(finalAction).length) apiMessage.action = JSON.stringify(finalAction) + if (sourceDocuments?.length) apiMessage.sourceDocuments = JSON.stringify(sourceDocuments) + if (usedTools?.length) apiMessage.usedTools = JSON.stringify(usedTools) + if (agentReasoning?.length) apiMessage.agentReasoning = JSON.stringify(agentReasoning) + if (finalAction && Object.keys(finalAction).length) apiMessage.action = JSON.stringify(finalAction) const chatMessage = await utilAddChatMessage(apiMessage) await appServer.telemetry.sendTelemetry('agentflow_prediction_sent', { @@ -548,8 +555,8 @@ const utilBuildAgentResponse = async ( result.chatMessageId = chatMessage?.id if (sessionId) result.sessionId = sessionId if (memoryType) result.memoryType = memoryType - if (agentReasoning.length) result.agentReasoning = agentReasoning - if (Object.keys(finalAction).length) result.action = finalAction + if (agentReasoning?.length) result.agentReasoning = agentReasoning + if (finalAction && Object.keys(finalAction).length) result.action = finalAction return result } diff --git a/packages/server/src/utils/index.ts b/packages/server/src/utils/index.ts index 0b5e89b2e28..af58d727872 100644 --- a/packages/server/src/utils/index.ts +++ b/packages/server/src/utils/index.ts @@ -1,7 +1,6 @@ import path from 'path' import fs from 'fs' import logger from './logger' -import { Server } from 'socket.io' import { IComponentCredentials, IComponentNodes, @@ -436,8 +435,6 @@ type BuildFlowParams = { stopNodeId?: string uploads?: IFileUpload[] baseURL?: string - socketIO?: Server - socketIOClientId?: string } /** @@ -462,9 +459,7 @@ export const buildFlow = async ({ isUpsert, stopNodeId, uploads, - baseURL, - socketIO, - socketIOClientId + baseURL }: BuildFlowParams) => { const flowNodes = cloneDeep(reactFlowNodes) @@ -533,9 +528,7 @@ export const buildFlow = async ({ cachePool, dynamicVariables, uploads, - baseURL, - socketIO, - socketIOClientId + baseURL }) if (indexResult) upsertHistory['result'] = indexResult logger.debug(`[server]: Finished upserting ${reactFlowNode.data.label} (${reactFlowNode.data.id})`) @@ -561,8 +554,6 @@ export const buildFlow = async ({ dynamicVariables, uploads, baseURL, - socketIO, - socketIOClientId, componentNodes: componentNodes as ICommonObject }) diff --git a/packages/ui/package.json b/packages/ui/package.json index db9d022b714..5fb0bd39dc4 100644 --- a/packages/ui/package.json +++ b/packages/ui/package.json @@ -14,6 +14,7 @@ "@emotion/cache": "^11.4.0", "@emotion/react": "^11.10.6", "@emotion/styled": "^11.10.6", + "@microsoft/fetch-event-source": "^2.0.1", "@mui/base": "5.0.0-beta.40", "@mui/icons-material": "5.0.3", "@mui/lab": "5.0.0-alpha.156", diff --git a/packages/ui/src/api/prediction.js b/packages/ui/src/api/prediction.js index d3512843c2a..207d222f1a6 100644 --- a/packages/ui/src/api/prediction.js +++ b/packages/ui/src/api/prediction.js @@ -1,7 +1,9 @@ import client from './client' const sendMessageAndGetPrediction = (id, input) => client.post(`/internal-prediction/${id}`, input) +const sendMessageAndStreamPrediction = (id, input) => client.post(`/internal-prediction/stream/${id}`, input) export default { - sendMessageAndGetPrediction + sendMessageAndGetPrediction, + sendMessageAndStreamPrediction } diff --git a/packages/ui/src/views/chatmessage/ChatMessage.jsx b/packages/ui/src/views/chatmessage/ChatMessage.jsx index 294352055eb..f9bf3431316 100644 --- a/packages/ui/src/views/chatmessage/ChatMessage.jsx +++ b/packages/ui/src/views/chatmessage/ChatMessage.jsx @@ -1,7 +1,6 @@ import { useState, useRef, useEffect, useCallback, Fragment } from 'react' import { useSelector, useDispatch } from 'react-redux' import PropTypes from 'prop-types' -import socketIOClient from 'socket.io-client' import { cloneDeep } from 'lodash' import rehypeMathjax from 'rehype-mathjax' import rehypeRaw from 'rehype-raw' @@ -9,6 +8,7 @@ import remarkGfm from 'remark-gfm' import remarkMath from 'remark-math' import axios from 'axios' import { v4 as uuidv4 } from 'uuid' +import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source' import { Box, @@ -171,7 +171,6 @@ export const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, preview type: 'apiMessage' } ]) - const [socketIOClientId, setSocketIOClientId] = useState('') const [isChatFlowAvailableToStream, setIsChatFlowAvailableToStream] = useState(false) const [isChatFlowAvailableForSpeech, setIsChatFlowAvailableForSpeech] = useState(false) const [sourceDialogOpen, setSourceDialogOpen] = useState(false) @@ -500,6 +499,14 @@ export const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, preview }) } + const updateErrorMessage = (errorMessage) => { + setMessages((prevMessages) => { + let allMessages = [...cloneDeep(prevMessages)] + allMessages.push({ message: errorMessage, type: 'apiMessage' }) + return allMessages + }) + } + const updateLastMessageSourceDocuments = (sourceDocuments) => { setMessages((prevMessages) => { let allMessages = [...cloneDeep(prevMessages)] @@ -614,6 +621,34 @@ export const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, preview handleSubmit(undefined, elem.label, action) } + const updateMetadata = (data, input) => { + // set message id that is needed for feedback + if (data.chatMessageId) { + setMessages((prevMessages) => { + let allMessages = [...cloneDeep(prevMessages)] + if (allMessages[allMessages.length - 1].type === 'apiMessage') { + allMessages[allMessages.length - 1].id = data.chatMessageId + } + return allMessages + }) + } + + if (data.chatId) { + setChatId(data.chatId) + } + + if (input === '' && data.question) { + // the response contains the question even if it was in an audio format + // so if input is empty but the response contains the question, update the user message to show the question + setMessages((prevMessages) => { + let allMessages = [...cloneDeep(prevMessages)] + if (allMessages[allMessages.length - 2].type === 'apiMessage') return allMessages + allMessages[allMessages.length - 2].message = data.question + return allMessages + }) + } + } + // Handle form submission const handleSubmit = async (e, selectedInput, action) => { if (e) e.preventDefault() @@ -649,7 +684,7 @@ export const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, preview } if (uploads && uploads.length > 0) params.uploads = uploads if (leadEmail) params.leadEmail = leadEmail - if (isChatFlowAvailableToStream) params.socketIOClientId = socketIOClientId + if (action) params.action = action if (uploadedFiles.length > 0) { @@ -671,33 +706,15 @@ export const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, preview } } - const response = await predictionApi.sendMessageAndGetPrediction(chatflowid, params) - - if (response.data) { - const data = response.data + if (isChatFlowAvailableToStream) { + fetchResponseFromEventStream(chatflowid, params) + } else { + const response = await predictionApi.sendMessageAndGetPrediction(chatflowid, params) + if (response.data) { + const data = response.data - setMessages((prevMessages) => { - let allMessages = [...cloneDeep(prevMessages)] - if (allMessages[allMessages.length - 1].type === 'apiMessage') { - allMessages[allMessages.length - 1].id = data?.chatMessageId - } - return allMessages - }) - - setChatId(data.chatId) - - if (input === '' && data.question) { - // the response contains the question even if it was in an audio format - // so if input is empty but the response contains the question, update the user message to show the question - setMessages((prevMessages) => { - let allMessages = [...cloneDeep(prevMessages)] - if (allMessages[allMessages.length - 2].type === 'apiMessage') return allMessages - allMessages[allMessages.length - 2].message = data.question - return allMessages - }) - } + updateMetadata(data, input) - if (!isChatFlowAvailableToStream) { let text = '' if (data.text) text = data.text else if (data.json) text = '```json\n' + JSON.stringify(data.json, null, 2) @@ -717,15 +734,16 @@ export const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, preview feedback: null } ]) + + setLocalStorageChatflow(chatflowid, data.chatId) + setLoading(false) + setUserInput('') + setUploadedFiles([]) + setTimeout(() => { + inputRef.current?.focus() + scrollToBottom() + }, 100) } - setLocalStorageChatflow(chatflowid, data.chatId) - setLoading(false) - setUserInput('') - setUploadedFiles([]) - setTimeout(() => { - inputRef.current?.focus() - scrollToBottom() - }, 100) } } catch (error) { handleError(error.response.data.message) @@ -733,6 +751,88 @@ export const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, preview } } + const fetchResponseFromEventStream = async (chatflowid, params) => { + const chatId = params.chatId + const input = params.question + const username = localStorage.getItem('username') + const password = localStorage.getItem('password') + params.streaming = true + await fetchEventSource(`${baseURL}/api/v1/internal-prediction/${chatflowid}`, { + openWhenHidden: true, + method: 'POST', + body: JSON.stringify(params), + headers: { + 'Content-Type': 'application/json', + Authorization: username && password ? `Basic ${btoa(`${username}:${password}`)}` : undefined, + 'x-request-from': 'internal' + }, + async onopen(response) { + if (response.ok && response.headers.get('content-type') === EventStreamContentType) { + //console.log('EventSource Open') + } + }, + async onmessage(ev) { + const payload = JSON.parse(ev.data) + switch (payload.event) { + case 'start': + setMessages((prevMessages) => [...prevMessages, { message: '', type: 'apiMessage' }]) + break + case 'token': + updateLastMessage(payload.data) + break + case 'sourceDocuments': + updateLastMessageSourceDocuments(payload.data) + break + case 'usedTools': + updateLastMessageUsedTools(payload.data) + break + case 'fileAnnotations': + updateLastMessageFileAnnotations(payload.data) + break + case 'agentReasoning': + updateLastMessageAgentReasoning(payload.data) + break + case 'action': + updateLastMessageAction(payload.data) + break + case 'nextAgent': + updateLastMessageNextAgent(payload.data) + break + case 'metadata': + updateMetadata(payload.data, input) + break + case 'error': + updateErrorMessage(payload.data) + break + case 'abort': + abortMessage(payload.data) + closeResponse() + break + case 'end': + setLocalStorageChatflow(chatflowid, chatId) + closeResponse() + break + } + }, + async onclose() { + closeResponse() + }, + async onerror(err) { + console.error('EventSource Error: ', err) + closeResponse() + } + }) + } + + const closeResponse = () => { + setLoading(false) + setUserInput('') + setUploadedFiles([]) + setTimeout(() => { + inputRef.current?.focus() + scrollToBottom() + }, 100) + } // Prevent blank submissions and allow for multiline input const handleEnter = (e) => { // Check if IME composition is in progress @@ -899,7 +999,6 @@ export const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, preview }, [isDialog, inputRef]) useEffect(() => { - let socket if (open && chatflowid) { // API request getChatmessageApi.request(chatflowid) @@ -918,33 +1017,6 @@ export const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, preview setIsLeadSaved(!!savedLead) setLeadEmail(savedLead.email) } - - // SocketIO - socket = socketIOClient(baseURL) - - socket.on('connect', () => { - setSocketIOClientId(socket.id) - }) - - socket.on('start', () => { - setMessages((prevMessages) => [...prevMessages, { message: '', type: 'apiMessage' }]) - }) - - socket.on('sourceDocuments', updateLastMessageSourceDocuments) - - socket.on('usedTools', updateLastMessageUsedTools) - - socket.on('fileAnnotations', updateLastMessageFileAnnotations) - - socket.on('token', updateLastMessage) - - socket.on('agentReasoning', updateLastMessageAgentReasoning) - - socket.on('action', updateLastMessageAction) - - socket.on('nextAgent', updateLastMessageNextAgent) - - socket.on('abort', abortMessage) } return () => { @@ -957,10 +1029,6 @@ export const ChatMessage = ({ open, chatflowid, isAgentCanvas, isDialog, preview type: 'apiMessage' } ]) - if (socket) { - socket.disconnect() - setSocketIOClientId('') - } } // eslint-disable-next-line react-hooks/exhaustive-deps diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 000a7e1a64a..c0d2411f3a1 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -613,6 +613,9 @@ importers: '@emotion/styled': specifier: ^11.10.6 version: 11.11.0(@emotion/react@11.11.4(@types/react@18.2.65)(react@18.2.0))(@types/react@18.2.65)(react@18.2.0) + '@microsoft/fetch-event-source': + specifier: ^2.0.1 + version: 2.0.1 '@mui/base': specifier: 5.0.0-beta.40 version: 5.0.0-beta.40(@types/react@18.2.65)(react-dom@18.2.0(react@18.2.0))(react@18.2.0) @@ -3741,6 +3744,9 @@ packages: '@mendable/firecrawl-js@0.0.28': resolution: { integrity: sha512-Xa+ZbBQkoR/KHM1ZpvJBdLWSCdRoRGyllDNoVvhKxGv9qXZk9h/lBxbqp3Kc1Kg2L2JJnJCkmeaTUCAn8y33GA== } + '@microsoft/fetch-event-source@2.0.1': + resolution: { integrity: sha512-W6CLUJ2eBMw3Rec70qrsEW0jOm/3twwJv21mrmj2yORiaVmVYGS4sSS5yUwvQc1ZlDLYGPnClVWmUUMagKNsfA== } + '@mistralai/mistralai@0.1.3': resolution: { integrity: sha512-WUHxC2xdeqX9PTXJEqdiNY54vT2ir72WSJrZTTBKRnkfhX6zIfCYA24faRlWjUB5WTpn+wfdGsTMl3ArijlXFA== } @@ -21315,6 +21321,8 @@ snapshots: transitivePeerDependencies: - debug + '@microsoft/fetch-event-source@2.0.1': {} + '@mistralai/mistralai@0.1.3(encoding@0.1.13)': dependencies: node-fetch: 2.7.0(encoding@0.1.13)