This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 810e67c

Caenorstapeforest authored and committed
Add fast implementation of LARS (#16122)
* add MXNet operator for fast LARS
* add unit tests for fast LARS related MXNet ops
* fix preloaded_multi_* dtype inference; add SGDwFastLARS optimizer and test (conflicts: tests/python/gpu/test_operator_gpu.py)
* remove commented-out cast from lenet5 model
* fix lint
* add more documentation; rename SGDwFastLARS to LARS, removing the redundant 'lars' prefix from its parameters
* make the optimizer code Python 2 compatible
* fix lint
* replace push_back with emplace_back for clang-tidy
1 parent 66f1656 commit 810e67c
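For context, a minimal usage sketch (not part of the diff): the class added below is registered under the name 'lars', so it can be created through the regular optimizer factory or handed to a Gluon Trainer. The hyper-parameter names (momentum, eta, eps, momentum_correction) come from its __init__ as shown in the diff; the network `net` is a hypothetical placeholder.

import mxnet as mx

# Create the optimizer by its registered name; hyper-parameters mirror LARS.__init__.
opt = mx.optimizer.create('lars', learning_rate=0.1, momentum=0.9,
                          eta=0.001, eps=0, momentum_correction=True)

# net is assumed to be an existing Gluon network (hypothetical in this sketch).
# trainer = mx.gluon.Trainer(net.collect_params(), opt)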

File tree

12 files changed (+1682, −4 lines)


python/mxnet/optimizer/optimizer.py

Lines changed: 267 additions & 4 deletions
@@ -26,18 +26,21 @@
 import os
 import numpy
 from ..base import py_str
-from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
+from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply,
+                       multi_sum_sq, multi_lars, norm as NDnorm)
 from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                        mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                        signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update,
                        multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
-                       multi_mp_sgd_mom_update)
+                       multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
+                       preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
+                       preloaded_multi_mp_sgd_mom_update)
 from ..ndarray import sparse
 from ..random import normal
 from ..util import is_np_array

 __all__ = [
-    'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD',
+    'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
     'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum',
     'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
 ]
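A note on the newly imported preloaded_multi_* ops (sketch below, not part of the diff): unlike multi_sgd_update and friends, which take learning rates and weight decays as scalar attributes, the preloaded variants read per-tensor lrs and wds from NDArrays appended after the interleaved (weight, grad[, state]) list. That is what lets the LARS hunk below feed them rates computed on-device by multi_lars. A minimal, hedged call mirroring that pattern with toy shapes:

import mxnet as mx
from mxnet.ndarray import preloaded_multi_sgd_update

# Two toy weights and gradients (shapes are arbitrary for the sketch).
weights = [mx.nd.ones((2, 2)), mx.nd.ones((3,))]
grads = [mx.nd.ones((2, 2)) * 0.1, mx.nd.ones((3,)) * 0.1]

# Per-tensor learning rates and weight decays, preloaded as NDArrays.
lrs = mx.nd.array([0.1, 0.05], dtype='float32')
wds = mx.nd.array([1e-4, 0.0], dtype='float32')

# Interleave (w0, g0, w1, g1), then append the preloaded lrs and wds,
# just as _update_impl does via _flatten_list in the diff below.
flat = [t for pair in zip(weights, grads) for t in pair]
preloaded_multi_sgd_update(*(flat + [lrs, wds]), out=weights,
                           num_weights=len(weights), rescale_grad=1.0)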
@@ -781,6 +784,266 @@ def update(self, index, weight, grad, state):
         ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight,
                     lr=lr, wd=wd, **kwargs)

+@register
+class LARS(Optimizer):
+    """The LARS optimizer from 'Large Batch Training of Convolutional Networks' \
+    (https://arxiv.org/abs/1708.03888)
+
+    Behaves mostly like SGD with momentum and weight decay, but adaptively scales \
+    the learning rate for each layer (except bias and batch norm parameters):
+    w_norm = L2norm(weights)
+    g_norm = L2norm(gradients)
+    if w_norm > 0 and g_norm > 0:
+        lr_layer = lr * lr_mult * eta * w_norm / (g_norm + weight_decay * w_norm + eps)
+    else:
+        lr_layer = lr * lr_mult
+
+    Parameters
+    ----------
+    momentum : float, optional
+        The momentum value.
+    lazy_update : bool, optional
+        Default is True. If True, lazy updates are applied \
+        if the storage types of weight and grad are both ``row_sparse``.
+    eta : float, optional
+        LARS coefficient used to scale the learning rate. Default is 0.001.
+    eps : float, optional
+        Optional epsilon in case of very small gradients. Default is 0.
+    momentum_correction : bool, optional
+        If True, scale momentum w.r.t. the global learning rate change (with an lr_scheduler) \
+        as indicated in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour' \
+        (https://arxiv.org/pdf/1706.02677.pdf).
+        Default is True.
+    """
+    def __init__(self, momentum=0.0, lazy_update=True, eta=0.001, eps=0,
+                 momentum_correction=True, **kwargs):
+        super(LARS, self).__init__(**kwargs)
+        self.momentum = momentum
+        self.momentum_correction = momentum_correction
+        self.lazy_update = lazy_update
+        self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4"))
+        self.eta = eta
+        self.eps = eps
+        self.skip = 0
+        self.last_lr = None
+        self.cur_lr = None
+
+    def _get_lrs(self, indices):
+        """Gets the learning rates given the indices of the weights.
+
+        Parameters
+        ----------
+        indices : list of int
+            Indices corresponding to weights.
+
+        Returns
+        -------
+        lrs : list of float
+            Learning rates for those indices.
+        """
+        if self.cur_lr is not None:
+            self.last_lr = self.cur_lr
+
+        if self.lr_scheduler is not None:
+            lr = self.lr_scheduler(self.num_update)
+        else:
+            lr = self.lr
+
+        if self.cur_lr is None:
+            self.last_lr = lr
+        self.cur_lr = lr
+
+        lrs = [lr for _ in indices]
+        for i, index in enumerate(indices):
+            if index in self.param_dict:
+                lrs[i] *= self.param_dict[index].lr_mult
+            elif index in self.lr_mult:
+                lrs[i] *= self.lr_mult[index]
+            elif index in self.idx2name:
+                lrs[i] *= self.lr_mult.get(self.idx2name[index], 1.0)
+        return lrs
+
+    def set_wd_mult(self, args_wd_mult):
+        self.wd_mult = {}
+        for n in self.idx2name.values():
+            is_weight = n.endswith('_weight')
+
+            if not is_weight:
+                self.wd_mult[n] = 0.0
+
+        if self.sym_info:
+            attr, arg_names = self.sym_info
+            for name in arg_names:
+                if name in attr and '__wd_mult__' in attr[name]:
+                    self.wd_mult[name] = float(attr[name]['__wd_mult__'])
+        self.wd_mult.update(args_wd_mult)
+
+    def create_state_multi_precision(self, index, weight):
+        weight_master_copy = None
+        if self.multi_precision and weight.dtype == numpy.float16:
+            weight_master_copy = weight.astype(numpy.float32)
+            return (self.create_state(index, weight_master_copy), weight_master_copy)
+        if weight.dtype == numpy.float16 and not self.multi_precision:
+            warnings.warn("Accumulating with float16 in optimizer can lead to "
+                          "poor accuracy or slow convergence. "
+                          "Consider using multi_precision=True option of the "
+                          "SGD optimizer")
+        return self.create_state(index, weight)
+
+    def create_state(self, index, weight):
+        momentum = None
+        if self.momentum != 0.0:
+            stype = weight.stype if self.lazy_update else 'default'
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
+        return momentum
+
+    def _l2norm(self, v, rescale=False):
+        """L2 Norm implementation"""
+        v = v.astype('float32')
+        if rescale:
+            v *= self.rescale_grad
+        norm = NDnorm(v).asnumpy()[0]
+        return norm
+
+    def _get_lars(self, i, weight, g, lr, wd):
+        """Returns a scaling factor for the learning rate for this layer"""
+        name = self.idx2name[i] if i in self.idx2name else str(i)
+        if name.endswith('gamma') or name.endswith('beta') or name.endswith('bias'):
+            return lr
+
+        w_norm = self._l2norm(weight)
+        g_norm = self._l2norm(g, rescale=True)
+
+        if w_norm > 0.0 and g_norm > 0.0:
+            lars = self.eta * w_norm / (g_norm + wd * w_norm + self.eps)
+        else:
+            lars = 1.0
+        return lars * lr
+
+    def _update_impl(self, indices, weights, grads, states, multi_precision=False):
+        aggregate = True
+        if not isinstance(indices, (tuple, list)):
+            indices = [indices]
+            weights = [weights]
+            grads = [grads]
+            states = [states]
+        for weight, grad in zip(weights, grads):
+            assert(isinstance(weight, NDArray))
+            assert(isinstance(grad, NDArray))
+            aggregate = (aggregate and
+                         weight.stype == 'default' and
+                         grad.stype == 'default')
+        self._update_count(indices)
+        lrs = self._get_lrs(indices)
+        wds = self._get_wds(indices)
+
+        kwargs = {'rescale_grad': self.rescale_grad}
+        if self.momentum > 0:
+            kwargs['momentum'] = (self.momentum * (self.cur_lr / self.last_lr)) \
+                                 if (self.momentum_correction and self.last_lr != 0) else \
+                                 self.momentum
+
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
+
+        if aggregate:
+            nb_params = len(indices)
+            names = [self.idx2name[i] if i in self.idx2name else str(i) for i in indices]
+            lars_idx = [i for i in range(nb_params) if
+                        not(names[i].endswith('gamma') or names[i].endswith('beta') or
+                            names[i].endswith('bias'))]
+            nb_lars = len(lars_idx)
+            no_lars_idx = [i for i in range(nb_params) if
+                           (names[i].endswith('gamma') or names[i].endswith('beta') or
+                            names[i].endswith('bias'))]
+            cur_ctx = weights[0].context
+            full_idx = lars_idx + no_lars_idx
+            new_lrs = array([lrs[i] for i in full_idx], ctx=cur_ctx, dtype='float32')
+            new_wds = array([wds[i] for i in full_idx], ctx=cur_ctx, dtype='float32')
+            new_weights = [weights[i] for i in full_idx]
+            new_grads = [grads[i] for i in full_idx]
+            new_states = [states[i] for i in full_idx]
+            if nb_lars > 0:
+                w_sum_sq = multi_sum_sq(*new_weights[:nb_lars], num_arrays=nb_lars)
+                g_sum_sq = multi_sum_sq(*new_grads[:nb_lars], num_arrays=nb_lars)
+                multi_lars(new_lrs[:nb_lars], w_sum_sq, g_sum_sq, new_wds[:nb_lars],
+                           eta=self.eta, eps=self.eps, rescale_grad=self.rescale_grad,
+                           out=new_lrs[:nb_lars])
+            # Same as the usual path, but using the preloaded SGD functions
+            sidx = 0
+            while sidx < len(indices):
+                eidx = sidx + len(new_weights[sidx:sidx+self.aggregate_num])
+                if not multi_precision:
+                    if self.momentum > 0:
+                        preloaded_multi_sgd_mom_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx],
+                                                new_states[sidx:eidx])) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                    else:
+                        preloaded_multi_sgd_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx])) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                else:
+                    if self.momentum > 0:
+                        preloaded_multi_mp_sgd_mom_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx],
+                                                *zip(*new_states[sidx:eidx]))) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                    else:
+                        preloaded_multi_mp_sgd_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx],
+                                                list(zip(*new_states[sidx:eidx]))[1])) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                sidx += self.aggregate_num
+        else:
+            lrs = [self._get_lars(i, w, g, lr, wd) for (i, w, g, lr, wd) in
+                   zip(indices, weights, grads, lrs, wds)]

+            for weight, grad, state, lr, wd in zip(weights, grads, states, lrs, wds):
+                if not multi_precision:
+                    if state is not None:
+                        sgd_mom_update(weight, grad, state, out=weight,
+                                       lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
+                    else:
+                        sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
+                                   lr=lr, wd=wd, **kwargs)
+                else:
+                    if state[0] is not None:
+                        mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight,
+                                          lr=lr, wd=wd, **kwargs)
+                    else:
+                        mp_sgd_update(weight, grad, state[1], out=weight,
+                                      lr=lr, wd=wd, **kwargs)
+
+    def update(self, index, weight, grad, state):
+        self._update_impl(index, weight, grad, state, multi_precision=False)
+
+    def update_multi_precision(self, index, weight, grad, state):
+        if not isinstance(index, (tuple, list)):
+            use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
+        else:
+            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
+        self._update_impl(index, weight, grad, state,
+                          multi_precision=use_multi_precision)
+
+#
 @register
 class LBSGD(Optimizer):
     """The Large Batch SGD optimizer with momentum and weight decay.
@@ -812,7 +1075,7 @@ class LBSGD(Optimizer):

     warmup_strategy: string ('linear', 'power2', 'sqrt'. , 'lars' default : 'linear')
     warmup_epochs: unsigned, default: 5
-    batch_scale: unsigned, default: 1 (same as batch size*numworkers)
+    batch_scale: unsigned, default: 1 (same as batch size * numworkers)
     updates_per_epoch: updates_per_epoch (default: 32, Default might not reflect true number batches per epoch. Used for warmup.)
     begin_epoch: unsigned, default 0, starting epoch.
     """
