Skip to content

<mdspan>: Fix layout_stride::mapping<E>::is_exhaustive() #5477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 103 additions & 2 deletions stl/inc/mdspan
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#if !_HAS_CXX23
_EMIT_STL_WARNING(STL4038, "The contents of <mdspan> are available only with C++23 or later.");
#else // ^^^ !_HAS_CXX23 / _HAS_CXX23 vvv
#include <algorithm>
#include <array>
#include <span>
#include <tuple>
Expand Down Expand Up @@ -311,6 +312,10 @@ public:
return ((0 <= _Indices && _Indices < extent(_Seq)) && ...);
}
}

template <class _ExtentsT>
friend constexpr pair<size_t, size_t> _Count_dynamic_extents_equal_to_zero_or_one(
const _ExtentsT&) noexcept; // NB: used by 'layout_stride::mapping<E>::is_exhaustive'
};

template <class>
Expand Down Expand Up @@ -788,6 +793,28 @@ concept _Layout_mapping_alike = requires {
bool_constant<_Mp::is_always_unique()>::value;
};

template <class _Extents, size_t _Val>
constexpr size_t _Count_static_extents_equal_to = 0;

template <class _IndexType, size_t... _Extents, size_t _Val>
constexpr size_t _Count_static_extents_equal_to<extents<_IndexType, _Extents...>, _Val> =
(static_cast<size_t>(_Extents == _Val) + ... + 0);

template <class _Extents>
_NODISCARD constexpr pair<size_t, size_t> _Count_dynamic_extents_equal_to_zero_or_one(const _Extents& _Exts) noexcept {
_STL_INTERNAL_STATIC_ASSERT(_Is_extents<_Extents> && _Extents::rank_dynamic() != 0);
size_t _Zero_extents = 0;
size_t _One_extents = 0;
for (const auto& _Ext : _Exts._Array) {
if (_Ext == 0) {
++_Zero_extents;
} else if (_Ext == 1) {
++_One_extents;
}
}
return {_Zero_extents, _One_extents};
}

template <class _Extents>
class layout_stride::mapping : private _Maybe_fully_static_extents<_Extents>,
private _Maybe_empty_array<typename _Extents::index_type, _Extents::rank()> {
Expand Down Expand Up @@ -963,11 +990,53 @@ public:
}

_NODISCARD constexpr bool is_exhaustive() const noexcept {
constexpr size_t _Static_zero_extents = _Count_static_extents_equal_to<extents_type, 0>;
if constexpr (extents_type::rank() == 0) {
return true;
} else if constexpr (extents_type::rank() == 1) {
return this->_Array[0] == 1;
} else if constexpr (_Static_zero_extents >= 2) {
// Per N5008 [mdspan.layout.stride.obs]/5.2, we are looking for a permutation P of integers in the range
// '[0, rank)' such that 'stride(p[i]) == stride(p[i-1])*extent(p[i-1])' is true for 'i' in the range
// '[1, rank)'. Knowing that at least two extents are equal to zero, we can deduce that such a permutation
// does not exist:
// - Some 'stride(p[j])' would have to be equal to 'stride(p[j-1])*extents(p[j-1]) = stride(p[j-1])*0 = 0'
// which is not possible.
// - Only 'extent(p[rank-1])' can be equal to 0, because it's not required to satisfy the condition above.
// Since we have two or more extents equal to 0 this is not possible either.
return false;
} else if constexpr (extents_type::rank() == 2) {
return (this->_Array[0] == 1 && this->_Array[1] == this->_Exts.extent(0))
|| (this->_Array[1] == 1 && this->_Array[0] == this->_Exts.extent(1));
} else {
return required_span_size()
== _Fwd_prod_of_extents<extents_type>::_Calculate(this->_Exts, extents_type::_Rank);
// NB: Extents equal to 1 are problematic too - sometimes in such cases even when the mapping is exhaustive
// this function should return false.
// For example, when the extents are [2, 1, 2] and the strides are [1, 5, 2], the mapping is exhaustive
// per N5008 [mdspan.layout.reqmts]/16 but not per N5008 [mdspan.layout.stride.obs]/5.2.
constexpr size_t _Static_zero_or_one_extents =
_Static_zero_extents + _Count_static_extents_equal_to<extents_type, 1>;

if constexpr (extents_type::rank_dynamic() != 0) {
const auto [_Dynamic_zero_extents, _Dynamic_one_extents] =
_Count_dynamic_extents_equal_to_zero_or_one(this->_Exts);

const size_t _All_zero_extents = _Static_zero_extents + _Dynamic_zero_extents;
if (_All_zero_extents >= 2) {
return false;
}

const size_t _All_zero_or_one_extents =
_Static_zero_or_one_extents + _Dynamic_zero_extents + _Dynamic_one_extents;
if (_All_zero_or_one_extents == 0) {
return _Is_exhaustive_common_case();
}

return _Is_exhaustive_special_case();
} else if constexpr (_Static_zero_or_one_extents == 0) {
return _Is_exhaustive_common_case();
} else {
return _Is_exhaustive_special_case();
}
}
}

Expand Down Expand Up @@ -1036,6 +1105,38 @@ private:

return static_cast<index_type>(((_Indices * this->_Array[_Seq]) + ... + 0));
}

_NODISCARD constexpr bool _Is_exhaustive_common_case() const noexcept {
return required_span_size() == _Fwd_prod_of_extents<extents_type>::_Calculate(this->_Exts, extents_type::_Rank);
}

_NODISCARD constexpr bool _Is_exhaustive_special_case() const noexcept {
using _Stride_extent_pair = pair<rank_type, rank_type>;
array<_Stride_extent_pair, extents_type::rank()> _Pairs;
for (rank_type _Idx = 0; _Idx < extents_type::_Rank; ++_Idx) {
rank_type _Ext = static_cast<rank_type>(this->_Exts.extent(_Idx));
if (_Ext == 0) {
// NB: _Ext equal to zero is special - we want it to end up as close to the end of the sorted range as
// possible, so we assign max value of rank_type to it.
_Ext = static_cast<rank_type>(-1);
}

_Pairs[_Idx] = {static_cast<rank_type>(this->_Array[_Idx]), _Ext};
}

_RANGES sort(_Pairs);
if (_Pairs[0].first != 1) {
return false;
}

for (rank_type _Idx = 1; _Idx < extents_type::_Rank; ++_Idx) {
if (_Pairs[_Idx].first != _Pairs[_Idx - 1].first * _Pairs[_Idx - 1].second) {
return false;
}
}

return true;
}
};

_EXPORT_STD template <class _ElementType>
Expand Down
6 changes: 3 additions & 3 deletions tests/libcxx/expected_results.txt
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ std/ranges/range.adaptors/range.join/range.join.iterator/arrow.pass.cpp FAIL
# If any feature-test macro test is failing, this consolidated test will also fail.
std/language.support/support.limits/support.limits.general/version.version.compile.pass.cpp FAIL

# libc++ incorrectly implements `layout_stride::mapping<E>::is_exhaustive()`
std/containers/views/mdspan/layout_stride/is_exhaustive_corner_case.pass.cpp FAIL


# *** INTERACTIONS WITH MSVC THAT UPSTREAM LIKELY WON'T FIX ***
# These tests set an allocator with a max_size() too small to default construct an unordered container
Expand Down Expand Up @@ -919,9 +922,6 @@ std/time/time.syn/formatter.month_day_last.pass.cpp FAIL
# Our monotonic_buffer_resource takes "user" space for metadata, which it probably should not do.
std/utilities/utility/mem.res/mem.res.monotonic.buffer/mem.res.monotonic.buffer.mem/allocate_with_initial_size.pass.cpp FAIL

# Likely STL bug in layout_stride::mapping::is_exhaustive().
std/containers/views/mdspan/layout_stride/is_exhaustive_corner_case.pass.cpp FAIL


# *** NOT YET ANALYZED ***
# Not analyzed. Clang instantiates BoomOnAnything during template argument substitution.
Expand Down
95 changes: 79 additions & 16 deletions tests/std/tests/P0009R18_mdspan_layout_stride/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,22 +400,85 @@ constexpr void check_required_span_size() {
}

constexpr void check_is_exhaustive() {
{ // Check exhaustive mappings (all possibilities)
using E = extents<int, 2, 3, 5>;
assert((layout_stride::mapping<E>{E{}, array{1, 2, 6}}.is_exhaustive()));
assert((layout_stride::mapping<E>{E{}, array{1, 10, 2}}.is_exhaustive()));
assert((layout_stride::mapping<E>{E{}, array{3, 1, 6}}.is_exhaustive()));
assert((layout_stride::mapping<E>{E{}, array{15, 1, 3}}.is_exhaustive()));
assert((layout_stride::mapping<E>{E{}, array{5, 10, 1}}.is_exhaustive()));
assert((layout_stride::mapping<E>{E{}, array{15, 5, 1}}.is_exhaustive()));
}

{ // Check non-exhaustive mappings
using E = extents<int, 2, 5, 8>;
assert((!layout_stride::mapping<E>{E{}, array{1, 2, 12}}.is_exhaustive()));
assert((!layout_stride::mapping<E>{E{}, array{8, 18, 1}}.is_exhaustive()));
assert((!layout_stride::mapping<E>{E{}, array{5, 1, 12}}.is_exhaustive()));
}
auto check = [](const auto& exts, const auto& strides, bool expected) {
layout_stride::mapping m{exts, strides};
assert(m.is_exhaustive() == expected);
};

// rank() is equal to 0
check(extents<int>{}, array<int, 0>{}, true);

// rank() is equal to 1
check(extents<int, 0>{}, array{1}, true);
check(dextents<int, 1>{0}, array{2}, false);
check(extents<int, 1>{}, array{3}, false);
check(dextents<int, 1>{2}, array{2}, false);
check(extents<int, 3>{}, array{1}, true);
check(dextents<int, 1>{4}, array{1}, true);

// rank() is equal to 2
check(extents<int, 3, 3>{}, array{1, 3}, true);
check(extents<int, dynamic_extent, 3>{3}, array{3, 1}, true);
check(extents<int, 3, dynamic_extent>{3}, array{4, 1}, false);
check(dextents<int, 2>{3, 3}, array{3, 1}, true);
check(extents<int, 4, 5>{}, array{5, 1}, true);
check(extents<int, 6, dynamic_extent>{5}, array{1, 6}, true);
check(extents<int, dynamic_extent, 7>{5}, array{1, 8}, false);
check(dextents<int, 2>{6, 5}, array{1, 10}, false);
check(extents<int, 0, 3>{}, array{3, 1}, true);
check(extents<int, 0, 3>{}, array{6, 2}, false);
check(extents<int, dynamic_extent, 3>{0}, array{6, 1}, false);
check(extents<int, 0, dynamic_extent>{3}, array{6, 2}, false);
check(dextents<int, 2>{0, 3}, array{7, 2}, false);
check(extents<int, 0, 0>{}, array{1, 1}, false);
check(extents<int, 0, dynamic_extent>{0, 0}, array{1, 1}, false);
check(dextents<int, 2>{0, 0}, array{1, 2}, false);
check(extents<int, 1, dynamic_extent>{0}, array{1, 2}, false);

// rank() is greater than 2
check(extents<int, 2, 3, 5>{}, array{1, 2, 6}, true);
check(extents<int, dynamic_extent, 3, 5>{2}, array{1, 10, 2}, true);
check(extents<int, 2, 3, dynamic_extent>{5}, array{3, 1, 6}, true);
check(extents<int, dynamic_extent, dynamic_extent, 5>{2, 3}, array{15, 1, 3}, true);
check(extents<int, 2, dynamic_extent, dynamic_extent>{3, 5}, array{5, 10, 1}, true);
check(dextents<int, 3>{2, 3, 5}, array{15, 5, 1}, true);
check(extents<int, 2, 5, 8>{}, array{1, 2, 12}, false);
check(extents<int, 2, dynamic_extent, 8>{5}, array{8, 18, 1}, false);
check(dextents<int, 3>{2, 5, 8}, array{5, 1, 12}, false);

// rank() is greater than 2 and some extents are equal to 0
check(extents<int, 2, 0, 7>{}, array{7, 14, 1}, true);
check(extents<int, dynamic_extent, 0, 7>{2}, array{1, 14, 2}, true);
check(extents<int, 2, dynamic_extent, 7>{0}, array{14, 28, 1}, false);
check(extents<int, 2, dynamic_extent, dynamic_extent>{0, 7}, array{1, 2, 2}, false);
check(dextents<int, 3>{2, 0, 7}, array{2, 28, 4}, false);
check(extents<int, 5, 0, 0>{}, array{3, 1, 1}, false);
check(extents<int, 5, dynamic_extent, 0>{0}, array{1, 5, 1}, false);
check(dextents<int, 3>{5, 0, 0}, array{2, 1, 10}, false);
check(extents<int, 0, 0, 0>{}, array{1, 1, 1}, false);
check(extents<int, 0, 1, 1>{}, array{1, 1, 1}, true);

// rank() is greater than 2 - one extent is equal to 0 while others are equal to each other
check(extents<int, 3, 0, 3>{}, array{1, 9, 3}, true);
check(extents<int, dynamic_extent, 0, 3>{3}, array{3, 9, 1}, true);
check(extents<int, 3, dynamic_extent, dynamic_extent>{0, 3}, array{1, 3, 3}, false);
check(dextents<int, 3>{3, 0, 3}, array{1, 4, 8}, false);
check(dextents<int, 3>{0, 1, 1}, array{1, 1, 1}, true);

// required_span_size() is equal to 1
check(extents<int, 1>{}, array{1}, true);
check(dextents<int, 1>{1}, array{3}, false);
check(extents<int, 1, dynamic_extent>{1}, array{1, 1}, true);
check(extents<int, 1, 1, 1>{}, array{1, 2, 1}, false);

// Mapping is exhaustive, but is_exhaustive() should return false because of the way standard defined this function
check(extents<int, 3, 1>{}, array{1, 4}, false);
check(dextents<int, 3>{5, 1, 2}, array{2, 11, 1}, false);
check(dextents<int, 3>{2, 3, 1}, array{3, 1, 8}, false);
check(extents<int, 1, dynamic_extent, 7>{6}, array{50, 7, 1}, false);
check(dextents<int, 2>{1, 2}, array{5, 1}, false);
check(extents<int, 6, 1>{}, array{1, 10}, false);
check(dextents<int, 3>{2, 1, 2}, array{3, 3, 1}, false);
}

constexpr void check_call_operator() {
Expand Down