@@ -55,23 +55,33 @@ def init_kv():
55
55
def test_sync_push_pull ():
56
56
kv , my_rank , nworker = init_kv ()
57
57
num_gpus = 2
58
- def check_default_keys (kv , my_rank , nworker ):
59
- nrepeat = 3
58
+ def check_default_keys (kv , my_rank , nworker , nrepeat = 3 , offset = 0 , use_pushpull = False ):
60
59
# checks pull after push in loop, because behavior during
61
60
# consecutive pushes doesn't offer any guarantees
62
- for i in range (nrepeat ):
61
+ for i in range (offset , nrepeat ):
63
62
scale = my_rank + 1
64
- kv .push ('3' , [mx .nd .ones (shape , ctx = mx .gpu (j )) * scale for j in range (num_gpus )])
65
- kv .push ('99' , [mx .nd .ones (big_shape , ctx = mx .gpu (j )) * scale for j in range (num_gpus )])
66
63
num = (nworker + 1 ) * nworker * rate * num_gpus / 2 * (i + 1 ) + 1
64
+
65
+ arr = [mx .nd .ones (shape , ctx = mx .gpu (j )) * scale for j in range (num_gpus )]
67
66
val = mx .nd .zeros (shape )
68
- kv .pull ('3' , out = val )
67
+ if use_pushpull :
68
+ kv .pushpull ('3' , arr , out = val )
69
+ else :
70
+ kv .push ('3' , arr )
71
+ kv .pull ('3' , out = val )
69
72
check_diff_to_scalar (val , num )
70
- val2 = mx .nd .zeros (big_shape )
71
- kv .pull ('99' , out = val2 )
72
- check_diff_to_scalar (val2 , num )
73
73
74
- check_default_keys (kv , my_rank , nworker )
74
+ big_arr = [mx .nd .ones (big_shape , ctx = mx .gpu (j )) * scale for j in range (num_gpus )]
75
+ big_val = mx .nd .zeros (big_shape )
76
+ if use_pushpull :
77
+ kv .pushpull ('99' , big_arr , out = big_val )
78
+ else :
79
+ kv .push ('99' , big_arr )
80
+ kv .pull ('99' , out = big_val )
81
+ check_diff_to_scalar (big_val , num )
82
+
83
+ check_default_keys (kv , my_rank , nworker , nrepeat = 3 , offset = 0 , use_pushpull = False )
84
+ check_default_keys (kv , my_rank , nworker , nrepeat = 3 , offset = 3 , use_pushpull = True )
75
85
print ('worker ' + str (my_rank ) + ' is done' )
76
86
77
87
def test_sync_init ():
0 commit comments