diff --git a/torchx/components/dist.py b/torchx/components/dist.py index e369ec5b7..b8c2654f7 100644 --- a/torchx/components/dist.py +++ b/torchx/components/dist.py @@ -122,7 +122,7 @@ ----------------- """ from pathlib import Path -from typing import Optional +from typing import Dict, Optional import torchx import torchx.specs as specs @@ -139,6 +139,7 @@ def ddp( memMB: int = 1024, h: Optional[str] = None, j: str = "1x2", + env: Optional[Dict[str, str]] = None, rdzv_endpoint: str = "etcd-server.default.svc.cluster.local:2379", ) -> specs.AppDef: """ @@ -160,6 +161,7 @@ def ddp( memMB: cpu memory in MB per replica h: a registered named resource (if specified takes precedence over cpu, gpu, memMB) j: {nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus + env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3) rdzv_endpoint: etcd server endpoint (only matters when nnodes > 1) """ @@ -199,6 +201,7 @@ def ddp( script, *script_args, ], + env=env or {}, ) ], )