|
22 | 22 | from dstack.core.repo import RepoAddress, RepoData
|
23 | 23 | from dstack.utils.common import _quoted
|
24 | 24 |
|
| 25 | +DEFAULT_CPU = 2 |
| 26 | +DEFAULT_MEM = "8GB" |
| 27 | + |
25 | 28 |
|
26 | 29 | def _str_to_mib(s: str) -> int:
|
27 | 30 | ns = s.replace(" ", "").lower()
|
@@ -187,43 +190,34 @@ def _parse_base_args(self, args: Namespace, unknown_args):
|
187 | 190 | env = self.provider_data.get("env") or []
|
188 | 191 | env.extend(args.env)
|
189 | 192 | self.provider_data["env"] = env
|
190 |
| - if ( |
191 |
| - args.cpu |
192 |
| - or args.memory |
193 |
| - or args.gpu |
194 |
| - or args.gpu_name |
195 |
| - or args.gpu_memory |
196 |
| - or args.shm_size |
197 |
| - or args.interruptible |
198 |
| - or args.local |
199 |
| - ): |
200 |
| - resources = self.provider_data.get("resources") or {} |
201 |
| - self.provider_data["resources"] = resources |
202 |
| - if args.cpu: |
203 |
| - resources["cpu"] = args.cpu |
204 |
| - if args.memory: |
205 |
| - resources["memory"] = args.memory |
206 |
| - if args.gpu or args.gpu_name or args.gpu_memory: |
207 |
| - gpu = ( |
208 |
| - self.provider_data["resources"].get("gpu") or {} |
209 |
| - if self.provider_data.get("resources") |
210 |
| - else {} |
211 |
| - ) |
212 |
| - if type(gpu) is int: |
213 |
| - gpu = {"count": gpu} |
214 |
| - resources["gpu"] = gpu |
215 |
| - if args.gpu: |
216 |
| - gpu["count"] = args.gpu |
217 |
| - if args.gpu_memory: |
218 |
| - gpu["memory"] = args.gpu_memory |
219 |
| - if args.gpu_name: |
220 |
| - gpu["name"] = args.gpu_name |
221 |
| - if args.shm_size: |
222 |
| - resources["shm_size"] = args.shm_size |
223 |
| - if args.interruptible: |
224 |
| - resources["interruptible"] = True |
225 |
| - if args.local: |
226 |
| - resources["local"] = True |
| 193 | + |
| 194 | + resources = self.provider_data.get("resources") or {} |
| 195 | + self.provider_data["resources"] = resources |
| 196 | + if args.cpu: |
| 197 | + resources["cpu"] = args.cpu |
| 198 | + if args.memory: |
| 199 | + resources["memory"] = args.memory |
| 200 | + if args.gpu or args.gpu_name or args.gpu_memory: |
| 201 | + gpu = ( |
| 202 | + self.provider_data["resources"].get("gpu") or {} |
| 203 | + if self.provider_data.get("resources") |
| 204 | + else {} |
| 205 | + ) |
| 206 | + if type(gpu) is int: |
| 207 | + gpu = {"count": gpu} |
| 208 | + resources["gpu"] = gpu |
| 209 | + if args.gpu: |
| 210 | + gpu["count"] = args.gpu |
| 211 | + if args.gpu_memory: |
| 212 | + gpu["memory"] = args.gpu_memory |
| 213 | + if args.gpu_name: |
| 214 | + gpu["name"] = args.gpu_name |
| 215 | + if args.shm_size: |
| 216 | + resources["shm_size"] = args.shm_size |
| 217 | + if args.interruptible: |
| 218 | + resources["interruptible"] = True |
| 219 | + if args.local: |
| 220 | + resources["local"] = True |
227 | 221 | if unknown_args:
|
228 | 222 | self.provider_data["run_args"] = unknown_args
|
229 | 223 |
|
@@ -393,58 +387,47 @@ def _get_list_data(self, name: str) -> Optional[List[str]]:
|
393 | 387 | else:
|
394 | 388 | return v
|
395 | 389 |
|
396 |
| - def _resources(self) -> Optional[Requirements]: |
397 |
| - if self.provider_data.get("resources"): |
398 |
| - resources = Requirements() |
399 |
| - if self.provider_data["resources"].get("cpu"): |
400 |
| - if not str(self.provider_data["resources"]["cpu"]).isnumeric(): |
401 |
| - sys.exit("resources.cpu should be an integer") |
402 |
| - cpu = int(self.provider_data["resources"]["cpu"]) |
403 |
| - if cpu > 0: |
404 |
| - resources.cpus = cpu |
405 |
| - if self.provider_data["resources"].get("memory"): |
406 |
| - resources.memory_mib = _str_to_mib(self.provider_data["resources"]["memory"]) |
407 |
| - gpu = self.provider_data["resources"].get("gpu") |
408 |
| - if gpu: |
409 |
| - if str(gpu).isnumeric(): |
410 |
| - gpu = int(self.provider_data["resources"]["gpu"]) |
411 |
| - if gpu > 0: |
412 |
| - resources.gpus = GpusRequirements(count=gpu) |
413 |
| - else: |
414 |
| - gpu_count = 0 |
415 |
| - gpu_name = None |
416 |
| - if str(gpu.get("count")).isnumeric(): |
417 |
| - gpu_count = int(gpu.get("count")) |
418 |
| - if gpu.get("name"): |
419 |
| - gpu_name = gpu.get("name") |
420 |
| - if not gpu_count: |
421 |
| - gpu_count = 1 |
422 |
| - if gpu_count: |
423 |
| - resources.gpus = GpusRequirements(count=gpu_count, name=gpu_name) |
424 |
| - for resource_name in self.provider_data["resources"]: |
425 |
| - if resource_name.endswith("/gpu") and len(resource_name) > 4: |
426 |
| - if not str(self.provider_data["resources"][resource_name]).isnumeric(): |
427 |
| - sys.exit(f"resources.'{resource_name}' should be an integer") |
428 |
| - gpu = int(self.provider_data["resources"][resource_name]) |
429 |
| - if gpu > 0: |
430 |
| - resources.gpus = GpusRequirements(count=gpu, name=resource_name[:-4]) |
431 |
| - if self.provider_data["resources"].get("shm_size"): |
432 |
| - resources.shm_size_mib = _str_to_mib(self.provider_data["resources"]["shm_size"]) |
433 |
| - if self.provider_data["resources"].get("interruptible"): |
434 |
| - resources.interruptible = self.provider_data["resources"]["interruptible"] |
435 |
| - if self.provider_data["resources"].get("local"): |
436 |
| - resources.local = self.provider_data["resources"]["local"] |
437 |
| - if ( |
438 |
| - resources.cpus |
439 |
| - or resources.memory_mib |
440 |
| - or resources.gpus |
441 |
| - or resources.shm_size_mib |
442 |
| - or resources.interruptible |
443 |
| - or resources.local |
444 |
| - ): |
445 |
| - return resources |
| 390 | + def _resources(self) -> Requirements: |
| 391 | + resources = Requirements() |
| 392 | + cpu = self.provider_data["resources"].get("cpu", DEFAULT_CPU) |
| 393 | + if not str(cpu).isnumeric(): |
| 394 | + sys.exit("resources.cpu should be an integer") |
| 395 | + cpu = int(cpu) |
| 396 | + if cpu > 0: |
| 397 | + resources.cpus = cpu |
| 398 | + memory = self.provider_data["resources"].get("memory", DEFAULT_MEM) |
| 399 | + resources.memory_mib = _str_to_mib(memory) |
| 400 | + gpu = self.provider_data["resources"].get("gpu") |
| 401 | + if gpu: |
| 402 | + if str(gpu).isnumeric(): |
| 403 | + gpu = int(self.provider_data["resources"]["gpu"]) |
| 404 | + if gpu > 0: |
| 405 | + resources.gpus = GpusRequirements(count=gpu) |
446 | 406 | else:
|
447 |
| - return None |
| 407 | + gpu_count = 0 |
| 408 | + gpu_name = None |
| 409 | + if str(gpu.get("count")).isnumeric(): |
| 410 | + gpu_count = int(gpu.get("count")) |
| 411 | + if gpu.get("name"): |
| 412 | + gpu_name = gpu.get("name") |
| 413 | + if not gpu_count: |
| 414 | + gpu_count = 1 |
| 415 | + if gpu_count: |
| 416 | + resources.gpus = GpusRequirements(count=gpu_count, name=gpu_name) |
| 417 | + for resource_name in self.provider_data["resources"]: |
| 418 | + if resource_name.endswith("/gpu") and len(resource_name) > 4: |
| 419 | + if not str(self.provider_data["resources"][resource_name]).isnumeric(): |
| 420 | + sys.exit(f"resources.'{resource_name}' should be an integer") |
| 421 | + gpu = int(self.provider_data["resources"][resource_name]) |
| 422 | + if gpu > 0: |
| 423 | + resources.gpus = GpusRequirements(count=gpu, name=resource_name[:-4]) |
| 424 | + if self.provider_data["resources"].get("shm_size"): |
| 425 | + resources.shm_size_mib = _str_to_mib(self.provider_data["resources"]["shm_size"]) |
| 426 | + if self.provider_data["resources"].get("interruptible"): |
| 427 | + resources.interruptible = self.provider_data["resources"]["interruptible"] |
| 428 | + if self.provider_data["resources"].get("local"): |
| 429 | + resources.local = self.provider_data["resources"]["local"] |
| 430 | + return resources |
448 | 431 |
|
449 | 432 | @staticmethod
|
450 | 433 | def _extend_commands_with_env(commands, env):
|
|
0 commit comments