Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 05af5c4

Browse files
[BUGFIX] Fix race condition in kvstore.pushpull (#17007)
* add back gluon test * fix typo * change back gpu ctx * also handle the case there some are pull and some are pushpull * fix typo
1 parent 04ebe45 commit 05af5c4

File tree

2 files changed

+43
-27
lines changed

2 files changed

+43
-27
lines changed

src/kvstore/kvstore_dist_server.h

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -364,21 +364,34 @@ class KVStoreDistServer {
364364
if (log_verbose_) {
365365
LOG(INFO) << "sent response to " << update_buf->request.size() << " workers";
366366
}
367+
/**
368+
* Request can be for either push, pull or pushpull
369+
* If pull flag is set, respond immediately with the updated values
370+
* Otherwise, only send the notification
371+
*/
372+
bool has_pull = false;
367373
for (const auto& req : update_buf->request) {
368-
/**
369-
* Request can be for either push, pull or pushpull
370-
* If pull flag is set, respond immediately with the updated values
371-
* Otherwise, only send the notification
372-
*/
373-
if (req.pull) {
374-
DefaultStorageResponse(type, key, req, req_data, server);
375-
} else {
374+
has_pull = has_pull || req.pull;
375+
}
376+
if (has_pull) {
377+
// if there is a pull request, perform WaitToRead() once before DefaultStorageResponse
378+
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
379+
stored.WaitToRead();
380+
for (const auto& req : update_buf->request) {
381+
if (req.pull) {
382+
DefaultStorageResponse(type, key, req, req_data, server);
383+
}
384+
}
385+
update_buf->request.clear();
386+
} else {
387+
// otherwise, send response directly
388+
for (const auto& req : update_buf->request) {
376389
server->Response(req);
377390
}
391+
update_buf->request.clear();
392+
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
393+
stored.WaitToRead();
378394
}
379-
update_buf->request.clear();
380-
if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]);
381-
stored.WaitToRead();
382395
} else {
383396
update_buf->merged.WaitToRead();
384397
}

tests/nightly/dist_device_sync_kvstore.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ def check_diff_to_scalar(A, x, rank=None):
4444
def init_kv():
4545
# init kv dns keys
4646
kv.init(keys, [mx.nd.ones(shape)] * len(keys))
47+
kv.init('9', mx.nd.ones(shape))
48+
kv.init('10', mx.nd.ones(shape))
4749
kv.init('99', mx.nd.ones(big_shape))
50+
kv.init('100', mx.nd.ones(big_shape))
4851
# worker info
4952
my_rank = kv.rank
5053
nworker = kv.num_workers
@@ -55,33 +58,30 @@ def init_kv():
5558
def test_sync_push_pull():
5659
kv, my_rank, nworker = init_kv()
5760
num_gpus = 2
58-
def check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False):
61+
def check_default_keys(kv, my_rank, nworker, nrepeat=3):
5962
# checks pull after push in loop, because behavior during
6063
# consecutive pushes doesn't offer any guarantees
61-
for i in range(offset, nrepeat):
64+
for i in range(nrepeat):
6265
scale = my_rank + 1
6366
num = (nworker + 1) * nworker * rate * num_gpus / 2 * (i + 1) + 1
6467

6568
arr = [mx.nd.ones(shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]
6669
val = mx.nd.zeros(shape)
67-
if use_pushpull:
68-
kv.pushpull('3', arr, out=val)
69-
else:
70-
kv.push('3', arr)
71-
kv.pull('3', out=val)
70+
kv.push('9', arr)
71+
kv.pull('9', out=val)
72+
check_diff_to_scalar(val, num)
73+
kv.pushpull('10', arr, out=val)
7274
check_diff_to_scalar(val, num)
7375

7476
big_arr = [mx.nd.ones(big_shape, ctx=mx.gpu(j)) * scale for j in range(num_gpus)]
7577
big_val = mx.nd.zeros(big_shape)
76-
if use_pushpull:
77-
kv.pushpull('99', big_arr, out=big_val)
78-
else:
79-
kv.push('99', big_arr)
80-
kv.pull('99', out=big_val)
78+
kv.push('99', big_arr)
79+
kv.pull('99', out=big_val)
80+
check_diff_to_scalar(big_val, num)
81+
kv.pushpull('100', big_arr, out=big_val)
8182
check_diff_to_scalar(big_val, num)
8283

83-
check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=0, use_pushpull=False)
84-
check_default_keys(kv, my_rank, nworker, nrepeat=3, offset=3, use_pushpull=True)
84+
check_default_keys(kv, my_rank, nworker, nrepeat=3)
8585
print('worker ' + str(my_rank) + ' is done')
8686

8787
def test_sync_init():
@@ -106,10 +106,12 @@ def check_trainer_kv_update(update_on_kv):
106106
x = params.get('x', shape=(10,1), lr_mult=1.0)
107107
params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
108108
try:
109-
trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv)
109+
trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1},
110+
kvstore=kv, update_on_kvstore=update_on_kv)
110111
trainer._init_kvstore()
111112
assert trainer._kv_initialized
112-
assert trainer._update_on_kvstore is True
113+
if update_on_kv is not None:
114+
assert trainer._update_on_kvstore is update_on_kv
113115
except ValueError:
114116
assert update_on_kv is False
115117

@@ -122,3 +124,4 @@ def check_trainer_kv_update(update_on_kv):
122124
if __name__ == "__main__":
123125
test_sync_init()
124126
test_sync_push_pull()
127+
test_gluon_trainer_type()

0 commit comments

Comments
 (0)