Skip to content

refactor(stdio): improve stdio server message handling #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 29, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 81 additions & 58 deletions server/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,85 @@ func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
s.contextFunc = fn
}

// handleNotifications continuously processes notifications from the session's notification channel
// and writes them to the provided output. It runs until the context is cancelled.
// Any errors encountered while writing notifications are logged but do not stop the handler.
func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
for {
select {
case notification := <-stdioSessionInstance.notifications:
if err := s.writeResponse(notification, stdout); err != nil {
s.errLogger.Printf("Error writing notification: %v", err)
}
case <-ctx.Done():
return
}
}
}

// processInputStream continuously reads and processes messages from the input stream.
// It handles EOF gracefully as a normal termination condition.
// The function returns when either:
// - The context is cancelled (returns context.Err())
// - EOF is encountered (returns nil)
// - An error occurs while reading or processing messages (returns the error)
func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
for {
if err := ctx.Err(); err != nil {
return err
}

line, err := s.readNextLine(ctx, reader)
if err != nil {
if err == io.EOF {
return nil
}
s.errLogger.Printf("Error reading input: %v", err)
return err
}

if err := s.processMessage(ctx, line, stdout); err != nil {
if err == io.EOF {
return nil
}
s.errLogger.Printf("Error handling message: %v", err)
return err
}
}
}

// readNextLine reads a single line from the input reader in a context-aware manner.
// It uses channels to make the read operation cancellable via context.
// Returns the read line and any error encountered. If the context is cancelled,
// returns an empty string and the context's error. EOF is returned when the input
// stream is closed.
func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) {
readChan := make(chan string, 1)
errChan := make(chan error, 1)
defer func() {
close(readChan)
close(errChan)
}()

go func() {
line, err := reader.ReadString('\n')
if err != nil {
errChan <- err
return
}
readChan <- line
}()

select {
case <-ctx.Done():
return "", ctx.Err()
case err := <-errChan:
return "", err
case line := <-readChan:
return line, nil
}
}

// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output.
// It runs until the context is cancelled or an error occurs.
// Returns an error if there are issues with reading input or writing output.
Expand All @@ -126,64 +205,8 @@ func (s *StdioServer) Listen(
reader := bufio.NewReader(stdin)

// Start notification handler
go func() {
for {
select {
case notification := <-stdioSessionInstance.notifications:
err := s.writeResponse(
notification,
stdout,
)
if err != nil {
s.errLogger.Printf(
"Error writing notification: %v",
err,
)
}
case <-ctx.Done():
return
}
}
}()

for {
select {
case <-ctx.Done():
return ctx.Err()
default:
// Use a goroutine to make the read cancellable
readChan := make(chan string, 1)
errChan := make(chan error, 1)

go func() {
line, err := reader.ReadString('\n')
if err != nil {
errChan <- err
return
}
readChan <- line
}()

select {
case <-ctx.Done():
return ctx.Err()
case err := <-errChan:
if err == io.EOF {
return nil
}
s.errLogger.Printf("Error reading input: %v", err)
return err
case line := <-readChan:
if err := s.processMessage(ctx, line, stdout); err != nil {
if err == io.EOF {
return nil
}
s.errLogger.Printf("Error handling message: %v", err)
return err
}
}
}
}
go s.handleNotifications(ctx, stdout)
return s.processInputStream(ctx, reader, stdout)
}

// processMessage handles a single JSON-RPC message and writes the response.
Expand Down