Fix deterministic indexing with broadcast #1705

Open
wants to merge 6 commits into base: main

chunhuanMeng
Contributor

Introduces enhancements to the index_put implementation for XPU tensors, focusing on deterministic behavior, improved shape handling, and expanded test coverage. Key changes include adding new helper functions, extending the makeLinearIndex and computeLinearIndex methods, and updating the associated test suite.

Enhancements to index_put Implementation:

  • New Helper Function for Shape Handling:

    • Introduced valsShape to compute the target shape for expanded values during index_put operations. This simplifies and centralizes shape manipulation logic. (src/ATen/native/xpu/sycl/Indexing.cpp)
  • Extended makeLinearIndex and computeLinearIndex:

    • Added dims_before and dims_indexed to track the number of dimensions that precede the indexed dimensions and the number of dimensions being indexed. These are now returned as part of the tuple from computeLinearIndex and propagated through makeLinearIndex. (src/ATen/native/xpu/sycl/IndexingUtils.h)
  • Simplified Value Expansion in index_put_deterministic_kernel:

    • Replaced manual size inference and expansion logic with a call to valsShape. This makes the code more concise and reduces duplication. (src/ATen/native/xpu/sycl/Indexing.cpp)
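
To make the shape handling concrete, here is a minimal plain-Python sketch of what a helper like valsShape could compute. It is not the C++ code from this PR: the name vals_shape, its parameters, and the rule it applies (replace the indexed dimensions with the broadcast index shape and keep the rest) are assumptions made for illustration.

```python
def vals_shape(self_sizes, dims_before, dims_indexed, replacement_shape):
    """Hypothetical model: the shape `values` would be expanded to.

    self_sizes        -- sizes of the tensor being written into
    dims_before       -- number of dimensions before the indexed ones
    dims_indexed      -- number of dimensions being indexed
    replacement_shape -- broadcast shape of the index tensors
    """
    end = dims_before + dims_indexed
    if dims_before < 0 or dims_indexed < 0 or end > len(self_sizes):
        raise ValueError("indexed dimensions fall outside the tensor rank")
    # Keep the leading dims, swap the indexed dims for the broadcast index
    # shape, then keep the trailing dims.
    return (
        list(self_sizes[:dims_before])
        + list(replacement_shape)
        + list(self_sizes[end:])
    )


# Index dim 0 of a (2, 3, 4) tensor with a 1-D index of length 2.
print(vals_shape([2, 3, 4], 0, 1, [2]))  # [2, 3, 4]
# Index dims 1 and 2 of a (5, 6, 7) tensor with indices that broadcast to (3,).
print(vals_shape([5, 6, 7], 1, 2, [3]))  # [5, 3]
```

Under these assumptions, any values tensor has to be expandable to the returned shape before it is written.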

Test Suite Enhancements:

  • New Deterministic Tests:
    • Added a new test, test_index_put_deterministic_with_optional_tensors, to validate deterministic behavior of index_put with various tensor shapes and scenarios. This includes checks for shape mismatches and proper handling of 0D, 1D, and 2D values. (test/xpu/test_indexing_xpu.py)

These changes collectively improve the robustness, maintainability, and test coverage of the index_put functionality for XPU tensors.
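
For readers unfamiliar with why determinism matters here: when an index appears more than once and values are accumulated, the order of additions can change a floating-point result. The sketch below shows one common way to fix that order, a stable sort by index followed by sequential accumulation. It is plain Python, the function name is hypothetical, and it is not claimed to be how the XPU kernel in this PR works.

```python
def index_put_accumulate_sorted(dst, indices, values):
    """Accumulate values[k] into dst[indices[k]] in a fixed, repeatable order."""
    out = list(dst)
    # sorted() is stable, so duplicate indices keep their original relative
    # order and every run adds them in the same sequence.
    order = sorted(range(len(indices)), key=lambda k: indices[k])
    for k in order:
        out[indices[k]] += values[k]
    return out


result = index_put_accumulate_sorted(
    [0.0, 0.0, 0.0], [2, 0, 2, 0], [1.0, 2.0, 3.0, 4.0]
)
print(result)  # [6.0, 0.0, 4.0]
```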

Copilot AI review requested due to automatic review settings May 27, 2025 08:30

Copilot AI left a comment

Pull Request Overview

This PR enhances the index_put implementation on XPU by ensuring deterministic indexing, centralizing shape logic, and bolstering test coverage.

  • Introduce a valsShape helper to compute expanded-value shapes.
  • Extend computeLinearIndex and makeLinearIndex to return dims_before and dims_indexed.
  • Simplify value expansion in index_put_deterministic_kernel via valsShape.
  • Add new deterministic tests for index_put with optional tensors and shape-mismatch checks.
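
As a rough illustration of the two new counts, the sketch below derives them from a list of optional indices, where None stands for an undefined index. It assumes dims_before is the number of leading undefined indices and dims_indexed is the number of defined ones, and that the defined indices are adjacent. The function name is hypothetical and the code is plain Python rather than the C++ in IndexingUtils.h.

```python
def count_index_dims(indices):
    """Return (dims_before, dims_indexed) for a list of optional indices."""
    dims_before = 0
    dims_indexed = 0
    for idx in indices:
        if idx is not None:
            # A defined index means this dimension is being indexed.
            dims_indexed += 1
        elif dims_indexed == 0:
            # Undefined index seen before any defined one: a leading dimension.
            dims_before += 1
        # Undefined indices after the indexed block are trailing dims; ignored.
    return dims_before, dims_indexed


print(count_index_dims([None, [0, 1], [1, 2]]))  # (1, 2)
print(count_index_dims([[0, 1], None]))          # (0, 1)
```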

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| test/xpu/test_indexing_xpu.py | Added deterministic index_put tests, including 0D/1D/2D values and mismatch assertions |
| src/ATen/native/xpu/sycl/IndexingUtils.h | Extended computeLinearIndex/makeLinearIndex to return two new dimension counts |
| src/ATen/native/xpu/sycl/Indexing.cpp | Added valsShape helper and replaced manual expansion in the deterministic kernel |
Comments suppressed due to low confidence (2)

test/xpu/test_indexing_xpu.py:18

  • [nitpick] The helper names func and func1 are ambiguous—consider renaming them to clearly reflect their purpose (e.g., index_put_with_guard and simple_index_put).
def func(x, i, v):

test/xpu/test_indexing_xpu.py:35

  • [nitpick] Variable values2d does not match the value0d/value1d pattern—rename to value2d for consistency.
values2d = torch.randn(n, 1)

out_cpu = func(t, indices, value1d)
t = torch.zeros(2, 3, 4)
ind = torch.tensor([0, 1])
val = torch.randn(6, 2)

Copilot AI May 27, 2025

Tests use different error regexes ("shape mismatch" vs "must match") for CPU vs XPU—consider unifying the expected message or adding a brief comment explaining the discrepancy to prevent brittleness.

Suggested change
- val = torch.randn(6, 2)
+ val = torch.randn(6, 2)
+ # The error messages differ between CPU ("shape mismatch") and XPU ("must match")
+ # due to implementation-specific differences in error handling.
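
The mismatch this test exercises can be illustrated without a device. The sketch below is a plain-Python expandability check that follows the usual broadcasting rule (trailing dimensions must be equal or 1). It assumes ind indexes the first dimension of the (2, 3, 4) tensor, so the target shape stays (2, 3, 4). The function name is hypothetical, and the real CPU and XPU checks may differ in detail.

```python
def can_expand(value_shape, target_shape):
    """True if value_shape can be expanded to target_shape."""
    if len(value_shape) > len(target_shape):
        return False
    # Compare trailing dimensions; a size of 1 stretches to any size.
    for v, t in zip(reversed(value_shape), reversed(target_shape)):
        if v != 1 and v != t:
            return False
    return True


target = (2, 3, 4)  # assumed target shape for t[ind] with ind = [0, 1]
print(can_expand((6, 2), target))  # False: last dim 2 does not match 4
print(can_expand((), target))      # True: a 0D value expands to any shape
print(can_expand((3, 1), target))  # True
```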

Copilot uses AI. Check for mistakes.

@@ -32,17 +32,38 @@ def func(x, i, v):
indices_dev = indices.to(device)
value0d = torch.tensor(10.0)
value1d = torch.tensor([1.0, 2.0])
values2d = torch.randn(n, 1)

for val in (value0d, value1d, values2d):

Copilot AI May 27, 2025

[nitpick] The loop reuses the same t_dev/t across multiple func calls, mutating them cumulatively—consider reinitializing t and t_dev inside the loop to make each subtest independent.

Suggested change
- for val in (value0d, value1d, values2d):
+ for val in (value0d, value1d, values2d):
+     t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
+     t_dev = t.to(device)

@@ -70,8 +68,10 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
// are not being index.
Tensor linearIndex;
int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;

Copilot AI May 27, 2025

The new tuple fields dims_before and dims_indexed would benefit from a brief inline comment explaining their meaning and relationship to the indexing algorithm.

Suggested change
- int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;
+ int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;
+ // `dims_before` counts the number of dimensions before the indexed dimensions.
+ // `dims_indexed` counts the number of dimensions that are being indexed.

@@ -609,6 +609,21 @@ void index_put_kernel(
}
}

DimVector valsShape(

Copilot AI May 27, 2025

[nitpick] Consider marking valsShape as static inline or moving its declaration to the header with a doc comment, so its purpose and usage are clearer and the compiler can inline it across translation units.

@xytintel
Contributor

@chunhuanMeng Pls remove test_index_put_deterministic_with_optional_tensors_xpu from skip list

@chunhuanMeng
Contributor Author

> @chunhuanMeng Pls remove test_index_put_deterministic_with_optional_tensors_xpu from skip list

done

2 participants