Skip to content

Commit 988301d

Browse files
committed
Fix trainjob controller not setting the trainer podset count value correctly
1 parent 50d89e0 commit 988301d

File tree

4 files changed

+80
-1
lines changed

4 files changed

+80
-1
lines changed

pkg/controller/jobs/trainjob/trainjob_controller.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,15 @@ func getChildJobSet(ctx context.Context, t *TrainJob) (*jobsetapi.JobSet, error)
177177
if !ok {
178178
return nil, err
179179
}
180+
181+
// Jobset replicaJob parallelism/completions are set outside of the jobset builder
182+
for psIdx, ps := range info.TemplateSpec.PodSets {
183+
if ps.Count != nil {
184+
jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Parallelism = ps.Count
185+
jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Completions = ps.Count
186+
}
187+
}
188+
180189
jobsetApply := kftrainerjobset.NewBuilder(jobsetapplyapi.JobSet(t.Name, t.Namespace).
181190
WithSpec(jobSetSpec)).Initializer(trainJob).Trainer(info, trainJob).PodLabels(info.Scheduler.PodLabels).Build()
182191

pkg/controller/jobs/trainjob/trainjob_controller_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ func TestRunWithPodsetsInfo(t *testing.T) {
7979
testJobset := testingjobset.MakeJobSet("", "").ReplicatedJobs(
8080
testingjobset.ReplicatedJobRequirements{
8181
Name: "node",
82+
Labels: map[string]string{
83+
"trainer.kubeflow.org/trainjob-ancestor-step": "trainer",
84+
},
8285
}).Obj()
8386
testCtr := testingtrainjob.MakeClusterTrainingRuntime("test", testJobset.Spec)
8487

@@ -361,6 +364,15 @@ func TestReconciler(t *testing.T) {
361364
Replicas: 1,
362365
Parallelism: 1,
363366
Completions: 1,
367+
Labels: map[string]string{
368+
"trainer.kubeflow.org/trainjob-ancestor-step": "trainer",
369+
},
370+
},
371+
testingjobset.ReplicatedJobRequirements{
372+
Name: "foo",
373+
Replicas: 1,
374+
Parallelism: 1,
375+
Completions: 1,
364376
}).Obj()
365377
testCtr := testingtrainjob.MakeClusterTrainingRuntime("test", testJobset.Spec)
366378

@@ -386,6 +398,35 @@ func TestReconciler(t *testing.T) {
386398
SubGroupIndexLabel(ptr.To(jobsetapi.JobIndexKey)).
387399
SubGroupCount(ptr.To[int32](1)).
388400
Obj(),
401+
*utiltestingapi.MakePodSet("foo", 1).
402+
PodIndexLabel(ptr.To("batch.kubernetes.io/job-completion-index")).
403+
SubGroupIndexLabel(ptr.To(jobsetapi.JobIndexKey)).
404+
SubGroupCount(ptr.To[int32](1)).
405+
Obj(),
406+
).
407+
Obj(),
408+
},
409+
},
410+
"podset count for the trainer job is set to .Spec.Trainer.NumNodes": {
411+
reconcilerOptions: []jobframework.Option{
412+
jobframework.WithManageJobsWithoutQueueName(true),
413+
jobframework.WithManagedJobsNamespaceSelector(labels.Everything()),
414+
},
415+
trainJob: testTrainJob.Clone().TrainerNumNodes(2).Obj(),
416+
wantTrainJob: testTrainJob.Clone().TrainerNumNodes(2).Obj(),
417+
wantWorkloads: []kueue.Workload{
418+
*utiltestingapi.MakeWorkload(testTrainJob.Name, testTrainJob.Namespace).
419+
PodSets(
420+
*utiltestingapi.MakePodSet("node", 2).
421+
PodIndexLabel(ptr.To("batch.kubernetes.io/job-completion-index")).
422+
SubGroupIndexLabel(ptr.To(jobsetapi.JobIndexKey)).
423+
SubGroupCount(ptr.To[int32](1)).
424+
Obj(),
425+
*utiltestingapi.MakePodSet("foo", 1).
426+
PodIndexLabel(ptr.To("batch.kubernetes.io/job-completion-index")).
427+
SubGroupIndexLabel(ptr.To(jobsetapi.JobIndexKey)).
428+
SubGroupCount(ptr.To[int32](1)).
429+
Obj(),
389430
).
390431
Obj(),
391432
},

pkg/util/testingjobs/trainjob/wrappers.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ func (t *TrainJobWrapper) TrainerImage(image string, cmd, args []string) *TrainJ
7979
return t
8080
}
8181

82+
// TrainerNumNodes sets the number of nodes that will be used in the Trainer job
83+
func (t *TrainJobWrapper) TrainerNumNodes(numNodes int32) *TrainJobWrapper {
84+
t.Spec.Trainer.NumNodes = ptr.To(numNodes)
85+
return t
86+
}
87+
8288
// Annotation sets a TrainJob annotation key and value
8389
func (t *TrainJobWrapper) Annotation(key, value string) *TrainJobWrapper {
8490
if t.Annotations == nil {

test/integration/singlecluster/controller/jobs/trainjob/trainjob_controller_test.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ var _ = ginkgo.Describe("Trainjob controller", ginkgo.Ordered, ginkgo.ContinueOn
9292
testingjobset.ReplicatedJobRequirements{
9393
Name: "node",
9494
Replicas: 1,
95+
Labels: map[string]string{
96+
"trainer.kubeflow.org/trainjob-ancestor-step": "trainer",
97+
},
98+
},
99+
testingjobset.ReplicatedJobRequirements{
100+
Name: "foo",
101+
Replicas: 1,
95102
}).
96103
Obj()
97104
testCtr = testingtrainjob.MakeClusterTrainingRuntime("test", testJobSet.Spec)
@@ -126,6 +133,7 @@ var _ = ginkgo.Describe("Trainjob controller", ginkgo.Ordered, ginkgo.ContinueOn
126133
Name: "test",
127134
Kind: ptr.To("ClusterTrainingRuntime"),
128135
}).
136+
TrainerNumNodes(2).
129137
Suspend(false).
130138
Queue("local-queue").
131139
Obj()
@@ -136,11 +144,14 @@ var _ = ginkgo.Describe("Trainjob controller", ginkgo.Ordered, ginkgo.ContinueOn
136144
}, util.Timeout, util.Interval).Should(gomega.Succeed())
137145
})
138146

139-
ginkgo.By("checking the workload is created", func() {
147+
ginkgo.By("checking the workload is created with the correct values", func() {
140148
wlLookupKey = types.NamespacedName{Name: workloadtrainjob.GetWorkloadNameForTrainJob(createdTrainJob.Name, createdTrainJob.UID), Namespace: ns.Name}
141149
gomega.Eventually(func(g gomega.Gomega) {
142150
g.Expect(k8sClient.Get(ctx, wlLookupKey, createdWorkload)).Should(gomega.Succeed())
143151
g.Expect(createdWorkload.Spec.QueueName).Should(gomega.Equal(kueue.LocalQueueName("local-queue")))
152+
g.Expect(createdWorkload.Spec.PodSets).Should(gomega.HaveLen(2))
153+
g.Expect(createdWorkload.Spec.PodSets[0].Count).Should(gomega.Equal(int32(2)))
154+
g.Expect(createdWorkload.Spec.PodSets[1].Count).Should(gomega.Equal(int32(1)))
144155
}, util.Timeout, util.Interval).Should(gomega.Succeed())
145156
})
146157

@@ -160,6 +171,12 @@ var _ = ginkgo.Describe("Trainjob controller", ginkgo.Ordered, ginkgo.ContinueOn
160171
corev1.ResourceCPU: kueue.ResourceFlavorReference(onDemandFlavor.Name),
161172
},
162173
},
174+
kueue.PodSetAssignment{
175+
Name: createdWorkload.Spec.PodSets[1].Name,
176+
Flavors: map[corev1.ResourceName]kueue.ResourceFlavorReference{
177+
corev1.ResourceCPU: kueue.ResourceFlavorReference(onDemandFlavor.Name),
178+
},
179+
},
163180
).Obj()
164181
util.SetQuotaReservation(ctx, k8sClient, wlLookupKey, admission)
165182
util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload)
@@ -234,6 +251,12 @@ var _ = ginkgo.Describe("Trainjob controller", ginkgo.Ordered, ginkgo.ContinueOn
234251
corev1.ResourceCPU: kueue.ResourceFlavorReference(onDemandFlavor.Name),
235252
},
236253
},
254+
kueue.PodSetAssignment{
255+
Name: createdWorkload.Spec.PodSets[1].Name,
256+
Flavors: map[corev1.ResourceName]kueue.ResourceFlavorReference{
257+
corev1.ResourceCPU: kueue.ResourceFlavorReference(onDemandFlavor.Name),
258+
},
259+
},
237260
).Obj()
238261
util.SetQuotaReservation(ctx, k8sClient, wlLookupKey, admission)
239262
util.SyncAdmittedConditionForWorkloads(ctx, k8sClient, createdWorkload)

0 commit comments

Comments
 (0)