@@ -30,13 +30,13 @@ import (
30
30
)
31
31
32
32
func TestRayFinetuneLlmDeepspeedDemoLlama_2_7b (t * testing.T ) {
33
- rayFinetuneLlmDeepspeed (t , 1 , "zero_3_llama_2_7b.json" )
33
+ rayFinetuneLlmDeepspeed (t , 1 , "meta-llama/Llama-2-7b-chat-hf" , " zero_3_llama_2_7b.json" )
34
34
}
35
35
func TestRayFinetuneLlmDeepspeedDemoLlama_31_8b (t * testing.T ) {
36
- rayFinetuneLlmDeepspeed (t , 1 , "zero_3_offload_optim_param.json" )
36
+ rayFinetuneLlmDeepspeed (t , 1 , "meta-llama/Meta-Llama-3.1-8B" , " zero_3_offload_optim_param.json" )
37
37
}
38
38
39
- func rayFinetuneLlmDeepspeed (t * testing.T , numGpus int , modelConfigFile string ) {
39
+ func rayFinetuneLlmDeepspeed (t * testing.T , numGpus int , modelName string , modelConfigFile string ) {
40
40
test := With (t )
41
41
42
42
// Create a namespace
@@ -56,21 +56,22 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
56
56
"import os" : "import os,time,sys" ,
57
57
"import sys" : "!cp /opt/app-root/notebooks/* ./\\ n\" ,\n \t \" !ls" ,
58
58
"from codeflare_sdk.cluster.auth import TokenAuthentication" : "from codeflare_sdk.cluster.auth import TokenAuthentication\\ n\" ,\n \t \" from codeflare_sdk.job import RayJobClient" ,
59
- "token = ''" : fmt .Sprintf ("token = '%s'" , userToken ),
60
- "server = ''" : fmt .Sprintf ("server = '%s'" , GetOpenShiftApiUrl (test )),
61
- "namespace='ray-finetune-llm-deepspeed'" : fmt .Sprintf ("namespace='%s'" , namespace .Name ),
62
- "head_cpus=16" : "head_cpus=2" ,
63
- "head_extended_resource_requests=1" : "head_extended_resource_requests=0" ,
64
- "num_workers=7" : "num_workers=1" ,
65
- "worker_cpu_requests=16" : "worker_cpu_requests=4" ,
66
- "worker_cpu_limits=16" : "worker_cpu_limits=4" ,
67
- "worker_memory_requests=128" : "worker_memory_requests=64" ,
68
- "worker_memory_limits=256" : "worker_memory_limits=128" ,
69
- "head_memory=128" : "head_memory=48" ,
70
- "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
71
- "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
72
- "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
73
- "--ds-config=./deepspeed_configs/zero_3_offload_optim+param.json" : fmt .Sprintf ("--ds-config=./%s \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" , modelConfigFile ),
59
+ "token = ''" : fmt .Sprintf ("token = '%s'" , userToken ),
60
+ "server = ''" : fmt .Sprintf ("server = '%s'" , GetOpenShiftApiUrl (test )),
61
+ "namespace='ray-finetune-llm-deepspeed'" : fmt .Sprintf ("namespace='%s'" , namespace .Name ),
62
+ "head_cpus=16" : "head_cpus=2" ,
63
+ "head_extended_resource_requests=1" : "head_extended_resource_requests=0" ,
64
+ "num_workers=7" : "num_workers=1" ,
65
+ "worker_cpu_requests=16" : "worker_cpu_requests=4" ,
66
+ "worker_cpu_limits=16" : "worker_cpu_limits=4" ,
67
+ "worker_memory_requests=128" : "worker_memory_requests=64" ,
68
+ "worker_memory_limits=256" : "worker_memory_limits=128" ,
69
+ "head_memory=128" : "head_memory=48" ,
70
+ "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
71
+ "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
72
+ "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
73
+ "--model-name=meta-llama/Meta-Llama-3.1-8B" : fmt .Sprintf ("--model-name=%s" , modelName ),
74
+ "--ds-config=./deepspeed_configs/zero_3_offload_optim_param.json" : fmt .Sprintf ("--ds-config=./%s \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" , modelConfigFile ),
74
75
"--batch-size-per-device=32" : "--batch-size-per-device=6" ,
75
76
"--eval-batch-size-per-device=32" : "--eval-batch-size-per-device=6" ,
76
77
"'pip': 'requirements.txt'" : "'pip': '/opt/app-root/src/requirements.txt'" ,
@@ -117,8 +118,6 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
117
118
),
118
119
)
119
120
120
- time .Sleep (30 * time .Second )
121
-
122
121
// Fetch created raycluster
123
122
rayClusterName := "ray"
124
123
rayCluster , err := test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Get (test .Ctx (), rayClusterName , metav1.GetOptions {})
@@ -128,37 +127,44 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
128
127
dashboardUrl := GetDashboardUrl (test , namespace , rayCluster )
129
128
rayClusterClientConfig := RayClusterClientConfig {Address : dashboardUrl .String (), Client : nil , InsecureSkipVerify : true }
130
129
rayClient , err := NewRayClusterClient (rayClusterClientConfig , test .Config ().BearerToken )
131
- if err != nil {
132
- test .T ().Errorf ("%s" , err )
133
- }
130
+ test .Expect (err ).ToNot (HaveOccurred (), fmt .Sprintf ("Failed to create new raycluster client: %s" , err ))
134
131
132
+ // wait until rayjob exists
133
+ test .Eventually (func () []RayJobDetailsResponse {
134
+ rayJobs , err := rayClient .GetJobs ()
135
+ test .Expect (err ).ToNot (HaveOccurred (), fmt .Sprintf ("Failed to fetch ray-jobs : %s" , err ))
136
+ return * rayJobs
137
+ }, TestTimeoutMedium , 1 * time .Second ).Should (HaveLen (1 ), "Ray job not found" )
138
+
139
+ // Get test job-id
135
140
jobID := GetTestJobId (test , rayClient , dashboardUrl .Host )
136
- test .Expect (jobID ).ToNot (Equal ( nil ))
141
+ test .Expect (jobID ).ToNot (BeEmpty ( ))
137
142
138
143
// Wait for the job to be succeeded or failed
139
144
var rayJobStatus string
140
- fmt . Printf ("Waiting for job to be Succeeded...\n " )
145
+ test . T (). Logf ("Waiting for job to be Succeeded...\n " )
141
146
test .Eventually (func () string {
142
147
resp , err := rayClient .GetJobDetails (jobID )
143
- test .Expect (err ).ToNot (HaveOccurred ())
148
+ test .Expect (err ).ToNot (HaveOccurred (), fmt . Sprintf ( "Failed to get job details :%s" , err ) )
144
149
rayJobStatusVal := resp .Status
145
150
if rayJobStatusVal == "SUCCEEDED" || rayJobStatusVal == "FAILED" {
146
- fmt . Printf ( "JobStatus : %s\n " , rayJobStatusVal )
151
+ test . T (). Logf ( "JobStatus - %s\n " , rayJobStatusVal )
147
152
rayJobStatus = rayJobStatusVal
148
- WriteRayJobAPILogs (test , rayClient , jobID )
149
153
return rayJobStatus
150
154
}
151
155
if rayJobStatus != rayJobStatusVal && rayJobStatusVal != "SUCCEEDED" {
152
- fmt . Printf ( "JobStatus : %s...\n " , rayJobStatusVal )
156
+ test . T (). Logf ( "JobStatus - %s...\n " , rayJobStatusVal )
153
157
rayJobStatus = rayJobStatusVal
154
158
}
155
159
return rayJobStatus
156
- }, TestTimeoutDouble , 3 * time .Second ).Should (Or (Equal ("SUCCEEDED" ), Equal ("FAILED" )), "Job did not complete within the expected time" )
160
+ }, TestTimeoutDouble , 1 * time .Second ).Should (Or (Equal ("SUCCEEDED" ), Equal ("FAILED" )), "Job did not complete within the expected time" )
157
161
// Store job logs in output directory
158
162
WriteRayJobAPILogs (test , rayClient , jobID )
163
+
164
+ // Assert ray-job status after job execution
159
165
test .Expect (rayJobStatus ).To (Equal ("SUCCEEDED" ), "RayJob failed !" )
160
166
161
167
// Make sure the RayCluster finishes and is deleted
162
- test .Eventually (RayClusters (test , namespace .Name ), TestTimeoutMedium ).
163
- Should (HaveLen ( 0 ))
168
+ test .Eventually (RayClusters (test , namespace .Name ), TestTimeoutLong ).
169
+ Should (BeEmpty ( ))
164
170
}
0 commit comments