diff --git a/client/stdio_test.go b/client/stdio_test.go index 7bffa3b2..8c9ff299 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -7,7 +7,7 @@ import ( "log/slog" "os" "os/exec" - "path/filepath" + "runtime" "sync" "testing" "time" @@ -19,6 +19,7 @@ func compileTestServer(outputPath string) error { cmd := exec.Command( "go", "build", + "-buildmode=pie", "-o", outputPath, "../testdata/mockstdio_server.go", @@ -33,10 +34,22 @@ func compileTestServer(outputPath string) error { } func TestStdioMCPClient(t *testing.T) { - // Compile mock server - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first + mockServerPath += ".exe" + } + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index aa728ec6..53db7a0f 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "os/exec" - "path/filepath" "runtime" "sync" "testing" @@ -19,6 +18,7 @@ func compileTestServer(outputPath string) error { cmd := exec.Command( "go", "build", + "-buildmode=pie", "-o", outputPath, "../../testdata/mockstdio_server.go", @@ -26,18 +26,30 @@ func compileTestServer(outputPath string) error { if output, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) } + // Verify the binary was actually created + if _, err := os.Stat(outputPath); os.IsNotExist(err) { + return fmt.Errorf("mock server binary not found at %s after compilation", outputPath) + } return nil } func TestStdio(t *testing.T) { - // Compile mock server - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + // Add .exe suffix on Windows if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) @@ -48,9 +60,9 @@ func TestStdio(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - err := stdio.Start(ctx) - if err != nil { - t.Fatalf("Failed to start Stdio transport: %v", err) + startErr := stdio.Start(ctx) + if startErr != nil { + t.Fatalf("Failed to start Stdio transport: %v", startErr) } defer stdio.Close() @@ -307,13 +319,22 @@ func TestStdioErrors(t *testing.T) { }) t.Run("RequestBeforeStart", func(t *testing.T) { - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + // Add .exe suffix on Windows if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) @@ -328,23 +349,31 @@ func TestStdioErrors(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() - _, err := uninitiatedStdio.SendRequest(ctx, request) - if err == nil { + _, reqErr := uninitiatedStdio.SendRequest(ctx, request) + if reqErr == nil { t.Errorf("Expected SendRequest to panic before Start(), but it didn't") - } else if err.Error() != "stdio client not started" { - t.Errorf("Expected error 'stdio client not started', got: %v", err) + } else if reqErr.Error() != "stdio client not started" { + t.Errorf("Expected error 'stdio client not started', got: %v", reqErr) } }) t.Run("RequestAfterClose", func(t *testing.T) { - // Compile mock server - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + // Add .exe suffix on Windows if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) @@ -353,8 +382,8 @@ func TestStdioErrors(t *testing.T) { // Start the transport ctx := context.Background() - if err := stdio.Start(ctx); err != nil { - t.Fatalf("Failed to start Stdio transport: %v", err) + if startErr := stdio.Start(ctx); startErr != nil { + t.Fatalf("Failed to start Stdio transport: %v", startErr) } // Close the transport - ignore errors like "broken pipe" since the process might exit already @@ -370,8 +399,8 @@ func TestStdioErrors(t *testing.T) { Method: "ping", } - _, err := stdio.SendRequest(ctx, request) - if err == nil { + _, sendErr := stdio.SendRequest(ctx, request) + if sendErr == nil { t.Errorf("Expected error when sending request after close, got nil") } })