Skip to content

Commit eb857a9

Browse files
committed
kubernetes_scheduler: add service_account runopt so users can specify per-job ACLs
1 parent 1f1a50f commit eb857a9

File tree

3 files changed

+41
-9
lines changed

3 files changed

+41
-9
lines changed

torchx/pipelines/kfp/adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def container_from_app(
235235
def resource_from_app(
236236
app: api.AppDef,
237237
queue: str,
238+
service_account: Optional[str] = None,
238239
) -> dsl.ResourceOp:
239240
"""
240241
resource_from_app generates a KFP ResourceOp from the provided app that uses
@@ -266,5 +267,5 @@ def resource_from_app(
266267
action="create",
267268
success_condition="status.state.phase = Completed",
268269
failure_condition="status.state.phase = Failed",
269-
k8s_resource=app_to_resource(app, queue),
270+
k8s_resource=app_to_resource(app, queue, service_account=service_account),
270271
)

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def sanitize_for_serialization(obj: object) -> object:
157157
return api.sanitize_for_serialization(obj)
158158

159159

160-
def role_to_pod(name: str, role: Role) -> "V1Pod":
160+
def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod":
161161
from kubernetes.client.models import ( # noqa: F811 redefinition of unused
162162
V1Pod,
163163
V1PodSpec,
@@ -207,6 +207,7 @@ def role_to_pod(name: str, role: Role) -> "V1Pod":
207207
spec=V1PodSpec(
208208
containers=[container],
209209
restart_policy="Never",
210+
service_account_name=service_account,
210211
),
211212
metadata=V1ObjectMeta(
212213
annotations={
@@ -232,7 +233,9 @@ def cleanup_str(data: str) -> str:
232233
return "".join(re.findall(pattern, data.lower()))
233234

234235

235-
def app_to_resource(app: AppDef, queue: str) -> Dict[str, object]:
236+
def app_to_resource(
237+
app: AppDef, queue: str, service_account: Optional[str]
238+
) -> Dict[str, object]:
236239
"""
237240
app_to_resource creates a volcano job kubernetes resource definition from
238241
the provided AppDef. The resource definition can be used to launch the
@@ -263,7 +266,7 @@ def app_to_resource(app: AppDef, queue: str) -> Dict[str, object]:
263266
if role_idx == 0 and replica_id == 0:
264267
replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
265268

266-
pod = role_to_pod(name, replica_role)
269+
pod = role_to_pod(name, replica_role, service_account)
267270
pod.metadata.labels.update(pod_labels(app, role_idx, role, replica_id))
268271
task: Dict[str, Any] = {
269272
"replicas": 1,
@@ -437,7 +440,12 @@ def _submit_dryrun(
437440
# map any local images to the remote image
438441
images_to_push = self._update_app_images(app, cfg)
439442

440-
resource = app_to_resource(app, queue)
443+
service_account = cfg.get("service_account")
444+
assert service_account is None or isinstance(
445+
service_account, str
446+
), "service_account must be a str"
447+
448+
resource = app_to_resource(app, queue, service_account)
441449
req = KubernetesJob(
442450
resource=resource,
443451
images_to_push=images_to_push,
@@ -477,6 +485,11 @@ def run_opts(self) -> runopts:
477485
type_=str,
478486
help="The image repository to use when pushing patched images, must have push access. Ex: example.com/your/container",
479487
)
488+
opts.add(
489+
"service_account",
490+
type_=str,
491+
help="The service account name to set on the pod specs",
492+
)
480493
return opts
481494

482495
def describe(self, app_id: str) -> Optional[DescribeAppResponse]:

torchx/schedulers/test/kubernetes_scheduler_test.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_app_to_resource_resolved_macros(self) -> None:
6868
"torchx.schedulers.kubernetes_scheduler.make_unique"
6969
) as make_unique_ctx:
7070
make_unique_ctx.return_value = unique_app_name
71-
resource = app_to_resource(app, "test_queue")
71+
resource = app_to_resource(app, "test_queue", service_account=None)
7272
actual_cmd = (
7373
# pyre-ignore [16]
7474
resource["spec"]["tasks"][0]["template"]
@@ -88,7 +88,7 @@ def test_app_to_resource_resolved_macros(self) -> None:
8888

8989
def test_retry_policy_not_set(self) -> None:
9090
app = _test_app()
91-
resource = app_to_resource(app, "test_queue")
91+
resource = app_to_resource(app, "test_queue", service_account=None)
9292
self.assertListEqual(
9393
[
9494
{"event": "PodEvicted", "action": "RestartJob"},
@@ -99,7 +99,7 @@ def test_retry_policy_not_set(self) -> None:
9999
)
100100
for role in app.roles:
101101
role.max_retries = 0
102-
resource = app_to_resource(app, "test_queue")
102+
resource = app_to_resource(app, "test_queue", service_account=None)
103103
self.assertFalse("policies" in resource["spec"]["tasks"][0])
104104
self.assertFalse("maxRetry" in resource["spec"]["tasks"][0])
105105

@@ -115,7 +115,7 @@ def test_role_to_pod(self) -> None:
115115
)
116116

117117
app = _test_app()
118-
pod = role_to_pod("name", app.roles[0])
118+
pod = role_to_pod("name", app.roles[0], service_account="srvacc")
119119

120120
requests = {
121121
"cpu": "2000m",
@@ -146,6 +146,7 @@ def test_role_to_pod(self) -> None:
146146
spec=V1PodSpec(
147147
containers=[container],
148148
restart_policy="Never",
149+
service_account_name="srvacc",
149150
),
150151
metadata=V1ObjectMeta(
151152
annotations={
@@ -298,6 +299,22 @@ def test_submit_dryrun_patch(self) -> None:
298299
},
299300
)
300301

302+
def test_submit_dryrun_service_account(self) -> None:
303+
scheduler = create_scheduler("test")
304+
self.assertIn("service_account", scheduler.run_opts()._opts)
305+
app = _test_app()
306+
cfg = {
307+
"queue": "testqueue",
308+
"service_account": "srvacc",
309+
}
310+
311+
info = scheduler._submit_dryrun(app, cfg)
312+
self.assertIn("'service_account_name': 'srvacc'", str(info.request.resource))
313+
314+
del cfg["service_account"]
315+
info = scheduler._submit_dryrun(app, cfg)
316+
self.assertIn("service_account_name': None", str(info.request.resource))
317+
301318
@patch("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object")
302319
def test_submit(self, create_namespaced_custom_object: MagicMock) -> None:
303320
create_namespaced_custom_object.return_value = {
@@ -426,6 +443,7 @@ def test_runopts(self) -> None:
426443
"queue",
427444
"namespace",
428445
"image_repo",
446+
"service_account",
429447
},
430448
)
431449

0 commit comments

Comments
 (0)