@@ -44,7 +44,10 @@ def check_diff_to_scalar(A, x, rank=None):
44
44
def init_kv ():
45
45
# init kv dns keys
46
46
kv .init (keys , [mx .nd .ones (shape )] * len (keys ))
47
+ kv .init ('9' , mx .nd .ones (shape ))
48
+ kv .init ('10' , mx .nd .ones (shape ))
47
49
kv .init ('99' , mx .nd .ones (big_shape ))
50
+ kv .init ('100' , mx .nd .ones (big_shape ))
48
51
# worker info
49
52
my_rank = kv .rank
50
53
nworker = kv .num_workers
@@ -55,33 +58,30 @@ def init_kv():
55
58
def test_sync_push_pull():
    """Check push/pull and pushpull results on the small and big keys.

    Every worker sends one ones-tensor per GPU, scaled by ``rank + 1``,
    and compares what comes back against a scalar expected value.
    Keys '9' and '99' go through separate push and pull calls, while
    keys '10' and '100' go through the fused pushpull call.
    """
    kv, my_rank, nworker = init_kv()
    num_gpus = 2

    def check_default_keys(kv, my_rank, nworker, nrepeat=3):
        """Run ``nrepeat`` rounds of push/pull and pushpull checks."""

        def roundtrip(push_key, pushpull_key, dims, factor, expected):
            # One tensor per GPU, all holding the same scaled ones.
            per_gpu = [mx.nd.ones(dims, ctx=mx.gpu(dev)) * factor
                       for dev in range(num_gpus)]
            result = mx.nd.zeros(dims)
            # Pull right after each push: the state between consecutive
            # pushes is not guaranteed, so it is never inspected.
            kv.push(push_key, per_gpu)
            kv.pull(push_key, out=result)
            check_diff_to_scalar(result, expected)
            kv.pushpull(pushpull_key, per_gpu, out=result)
            check_diff_to_scalar(result, expected)

        factor = my_rank + 1
        for step in range(nrepeat):
            # NOTE(review): presumably the initial 1 plus the accumulated
            # rate-scaled sum over all workers and GPUs after step + 1
            # rounds -- confirm against the server-side updater.
            expected = (nworker + 1) * nworker * rate * num_gpus / 2 * (step + 1) + 1
            roundtrip('9', '10', shape, factor, expected)
            roundtrip('99', '100', big_shape, factor, expected)

    check_default_keys(kv, my_rank, nworker, nrepeat=3)
    print('worker ' + str(my_rank) + ' is done')
86
86
87
87
def test_sync_init ():
@@ -106,10 +106,12 @@ def check_trainer_kv_update(update_on_kv):
106
106
x = params .get ('x' , shape = (10 ,1 ), lr_mult = 1.0 )
107
107
params .initialize (ctx = [mx .cpu (0 ), mx .cpu (1 )], init = 'zeros' )
108
108
try :
109
- trainer = mx .gluon .Trainer (params , 'sgd' , {'learning_rate' : 0.1 }, kvstore = kv , update_on_kvstore = update_on_kv )
109
+ trainer = mx .gluon .Trainer (params , 'sgd' , {'learning_rate' : 0.1 },
110
+ kvstore = kv , update_on_kvstore = update_on_kv )
110
111
trainer ._init_kvstore ()
111
112
assert trainer ._kv_initialized
112
- assert trainer ._update_on_kvstore is True
113
+ if update_on_kv is not None :
114
+ assert trainer ._update_on_kvstore is update_on_kv
113
115
except ValueError :
114
116
assert update_on_kv is False
115
117
@@ -122,3 +124,4 @@ def check_trainer_kv_update(update_on_kv):
122
124
if __name__ == "__main__":
    # Script entry point: run each distributed test once, in order.
    for run_case in (test_sync_init, test_sync_push_pull, test_gluon_trainer_type):
        run_case()
0 commit comments