 import os
 import numpy
 from ..base import py_str
-from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
+from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply,
+                       multi_sum_sq, multi_lars, norm as NDnorm)
 from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                        mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                        signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update,
                        multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
-                       multi_mp_sgd_mom_update)
+                       multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
+                       preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
+                       preloaded_multi_mp_sgd_mom_update)
 from ..ndarray import sparse
 from ..random import normal
 from ..util import is_np_array
 
 __all__ = [
-    'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD',
+    'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
     'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum',
     'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
 ]
@@ -781,6 +784,266 @@ def update(self, index, weight, grad, state):
         ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight,
                     lr=lr, wd=wd, **kwargs)
 
+@register
+class LARS(Optimizer):
+    """The LARS optimizer from 'Large Batch Training of Convolutional Networks' \
+    (https://arxiv.org/abs/1708.03888)
+
+    Behaves mostly like SGD with momentum and weight decay, but adaptively scales \
+    the learning rate for each layer (except bias and batch-norm parameters):
+    w_norm = L2norm(weights)
+    g_norm = L2norm(gradients)
+    if w_norm > 0 and g_norm > 0:
+        lr_layer = lr * lr_mult * eta * w_norm / (g_norm + weight_decay * w_norm + eps)
+    else:
+        lr_layer = lr * lr_mult
+
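+    For example (illustrative numbers only), with eta = 0.001, w_norm = 10,
+    g_norm = 1 and weight_decay = 0, the layer-wise rate is
+    lr * lr_mult * 0.001 * 10 / 1, i.e. 0.01 * lr * lr_mult.
+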
+    Parameters
+    ----------
+    momentum : float, optional
+        The momentum value.
+    lazy_update : bool, optional
+        Default is True. If True, lazy updates are applied \
+        if the storage types of weight and grad are both ``row_sparse``.
+    eta : float, optional
+        LARS coefficient used to scale the learning rate. Default set to 0.001.
+    eps : float, optional
+        Optional epsilon in case of very small gradients. Default set to 0.
+    momentum_correction : bool, optional
+        If True, scale momentum w.r.t. the global learning rate change (with an \
+        lr_scheduler), as indicated in 'Accurate, Large Minibatch SGD: Training \
+        ImageNet in 1 Hour' (https://arxiv.org/pdf/1706.02677.pdf).
+        Default set to True.
+    """
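+    # Illustrative usage only (the hyper-parameter values below are an example,
+    # not a recommendation): since the class is registered via @register, it can
+    # be created through the optimizer registry, e.g.
+    #     opt = mx.optimizer.create('lars', learning_rate=0.6, momentum=0.9,
+    #                               wd=0.0001, eta=0.001)
+    # and then wrapped with get_updater(opt) like any other optimizer.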
+    def __init__(self, momentum=0.0, lazy_update=True, eta=0.001, eps=0,
+                 momentum_correction=True, **kwargs):
+        super(LARS, self).__init__(**kwargs)
+        self.momentum = momentum
+        self.momentum_correction = momentum_correction
+        self.lazy_update = lazy_update
+        self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4"))
+        self.eta = eta
+        self.eps = eps
+        self.skip = 0
+        self.last_lr = None
+        self.cur_lr = None
+
+
+    def _get_lrs(self, indices):
+        """Gets the learning rates given the indices of the weights.
+
+        Parameters
+        ----------
+        indices : list of int
+            Indices corresponding to weights.
+
+        Returns
+        -------
+        lrs : list of float
+            Learning rates for those indices.
+        """
+        if self.cur_lr is not None:
+            self.last_lr = self.cur_lr
+
+        if self.lr_scheduler is not None:
+            lr = self.lr_scheduler(self.num_update)
+        else:
+            lr = self.lr
+
+        if self.cur_lr is None:
+            self.last_lr = lr
+        self.cur_lr = lr
+
+        lrs = [lr for _ in indices]
+        for i, index in enumerate(indices):
+            if index in self.param_dict:
+                lrs[i] *= self.param_dict[index].lr_mult
+            elif index in self.lr_mult:
+                lrs[i] *= self.lr_mult[index]
+            elif index in self.idx2name:
+                lrs[i] *= self.lr_mult.get(self.idx2name[index], 1.0)
+        return lrs
+
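+    # Unlike the base Optimizer, this override disables weight decay (wd_mult = 0)
+    # for every parameter whose name does not end in '_weight', unless overridden
+    # by a '__wd_mult__' symbol attribute or by args_wd_mult.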
+    def set_wd_mult(self, args_wd_mult):
+        self.wd_mult = {}
+        for n in self.idx2name.values():
+            is_weight = n.endswith('_weight')
+
+            if not is_weight:
+                self.wd_mult[n] = 0.0
+
+        if self.sym_info:
+            attr, arg_names = self.sym_info
+            for name in arg_names:
+                if name in attr and '__wd_mult__' in attr[name]:
+                    self.wd_mult[name] = float(attr[name]['__wd_mult__'])
+        self.wd_mult.update(args_wd_mult)
+
+    def create_state_multi_precision(self, index, weight):
+        weight_master_copy = None
+        if self.multi_precision and weight.dtype == numpy.float16:
+            weight_master_copy = weight.astype(numpy.float32)
+            return (self.create_state(index, weight_master_copy), weight_master_copy)
+        if weight.dtype == numpy.float16 and not self.multi_precision:
+            warnings.warn("Accumulating with float16 in optimizer can lead to "
+                          "poor accuracy or slow convergence. "
+                          "Consider using multi_precision=True option of the "
+                          "LARS optimizer")
+        return self.create_state(index, weight)
+
+    def create_state(self, index, weight):
+        momentum = None
+        if self.momentum != 0.0:
+            stype = weight.stype if self.lazy_update else 'default'
+            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
+        return momentum
+
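+    # The norm is computed in float32 and copied back to the host via asnumpy(),
+    # which synchronizes; this helper is only used by the non-aggregated fallback
+    # path in _update_impl (the aggregated path uses multi_sum_sq/multi_lars).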
+    def _l2norm(self, v, rescale=False):
+        """L2 Norm implementation"""
+        v = v.astype('float32')
+        if rescale:
+            v *= self.rescale_grad
+        norm = NDnorm(v).asnumpy()[0]
+        return norm
+
+    def _get_lars(self, i, weight, g, lr, wd):
+        """Returns the LARS-scaled learning rate for this layer"""
+        name = self.idx2name[i] if i in self.idx2name else str(i)
+        if name.endswith('gamma') or name.endswith('beta') or name.endswith('bias'):
+            return lr
+
+        w_norm = self._l2norm(weight)
+        g_norm = self._l2norm(g, rescale=True)
+
+        if w_norm > 0.0 and g_norm > 0.0:
+            lars = self.eta * w_norm / (g_norm + wd * w_norm + self.eps)
+        else:
+            lars = 1.0
+        return lars * lr
+
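+    # Two code paths: when all weights/grads use default (dense) storage, the norms
+    # and per-layer learning rates are computed with the fused multi_sum_sq and
+    # multi_lars kernels and applied with the preloaded_multi_*sgd* operators,
+    # aggregate_num weights at a time; otherwise each parameter falls back to
+    # _get_lars plus the regular (mp_)sgd(_mom)_update operators.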
+    def _update_impl(self, indices, weights, grads, states, multi_precision=False):
+        aggregate = True
+        if not isinstance(indices, (tuple, list)):
+            indices = [indices]
+            weights = [weights]
+            grads = [grads]
+            states = [states]
+        for weight, grad in zip(weights, grads):
+            assert isinstance(weight, NDArray)
+            assert isinstance(grad, NDArray)
+            aggregate = (aggregate and
+                         weight.stype == 'default' and
+                         grad.stype == 'default')
+        self._update_count(indices)
+        lrs = self._get_lrs(indices)
+        wds = self._get_wds(indices)
+
+        kwargs = {'rescale_grad': self.rescale_grad}
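+        # Momentum correction (see the paper referenced in the class docstring):
+        # when the scheduler changes the global learning rate, the momentum
+        # coefficient is rescaled by cur_lr / last_lr.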
+        if self.momentum > 0:
+            kwargs['momentum'] = (self.momentum * (self.cur_lr / self.last_lr)) \
+                                 if (self.momentum_correction and self.last_lr != 0) else \
+                                 self.momentum
+
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
+
+        if aggregate:
+            nb_params = len(indices)
+            names = [self.idx2name[i] if i in self.idx2name else str(i) for i in indices]
+            lars_idx = [i for i in range(nb_params) if
+                        not(names[i].endswith('gamma') or names[i].endswith('beta') or
+                            names[i].endswith('bias'))]
+            nb_lars = len(lars_idx)
+            no_lars_idx = [i for i in range(nb_params) if
+                           (names[i].endswith('gamma') or names[i].endswith('beta') or
+                            names[i].endswith('bias'))]
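+            # Reorder parameters so that the LARS-rescaled ones come first; only
+            # the first nb_lars entries of new_lrs are overwritten by multi_lars
+            # below, while gamma/beta/bias keep their plain learning rate.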
+            cur_ctx = weights[0].context
+            full_idx = lars_idx + no_lars_idx
+            new_lrs = array([lrs[i] for i in full_idx], ctx=cur_ctx, dtype='float32')
+            new_wds = array([wds[i] for i in full_idx], ctx=cur_ctx, dtype='float32')
+            new_weights = [weights[i] for i in full_idx]
+            new_grads = [grads[i] for i in full_idx]
+            new_states = [states[i] for i in full_idx]
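+            # multi_sum_sq computes the squared L2 norm of every array in one call;
+            # multi_lars then combines (lr, ||w||^2, ||g||^2, wd) into the per-layer
+            # LARS learning rates, writing the result back into new_lrs in place.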
+            if nb_lars > 0:
+                w_sum_sq = multi_sum_sq(*new_weights[:nb_lars], num_arrays=nb_lars)
+                g_sum_sq = multi_sum_sq(*new_grads[:nb_lars], num_arrays=nb_lars)
+                multi_lars(new_lrs[:nb_lars], w_sum_sq, g_sum_sq, new_wds[:nb_lars],
+                           eta=self.eta, eps=self.eps, rescale_grad=self.rescale_grad,
+                           out=new_lrs[:nb_lars])
+            # Same as the usual SGD path, but using the preloaded multi-SGD operators
+            sidx = 0
+            while sidx < len(indices):
+                eidx = sidx + len(new_weights[sidx:sidx+self.aggregate_num])
+                if not multi_precision:
+                    if self.momentum > 0:
+                        preloaded_multi_sgd_mom_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx],
+                                                new_states[sidx:eidx])) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                    else:
+                        preloaded_multi_sgd_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx])) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                else:
+                    if self.momentum > 0:
+                        preloaded_multi_mp_sgd_mom_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx],
+                                                *zip(*new_states[sidx:eidx]))) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                    else:
+                        preloaded_multi_mp_sgd_update(
+                            *(_flatten_list(zip(new_weights[sidx:eidx],
+                                                new_grads[sidx:eidx],
+                                                list(zip(*new_states[sidx:eidx]))[1])) +
+                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
+                            out=new_weights[sidx:eidx],
+                            num_weights=len(new_weights[sidx:eidx]),
+                            **kwargs)
+                sidx += self.aggregate_num
+        else:
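+            # Fallback for non-default (e.g. sparse) storage: compute the LARS
+            # scaling per parameter with _get_lars and apply the standard
+            # (mp_)sgd(_mom)_update operators one weight at a time.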
+            lrs = [self._get_lars(i, w, g, lr, wd) for (i, w, g, lr, wd) in
+                   zip(indices, weights, grads, lrs, wds)]
+
+            for weight, grad, state, lr, wd in zip(weights, grads, states, lrs, wds):
+                if not multi_precision:
+                    if state is not None:
+                        sgd_mom_update(weight, grad, state, out=weight,
+                                       lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
+                    else:
+                        sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
+                                   lr=lr, wd=wd, **kwargs)
+                else:
+                    if state[0] is not None:
+                        mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight,
+                                          lr=lr, wd=wd, **kwargs)
+                    else:
+                        mp_sgd_update(weight, grad, state[1], out=weight,
+                                      lr=lr, wd=wd, **kwargs)
+
+    def update(self, index, weight, grad, state):
+        self._update_impl(index, weight, grad, state, multi_precision=False)
+
+    def update_multi_precision(self, index, weight, grad, state):
+        if not isinstance(index, (tuple, list)):
+            use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
+        else:
+            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
+        self._update_impl(index, weight, grad, state,
+                          multi_precision=use_multi_precision)
+
+#
 @register
 class LBSGD(Optimizer):
     """The Large Batch SGD optimizer with momentum and weight decay.
@@ -812,7 +1075,7 @@ class LBSGD(Optimizer):
 
     warmup_strategy: string ('linear', 'power2', 'sqrt'. , 'lars' default : 'linear')
     warmup_epochs: unsigned, default: 5
-    batch_scale: unsigned, default: 1 (same as batch size*numworkers)
+    batch_scale: unsigned, default: 1 (same as batch size * numworkers)
     updates_per_epoch: updates_per_epoch (default: 32, Default might not reflect true number batches per epoch. Used for warmup.)
     begin_epoch: unsigned, default 0, starting epoch.
     """