From a7a9412b7c66461732ca1100e83e42c68a1cdc46 Mon Sep 17 00:00:00 2001 From: Ira Date: Tue, 3 Jun 2025 11:51:53 +0300 Subject: [PATCH] Server unit test and utility to help with such tests Signed-off-by: Ira --- pkg/epp/server/server_test.go | 192 ++++++++++++++++++++++++++++++++++ test/utils/server.go | 190 +++++++++++++++++++++++++++++++++ 2 files changed, 382 insertions(+) create mode 100644 pkg/epp/server/server_test.go create mode 100644 test/utils/server.go diff --git a/pkg/epp/server/server_test.go b/pkg/epp/server/server_test.go new file mode 100644 index 000000000..3696f5a71 --- /dev/null +++ b/pkg/epp/server/server_test.go @@ -0,0 +1,192 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package server + +import ( + "context" + "fmt" + "testing" + + pb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" + testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" + "sigs.k8s.io/gateway-api-inference-extension/test/utils" +) + +const ( + bufSize = 1024 * 1024 + podName = "pod1" + podAddress = "1.2.3.4" + poolPort = int32(5678) + destinationEndpointHintKey = "test-target" + namespace = "ns1" +) + +func TestServer(t *testing.T) { + theHeaderValue := "body" + requestHeader := "x-test" + + expectedRequestHeaders := map[string]string{destinationEndpointHintKey: fmt.Sprintf("%s:%d", podAddress, poolPort), + "Content-Length": "42", ":method": "POST", requestHeader: theHeaderValue} + expectedResponseHeaders := map[string]string{"x-went-into-resp-headers": "true", ":method": "POST", requestHeader: theHeaderValue} + expectedSchedulerHeaders := map[string]string{":method": "POST", requestHeader: theHeaderValue} + + t.Run("server", func(t *testing.T) { + tsModel := "food-review" + model := testutil.MakeInferenceModel("v1"). + CreationTimestamp(metav1.Unix(1000, 0)). + ModelName(tsModel).ObjRef() + + director := &testDirector{} + ctx, cancel, ds, _ := utils.PrepareForTestStreamingServer([]*v1alpha2.InferenceModel{model}, + []*v1.Pod{{ObjectMeta: metav1.ObjectMeta{Name: podName}}}, "test-pool1", namespace, poolPort) + + streamingServer := handlers.NewStreamingServer(namespace, destinationEndpointHintKey, ds, director) + + testListener, errChan := utils.SetupTestStreamingServer(t, ctx, ds, streamingServer) + process, conn := utils.GetStreamingServerClient(ctx, t) + defer conn.Close() + + // Send request headers - no response expected + headers := utils.BuildEnvoyGRPCHeaders(map[string]string{requestHeader: theHeaderValue, ":method": "POST"}, true) + request := &pb.ProcessingRequest{ + Request: &pb.ProcessingRequest_RequestHeaders{ + RequestHeaders: headers, + }, + } + err := process.Send(request) + if err != nil { + t.Error("Error sending request headers", err) + } + + // Send request body + requestBody := "{\"model\":\"food-review\",\"prompt\":\"Is banana tasty?\"}" + expectedBody := "{\"model\":\"v1\",\"prompt\":\"Is banana tasty?\"}" + request = &pb.ProcessingRequest{ + Request: &pb.ProcessingRequest_RequestBody{ + RequestBody: &pb.HttpBody{ + Body: []byte(requestBody), + EndOfStream: true, + }, + }, + } + err = process.Send(request) + if err != nil { + t.Error("Error sending request body", err) + } + + // Receive request headers and check + responseReqHeaders, err := process.Recv() + if err != nil { + t.Error("Error receiving response", err) + } else { + if responseReqHeaders == nil || responseReqHeaders.GetRequestHeaders() == nil || + responseReqHeaders.GetRequestHeaders().Response == nil || + responseReqHeaders.GetRequestHeaders().Response.HeaderMutation == nil || + responseReqHeaders.GetRequestHeaders().Response.HeaderMutation.SetHeaders == nil { + t.Error("Invalid request headers response") + } else if !utils.CheckEnvoyGRPCHeaders(t, responseReqHeaders.GetRequestHeaders().Response, expectedRequestHeaders) { + t.Error("Incorrect request headers") + } + } + + // Receive request body and check + responseReqBody, err := process.Recv() + if err != nil { + t.Error("Error receiving response", err) + } else { + if responseReqBody == nil || responseReqBody.GetRequestBody() == nil || + responseReqBody.GetRequestBody().Response == nil || + responseReqBody.GetRequestBody().Response.BodyMutation == nil || + responseReqBody.GetRequestBody().Response.BodyMutation.GetStreamedResponse() == nil { + t.Error("Invalid request body response") + } else { + body := responseReqBody.GetRequestBody().Response.BodyMutation.GetStreamedResponse().Body + if string(body) != expectedBody { + t.Errorf("Incorrect body %s expected %s", string(body), expectedBody) + } + } + } + + // Check headers passed to the scheduler + if len(director.requestHeaders) != 2 { + t.Errorf("Incorrect number of request headers %d instead of 2", len(director.requestHeaders)) + } + for expectedKey, expectedValue := range expectedSchedulerHeaders { + got, ok := director.requestHeaders[expectedKey] + if !ok { + t.Errorf("Missing header %s", expectedKey) + } else if got != expectedValue { + t.Errorf("Incorrect value for header %s, want %s got %s", expectedKey, expectedValue, got) + } + } + + // Send response headers + headers = utils.BuildEnvoyGRPCHeaders(map[string]string{requestHeader: theHeaderValue, ":method": "POST"}, false) + request = &pb.ProcessingRequest{ + Request: &pb.ProcessingRequest_ResponseHeaders{ + ResponseHeaders: headers, + }, + } + err = process.Send(request) + if err != nil { + t.Error("Error sending response", err) + } + + // Receive response headers and check + response, err := process.Recv() + if err != nil { + t.Error("Error receiving response", err) + } else { + if response == nil || response.GetResponseHeaders() == nil || response.GetResponseHeaders().Response == nil || + response.GetResponseHeaders().Response.HeaderMutation == nil || + response.GetResponseHeaders().Response.HeaderMutation.SetHeaders == nil { + t.Error("Invalid response") + } else if !utils.CheckEnvoyGRPCHeaders(t, response.GetResponseHeaders().Response, expectedResponseHeaders) { + t.Error("Incorrect response headers") + } + } + + cancel() + <-errChan + testListener.Close() + }) +} + +type testDirector struct { + requestHeaders map[string]string +} + +func (ts *testDirector) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { + ts.requestHeaders = reqCtx.Request.Headers + + reqCtx.Request.Body["model"] = "v1" + reqCtx.TargetEndpoint = fmt.Sprintf("%s:%d", podAddress, poolPort) + return reqCtx, nil +} + +func (ts *testDirector) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) { + return reqCtx, nil +} + +func (ts *testDirector) GetRandomPod() *backend.Pod { + return nil +} diff --git a/test/utils/server.go b/test/utils/server.go new file mode 100644 index 000000000..f3d0a5a94 --- /dev/null +++ b/test/utils/server.go @@ -0,0 +1,190 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "context" + "net" + "testing" + "time" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" +) + +const bufSize = 1024 * 1024 + +var testListener *bufconn.Listener + +func PrepareForTestStreamingServer(models []*v1alpha2.InferenceModel, pods []*v1.Pod, poolName string, namespace string, + poolPort int32) (context.Context, context.CancelFunc, datastore.Datastore, *metrics.FakePodMetricsClient) { + ctx, cancel := context.WithCancel(context.Background()) + + pmc := &metrics.FakePodMetricsClient{} + pmf := metrics.NewPodMetricsFactory(pmc, time.Second) + ds := datastore.NewDatastore(ctx, pmf) + + initObjs := []client.Object{} + for _, model := range models { + initObjs = append(initObjs, model) + ds.ModelSetIfOlder(model) + } + for _, pod := range pods { + initObjs = append(initObjs, pod) + ds.PodUpdateOrAddIfNotExist(pod) + } + + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = v1alpha2.Install(scheme) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(initObjs...). + Build() + pool := testutil.MakeInferencePool(poolName).Namespace(namespace).ObjRef() + pool.Spec.TargetPortNumber = poolPort + _ = ds.PoolSet(context.Background(), fakeClient, pool) + + return ctx, cancel, ds, pmc +} + +func SetupTestStreamingServer(t *testing.T, ctx context.Context, ds datastore.Datastore, + streamingServer pb.ExternalProcessorServer) (*bufconn.Listener, chan error) { + testListener = bufconn.Listen(bufSize) + + errChan := make(chan error) + go func() { + err := LaunchTestGRPCServer(streamingServer, ctx, testListener) + if err != nil { + t.Error("Error launching listener", err) + } + errChan <- err + }() + + time.Sleep(2 * time.Second) + return testListener, errChan +} + +func testDialer(context.Context, string) (net.Conn, error) { + return testListener.Dial() +} + +func GetStreamingServerClient(ctx context.Context, t *testing.T) (pb.ExternalProcessor_ProcessClient, *grpc.ClientConn) { + opts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(testDialer), + } + conn, err := grpc.NewClient("passthrough://bufconn", opts...) + if err != nil { + t.Error(err) + return nil, nil + } + + extProcClient := pb.NewExternalProcessorClient(conn) + process, err := extProcClient.Process(ctx) + if err != nil { + t.Error(err) + return nil, nil + } + + return process, conn +} + +// LaunchTestGRPCServer actually starts the server (enables testing) +func LaunchTestGRPCServer(s pb.ExternalProcessorServer, ctx context.Context, listener net.Listener) error { + grpcServer := grpc.NewServer() + + pb.RegisterExternalProcessorServer(grpcServer, s) + + // Shutdown on context closed. + // Terminate the server on context closed. + go func() { + <-ctx.Done() + grpcServer.GracefulStop() + }() + + if err := grpcServer.Serve(listener); err != nil { + return err + } + + return nil +} + +func CheckEnvoyGRPCHeaders(t *testing.T, response *pb.CommonResponse, expectedHeaders map[string]string) bool { + headers := response.HeaderMutation.SetHeaders + for expectedKey, expectedValue := range expectedHeaders { + found := false + for _, header := range headers { + if header.Header.Key == expectedKey { + if expectedValue != string(header.Header.RawValue) { + t.Errorf("Incorrect value for header %s, want %s got %s", expectedKey, expectedValue, + string(header.Header.RawValue)) + return false + } + found = true + break + } + } + if !found { + t.Errorf("Missing header %s", expectedKey) + return false + } + } + + for _, header := range headers { + expectedValue, ok := expectedHeaders[header.Header.Key] + if !ok { + t.Errorf("Unexpected header %s", header.Header.Key) + return false + } else if expectedValue != string(header.Header.RawValue) { + t.Errorf("Incorrect value for header %s, want %s got %s", header.Header.Key, expectedValue, + string(header.Header.RawValue)) + return false + } + } + return true +} + +func BuildEnvoyGRPCHeaders(headers map[string]string, rawValue bool) *pb.HttpHeaders { + headerValues := make([]*corev3.HeaderValue, 0) + for key, value := range headers { + header := &corev3.HeaderValue{Key: key} + if rawValue { + header.RawValue = []byte(value) + } else { + header.Value = value + } + headerValues = append(headerValues, header) + } + return &pb.HttpHeaders{ + Headers: &corev3.HeaderMap{ + Headers: headerValues, + }, + } +}