@@ -30,13 +30,13 @@ import (
30
30
)
31
31
32
32
func TestRayFinetuneLlmDeepspeedDemoLlama_2_7b (t * testing.T ) {
33
- rayFinetuneLlmDeepspeed (t , 1 , "zero_3_llama_2_7b.json" )
33
+ rayFinetuneLlmDeepspeed (t , 1 , "meta-llama/Llama-2-7b-chat-hf" , " zero_3_llama_2_7b.json" )
34
34
}
35
35
func TestRayFinetuneLlmDeepspeedDemoLlama_31_8b (t * testing.T ) {
36
- rayFinetuneLlmDeepspeed (t , 1 , "zero_3_offload_optim_param.json" )
36
+ rayFinetuneLlmDeepspeed (t , 1 , "meta-llama/Meta-Llama-3.1-8B" , " zero_3_offload_optim_param.json" )
37
37
}
38
38
39
- func rayFinetuneLlmDeepspeed (t * testing.T , numGpus int , modelConfigFile string ) {
39
+ func rayFinetuneLlmDeepspeed (t * testing.T , numGpus int , modelName string , modelConfigFile string ) {
40
40
test := With (t )
41
41
42
42
// Create a namespace
@@ -56,21 +56,22 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
56
56
"import os" : "import os,time,sys" ,
57
57
"import sys" : "!cp /opt/app-root/notebooks/* ./\\ n\" ,\n \t \" !ls" ,
58
58
"from codeflare_sdk.cluster.auth import TokenAuthentication" : "from codeflare_sdk.cluster.auth import TokenAuthentication\\ n\" ,\n \t \" from codeflare_sdk.job import RayJobClient" ,
59
- "token = ''" : fmt .Sprintf ("token = '%s'" , userToken ),
60
- "server = ''" : fmt .Sprintf ("server = '%s'" , GetOpenShiftApiUrl (test )),
61
- "namespace='ray-finetune-llm-deepspeed'" : fmt .Sprintf ("namespace='%s'" , namespace .Name ),
62
- "head_cpus=16" : "head_cpus=2" ,
63
- "head_extended_resource_requests=1" : "head_extended_resource_requests=0" ,
64
- "num_workers=7" : "num_workers=1" ,
65
- "worker_cpu_requests=16" : "worker_cpu_requests=4" ,
66
- "worker_cpu_limits=16" : "worker_cpu_limits=4" ,
67
- "worker_memory_requests=128" : "worker_memory_requests=64" ,
68
- "worker_memory_limits=256" : "worker_memory_limits=128" ,
69
- "head_memory=128" : "head_memory=48" ,
70
- "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
71
- "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
72
- "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
73
- "--ds-config=./deepspeed_configs/zero_3_offload_optim+param.json" : fmt .Sprintf ("--ds-config=./%s \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" , modelConfigFile ),
59
+ "token = ''" : fmt .Sprintf ("token = '%s'" , userToken ),
60
+ "server = ''" : fmt .Sprintf ("server = '%s'" , GetOpenShiftApiUrl (test )),
61
+ "namespace='ray-finetune-llm-deepspeed'" : fmt .Sprintf ("namespace='%s'" , namespace .Name ),
62
+ "head_cpus=16" : "head_cpus=2" ,
63
+ "head_extended_resource_requests=1" : "head_extended_resource_requests=0" ,
64
+ "num_workers=7" : "num_workers=1" ,
65
+ "worker_cpu_requests=16" : "worker_cpu_requests=4" ,
66
+ "worker_cpu_limits=16" : "worker_cpu_limits=4" ,
67
+ "worker_memory_requests=128" : "worker_memory_requests=64" ,
68
+ "worker_memory_limits=256" : "worker_memory_limits=128" ,
69
+ "head_memory=128" : "head_memory=48" ,
70
+ "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
71
+ "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
72
+ "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
73
+ "--model-name=meta-llama/Meta-Llama-3.1-8B" : fmt .Sprintf ("--model-name=%s" , modelName ),
74
+ "--ds-config=./deepspeed_configs/zero_3_offload_optim_param.json" : fmt .Sprintf ("--ds-config=./%s \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" , modelConfigFile ),
74
75
"--batch-size-per-device=32" : "--batch-size-per-device=6" ,
75
76
"--eval-batch-size-per-device=32" : "--eval-batch-size-per-device=6" ,
76
77
"'pip': 'requirements.txt'" : "'pip': '/opt/app-root/src/requirements.txt'" ,
@@ -83,7 +84,6 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
83
84
updatedNotebookContent = strings .Replace (updatedNotebookContent , oldValue , newValue , - 1 )
84
85
}
85
86
updatedNotebook := []byte (updatedNotebookContent )
86
- os .WriteFile ("demo.ipynb" , updatedNotebook , 0644 )
87
87
88
88
// Test configuration
89
89
jupyterNotebookConfigMapFileName := "ray_finetune_llm_deepspeed.ipynb"
@@ -117,8 +117,6 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
117
117
),
118
118
)
119
119
120
- time .Sleep (30 * time .Second )
121
-
122
120
// Fetch created raycluster
123
121
rayClusterName := "ray"
124
122
rayCluster , err := test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Get (test .Ctx (), rayClusterName , metav1.GetOptions {})
@@ -128,37 +126,44 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string)
128
126
dashboardUrl := GetDashboardUrl (test , namespace , rayCluster )
129
127
rayClusterClientConfig := RayClusterClientConfig {Address : dashboardUrl .String (), Client : nil , InsecureSkipVerify : true }
130
128
rayClient , err := NewRayClusterClient (rayClusterClientConfig , test .Config ().BearerToken )
131
- if err != nil {
132
- test .T ().Errorf ("%s" , err )
133
- }
129
+ test .Expect (err ).ToNot (HaveOccurred (), fmt .Sprintf ("Failed to create new raycluster client: %s" , err ))
134
130
131
+ // wait until rayjob exists
132
+ test .Eventually (func () []RayJobDetailsResponse {
133
+ rayJobs , err := rayClient .GetJobs ()
134
+ test .Expect (err ).ToNot (HaveOccurred (), fmt .Sprintf ("Failed to fetch ray-jobs : %s" , err ))
135
+ return * rayJobs
136
+ }, TestTimeoutMedium , 1 * time .Second ).Should (HaveLen (1 ), "Ray job not found" )
137
+
138
+ // Get test job-id
135
139
jobID := GetTestJobId (test , rayClient , dashboardUrl .Host )
136
- test .Expect (jobID ).ToNot (Equal ( nil ))
140
+ test .Expect (jobID ).ToNot (BeEmpty ( ))
137
141
138
142
// Wait for the job to be succeeded or failed
139
143
var rayJobStatus string
140
- fmt . Printf ("Waiting for job to be Succeeded...\n " )
144
+ test . T (). Logf ("Waiting for job to be Succeeded...\n " )
141
145
test .Eventually (func () string {
142
146
resp , err := rayClient .GetJobDetails (jobID )
143
- test .Expect (err ).ToNot (HaveOccurred ())
147
+ test .Expect (err ).ToNot (HaveOccurred (), fmt . Sprintf ( "Failed to get job details :%s" , err ) )
144
148
rayJobStatusVal := resp .Status
145
149
if rayJobStatusVal == "SUCCEEDED" || rayJobStatusVal == "FAILED" {
146
- fmt . Printf ( "JobStatus : %s\n " , rayJobStatusVal )
150
+ test . T (). Logf ( "JobStatus - %s\n " , rayJobStatusVal )
147
151
rayJobStatus = rayJobStatusVal
148
- WriteRayJobAPILogs (test , rayClient , jobID )
149
152
return rayJobStatus
150
153
}
151
154
if rayJobStatus != rayJobStatusVal && rayJobStatusVal != "SUCCEEDED" {
152
- fmt . Printf ( "JobStatus : %s...\n " , rayJobStatusVal )
155
+ test . T (). Logf ( "JobStatus - %s...\n " , rayJobStatusVal )
153
156
rayJobStatus = rayJobStatusVal
154
157
}
155
158
return rayJobStatus
156
- }, TestTimeoutDouble , 3 * time .Second ).Should (Or (Equal ("SUCCEEDED" ), Equal ("FAILED" )), "Job did not complete within the expected time" )
159
+ }, TestTimeoutDouble , 1 * time .Second ).Should (Or (Equal ("SUCCEEDED" ), Equal ("FAILED" )), "Job did not complete within the expected time" )
157
160
// Store job logs in output directory
158
161
WriteRayJobAPILogs (test , rayClient , jobID )
162
+
163
+ // Assert ray-job status after job execution
159
164
test .Expect (rayJobStatus ).To (Equal ("SUCCEEDED" ), "RayJob failed !" )
160
165
161
166
// Make sure the RayCluster finishes and is deleted
162
- test .Eventually (RayClusters (test , namespace .Name ), TestTimeoutMedium ).
163
- Should (HaveLen ( 0 ))
167
+ test .Eventually (RayClusters (test , namespace .Name ), TestTimeoutLong ).
168
+ Should (BeEmpty ( ))
164
169
}
0 commit comments