Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 8d4ac0e

Browse files
author
Anand J
committed
Add PushPull test cases
1 parent d9c4588 commit 8d4ac0e

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

tests/nightly/dist_device_sync_kvstore.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,33 @@ def init_kv():
5555
def test_sync_push_pull():
5656
kv, my_rank, nworker = init_kv()
5757
num_gpus = 2
58-
def check_default_keys(kv, my_rank, nworker):
59-
nrepeat = 3
58+
def check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False):
6059
# checks pull after push in loop, because behavior during
6160
# consecutive pushes doesn't offer any guarantees
62-
for i in range(nrepeat):
61+
for i in range(offset, nrepeat):
6362
scale = my_rank + 1
64-
kv.push('3', [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)])
65-
kv.push('99', [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)])
6663
num = (nworker + 1) * nworker * rate * num_gpus / 2 * (i + 1) + 1
64+
65+
arr = [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]
6766
val = mx.nd.zeros(shape)
68-
kv.pull('3', out=val)
67+
if use_pushpull:
68+
kv.pushpull('3', arr, out=val)
69+
else:
70+
kv.push('3', arr)
71+
kv.pull('3', out=val)
6972
check_diff_to_scalar(val, num)
70-
val2 = mx.nd.zeros(big_shape)
71-
kv.pull('99', out=val2)
72-
check_diff_to_scalar(val2, num)
7373

74-
check_default_keys(kv, my_rank, nworker)
74+
big_arr = [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]
75+
big_val = mx.nd.zeros(big_shape)
76+
if use_pushpull:
77+
kv.pushpull('99', big_arr, out=big_val)
78+
else:
79+
kv.push('99', big_arr)
80+
kv.pull('99', out=big_val)
81+
check_diff_to_scalar(big_val, num)
82+
83+
check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False)
84+
check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=3, use_pushpull=True)
7585
print('worker ' + str(my_rank) + ' is done')
7686

7787
def test_sync_init():

0 commit comments

Comments
 (0)