diff --git a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py index 0bd55d13..1e34b53f 100644 --- a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py +++ b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py @@ -182,7 +182,7 @@ def __init__(self, local_only=False, rpc_router=self.rpc_router, device=self.device ) else: - assert isinstance(self.dist_node_labels, Dict) + assert self.dist_node_labels is None or isinstance(self.dist_node_labels, Dict) if self.dist_node_labels is not None and \ all(isinstance(value, Feature) for value in self.dist_node_labels.values()): self.dist_node_labels = DistFeature(