Skip to content

Commit b87c91a

Browse files
Added modelName parameter to be considered for testing finetune demo
1 parent 34de932 commit b87c91a

File tree

5 files changed

+41
-81
lines changed

5 files changed

+41
-81
lines changed

tests/odh/mnist_ray_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ func mnistRay(t *testing.T, numGpus int) {
131131

132132
// Fetch created raycluster
133133
rayClusterName := "mnisttest"
134-
// Wait until raycluster is up and running
135134
rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Get(test.Ctx(), rayClusterName, metav1.GetOptions{})
136135
test.Expect(err).ToNot(HaveOccurred())
137136

tests/odh/mnist_raytune_hpo_test.go

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -129,53 +129,12 @@ func mnistRayTuneHpo(t *testing.T, numGpus int) {
129129
ContainElement(WithTransform(KueueWorkloadAdmitted, BeTrueBecause("Workload failed to be admitted"))),
130130
),
131131
)
132-
time.Sleep(30 * time.Second)
133132

134133
// Fetch created raycluster
135134
rayClusterName := "mnisthpotest"
136135
rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Get(test.Ctx(), rayClusterName, metav1.GetOptions{})
137136
test.Expect(err).ToNot(HaveOccurred())
138137

139-
// Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
140-
dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
141-
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}
142-
rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
143-
if err != nil {
144-
test.T().Errorf("%s", err)
145-
}
146-
147-
jobID := GetTestJobId(test, rayClient, dashboardUrl.Host)
148-
test.Expect(jobID).ToNot(Equal(nil))
149-
150-
// Wait for the job to be succeeded or failed
151-
var rayJobStatus string
152-
fmt.Printf("Waiting for job to be Succeeded...\n")
153-
test.Eventually(func() string {
154-
resp, err := rayClient.GetJobDetails(jobID)
155-
test.Expect(err).ToNot(HaveOccurred())
156-
rayJobStatusVal := resp.Status
157-
if rayJobStatusVal == "SUCCEEDED" || rayJobStatusVal == "FAILED" {
158-
fmt.Printf("JobStatus : %s\n", rayJobStatusVal)
159-
rayJobStatus = rayJobStatusVal
160-
return rayJobStatus
161-
}
162-
if rayJobStatus != rayJobStatusVal && rayJobStatusVal != "SUCCEEDED" {
163-
fmt.Printf("JobStatus : %s...\n", rayJobStatusVal)
164-
rayJobStatus = rayJobStatusVal
165-
}
166-
return rayJobStatus
167-
}, TestTimeoutDouble, 3*time.Second).Should(Or(Equal("SUCCEEDED"), Equal("FAILED")), "Job did not complete within the expected time")
168-
test.Expect(rayJobStatus).To(Equal("SUCCEEDED"), "RayJob failed !")
169-
170-
// Store job logs in output directory
171-
WriteRayJobAPILogs(test, rayClient, jobID)
172-
173-
// Fetch created raycluster
174-
rayClusterName := "mnisthpotest"
175-
// Wait until raycluster is up and running
176-
rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Get(test.Ctx(), rayClusterName, metav1.GetOptions{})
177-
test.Expect(err).ToNot(HaveOccurred())
178-
179138
// Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
180139
dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
181140
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}

tests/odh/notebook.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ import (
2929
"k8s.io/apimachinery/pkg/util/yaml"
3030
)
3131

32-
const recommendedTagAnnotation = "opendatahub.io/workbench-image-recommended"
33-
3432
var notebookResource = schema.GroupVersionResource{Group: "kubeflow.org", Version: "v1", Resource: "notebooks"}
3533

3634
type NotebookProps struct {

tests/odh/ray_finetune_llm_deepspeed_test.go

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ import (
3030
)
3131

3232
func TestRayFinetuneLlmDeepspeedDemoLlama_2_7b(t *testing.T) {
33-
rayFinetuneLlmDeepspeed(t, 1, "zero_3_llama_2_7b.json")
33+
rayFinetuneLlmDeepspeed(t, 1, "meta-llama/Llama-2-7b-chat-hf", "zero_3_llama_2_7b.json")
3434
}
3535
func TestRayFinetuneLlmDeepspeedDemoLlama_31_8b(t *testing.T) {
36-
rayFinetuneLlmDeepspeed(t, 1, "zero_3_offload_optim_param.json")
36+
rayFinetuneLlmDeepspeed(t, 1, "meta-llama/Meta-Llama-3.1-8B", "zero_3_offload_optim_param.json")
3737
}
3838

39-
func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string) {
39+
func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelName string, modelConfigFile string) {
4040
test := With(t)
4141

4242
// Create a namespace
@@ -56,21 +56,22 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
5656
"import os": "import os,time,sys",
5757
"import sys": "!cp /opt/app-root/notebooks/* ./\\n\",\n\t\"!ls",
5858
"from codeflare_sdk.cluster.auth import TokenAuthentication": "from codeflare_sdk.cluster.auth import TokenAuthentication\\n\",\n\t\"from codeflare_sdk.job import RayJobClient",
59-
"token = ''": fmt.Sprintf("token = '%s'", userToken),
60-
"server = ''": fmt.Sprintf("server = '%s'", GetOpenShiftApiUrl(test)),
61-
"namespace='ray-finetune-llm-deepspeed'": fmt.Sprintf("namespace='%s'", namespace.Name),
62-
"head_cpus=16": "head_cpus=2",
63-
"head_extended_resource_requests=1": "head_extended_resource_requests=0",
64-
"num_workers=7": "num_workers=1",
65-
"worker_cpu_requests=16": "worker_cpu_requests=4",
66-
"worker_cpu_limits=16": "worker_cpu_limits=4",
67-
"worker_memory_requests=128": "worker_memory_requests=64",
68-
"worker_memory_limits=256": "worker_memory_limits=128",
69-
"head_memory=128": "head_memory=48",
70-
"client = cluster.job_client": "ray_dashboard = cluster.cluster_dashboard_uri()\\n\",\n\t\"header = {\\\"Authorization\\\": \\\"Bearer " + userToken + "\\\"}\\n\",\n\t\"client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\n",
71-
"--num-devices=8": fmt.Sprintf("--num-devices=%d", numGpus),
72-
"--num-epochs=3": fmt.Sprintf("--num-epochs=%d", 1),
73-
"--ds-config=./deepspeed_configs/zero_3_offload_optim+param.json": fmt.Sprintf("--ds-config=./%s \\\"\\n\",\n\t\" \\\"--lora-config=./lora.json \\\"\\n\",\n\t\" \\\"--as-test", modelConfigFile),
59+
"token = ''": fmt.Sprintf("token = '%s'", userToken),
60+
"server = ''": fmt.Sprintf("server = '%s'", GetOpenShiftApiUrl(test)),
61+
"namespace='ray-finetune-llm-deepspeed'": fmt.Sprintf("namespace='%s'", namespace.Name),
62+
"head_cpus=16": "head_cpus=2",
63+
"head_extended_resource_requests=1": "head_extended_resource_requests=0",
64+
"num_workers=7": "num_workers=1",
65+
"worker_cpu_requests=16": "worker_cpu_requests=4",
66+
"worker_cpu_limits=16": "worker_cpu_limits=4",
67+
"worker_memory_requests=128": "worker_memory_requests=64",
68+
"worker_memory_limits=256": "worker_memory_limits=128",
69+
"head_memory=128": "head_memory=48",
70+
"client = cluster.job_client": "ray_dashboard = cluster.cluster_dashboard_uri()\\n\",\n\t\"header = {\\\"Authorization\\\": \\\"Bearer " + userToken + "\\\"}\\n\",\n\t\"client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\n",
71+
"--num-devices=8": fmt.Sprintf("--num-devices=%d", numGpus),
72+
"--num-epochs=3": fmt.Sprintf("--num-epochs=%d", 1),
73+
"--model-name=meta-llama/Meta-Llama-3.1-8B": fmt.Sprintf("--model-name=%s", modelName),
74+
"--ds-config=./deepspeed_configs/zero_3_offload_optim_param.json": fmt.Sprintf("--ds-config=./%s \\\"\\n\",\n\t\" \\\"--lora-config=./lora.json \\\"\\n\",\n\t\" \\\"--as-test", modelConfigFile),
7475
"--batch-size-per-device=32": "--batch-size-per-device=6",
7576
"--eval-batch-size-per-device=32": "--eval-batch-size-per-device=6",
7677
"'pip': 'requirements.txt'": "'pip': '/opt/app-root/src/requirements.txt'",
@@ -117,8 +118,6 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
117118
),
118119
)
119120

120-
time.Sleep(30 * time.Second)
121-
122121
// Fetch created raycluster
123122
rayClusterName := "ray"
124123
rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Get(test.Ctx(), rayClusterName, metav1.GetOptions{})
@@ -128,37 +127,44 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
128127
dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
129128
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}
130129
rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
131-
if err != nil {
132-
test.T().Errorf("%s", err)
133-
}
130+
test.Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Failed to create new raycluster client: %s", err))
134131

132+
// Wait until the RayJob exists
133+
test.Eventually(func() []RayJobDetailsResponse {
134+
rayJobs, err := rayClient.GetJobs()
135+
test.Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Failed to fetch ray-jobs : %s", err))
136+
return *rayJobs
137+
}, TestTimeoutMedium, 1*time.Second).Should(HaveLen(1), "Ray job not found")
138+
139+
// Get test job-id
135140
jobID := GetTestJobId(test, rayClient, dashboardUrl.Host)
136-
test.Expect(jobID).ToNot(Equal(nil))
141+
test.Expect(jobID).ToNot(BeEmpty())
137142

138143
// Wait for the job to either succeed or fail
139144
var rayJobStatus string
140-
fmt.Printf("Waiting for job to be Succeeded...\n")
145+
test.T().Logf("Waiting for job to be Succeeded...\n")
141146
test.Eventually(func() string {
142147
resp, err := rayClient.GetJobDetails(jobID)
143-
test.Expect(err).ToNot(HaveOccurred())
148+
test.Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Failed to get job details :%s", err))
144149
rayJobStatusVal := resp.Status
145150
if rayJobStatusVal == "SUCCEEDED" || rayJobStatusVal == "FAILED" {
146-
fmt.Printf("JobStatus : %s\n", rayJobStatusVal)
151+
test.T().Logf("JobStatus - %s\n", rayJobStatusVal)
147152
rayJobStatus = rayJobStatusVal
148-
WriteRayJobAPILogs(test, rayClient, jobID)
149153
return rayJobStatus
150154
}
151155
if rayJobStatus != rayJobStatusVal && rayJobStatusVal != "SUCCEEDED" {
152-
fmt.Printf("JobStatus : %s...\n", rayJobStatusVal)
156+
test.T().Logf("JobStatus - %s...\n", rayJobStatusVal)
153157
rayJobStatus = rayJobStatusVal
154158
}
155159
return rayJobStatus
156-
}, TestTimeoutDouble, 3*time.Second).Should(Or(Equal("SUCCEEDED"), Equal("FAILED")), "Job did not complete within the expected time")
160+
}, TestTimeoutDouble, 1*time.Second).Should(Or(Equal("SUCCEEDED"), Equal("FAILED")), "Job did not complete within the expected time")
157161
// Store job logs in output directory
158162
WriteRayJobAPILogs(test, rayClient, jobID)
163+
164+
// Assert ray-job status after job execution
159165
test.Expect(rayJobStatus).To(Equal("SUCCEEDED"), "RayJob failed !")
160166

161167
// Make sure the RayCluster finishes and is deleted
162-
test.Eventually(RayClusters(test, namespace.Name), TestTimeoutMedium).
163-
Should(HaveLen(0))
168+
test.Eventually(RayClusters(test, namespace.Name), TestTimeoutLong).
169+
Should(BeEmpty())
164170
}

tests/odh/support.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@ import (
2222
"net/url"
2323
"os"
2424

25-
. "github.com/onsi/gomega"
2625
gomega "github.com/onsi/gomega"
2726
"github.com/project-codeflare/codeflare-common/support"
28-
. "github.com/project-codeflare/codeflare-common/support"
2927
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
3028
v1 "k8s.io/api/core/v1"
3129
)
@@ -50,23 +48,23 @@ func ReadFileExt(t support.Test, fileName string) []byte {
5048
func GetDashboardUrl(test support.Test, namespace *v1.Namespace, rayCluster *rayv1.RayCluster) *url.URL {
5149
dashboardName := "ray-dashboard-" + rayCluster.Name
5250
test.T().Logf("Raycluster created : %s\n", rayCluster.Name)
53-
route := GetRoute(test, namespace.Name, dashboardName)
51+
route := support.GetRoute(test, namespace.Name, dashboardName)
5452
hostname := route.Status.Ingress[0].Host
5553
dashboardUrl, _ := url.Parse("https://" + hostname)
5654
test.T().Logf("Ray-dashboard route : %s\n", dashboardUrl.String())
5755

5856
return dashboardUrl
5957
}
6058

61-
func GetTestJobId(test Test, rayClient RayClusterClient, hostName string) string {
59+
func GetTestJobId(test support.Test, rayClient support.RayClusterClient, hostName string) string {
6260
listJobsReq, err := http.NewRequest("GET", "https://"+hostName+"/api/jobs/", nil)
6361
if err != nil {
6462
test.T().Errorf("failed to do get request: %s\n", err)
6563
}
6664
listJobsReq.Header.Add("Authorization", "Bearer "+test.Config().BearerToken)
6765

6866
allJobsData, err := rayClient.GetJobs()
69-
test.Expect(err).ToNot(HaveOccurred())
67+
test.Expect(err).ToNot(gomega.HaveOccurred())
7068

7169
jobID := (*allJobsData)[0].SubmissionID
7270
if len(*allJobsData) > 0 {

0 commit comments

Comments
 (0)