Skip to content

Commit 2b63087

Browse files
committed
Fix edge weight normalization
1 parent c4154c8 commit 2b63087

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

csrc/cpu/rw_cpu.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,18 @@ void compute_cdf(const int64_t *rowptr, const float_t *edge_weight,
148148
at::parallel_for(0, numel - 1, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
149149
for(int64_t i = begin; i < end; i++) {
150150
int64_t row_start = rowptr[i], row_end = rowptr[i + 1];
151+
152+
// Compute sum to normalize weights
153+
float_t sum = 0.0;
154+
155+
for(int64_t j = row_start; j < row_end; j++) {
156+
sum += edge_weight[j];
157+
}
158+
151159
float_t acc = 0.0;
152160

153161
for(int64_t j = row_start; j < row_end; j++) {
154-
acc += edge_weight[j];
162+
acc += edge_weight[j] / sum;
155163
edge_weight_cdf[j] = acc;
156164
}
157165
}

csrc/cuda/rw_cuda.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,16 @@ __global__ void cdf_kernel(const int64_t *rowptr, const float_t *edge_weight,
159159
if (thread_idx < numel - 1) {
160160
int64_t row_start = rowptr[thread_idx], row_end = rowptr[thread_idx + 1];
161161

162+
float_t sum = 0.0;
163+
164+
for(int64_t i = row_start; i < row_end; i++) {
165+
sum += edge_weight[i];
166+
}
167+
162168
float_t acc = 0.0;
163169

164170
for(int64_t i = row_start; i < row_end; i++) {
165-
acc += edge_weight[i];
171+
acc += edge_weight[i] / sum;
166172
edge_weight_cdf[i] = acc;
167173
}
168174
}

torch_cluster/rw.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,6 @@ def random_walk(
6262
rowptr, col, start, walk_length, p, q,
6363
)
6464
else:
65-
# Normalize edge weights by node degrees
66-
edge_weight = edge_weight / deg[row]
67-
6865
node_seq, edge_seq = torch.ops.torch_cluster.random_walk_weighted(
6966
rowptr, col, edge_weight, start, walk_length, p, q,
7067
)

0 commit comments

Comments
 (0)