
Commit e45b5e4

jeffra, jren73, and tjruwase committed
ZeRO-Offload v1 (squash) (#345)
* update DSE to point to ZeRO-Offload staging

* ZeRO-2 enable CPU offload (#313)

* cpu-offload

* update

* deleted: deepspeed/pt/deepspeed_zero_optimizer_cpuoffload.py
  modified: deepspeed/pt/fp16_unfused_optimizer.py
  new file: install_output.txt
  modified: tests/unit/test_dynamic_loss_scale.py

* modified: deepspeed/pt/deepspeed_zero_optimizer.py

* update

* modified: deepspeed/pt/deepspeed_cpu_adam.py
  modified: deepspeed/pt/deepspeed_zero_optimizer.py
  modified: tests/unit/test_checkpointing.py
  modified: tests/unit/test_fp16.py

* deleted: install_output.txt

* modified: deepspeed/pt/fp16_unfused_optimizer.py
  modified: tests/unit/test_dynamic_loss_scale.py

* modified: deepspeed/pt/deepspeed_cpu_adam.py

* modified: deepspeed/pt/deepspeed_zero_optimizer.py

* modified: deepspeed/pt/deepspeed_cpu_adam.py
  modified: deepspeed/pt/deepspeed_zero_optimizer.py

* deleted: deepspeed_cpu_adam.py
  modified: deepspeed_light.py
  modified: deepspeed_zero_optimizer.py
  ../../deepspeed_zero_optimizer_cpu_offload.py

* modified: deepspeed/pt/deepspeed_light.py

* modified: deepspeed/pt/deepspeed_light.py
  modified: deepspeed/pt/deepspeed_zero_optimizer.py
  modified: deepspeed/pt/deepspeed_zero_utils.py
  modified: tests/unit/test_fp16.py

* modified: deepspeed/pt/deepspeed_config.py
  modified: deepspeed/pt/deepspeed_light.py
  modified: deepspeed/pt/deepspeed_zero_optimizer.py
  modified: tests/unit/test_checkpointing.py
  modified: tests/unit/test_fp16.py

* modified: deepspeed/pt/deepspeed_checkpointing.py

* update DSE to ZeRO-Offload commit

Co-authored-by: Jeff Rasley <[email protected]>

* Enable ZeRO checkpointing for ZeRO-Offload (#337)

* Enable ZeRO checkpointing for ZeRO-Offload
  Fix unit tests
  Bump DSE to 33b9fb77c8cecdb49118188890f662526d8e9397

* Fix accidental revert

* Add ZeRO-Offload checkpointing model tests (#344)

* Enable ZeRO checkpointing for ZeRO-Offload
  Fix unit tests
  Bump DSE to 33b9fb77c8cecdb49118188890f662526d8e9397

* Fix accidental revert

* Fix ZeRO-Offload checkpointing bug when changing gpu count
  Add checkpointing model tests for ZeRO-Offload
  Remove optimizer key from Megatron model tests
  Use different deepspeed master port for Megatron model tests

Co-authored-by: Jie <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 1661e83 commit e45b5e4

22 files changed: +564 −161 lines changed

DeepSpeedExamples (submodule pointer updated)

deepspeed/pt/deepspeed_zero_utils.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+import torch
+from torch.autograd import Variable
+import collections
+
+
+def async_migrate_to(obj, dev, main_stream=None):
+    if torch.is_tensor(obj):
+        obj = Variable(obj)
+    if isinstance(obj, Variable):
+        v = obj.cuda(dev, async=True)
+        if main_stream is not None:
+            v.data.record_stream(main_stream)
+        return v
+    elif isinstance(obj, collections.Mapping):
+        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
+    elif isinstance(obj, collections.Sequence):
+        return [async_copy_to(o, dev, main_stream) for o in obj]
+    else:
+        return obj
+
+
+def async_copy_to(obj, dev, main_stream=None):
+    if torch.is_tensor(obj):
+        obj = Variable(obj)
+    if isinstance(obj, Variable):
+        target = torch.empty_like(obj, device=dev).copy_(obj)
+        if main_stream is not None:
+            target.data.record_stream(main_stream)
+        return target
+    elif isinstance(obj, collections.Mapping):
+        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
+    elif isinstance(obj, collections.Sequence):
+        return [async_copy_to(o, dev, main_stream) for o in obj]
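Both helpers recurse over tensors, Mappings, and Sequences, staging each tensor on the target device and optionally recording the copy against a CUDA stream so its memory is not reused too early. Note that async_migrate_to keeps the pre-1.0 Variable wrapper and the async=True keyword argument (renamed to non_blocking=True in later PyTorch releases), while async_copy_to allocates with torch.empty_like and copies explicitly. A usage sketch for async_copy_to, illustrative only and not part of the commit (it assumes a CUDA device and imports the module at the path added here):

import torch

from deepspeed.pt.deepspeed_zero_utils import async_copy_to  # module path as added in this commit

# Capture the compute stream first, then issue the copy on a side stream so it
# can overlap with computation; record_stream ties the copied tensor's memory
# lifetime to the compute stream.
compute_stream = torch.cuda.current_stream()
copy_stream = torch.cuda.Stream()

cpu_tensor = torch.randn(4096)  # e.g. a gradient or optimizer-state partition staged on the host

with torch.cuda.stream(copy_stream):
    gpu_tensor = async_copy_to(cpu_tensor, 'cuda:0', main_stream=compute_stream)

copy_stream.synchronize()  # make sure the copy has finished before gpu_tensor is consumed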

deepspeed/runtime/config.py

Lines changed: 3 additions & 1 deletion

@@ -630,10 +630,12 @@ def _do_error_check(self):
         if self.zero_enabled:
             assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
             assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
+            if self.zero_config.cpu_offload is True:
+                assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)

         assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)

-        assert self.gradient_accumulation_steps, 'DeepSpeedConfig: {} is not defined'.format(
+        assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format(
             GRADIENT_ACCUMULATION_STEPS)

     def _do_warning_check(self):
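The added check only accepts cpu_offload together with ZeRO stage 2 (ZERO_OPTIMIZATION_GRADIENTS); any other stage fails _do_error_check on the new assertion. A sketch of configs that would pass and fail this validation, with illustrative values (the zero_optimization keys follow the schema shown in deepspeed/runtime/zero/config.py below):

# Passes _do_error_check: fp16 enabled, ZeRO stage 2, cpu_offload on.
valid_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2, "cpu_offload": True},
}

# Trips the new assertion: cpu_offload requested with ZeRO stage 1.
invalid_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 1, "cpu_offload": True},
}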

deepspeed/runtime/engine.py

Lines changed: 17 additions & 7 deletions

@@ -106,7 +106,6 @@ def __init__(self,
                  collate_fn=None,
                  config_params=None):
         super(DeepSpeedEngine, self).__init__()
-
         self.client_optimizer = optimizer
         self.client_model_parameters = model_parameters
         self.client_lr_scheduler = lr_scheduler
@@ -292,6 +291,9 @@ def zero_reduce_scatter(self):
     def zero_overlap_comm(self):
         return self._config.zero_config.overlap_comm

+    def zero_cpu_offload(self):
+        return self._config.zero_config.cpu_offload
+
     def zero_optimization_stage(self):
         return self._config.zero_optimization_stage

@@ -491,6 +493,7 @@ def _configure_distributed_model(self, model):

     # Configure optimizer
     def _configure_optimizer(self, client_optimizer, model_parameters):
+
         if client_optimizer is not None:
             basic_optimizer = client_optimizer
             logger.info('Using client Optimizer as basic optimizer')
@@ -504,13 +507,14 @@ def _configure_optimizer(self, client_optimizer, model_parameters):

         if self.zero_optimization():
             assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
-            if self.optimizer_name() != ADAM_OPTIMIZER:
+            if self.optimizer_name() not in [ADAM_OPTIMIZER]:
                 assert self.zero_allow_untested_optimizer(), \
                     'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'

                 logger.warning(
                     "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                 )
+
             self.optimizer = self._configure_zero_optimizer(basic_optimizer)
         elif self.amp_enabled():
             assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
@@ -522,8 +526,8 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
             self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
         else:
             self.optimizer = basic_optimizer
-
-        # logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
+        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer))
+        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))

     def _configure_basic_optimizer(self, model_parameters):
         optimizer_parameters = self.optimizer_params()
@@ -532,8 +536,11 @@ def _configure_basic_optimizer(self, model_parameters):
                 "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
             )
         if self.optimizer_name() == ADAM_OPTIMIZER:
-            from apex.optimizers.fused_adam import FusedAdam
-            optimizer = FusedAdam(model_parameters, **optimizer_parameters)
+            if self.zero_cpu_offload():
+                optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
+            else:
+                from apex.optimizers.fused_adam import FusedAdam
+                optimizer = FusedAdam(model_parameters, **optimizer_parameters)
         elif self.optimizer_name() == LAMB_OPTIMIZER:
             optimizer = FusedLamb(model_parameters, **optimizer_parameters)
         else:
@@ -610,6 +617,7 @@ def _configure_zero_optimizer(self, optimizer):
             dp_process_group=self.data_parallel_group,
             reduce_scatter=self.zero_reduce_scatter(),
             overlap_comm=self.zero_overlap_comm(),
+            cpu_offload=self.zero_cpu_offload(),
             mpu=self.mpu,
             postscale_gradients=self.postscale_gradients(),
             gradient_predivide_factor=self.gradient_predivide_factor())
@@ -844,7 +852,6 @@ def step(self):
             master_params = amp.master_params(self.optimizer)
             torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                            max_norm=self.gradient_clipping())
-
         self.optimizer.step()

         #zero grad in basic optimizer could be unreliable and may not exhibit
@@ -947,6 +954,9 @@ def _get_optimizer_param(self, param_name):
     def get_lr(self):
         return self._get_optimizer_param('lr')

+    def get_type(self):
+        return self._get_optimizer_param('type')
+
     def get_mom(self):
         return self._get_optimizer_param('betas')

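With cpu_offload enabled, _configure_basic_optimizer above selects torch.optim.Adam instead of apex's FusedAdam (the fused kernel only updates GPU tensors), and the flag is forwarded into the ZeRO stage-2 optimizer via the new cpu_offload= argument. A minimal end-to-end sketch of wiring this through deepspeed.initialize; the model and hyperparameters are hypothetical, and it assumes a CUDA device plus the usual deepspeed launcher environment:

import torch
import deepspeed

model = torch.nn.Linear(512, 512)  # stand-in model

ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2, "cpu_offload": True},
}

# With this config the engine builds torch.optim.Adam as the basic optimizer
# and wraps it in the ZeRO stage-2 optimizer with cpu_offload=True.
model_engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                      model_parameters=model.parameters(),
                                                      config_params=ds_config)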

deepspeed/runtime/zero/config.py

Lines changed: 15 additions & 7 deletions

@@ -23,6 +23,7 @@
     "overlap_comm": [true|false],
     "reduce_bucket_size": 500000000
     "load_from_fp32_weights": [true|false]
+    "cpu_offload": [true|false]
   }
 }
 '''
@@ -62,21 +63,22 @@
 ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
 ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True

+ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
+ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False
+
 ZERO_OPTIMIZATION_DEFAULT = {
-    ZERO_OPTIMIZATION_STAGE:
-    ZERO_OPTIMIZATION_STAGE_DEFAULT,
+    ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
-    ZERO_OPTIMIZATION_REDUCE_SCATTER:
-    ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
-    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
-    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
+    ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
+    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
     ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
     ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
     ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
-    ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
+    ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
+    ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT
 }


@@ -92,6 +94,7 @@ def __init__(self, param_dict):
         self.allgather_bucket_size = None
         self.overlap_comm = None
         self.load_from_fp32_weights = None
+        self.cpu_offload = None

         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -156,7 +159,12 @@ def _initialize(self, zero_config_dict):
             zero_config_dict,
             ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
             ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
+
         self.load_from_fp32_weights = get_scalar_param(
             zero_config_dict,
             ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
             ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)
+
+        self.cpu_offload = get_scalar_param(zero_config_dict,
+                                            ZERO_OPTIMIZATION_CPU_OFFLOAD,
+                                            ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
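Parsing mirrors the other ZeRO options: get_scalar_param reads "cpu_offload" from the zero_optimization section and falls back to the False default, so existing configs keep their behavior. A rough stand-in for that lookup, for illustration only (the real helper lives elsewhere in the runtime and is not part of this diff):

def get_scalar_param_sketch(config_dict, key, default):
    # Illustrative equivalent of the lookup pattern used above: return the
    # user-supplied value when present, otherwise the default.
    return config_dict.get(key, default)

zero_config_dict = {"stage": 2, "cpu_offload": True}
assert get_scalar_param_sketch(zero_config_dict, "cpu_offload", False) is True
assert get_scalar_param_sketch({}, "cpu_offload", False) is False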
