Skip to content

[SYCL] Replace __spirv_SubgroupShuffle{...}INTEL with __spirv_GroupNonUniformShuffle{...} generic versions. #14748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: sycl
Choose a base branch
from
Open
14 changes: 14 additions & 0 deletions sycl/include/sycl/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,20 @@ template <typename ValueT, typename IdT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformShuffle(__spv::Scope::Flag, ValueT, IdT) noexcept;

template <typename ValueT, typename IdT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformShuffleXor(__spv::Scope::Flag, ValueT, IdT) noexcept;

template <typename ValueT, typename IdT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT ValueT
__spirv_GroupNonUniformShuffleUp(__spv::Scope::Flag, ValueT, IdT) noexcept;

template <typename ValueT, typename IdT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
__SYCL_EXPORT ValueT __spirv_GroupNonUniformShuffleDown(__spv::Scope::Flag,
ValueT,
IdT) noexcept;

__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL __SYCL_EXPORT bool
__spirv_GroupNonUniformAll(__spv::Scope::Flag, bool);

Expand Down
88 changes: 18 additions & 70 deletions sycl/include/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,30 +796,19 @@ AtomicMax(multi_ptr<T, AddressSpace, IsDecorated> MPtr, memory_scope Scope,
// variants for all scalar types
#ifndef __NVPTX__

template <typename T>
struct TypeIsProhibitedForShuffleEmulation
: std::bool_constant<
check_type_in_v<vector_element_t<T>, double, long, long long,
unsigned long, unsigned long long, half>> {};

template <typename T>
struct VecTypeIsProhibitedForShuffleEmulation
: std::bool_constant<
(detail::get_vec_size<T>::size > 1) &&
TypeIsProhibitedForShuffleEmulation<vector_element_t<T>>::value> {};

// Note: Although SPIR-V supports vector shuffles, the OpenCL specification only
// allow scalars in the operations. As such, we scalarize those too, then
// expect vectorization from the device compiler if possible.
// https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_cl_khr_subgroup_shuffle
template <typename T>
using EnableIfNativeShuffle =
std::enable_if_t<detail::is_arithmetic<T>::value &&
!VecTypeIsProhibitedForShuffleEmulation<T>::value &&
!detail::is_marray_v<T>,
!detail::is_marray_v<T> && !detail::is_vec_v<T>,
T>;

template <typename T>
using EnableIfNonScalarShuffle =
std::enable_if_t<VecTypeIsProhibitedForShuffleEmulation<T>::value ||
detail::is_marray_v<T>,
T>;
std::enable_if_t<detail::is_marray_v<T> || detail::is_vec_v<T>, T>;

#else // ifndef __NVPTX__

Expand Down Expand Up @@ -924,23 +913,8 @@ EnableIfNativeShuffle<T> Shuffle(GroupT g, T x, id<1> local_id) {
uint32_t LocalId = MapShuffleID(g, local_id);
#ifndef __NVPTX__
std::ignore = g;
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT> &&
detail::is_vec<T>::value) {
// Temporary work-around due to a bug in IGC.
// TODO: Remove when IGC bug is fixed.
T result;
for (int s = 0; s < x.size(); ++s)
result[s] = Shuffle(g, x[s], local_id);
return result;
} else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
return __spirv_GroupNonUniformShuffle(group_scope<GroupT>::value,
convertToOpenCLType(x), LocalId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleINTEL(convertToOpenCLType(x), LocalId);
}
return __spirv_GroupNonUniformShuffle(group_scope<GroupT>::value,
convertToOpenCLType(x), LocalId);
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
Expand All @@ -957,16 +931,7 @@ EnableIfNativeShuffle<T> ShuffleXor(GroupT g, T x, id<1> mask) {
#ifndef __NVPTX__
std::ignore = g;
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT> &&
detail::is_vec<T>::value) {
// Temporary work-around due to a bug in IGC.
// TODO: Remove when IGC bug is fixed.
T result;
for (int s = 0; s < x.size(); ++s)
result[s] = ShuffleXor(g, x[s], mask);
return result;
} else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
GroupT>) {
// Since the masks are relative to the groups, we could either try to adjust
// the mask or simply do the xor ourselves. Latter option is efficient,
// general, and simple so we go with that.
Expand All @@ -976,8 +941,9 @@ EnableIfNativeShuffle<T> ShuffleXor(GroupT g, T x, id<1> mask) {
convertToOpenCLType(x), TargetId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleXorINTEL(convertToOpenCLType(x),
static_cast<uint32_t>(mask.get(0)));
return __spirv_GroupNonUniformShuffleXor(
__spv::Scope::Subgroup, convertToOpenCLType(x),
static_cast<uint32_t>(mask.get(0)));
}
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
Expand All @@ -1004,16 +970,7 @@ template <typename GroupT, typename T>
EnableIfNativeShuffle<T> ShuffleDown(GroupT g, T x, uint32_t delta) {
#ifndef __NVPTX__
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT> &&
detail::is_vec<T>::value) {
// Temporary work-around due to a bug in IGC.
// TODO: Remove when IGC bug is fixed.
T result;
for (int s = 0; s < x.size(); ++s)
result[s] = ShuffleDown(g, x[s], delta);
return result;
} else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
GroupT>) {
id<1> TargetLocalId = g.get_local_id();
// ID outside the group range is UB, so we just keep the current item ID
// unchanged.
Expand All @@ -1024,8 +981,8 @@ EnableIfNativeShuffle<T> ShuffleDown(GroupT g, T x, uint32_t delta) {
convertToOpenCLType(x), TargetId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleDownINTEL(convertToOpenCLType(x),
convertToOpenCLType(x), delta);
return __spirv_GroupNonUniformShuffleDown(__spv::Scope::Subgroup,
convertToOpenCLType(x), delta);
}
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
Expand All @@ -1049,16 +1006,7 @@ template <typename GroupT, typename T>
EnableIfNativeShuffle<T> ShuffleUp(GroupT g, T x, uint32_t delta) {
#ifndef __NVPTX__
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT> &&
detail::is_vec<T>::value) {
// Temporary work-around due to a bug in IGC.
// TODO: Remove when IGC bug is fixed.
T result;
for (int s = 0; s < x.size(); ++s)
result[s] = ShuffleUp(g, x[s], delta);
return result;
} else if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
GroupT>) {
GroupT>) {
id<1> TargetLocalId = g.get_local_id();
// Underflow is UB, so we just keep the current item ID unchanged.
if (TargetLocalId[0] >= delta)
Expand All @@ -1068,8 +1016,8 @@ EnableIfNativeShuffle<T> ShuffleUp(GroupT g, T x, uint32_t delta) {
convertToOpenCLType(x), TargetId);
} else {
// Subgroup.
return __spirv_SubgroupShuffleUpINTEL(convertToOpenCLType(x),
convertToOpenCLType(x), delta);
return __spirv_GroupNonUniformShuffleUp(__spv::Scope::Subgroup,
convertToOpenCLType(x), delta);
}
#else
if constexpr (ext::oneapi::experimental::is_user_constructed_group_v<
Expand Down
19 changes: 0 additions & 19 deletions sycl/include/syclcompat/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,6 @@
#include <sycl/ext/oneapi/experimental/cuda/masked_shuffles.hpp>
#endif

// TODO: Remove these function definitions once they exist in the DPC++ compiler
#if defined(__SYCL_DEVICE_ONLY__) && defined(__INTEL_LLVM_COMPILER)
template <typename T>
__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT
__attribute__((noduplicate)) T
__spirv_GroupNonUniformShuffle(__spv::Scope::Flag, T, unsigned) noexcept;

template <typename T>
__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT
__attribute__((noduplicate)) T
__spirv_GroupNonUniformShuffleDown(__spv::Scope::Flag, T,
unsigned) noexcept;

template <typename T>
__SYCL_CONVERGENT__ extern SYCL_EXTERNAL __SYCL_EXPORT
__attribute__((noduplicate)) T
__spirv_GroupNonUniformShuffleUp(__spv::Scope::Flag, T, unsigned) noexcept;
#endif

namespace syclcompat {

namespace detail {
Expand Down
Loading