
Commit 2bfbb3d

fix: Fix the test cases
Signed-off-by: Ce Gao <[email protected]>
1 parent 756e343 commit 2bfbb3d

File tree

3 files changed: 163 additions & 29 deletions


pkg/controller.v1/pytorch/elastic_test.go

Lines changed: 2 additions & 4 deletions

@@ -25,14 +25,12 @@ import (
 	pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1"
 )
 
-var (
-	backendC10D = pytorchv1.BackendC10D
-)
-
 func TestElasticGenerate(t *testing.T) {
 	gomega.RegisterFailHandler(ginkgo.Fail)
 	defer ginkgo.GinkgoRecover()
 
+	backendC10D := pytorchv1.BackendC10D
+
 	tests := []struct {
 		name string
 		job  *pytorchv1.PyTorchJob
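The hunk above moves backendC10D from package scope into TestElasticGenerate. Below is a minimal, self-contained sketch of that table-driven pattern; the test name, the single case, and the assertion are illustrative assumptions rather than the actual test body, and the ginkgo/gomega import paths are assumed from the surrounding code.

package pytorch

import (
	"testing"

	"github.com/onsi/ginkgo"
	"github.com/onsi/gomega"

	pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1"
)

// Sketch only: mirrors the shape of TestElasticGenerate, with the backend
// declared inside the test function so it cannot collide with package-level
// identifiers declared in other test files of the same package.
func TestElasticGenerateSketch(t *testing.T) {
	gomega.RegisterFailHandler(ginkgo.Fail)
	defer ginkgo.GinkgoRecover()

	backendC10D := pytorchv1.BackendC10D

	tests := []struct {
		name string
		job  *pytorchv1.PyTorchJob
	}{
		{
			// Illustrative case: an elastic job that only sets the rendezvous backend.
			name: "elastic policy with c10d backend",
			job: &pytorchv1.PyTorchJob{
				Spec: pytorchv1.PyTorchJobSpec{
					ElasticPolicy: &pytorchv1.ElasticPolicy{
						RDZVBackend: &backendC10D,
					},
				},
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			gomega.NewWithT(t).Expect(tc.job.Spec.ElasticPolicy.RDZVBackend).NotTo(gomega.BeNil())
		})
	}
}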

pkg/controller.v1/pytorch/pytorchjob_controller_suite_test.go

Lines changed: 9 additions & 9 deletions

@@ -38,10 +38,10 @@ import (
 // http://onsi.github.io/ginkgo/ to learn more about Ginkgo.
 
 var (
-	k8sClient client.Client
-	testEnv   *envtest.Environment
-	ctx       context.Context
-	cancel    context.CancelFunc
+	testK8sClient client.Client
+	testEnv       *envtest.Environment
+	testCtx       context.Context
+	testCancel    context.CancelFunc
 )
 
 func TestAPIs(t *testing.T) {
@@ -55,7 +55,7 @@ func TestAPIs(t *testing.T) {
 var _ = BeforeSuite(func() {
 	logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)))
 
-	ctx, cancel = context.WithCancel(context.TODO())
+	testCtx, testCancel = context.WithCancel(context.TODO())
 
 	By("bootstrapping test environment")
 	testEnv = &envtest.Environment{
@@ -72,9 +72,9 @@ var _ = BeforeSuite(func() {
 
 	//+kubebuilder:scaffold:scheme
 
-	k8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
+	testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
 	Expect(err).NotTo(HaveOccurred())
-	Expect(k8sClient).NotTo(BeNil())
+	Expect(testK8sClient).NotTo(BeNil())
 
 	mgr, err := ctrl.NewManager(cfg, ctrl.Options{
 		MetricsBindAddress: "0",
@@ -87,14 +87,14 @@ var _ = BeforeSuite(func() {
 
 	go func() {
 		defer GinkgoRecover()
-		err = mgr.Start(ctx)
+		err = mgr.Start(testCtx)
 		Expect(err).ToNot(HaveOccurred(), "failed to run manager")
 	}()
 }, 60)
 
 var _ = AfterSuite(func() {
 	By("tearing down the test environment")
-	cancel()
+	testCancel()
 	err := testEnv.Stop()
 	Expect(err).NotTo(HaveOccurred())
 })
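The renames in this suite file (k8sClient to testK8sClient, ctx to testCtx, cancel to testCancel) give the shared suite state distinctive names, since package-level identifiers in Go are visible to, and must not collide with, declarations in every other file of the same package. A rough sketch of how another file in the package can consume these shared variables follows; the helper names and signatures are hypothetical.

package pytorch

import (
	"k8s.io/apimachinery/pkg/types"

	pytorchv1 "github.com/kubeflow/training-operator/pkg/apis/pytorch/v1"
)

// createJob is a hypothetical helper in another file of the same package; it
// can use testK8sClient and testCtx directly because package-level variables
// declared in the suite file are shared across all files of a Go package.
func createJob(job *pytorchv1.PyTorchJob) error {
	return testK8sClient.Create(testCtx, job)
}

// getJob fetches the job back through the shared envtest client.
func getJob(key types.NamespacedName) (*pytorchv1.PyTorchJob, error) {
	job := &pytorchv1.PyTorchJob{}
	err := testK8sClient.Get(testCtx, key, job)
	return job, err
}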

pkg/controller.v1/pytorch/pytorchjob_controller_test.go

Lines changed: 152 additions & 16 deletions

@@ -32,15 +32,17 @@ import (
 var _ = Describe("PyTorchJob controller", func() {
 	// Define utility constants for object names and testing timeouts/durations and intervals.
 	const (
-		namespace = "default"
-		name      = "test-job"
-
-		timeout  = time.Second * 10
-		interval = time.Millisecond * 250
+		timeout      = time.Second * 10
+		interval     = time.Millisecond * 250
+		expectedPort = int32(8080)
 	)
 
 	Context("When creating the PyTorchJob", func() {
 		It("Should get the corresponding resources successfully", func() {
+			const (
+				namespace = "default"
+				name      = "test-job"
+			)
 			By("By creating a new PyTorchJob")
 			ctx := context.Background()
 			job := newPyTorchJobForTest(name, namespace)
@@ -56,7 +58,7 @@
 									Ports: []corev1.ContainerPort{
 										{
 											Name:          pytorchv1.DefaultPortName,
-											ContainerPort: 80,
+											ContainerPort: expectedPort,
 											Protocol:      corev1.ProtocolTCP,
 										},
 									},
@@ -76,7 +78,7 @@
 									Ports: []corev1.ContainerPort{
 										{
 											Name:          pytorchv1.DefaultPortName,
-											ContainerPort: 80,
+											ContainerPort: expectedPort,
 											Protocol:      corev1.ProtocolTCP,
 										},
 									},
@@ -87,41 +89,46 @@
 				},
 			}
 
-			Expect(k8sClient.Create(ctx, job)).Should(Succeed())
+			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
 
 			key := types.NamespacedName{Name: name, Namespace: namespace}
 			created := &pytorchv1.PyTorchJob{}
 
 			// We'll need to retry getting this newly created PyTorchJob, given that creation may not immediately happen.
 			Eventually(func() bool {
-				err := k8sClient.Get(ctx, key, created)
+				err := testK8sClient.Get(ctx, key, created)
 				return err == nil
 			}, timeout, interval).Should(BeTrue())
 
 			masterKey := types.NamespacedName{Name: fmt.Sprintf("%s-master-0", name), Namespace: namespace}
 			masterPod := &corev1.Pod{}
 			Eventually(func() bool {
-				err := k8sClient.Get(ctx, masterKey, masterPod)
+				err := testK8sClient.Get(ctx, masterKey, masterPod)
 				return err == nil
 			}, timeout, interval).Should(BeTrue())
 
 			masterSvc := &corev1.Service{}
 			Eventually(func() bool {
-				err := k8sClient.Get(ctx, masterKey, masterSvc)
+				err := testK8sClient.Get(ctx, masterKey, masterSvc)
 				return err == nil
 			}, timeout, interval).Should(BeTrue())
 
+			// Check the pod port.
 			Expect(masterPod.Spec.Containers[0].Ports).To(ContainElement(corev1.ContainerPort{
 				Name:          pytorchv1.DefaultPortName,
-				ContainerPort: 80,
+				ContainerPort: expectedPort,
 				Protocol:      corev1.ProtocolTCP}))
+			// Check MASTER_PORT and MASTER_ADDR env variable
 			Expect(masterPod.Spec.Containers[0].Env).To(ContainElements(corev1.EnvVar{
 				Name:  EnvMasterPort,
 				Value: fmt.Sprintf("%d", masterSvc.Spec.Ports[0].Port),
 			}, corev1.EnvVar{
 				Name:  EnvMasterAddr,
 				Value: masterSvc.Name,
 			}))
+			// Check service port.
+			Expect(masterSvc.Spec.Ports[0].Port).To(Equal(expectedPort))
+			// Check owner reference.
 			trueVal := true
 			Expect(masterPod.OwnerReferences).To(ContainElement(metav1.OwnerReference{
 				APIVersion:         pytorchv1.SchemeGroupVersion.String(),
@@ -143,19 +150,148 @@
 			// Test job status.
 			masterPod.Status.Phase = corev1.PodSucceeded
 			masterPod.ResourceVersion = ""
-			Expect(k8sClient.Status().Update(ctx, masterPod)).Should(Succeed())
+			Expect(testK8sClient.Status().Update(ctx, masterPod)).Should(Succeed())
+			Eventually(func() bool {
+				err := testK8sClient.Get(ctx, key, created)
+				if err != nil {
+					return false
+				}
+				return created.Status.ReplicaStatuses != nil && created.Status.
+					ReplicaStatuses[pytorchv1.PyTorchReplicaTypeMaster].Succeeded == 1
+			}, timeout, interval).Should(BeTrue())
+			// Check if the job is succeeded.
+			cond := getCondition(created.Status, commonv1.JobSucceeded)
+			Expect(cond.Status).To(Equal(corev1.ConditionTrue))
+			By("Deleting the PyTorchJob")
+			Expect(testK8sClient.Delete(ctx, job)).Should(Succeed())
+		})
+	})
+
+	Context("When creating the elastic PyTorchJob", func() {
+		// TODO(gaocegege): Test with more than 1 worker.
+		It("Should get the corresponding resources successfully", func() {
+			// Define the expected elastic policy.
+			var (
+				backendC10D = pytorchv1.BackendC10D
+				minReplicas = int32Ptr(1)
+				maxReplicas = int32Ptr(3)
+				maxRestarts = int32Ptr(3)
+				namespace   = "default"
+				name        = "easltic-job"
+			)
+
+			By("By creating a new PyTorchJob")
+			ctx := context.Background()
+			job := newPyTorchJobForTest(name, namespace)
+			job.Spec.ElasticPolicy = &pytorchv1.ElasticPolicy{
+				RDZVBackend: &backendC10D,
+				MaxReplicas: maxReplicas,
+				MinReplicas: minReplicas,
+				MaxRestarts: maxRestarts,
+			}
+			job.Spec.PyTorchReplicaSpecs = map[commonv1.ReplicaType]*commonv1.ReplicaSpec{
+				pytorchv1.PyTorchReplicaTypeWorker: {
+					Replicas: int32Ptr(1),
+					Template: corev1.PodTemplateSpec{
+						Spec: corev1.PodSpec{
+							Containers: []corev1.Container{
+								{
+									Image: "test-image",
+									Name:  pytorchv1.DefaultContainerName,
+									Ports: []corev1.ContainerPort{
+										{
+											Name:          pytorchv1.DefaultPortName,
+											ContainerPort: expectedPort,
+											Protocol:      corev1.ProtocolTCP,
+										},
+									},
+								},
+							},
+						},
+					},
+				},
+			}
+
+			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
+
+			key := types.NamespacedName{Name: name, Namespace: namespace}
+			created := &pytorchv1.PyTorchJob{}
+
+			// We'll need to retry getting this newly created PyTorchJob, given that creation may not immediately happen.
+			Eventually(func() bool {
+				err := testK8sClient.Get(ctx, key, created)
+				return err == nil
+			}, timeout, interval).Should(BeTrue())
+
+			workerKey := types.NamespacedName{Name: fmt.Sprintf("%s-worker-0", name), Namespace: namespace}
+			pod := &corev1.Pod{}
+			Eventually(func() bool {
+				err := testK8sClient.Get(ctx, workerKey, pod)
+				return err == nil
+			}, timeout, interval).Should(BeTrue())
+
+			svc := &corev1.Service{}
+			Eventually(func() bool {
+				err := testK8sClient.Get(ctx, workerKey, svc)
+				return err == nil
+			}, timeout, interval).Should(BeTrue())
+
+			// Check pod port.
+			Expect(pod.Spec.Containers[0].Ports).To(ContainElement(corev1.ContainerPort{
+				Name:          pytorchv1.DefaultPortName,
+				ContainerPort: expectedPort,
+				Protocol:      corev1.ProtocolTCP}))
+			// Check environment variables.
+			Expect(pod.Spec.Containers[0].Env).To(ContainElements(corev1.EnvVar{
+				Name:  EnvRDZVBackend,
+				Value: string(backendC10D),
+			}, corev1.EnvVar{
+				Name:  EnvNNodes,
+				Value: fmt.Sprintf("%d:%d", *minReplicas, *maxReplicas),
+			}, corev1.EnvVar{
+				Name:  EnvRDZVEndpoint,
+				Value: fmt.Sprintf("%s:%d", svc.Name, expectedPort),
+			}, corev1.EnvVar{
+				Name:  EnvMaxRestarts,
+				Value: fmt.Sprintf("%d", *maxRestarts),
+			}))
+			Expect(svc.Spec.Ports[0].Port).To(Equal(expectedPort))
+			// Check owner references.
+			trueVal := true
+			Expect(pod.OwnerReferences).To(ContainElement(metav1.OwnerReference{
+				APIVersion:         pytorchv1.SchemeGroupVersion.String(),
+				Kind:               pytorchv1.Kind,
+				Name:               name,
+				UID:                created.UID,
+				Controller:         &trueVal,
+				BlockOwnerDeletion: &trueVal,
+			}))
+			Expect(svc.OwnerReferences).To(ContainElement(metav1.OwnerReference{
+				APIVersion:         pytorchv1.SchemeGroupVersion.String(),
+				Kind:               pytorchv1.Kind,
+				Name:               name,
+				UID:                created.UID,
+				Controller:         &trueVal,
+				BlockOwnerDeletion: &trueVal,
+			}))
+
+			// Test job status.
+			pod.Status.Phase = corev1.PodSucceeded
+			pod.ResourceVersion = ""
+			Expect(testK8sClient.Status().Update(ctx, pod)).Should(Succeed())
 			Eventually(func() bool {
-				err := k8sClient.Get(ctx, key, created)
+				err := testK8sClient.Get(ctx, key, created)
 				if err != nil {
 					return false
 				}
-				return created.Status.ReplicaStatuses[pytorchv1.PyTorchReplicaTypeMaster].Succeeded == 1
+				return created.Status.ReplicaStatuses != nil && created.Status.
+					ReplicaStatuses[pytorchv1.PyTorchReplicaTypeWorker].Succeeded == 1
 			}, timeout, interval).Should(BeTrue())
 			// Check if the job is succeeded.
 			cond := getCondition(created.Status, commonv1.JobSucceeded)
 			Expect(cond.Status).To(Equal(corev1.ConditionTrue))
 			By("Deleting the PyTorchJob")
-			Expect(k8sClient.Delete(ctx, job)).Should(Succeed())
+			Expect(testK8sClient.Delete(ctx, job)).Should(Succeed())
 		})
 	})
 })
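The specs above call helpers such as newPyTorchJobForTest, int32Ptr, and getCondition that live elsewhere in the package and are not part of this diff. The following are hedged sketches of what the latter two plausibly look like; the commonv1 import path and the exact signatures are assumptions, not the repository's definitions.

package pytorch

import (
	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
)

// int32Ptr returns a pointer to the given int32, which is convenient for
// optional API fields such as Replicas, MinReplicas, and MaxReplicas.
func int32Ptr(i int32) *int32 {
	return &i
}

// getCondition returns the job condition with the given type, or nil if the
// status does not contain one yet.
func getCondition(status commonv1.JobStatus, condType commonv1.JobConditionType) *commonv1.JobCondition {
	for i := range status.Conditions {
		if status.Conditions[i].Type == condType {
			return &status.Conditions[i]
		}
	}
	return nil
}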
