diff --git a/client/transport/stdio.go b/client/transport/stdio.go index 7375ebe0..0297023f 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -100,7 +100,14 @@ func (c *Stdio) Start(ctx context.Context) error { // Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. // Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. func (c *Stdio) Close() error { + select { + case <-c.done: + return nil + default: + } + // cancel all in-flight request close(c.done) + if err := c.stdin.Close(); err != nil { return fmt.Errorf("failed to close stdin: %w", err) } @@ -182,27 +189,33 @@ func (c *Stdio) SendRequest( return nil, fmt.Errorf("stdio client not started") } - // Create the complete request structure - responseChan := make(chan *JSONRPCResponse, 1) - c.mu.Lock() - c.responses[request.ID] = responseChan - c.mu.Unlock() - + // Marshal request requestBytes, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } requestBytes = append(requestBytes, '\n') + // Register response channel + responseChan := make(chan *JSONRPCResponse, 1) + c.mu.Lock() + c.responses[request.ID] = responseChan + c.mu.Unlock() + deleteResponseChan := func() { + c.mu.Lock() + delete(c.responses, request.ID) + c.mu.Unlock() + } + + // Send request if _, err := c.stdin.Write(requestBytes); err != nil { + deleteResponseChan() return nil, fmt.Errorf("failed to write request: %w", err) } select { case <-ctx.Done(): - c.mu.Lock() - delete(c.responses, request.ID) - c.mu.Unlock() + deleteResponseChan() return nil, ctx.Err() case response := <-responseChan: return response, nil