diff --git a/pkg/middleware/trace.go b/pkg/middleware/trace.go index 9b11768f..6caf33d3 100644 --- a/pkg/middleware/trace.go +++ b/pkg/middleware/trace.go @@ -25,6 +25,7 @@ import ( "net/http" "sync" + "github.com/go-chi/chi/v5/middleware" "github.com/optimizely/agent/config" "github.com/rs/zerolog/log" "go.opentelemetry.io/otel" @@ -80,16 +81,6 @@ func (gen *traceIDGenerator) NewIDs(ctx context.Context) (trace.TraceID, trace.S return tid, sid } -type statusRecorder struct { - http.ResponseWriter - statusCode int -} - -func (r *statusRecorder) WriteHeader(code int) { - r.statusCode = code - r.ResponseWriter.WriteHeader(code) -} - func AddTracing(conf config.TracingConfig, tracerName, spanName string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { @@ -107,15 +98,12 @@ func AddTracing(conf config.TracingConfig, tracerName, spanName string) func(htt attribute.String(OptlySDKHeader, r.Header.Get(OptlySDKHeader)), ) - rec := &statusRecorder{ - ResponseWriter: w, - statusCode: http.StatusOK, - } + respWriter := middleware.NewWrapResponseWriter(w, r.ProtoMajor) - next.ServeHTTP(rec, r.WithContext(ctx)) + next.ServeHTTP(respWriter, r.WithContext(ctx)) span.SetAttributes( - semconv.HTTPStatusCodeKey.Int(rec.statusCode), + semconv.HTTPStatusCodeKey.Int(respWriter.Status()), ) } return http.HandlerFunc(fn) diff --git a/tests/acceptance/test_acceptance/test_odp_redis.py b/tests/acceptance/test_acceptance/test_odp_redis.py index a58e2db2..7e4b5fdc 100644 --- a/tests/acceptance/test_acceptance/test_odp_redis.py +++ b/tests/acceptance/test_acceptance/test_odp_redis.py @@ -52,6 +52,7 @@ def test_redis_save(session_override_sdk_key_odp): """ expected_segments = ["atsbugbashsegmenthaspurchased", "atsbugbashsegmentdob"] + expected_segments_rev = ["atsbugbashsegmentdob", "atsbugbashsegmenthaspurchased"] uId = "fs_user_id-$-matjaz-user-1" r = redis.Redis(host='localhost', port=6379, db=0) # clean redis before testing since several tests use same user_id @@ -72,7 +73,7 @@ def test_redis_save(session_override_sdk_key_odp): params=params) # Check saved segments - assert json.loads(json.dumps(expected_segments)) == json.loads(r.get(uId)) + assert json.loads(json.dumps(expected_segments)) == json.loads(r.get(uId)) or json.loads(json.dumps(expected_segments_rev)) == json.loads(r.get(uId)) assert json.loads(json.dumps(expected_redis_save)) == resp.json() assert resp.status_code == 200, resp.text