diff --git a/cmd/server/main.go b/cmd/server/main.go index a4266a00..bfe5df2b 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -3,12 +3,14 @@ package main import ( "context" "fmt" + "io" stdlog "log" "os" "os/signal" "syscall" "github.com/github/github-mcp-server/pkg/github" + iolog "github.com/github/github-mcp-server/pkg/log" gogithub "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/server" log "github.com/sirupsen/logrus" @@ -33,7 +35,8 @@ var ( if err != nil { stdlog.Fatal("Failed to initialize logger:", err) } - if err := runStdioServer(logger); err != nil { + logCommands := viper.GetBool("enable-command-logging") + if err := runStdioServer(logger, logCommands); err != nil { stdlog.Fatal("failed to run stdio server:", err) } }, @@ -45,9 +48,11 @@ func init() { // Add global flags that will be shared by all commands rootCmd.PersistentFlags().String("log-file", "", "Path to log file") + rootCmd.PersistentFlags().Bool("enable-command-logging", false, "When enabled, the server will log all command requests and responses to the log file") // Bind flag to viper viper.BindPFlag("log-file", rootCmd.PersistentFlags().Lookup("log-file")) + viper.BindPFlag("enable-command-logging", rootCmd.PersistentFlags().Lookup("enable-command-logging")) // Add subcommands rootCmd.AddCommand(stdioCmd) @@ -70,12 +75,13 @@ func initLogger(outPath string) (*log.Logger, error) { } logger := log.New() + logger.SetLevel(log.DebugLevel) logger.SetOutput(file) return logger, nil } -func runStdioServer(logger *log.Logger) error { +func runStdioServer(logger *log.Logger, logCommands bool) error { // Create app context ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() @@ -97,7 +103,14 @@ func runStdioServer(logger *log.Logger) error { // Start listening for messages errC := make(chan error, 1) go func() { - errC <- stdioServer.Listen(ctx, os.Stdin, os.Stdout) + in, out := io.Reader(os.Stdin), io.Writer(os.Stdout) + + if logCommands { + loggedIO := iolog.NewIOLogger(in, out, logger) + in, out = loggedIO, loggedIO + } + + errC <- stdioServer.Listen(ctx, in, out) }() // Output github-mcp-server string diff --git a/go.mod b/go.mod index e53b8b6b..6fbe54a4 100644 --- a/go.mod +++ b/go.mod @@ -9,10 +9,12 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.19.0 + github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 ) require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/uuid v1.6.0 // indirect @@ -21,6 +23,7 @@ require ( github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect diff --git a/pkg/log/io.go b/pkg/log/io.go new file mode 100644 index 00000000..de221027 --- /dev/null +++ b/pkg/log/io.go @@ -0,0 +1,45 @@ +package log + +import ( + "io" + + log "github.com/sirupsen/logrus" +) + +// IOLogger is a wrapper around io.Reader and io.Writer that can be used +// to log the data being read and written from the underlying streams +type IOLogger struct { + reader io.Reader + writer io.Writer + logger *log.Logger +} + +// NewIOLogger creates a new IOLogger instance +func NewIOLogger(r io.Reader, w io.Writer, logger *log.Logger) *IOLogger { + return &IOLogger{ + reader: r, + writer: w, + logger: logger, + } +} + +// Read reads data from the underlying io.Reader and logs it. +func (l *IOLogger) Read(p []byte) (n int, err error) { + if l.reader == nil { + return 0, io.EOF + } + n, err = l.reader.Read(p) + if n > 0 { + l.logger.Infof("[stdin]: received %d bytes: %s", n, string(p[:n])) + } + return n, err +} + +// Write writes data to the underlying io.Writer and logs it. +func (l *IOLogger) Write(p []byte) (n int, err error) { + if l.writer == nil { + return 0, io.ErrClosedPipe + } + l.logger.Infof("[stdout]: sending %d bytes: %s", len(p), string(p)) + return l.writer.Write(p) +} diff --git a/pkg/log/io_test.go b/pkg/log/io_test.go new file mode 100644 index 00000000..0d0cd895 --- /dev/null +++ b/pkg/log/io_test.go @@ -0,0 +1,65 @@ +package log + +import ( + "bytes" + "strings" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func TestLoggedReadWriter(t *testing.T) { + t.Run("Read method logs and passes data", func(t *testing.T) { + // Setup + inputData := "test input data" + reader := strings.NewReader(inputData) + + // Create logger with buffer to capture output + var logBuffer bytes.Buffer + logger := log.New() + logger.SetOutput(&logBuffer) + logger.SetFormatter(&log.TextFormatter{ + DisableTimestamp: true, + }) + + lrw := NewIOLogger(reader, nil, logger) + + // Test Read + buf := make([]byte, 100) + n, err := lrw.Read(buf) + + // Assertions + assert.NoError(t, err) + assert.Equal(t, len(inputData), n) + assert.Equal(t, inputData, string(buf[:n])) + assert.Contains(t, logBuffer.String(), "[stdin]") + assert.Contains(t, logBuffer.String(), inputData) + }) + + t.Run("Write method logs and passes data", func(t *testing.T) { + // Setup + outputData := "test output data" + var writeBuffer bytes.Buffer + + // Create logger with buffer to capture output + var logBuffer bytes.Buffer + logger := log.New() + logger.SetOutput(&logBuffer) + logger.SetFormatter(&log.TextFormatter{ + DisableTimestamp: true, + }) + + lrw := NewIOLogger(nil, &writeBuffer, logger) + + // Test Write + n, err := lrw.Write([]byte(outputData)) + + // Assertions + assert.NoError(t, err) + assert.Equal(t, len(outputData), n) + assert.Equal(t, outputData, writeBuffer.String()) + assert.Contains(t, logBuffer.String(), "[stdout]") + assert.Contains(t, logBuffer.String(), outputData) + }) +}