Skip to content

Commit ea963b1

Browse files
committed
kubernetes_scheduler: add service_account runopt so users can specify per job acls
1 parent 1f1a50f commit ea963b1

File tree

3 files changed

+48
-13
lines changed

3 files changed

+48
-13
lines changed

torchx/pipelines/kfp/adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def container_from_app(
235235
def resource_from_app(
236236
app: api.AppDef,
237237
queue: str,
238+
service_account: Optional[str] = None,
238239
) -> dsl.ResourceOp:
239240
"""
240241
resource_from_app generates a KFP ResourceOp from the provided app that uses
@@ -266,5 +267,5 @@ def resource_from_app(
266267
action="create",
267268
success_condition="status.state.phase = Completed",
268269
failure_condition="status.state.phase = Failed",
269-
k8s_resource=app_to_resource(app, queue),
270+
k8s_resource=app_to_resource(app, queue, service_account=service_account),
270271
)

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def sanitize_for_serialization(obj: object) -> object:
157157
return api.sanitize_for_serialization(obj)
158158

159159

160-
def role_to_pod(name: str, role: Role) -> "V1Pod":
160+
def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod":
161161
from kubernetes.client.models import ( # noqa: F811 redefinition of unused
162162
V1Pod,
163163
V1PodSpec,
@@ -207,6 +207,7 @@ def role_to_pod(name: str, role: Role) -> "V1Pod":
207207
spec=V1PodSpec(
208208
containers=[container],
209209
restart_policy="Never",
210+
service_account_name=service_account,
210211
),
211212
metadata=V1ObjectMeta(
212213
annotations={
@@ -232,7 +233,9 @@ def cleanup_str(data: str) -> str:
232233
return "".join(re.findall(pattern, data.lower()))
233234

234235

235-
def app_to_resource(app: AppDef, queue: str) -> Dict[str, object]:
236+
def app_to_resource(
237+
app: AppDef, queue: str, service_account: Optional[str]
238+
) -> Dict[str, object]:
236239
"""
237240
app_to_resource creates a volcano job kubernetes resource definition from
238241
the provided AppDef. The resource definition can be used to launch the
@@ -263,7 +266,7 @@ def app_to_resource(app: AppDef, queue: str) -> Dict[str, object]:
263266
if role_idx == 0 and replica_id == 0:
264267
replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
265268

266-
pod = role_to_pod(name, replica_role)
269+
pod = role_to_pod(name, replica_role, service_account)
267270
pod.metadata.labels.update(pod_labels(app, role_idx, role, replica_id))
268271
task: Dict[str, Any] = {
269272
"replicas": 1,
@@ -437,7 +440,12 @@ def _submit_dryrun(
437440
# map any local images to the remote image
438441
images_to_push = self._update_app_images(app, cfg)
439442

440-
resource = app_to_resource(app, queue)
443+
service_account = cfg.get("service_account")
444+
assert service_account is None or isinstance(
445+
service_account, str
446+
), "service_account must be a str"
447+
448+
resource = app_to_resource(app, queue, service_account)
441449
req = KubernetesJob(
442450
resource=resource,
443451
images_to_push=images_to_push,
@@ -470,13 +478,21 @@ def run_opts(self) -> runopts:
470478
default="default",
471479
)
472480
opts.add(
473-
"queue", type_=str, help="Volcano queue to schedule job in", required=True
481+
"queue",
482+
type_=str,
483+
help="Volcano queue to schedule job in",
484+
required=True,
474485
)
475486
opts.add(
476487
"image_repo",
477488
type_=str,
478489
help="The image repository to use when pushing patched images, must have push access. Ex: example.com/your/container",
479490
)
491+
opts.add(
492+
"service_account",
493+
type_=str,
494+
help="The service account name to set on the pod specs",
495+
)
480496
return opts
481497

482498
def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
@@ -540,7 +556,7 @@ def log_iter(
540556

541557
namespace, name = app_id.split(":")
542558

543-
pod_name = f"{name}-{role_name}-{k}-0"
559+
pod_name = cleanup_str(f"{name}-{role_name}-{k}-0")
544560

545561
args: Dict[str, object] = {
546562
"name": pod_name,

torchx/schedulers/test/kubernetes_scheduler_test.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_app_to_resource_resolved_macros(self) -> None:
6868
"torchx.schedulers.kubernetes_scheduler.make_unique"
6969
) as make_unique_ctx:
7070
make_unique_ctx.return_value = unique_app_name
71-
resource = app_to_resource(app, "test_queue")
71+
resource = app_to_resource(app, "test_queue", service_account=None)
7272
actual_cmd = (
7373
# pyre-ignore [16]
7474
resource["spec"]["tasks"][0]["template"]
@@ -88,7 +88,7 @@ def test_app_to_resource_resolved_macros(self) -> None:
8888

8989
def test_retry_policy_not_set(self) -> None:
9090
app = _test_app()
91-
resource = app_to_resource(app, "test_queue")
91+
resource = app_to_resource(app, "test_queue", service_account=None)
9292
self.assertListEqual(
9393
[
9494
{"event": "PodEvicted", "action": "RestartJob"},
@@ -99,7 +99,7 @@ def test_retry_policy_not_set(self) -> None:
9999
)
100100
for role in app.roles:
101101
role.max_retries = 0
102-
resource = app_to_resource(app, "test_queue")
102+
resource = app_to_resource(app, "test_queue", service_account=None)
103103
self.assertFalse("policies" in resource["spec"]["tasks"][0])
104104
self.assertFalse("maxRetry" in resource["spec"]["tasks"][0])
105105

@@ -115,7 +115,7 @@ def test_role_to_pod(self) -> None:
115115
)
116116

117117
app = _test_app()
118-
pod = role_to_pod("name", app.roles[0])
118+
pod = role_to_pod("name", app.roles[0], service_account="srvacc")
119119

120120
requests = {
121121
"cpu": "2000m",
@@ -146,6 +146,7 @@ def test_role_to_pod(self) -> None:
146146
spec=V1PodSpec(
147147
containers=[container],
148148
restart_policy="Never",
149+
service_account_name="srvacc",
149150
),
150151
metadata=V1ObjectMeta(
151152
annotations={
@@ -298,6 +299,22 @@ def test_submit_dryrun_patch(self) -> None:
298299
},
299300
)
300301

302+
def test_submit_dryrun_service_account(self) -> None:
303+
scheduler = create_scheduler("test")
304+
self.assertIn("service_account", scheduler.run_opts()._opts)
305+
app = _test_app()
306+
cfg = {
307+
"queue": "testqueue",
308+
"service_account": "srvacc",
309+
}
310+
311+
info = scheduler._submit_dryrun(app, cfg)
312+
self.assertIn("'service_account_name': 'srvacc'", str(info.request.resource))
313+
314+
del cfg["service_account"]
315+
info = scheduler._submit_dryrun(app, cfg)
316+
self.assertIn("service_account_name': None", str(info.request.resource))
317+
301318
@patch("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object")
302319
def test_submit(self, create_namespaced_custom_object: MagicMock) -> None:
303320
create_namespaced_custom_object.return_value = {
@@ -426,6 +443,7 @@ def test_runopts(self) -> None:
426443
"queue",
427444
"namespace",
428445
"image_repo",
446+
"service_account",
429447
},
430448
)
431449

@@ -452,7 +470,7 @@ def test_log_iter(self, read_namespaced_pod_log: MagicMock) -> None:
452470
read_namespaced_pod_log.return_value = "foo reg\nfoo\nbar reg\n"
453471
lines = scheduler.log_iter(
454472
app_id="testnamespace:testjob",
455-
role_name="role",
473+
role_name="role_blah",
456474
k=1,
457475
regex="reg",
458476
since=datetime.now(),
@@ -472,7 +490,7 @@ def test_log_iter(self, read_namespaced_pod_log: MagicMock) -> None:
472490
kwargs,
473491
{
474492
"namespace": "testnamespace",
475-
"name": "testjob-role-1-0",
493+
"name": "testjob-roleblah-1-0",
476494
"timestamps": True,
477495
},
478496
)

0 commit comments

Comments
 (0)