Skip to content

Commit 963d9a5

Browse files
committed
test chat completions api in e2e case
Signed-off-by: Hang Yin <[email protected]>
1 parent e8834c3 commit 963d9a5

File tree

1 file changed

+44
-33
lines changed

1 file changed

+44
-33
lines changed

test/e2e/epp/e2e_test.go

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -55,38 +55,40 @@ var _ = ginkgo.Describe("InferencePool", func() {
5555
}, existsTimeout, interval).Should(gomega.Succeed())
5656

5757
ginkgo.By("Verifying connectivity through the inference extension")
58-
curlCmd := getCurlCommand(envoyName, nsName, envoyPort, modelName, curlTimeout)
58+
for _, testApi := range []string{"/completions", "/chat/completions"} {
59+
curlCmd := getCurlCommand(envoyName, nsName, envoyPort, modelName, curlTimeout, testApi)
5960

60-
// Ensure the expected responses include the inferencemodel target model names.
61-
var expected []string
62-
for _, m := range infModel.Spec.TargetModels {
63-
expected = append(expected, m.Name)
64-
}
65-
actual := make(map[string]int)
66-
gomega.Eventually(func() error {
67-
resp, err := testutils.ExecCommandInPod(ctx, cfg, scheme, kubeCli, nsName, "curl", "curl", curlCmd)
68-
if err != nil {
69-
return err
70-
}
71-
if !strings.Contains(resp, "200 OK") {
72-
return fmt.Errorf("did not get 200 OK: %s", resp)
61+
// Ensure the expected responses include the inferencemodel target model names.
62+
var expected []string
63+
for _, m := range infModel.Spec.TargetModels {
64+
expected = append(expected, m.Name)
7365
}
74-
for _, m := range expected {
75-
if strings.Contains(resp, m) {
76-
actual[m] = 0
66+
actual := make(map[string]int)
67+
gomega.Eventually(func() error {
68+
resp, err := testutils.ExecCommandInPod(ctx, cfg, scheme, kubeCli, nsName, "curl", "curl", curlCmd)
69+
if err != nil {
70+
return err
71+
}
72+
if !strings.Contains(resp, "200 OK") {
73+
return fmt.Errorf("did not get 200 OK: %s", resp)
74+
}
75+
for _, m := range expected {
76+
if strings.Contains(resp, m) {
77+
actual[m] = 0
78+
}
79+
}
80+
var got []string
81+
for m := range actual {
82+
got = append(got, m)
83+
}
84+
// Compare ignoring order
85+
if !cmp.Equal(got, expected, cmpopts.SortSlices(func(a, b string) bool { return a < b })) {
86+
return fmt.Errorf("actual (%v) != expected (%v); resp=%q", got, expected, resp)
7787
}
78-
}
79-
var got []string
80-
for m := range actual {
81-
got = append(got, m)
82-
}
83-
// Compare ignoring order
84-
if !cmp.Equal(got, expected, cmpopts.SortSlices(func(a, b string) bool { return a < b })) {
85-
return fmt.Errorf("actual (%v) != expected (%v); resp=%q", got, expected, resp)
86-
}
8788

88-
return nil
89-
}, readyTimeout, curlInterval).Should(gomega.Succeed())
89+
return nil
90+
}, readyTimeout, curlInterval).Should(gomega.Succeed())
91+
}
9092

9193
})
9294
})
@@ -110,16 +112,25 @@ func newInferenceModel(ns string) *v1alpha2.InferenceModel {
110112

111113
// getCurlCommand returns the command, as a slice of strings, for curl'ing
// the test model server at the given name, namespace, port, model name, and
// API path (e.g. "/completions" or "/chat/completions").
//
// The request body is chosen to match the API: a prompt-style payload for
// "/completions" and a messages-style payload for "/chat/completions". Any
// other api value yields a request with no body (no "-d" flag).
func getCurlCommand(name, ns, port, model string, timeout time.Duration, api string) []string {
	command := []string{
		"curl",
		"-i",
		"--max-time",
		// curl's --max-time takes whole seconds; fractional parts are truncated.
		strconv.Itoa(int(timeout.Seconds())),
		fmt.Sprintf("%s.%s.svc:%s/v1%s", name, ns, port, api),
		"-H",
		"Content-Type: application/json",
	}
	switch api {
	case "/completions":
		command = append(command,
			"-d",
			fmt.Sprintf(`{"model": "%s", "prompt": "Write as if you were a critic: San Francisco", "max_tokens": 100, "temperature": 0}`, model))
	case "/chat/completions":
		command = append(command,
			"-d",
			fmt.Sprintf(`{"model": "%s", "messages": [{"role": "user", "content": "Hello! Please introduce yourself"}], "max_tokens": 100, "temperature": 0}`, model))
	}
	return command
}

0 commit comments

Comments (0)