Skip to content

Commit 8139e41

Browse files
Testing CI in JAX example (kubeflow/trainer#2385)
* Add MNIST example with SPMD for JAX Illustrate how to use JAX's `pmap` to express and execute single-program multiple-data (SPMD) programs for data parallelism along a batch dimension Signed-off-by: Sandipan Panda <[email protected]> * Update CONTRIBUTING.md Use -- server-side to install the latest local changes of Training Operator control plane Signed-off-by: Sandipan Panda <[email protected]> * Add JAXJob output Signed-off-by: Sandipan Panda <[email protected]> * Update JAXJob CI images Signed-off-by: Sandipan Panda <[email protected]> * Adjust jaxjob spmd example batch size Signed-off-by: Sandipan Panda <[email protected]> * Add JAX Example Docker Image Build in CI Signed-off-by: sailesh duddupudi <[email protected]> * Fix script name typo Signed-off-by: sailesh duddupudi <[email protected]> * Update script permissions Signed-off-by: sailesh duddupudi <[email protected]> * Add KIND_CLUSTER env var Signed-off-by: sailesh duddupudi <[email protected]> * Increase timeouts Signed-off-by: sailesh duddupudi <[email protected]> * Test higher resources Signed-off-by: sailesh duddupudi <[email protected]> * Increase Timeout Signed-off-by: sailesh duddupudi <[email protected]> * remove resource reqs Signed-off-by: sailesh duddupudi <[email protected]> * test low batch size Signed-off-by: sailesh duddupudi <[email protected]> * test small batch size Signed-off-by: sailesh duddupudi <[email protected]> * Hardcode number of batches Signed-off-by: sailesh duddupudi <[email protected]> --------- Signed-off-by: Sandipan Panda <[email protected]> Signed-off-by: sailesh duddupudi <[email protected]> Co-authored-by: Sandipan Panda <[email protected]>
1 parent 2f2c5ba commit 8139e41

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

python/kubeflow/training/constants/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@
153153
JAXJOB_PLURAL = "jaxjobs"
154154
JAXJOB_CONTAINER = "jax"
155155
JAXJOB_REPLICA_TYPES = REPLICA_TYPE_WORKER.lower()
156-
JAXJOB_BASE_IMAGE = "docker.io/kubeflow/jaxjob-simple:latest"
156+
JAXJOB_BASE_IMAGE = "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest"
157157

158158
# Dictionary to get plural, model, and container for each Job kind.
159159
JOB_PARAMETERS = {

python/test/e2e/test_e2e_jaxjob.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ def generate_jaxjob(
155155
def generate_container() -> V1Container:
156156
return V1Container(
157157
name=CONTAINER_NAME,
158-
image="docker.io/kubeflow/jaxjob-simple:latest",
159-
command=["python", "train.py"],
160-
resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}),
158+
image=os.getenv("JAX_JOB_IMAGE", "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest"),
159+
resources=V1ResourceRequirements(limits={"memory": "3Gi", "cpu": "1.2"}),
161160
)

0 commit comments

Comments
 (0)