@@ -32,15 +32,17 @@ import (
32
32
var _ = Describe ("PyTorchJob controller" , func () {
33
33
// Define utility constants for object names and testing timeouts/durations and intervals.
34
34
const (
35
- namespace = "default"
36
- name = "test-job"
37
-
38
- timeout = time .Second * 10
39
- interval = time .Millisecond * 250
35
+ timeout = time .Second * 10
36
+ interval = time .Millisecond * 250
37
+ expectedPort = int32 (8080 )
40
38
)
41
39
42
40
Context ("When creating the PyTorchJob" , func () {
43
41
It ("Should get the corresponding resources successfully" , func () {
42
+ const (
43
+ namespace = "default"
44
+ name = "test-job"
45
+ )
44
46
By ("By creating a new PyTorchJob" )
45
47
ctx := context .Background ()
46
48
job := newPyTorchJobForTest (name , namespace )
@@ -56,7 +58,7 @@ var _ = Describe("PyTorchJob controller", func() {
56
58
Ports : []corev1.ContainerPort {
57
59
{
58
60
Name : pytorchv1 .DefaultPortName ,
59
- ContainerPort : 80 ,
61
+ ContainerPort : expectedPort ,
60
62
Protocol : corev1 .ProtocolTCP ,
61
63
},
62
64
},
@@ -76,7 +78,7 @@ var _ = Describe("PyTorchJob controller", func() {
76
78
Ports : []corev1.ContainerPort {
77
79
{
78
80
Name : pytorchv1 .DefaultPortName ,
79
- ContainerPort : 80 ,
81
+ ContainerPort : expectedPort ,
80
82
Protocol : corev1 .ProtocolTCP ,
81
83
},
82
84
},
@@ -87,41 +89,46 @@ var _ = Describe("PyTorchJob controller", func() {
87
89
},
88
90
}
89
91
90
- Expect (k8sClient .Create (ctx , job )).Should (Succeed ())
92
+ Expect (testK8sClient .Create (ctx , job )).Should (Succeed ())
91
93
92
94
key := types.NamespacedName {Name : name , Namespace : namespace }
93
95
created := & pytorchv1.PyTorchJob {}
94
96
95
97
// We'll need to retry getting this newly created PyTorchJob, given that creation may not immediately happen.
96
98
Eventually (func () bool {
97
- err := k8sClient .Get (ctx , key , created )
99
+ err := testK8sClient .Get (ctx , key , created )
98
100
return err == nil
99
101
}, timeout , interval ).Should (BeTrue ())
100
102
101
103
masterKey := types.NamespacedName {Name : fmt .Sprintf ("%s-master-0" , name ), Namespace : namespace }
102
104
masterPod := & corev1.Pod {}
103
105
Eventually (func () bool {
104
- err := k8sClient .Get (ctx , masterKey , masterPod )
106
+ err := testK8sClient .Get (ctx , masterKey , masterPod )
105
107
return err == nil
106
108
}, timeout , interval ).Should (BeTrue ())
107
109
108
110
masterSvc := & corev1.Service {}
109
111
Eventually (func () bool {
110
- err := k8sClient .Get (ctx , masterKey , masterSvc )
112
+ err := testK8sClient .Get (ctx , masterKey , masterSvc )
111
113
return err == nil
112
114
}, timeout , interval ).Should (BeTrue ())
113
115
116
+ // Check the pod port.
114
117
Expect (masterPod .Spec .Containers [0 ].Ports ).To (ContainElement (corev1.ContainerPort {
115
118
Name : pytorchv1 .DefaultPortName ,
116
- ContainerPort : 80 ,
119
+ ContainerPort : expectedPort ,
117
120
Protocol : corev1 .ProtocolTCP }))
121
+ // Check MASTER_PORT and MASTER_ADDR env variable
118
122
Expect (masterPod .Spec .Containers [0 ].Env ).To (ContainElements (corev1.EnvVar {
119
123
Name : EnvMasterPort ,
120
124
Value : fmt .Sprintf ("%d" , masterSvc .Spec .Ports [0 ].Port ),
121
125
}, corev1.EnvVar {
122
126
Name : EnvMasterAddr ,
123
127
Value : masterSvc .Name ,
124
128
}))
129
+ // Check service port.
130
+ Expect (masterSvc .Spec .Ports [0 ].Port ).To (Equal (expectedPort ))
131
+ // Check owner reference.
125
132
trueVal := true
126
133
Expect (masterPod .OwnerReferences ).To (ContainElement (metav1.OwnerReference {
127
134
APIVersion : pytorchv1 .SchemeGroupVersion .String (),
@@ -143,19 +150,148 @@ var _ = Describe("PyTorchJob controller", func() {
143
150
// Test job status.
144
151
masterPod .Status .Phase = corev1 .PodSucceeded
145
152
masterPod .ResourceVersion = ""
146
- Expect (k8sClient .Status ().Update (ctx , masterPod )).Should (Succeed ())
153
+ Expect (testK8sClient .Status ().Update (ctx , masterPod )).Should (Succeed ())
154
+ Eventually (func () bool {
155
+ err := testK8sClient .Get (ctx , key , created )
156
+ if err != nil {
157
+ return false
158
+ }
159
+ return created .Status .ReplicaStatuses != nil && created .Status .
160
+ ReplicaStatuses [pytorchv1 .PyTorchReplicaTypeMaster ].Succeeded == 1
161
+ }, timeout , interval ).Should (BeTrue ())
162
+ // Check if the job is succeeded.
163
+ cond := getCondition (created .Status , commonv1 .JobSucceeded )
164
+ Expect (cond .Status ).To (Equal (corev1 .ConditionTrue ))
165
+ By ("Deleting the PyTorchJob" )
166
+ Expect (testK8sClient .Delete (ctx , job )).Should (Succeed ())
167
+ })
168
+ })
169
+
170
+ Context ("When creating the elastic PyTorchJob" , func () {
171
+ // TODO(gaocegege): Test with more than 1 worker.
172
+ It ("Should get the corresponding resources successfully" , func () {
173
+ // Define the expected elastic policy.
174
+ var (
175
+ backendC10D = pytorchv1 .BackendC10D
176
+ minReplicas = int32Ptr (1 )
177
+ maxReplicas = int32Ptr (3 )
178
+ maxRestarts = int32Ptr (3 )
179
+ namespace = "default"
180
+ name = "easltic-job"
181
+ )
182
+
183
+ By ("By creating a new PyTorchJob" )
184
+ ctx := context .Background ()
185
+ job := newPyTorchJobForTest (name , namespace )
186
+ job .Spec .ElasticPolicy = & pytorchv1.ElasticPolicy {
187
+ RDZVBackend : & backendC10D ,
188
+ MaxReplicas : maxReplicas ,
189
+ MinReplicas : minReplicas ,
190
+ MaxRestarts : maxRestarts ,
191
+ }
192
+ job .Spec .PyTorchReplicaSpecs = map [commonv1.ReplicaType ]* commonv1.ReplicaSpec {
193
+ pytorchv1 .PyTorchReplicaTypeWorker : {
194
+ Replicas : int32Ptr (1 ),
195
+ Template : corev1.PodTemplateSpec {
196
+ Spec : corev1.PodSpec {
197
+ Containers : []corev1.Container {
198
+ {
199
+ Image : "test-image" ,
200
+ Name : pytorchv1 .DefaultContainerName ,
201
+ Ports : []corev1.ContainerPort {
202
+ {
203
+ Name : pytorchv1 .DefaultPortName ,
204
+ ContainerPort : expectedPort ,
205
+ Protocol : corev1 .ProtocolTCP ,
206
+ },
207
+ },
208
+ },
209
+ },
210
+ },
211
+ },
212
+ },
213
+ }
214
+
215
+ Expect (testK8sClient .Create (ctx , job )).Should (Succeed ())
216
+
217
+ key := types.NamespacedName {Name : name , Namespace : namespace }
218
+ created := & pytorchv1.PyTorchJob {}
219
+
220
+ // We'll need to retry getting this newly created PyTorchJob, given that creation may not immediately happen.
221
+ Eventually (func () bool {
222
+ err := testK8sClient .Get (ctx , key , created )
223
+ return err == nil
224
+ }, timeout , interval ).Should (BeTrue ())
225
+
226
+ workerKey := types.NamespacedName {Name : fmt .Sprintf ("%s-worker-0" , name ), Namespace : namespace }
227
+ pod := & corev1.Pod {}
228
+ Eventually (func () bool {
229
+ err := testK8sClient .Get (ctx , workerKey , pod )
230
+ return err == nil
231
+ }, timeout , interval ).Should (BeTrue ())
232
+
233
+ svc := & corev1.Service {}
234
+ Eventually (func () bool {
235
+ err := testK8sClient .Get (ctx , workerKey , svc )
236
+ return err == nil
237
+ }, timeout , interval ).Should (BeTrue ())
238
+
239
+ // Check pod port.
240
+ Expect (pod .Spec .Containers [0 ].Ports ).To (ContainElement (corev1.ContainerPort {
241
+ Name : pytorchv1 .DefaultPortName ,
242
+ ContainerPort : expectedPort ,
243
+ Protocol : corev1 .ProtocolTCP }))
244
+ // Check environment variables.
245
+ Expect (pod .Spec .Containers [0 ].Env ).To (ContainElements (corev1.EnvVar {
246
+ Name : EnvRDZVBackend ,
247
+ Value : string (backendC10D ),
248
+ }, corev1.EnvVar {
249
+ Name : EnvNNodes ,
250
+ Value : fmt .Sprintf ("%d:%d" , * minReplicas , * maxReplicas ),
251
+ }, corev1.EnvVar {
252
+ Name : EnvRDZVEndpoint ,
253
+ Value : fmt .Sprintf ("%s:%d" , svc .Name , expectedPort ),
254
+ }, corev1.EnvVar {
255
+ Name : EnvMaxRestarts ,
256
+ Value : fmt .Sprintf ("%d" , * maxRestarts ),
257
+ }))
258
+ Expect (svc .Spec .Ports [0 ].Port ).To (Equal (expectedPort ))
259
+ // Check owner references.
260
+ trueVal := true
261
+ Expect (pod .OwnerReferences ).To (ContainElement (metav1.OwnerReference {
262
+ APIVersion : pytorchv1 .SchemeGroupVersion .String (),
263
+ Kind : pytorchv1 .Kind ,
264
+ Name : name ,
265
+ UID : created .UID ,
266
+ Controller : & trueVal ,
267
+ BlockOwnerDeletion : & trueVal ,
268
+ }))
269
+ Expect (svc .OwnerReferences ).To (ContainElement (metav1.OwnerReference {
270
+ APIVersion : pytorchv1 .SchemeGroupVersion .String (),
271
+ Kind : pytorchv1 .Kind ,
272
+ Name : name ,
273
+ UID : created .UID ,
274
+ Controller : & trueVal ,
275
+ BlockOwnerDeletion : & trueVal ,
276
+ }))
277
+
278
+ // Test job status.
279
+ pod .Status .Phase = corev1 .PodSucceeded
280
+ pod .ResourceVersion = ""
281
+ Expect (testK8sClient .Status ().Update (ctx , pod )).Should (Succeed ())
147
282
Eventually (func () bool {
148
- err := k8sClient .Get (ctx , key , created )
283
+ err := testK8sClient .Get (ctx , key , created )
149
284
if err != nil {
150
285
return false
151
286
}
152
- return created .Status .ReplicaStatuses [pytorchv1 .PyTorchReplicaTypeMaster ].Succeeded == 1
287
+ return created .Status .ReplicaStatuses != nil && created .Status .
288
+ ReplicaStatuses [pytorchv1 .PyTorchReplicaTypeWorker ].Succeeded == 1
153
289
}, timeout , interval ).Should (BeTrue ())
154
290
// Check if the job is succeeded.
155
291
cond := getCondition (created .Status , commonv1 .JobSucceeded )
156
292
Expect (cond .Status ).To (Equal (corev1 .ConditionTrue ))
157
293
By ("Deleting the PyTorchJob" )
158
- Expect (k8sClient .Delete (ctx , job )).Should (Succeed ())
294
+ Expect (testK8sClient .Delete (ctx , job )).Should (Succeed ())
159
295
})
160
296
})
161
297
})
0 commit comments