Skip to content

Commit 215c9f0

Browse files
amitsabne1The ml_dtypes Authors
authored andcommitted
[XLA] Start adding S1 and U1 as data types
PiperOrigin-RevId: 706755357
1 parent 401ed6a commit 215c9f0

File tree

2 files changed

+35
-12
lines changed

2 files changed

+35
-12
lines changed

ml_dtypes/include/intn.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ struct intN {
241241
}
242242
};
243243

244+
using int1 = intN<1, int8_t>;
244245
using int2 = intN<2, int8_t>;
246+
using uint1 = intN<1, uint8_t>;
245247
using uint2 = intN<2, uint8_t>;
246248
using int4 = intN<4, int8_t>;
247249
using uint4 = intN<4, uint8_t>;
@@ -295,6 +297,12 @@ struct intN_numeric_limits_base {
295297

296298
namespace std {
297299

300+
template <>
301+
struct numeric_limits<ml_dtypes::int1>
302+
: public ml_dtypes::internal::intN_numeric_limits_base<ml_dtypes::int1> {};
303+
template <>
304+
struct numeric_limits<ml_dtypes::uint1>
305+
: public ml_dtypes::internal::intN_numeric_limits_base<ml_dtypes::uint1> {};
298306
template <>
299307
struct numeric_limits<ml_dtypes::int2>
300308
: public ml_dtypes::internal::intN_numeric_limits_base<ml_dtypes::int2> {};

ml_dtypes/tests/intn_test.cc

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct IntNTestParamNames {
5858
}
5959
};
6060

61-
using IntNTypes = ::testing::Types<int2, uint2, int4, uint4>;
61+
using IntNTypes = ::testing::Types<int1, uint1, int2, uint2, int4, uint4>;
6262
TYPED_TEST_SUITE(IntNTest, IntNTypes, IntNTestParamNames);
6363

6464
TEST(IntNTest, NumericLimits) {
@@ -69,6 +69,13 @@ TEST(IntNTest, NumericLimits) {
6969
EXPECT_EQ(static_cast<int>(std::numeric_limits<int4>::lowest()), -8);
7070
EXPECT_EQ(std::numeric_limits<int4>::digits, 3);
7171
EXPECT_EQ(std::numeric_limits<int4>::digits10, 0);
72+
EXPECT_EQ(std::numeric_limits<int1>::is_signed, true);
73+
EXPECT_EQ(std::numeric_limits<int1>::is_modulo, false);
74+
EXPECT_EQ(static_cast<int>(std::numeric_limits<int1>::min()), -1);
75+
EXPECT_EQ(static_cast<int>(std::numeric_limits<int1>::max()), 0);
76+
EXPECT_EQ(static_cast<int>(std::numeric_limits<int1>::lowest()), -1);
77+
EXPECT_EQ(std::numeric_limits<int1>::digits, 0);
78+
EXPECT_EQ(std::numeric_limits<int1>::digits10, 0);
7279
}
7380

7481
TEST(UIntNTest, NumericLimits) {
@@ -79,6 +86,13 @@ TEST(UIntNTest, NumericLimits) {
7986
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint4>::lowest()), 0);
8087
EXPECT_EQ(std::numeric_limits<uint4>::digits, 4);
8188
EXPECT_EQ(std::numeric_limits<uint4>::digits10, 1);
89+
EXPECT_EQ(std::numeric_limits<uint1>::is_signed, false);
90+
EXPECT_EQ(std::numeric_limits<uint1>::is_modulo, true);
91+
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint1>::min()), 0);
92+
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint1>::max()), 1);
93+
EXPECT_EQ(static_cast<int>(std::numeric_limits<uint4>::lowest()), 0);
94+
EXPECT_EQ(std::numeric_limits<uint1>::digits, 1);
95+
EXPECT_EQ(std::numeric_limits<uint1>::digits10, 0);
8296
}
8397

8498
TYPED_TEST(IntNTest, NumericLimitsBase) {
@@ -141,7 +155,7 @@ struct ConstexprEvaluator {
141155
};
142156

143157
// To avoid warnings about unused left-side of comma expressions,
144-
// we additionally pass the expression through a contexpr function.
158+
// we additionally pass the expression through a constexpr function.
145159
template <typename T>
146160
constexpr void ConstexprEvaluatorFunc(T&&) {}
147161

@@ -211,8 +225,8 @@ TYPED_TEST(IntNTest, Casts) {
211225
}
212226

213227
// Implicit conversion to optional.
214-
std::optional<int64_t> c = IntN(1);
215-
EXPECT_EQ(c, 1);
228+
std::optional<int64_t> c = IntN(0);
229+
EXPECT_EQ(c, 0);
216230

217231
// Loop through all valid values.
218232
for (int i = static_cast<int>(std::numeric_limits<IntN>::min());
@@ -329,17 +343,18 @@ struct CustomInt {
329343
int x;
330344
};
331345

332-
#define GEN_DEST_TYPES(Type) \
333-
std::pair<Type, bool>, std::pair<Type, uint2>, std::pair<Type, uint4>, \
334-
std::pair<Type, uint8_t>, std::pair<Type, uint16_t>, \
335-
std::pair<Type, uint32_t>, std::pair<Type, uint64_t>, \
336-
std::pair<Type, int2>, std::pair<Type, int4>, std::pair<Type, int8_t>, \
337-
std::pair<Type, int16_t>, std::pair<Type, int32_t>, \
346+
#define GEN_DEST_TYPES(Type) \
347+
std::pair<Type, bool>, std::pair<Type, uint1>, std::pair<Type, uint2>, \
348+
std::pair<Type, uint4>, std::pair<Type, uint8_t>, \
349+
std::pair<Type, uint16_t>, std::pair<Type, uint32_t>, \
350+
std::pair<Type, uint64_t>, std::pair<Type, int1>, std::pair<Type, int2>, \
351+
std::pair<Type, int4>, std::pair<Type, int8_t>, \
352+
std::pair<Type, int16_t>, std::pair<Type, int32_t>, \
338353
std::pair<Type, int64_t>, std::pair<Type, CustomInt>
339354

340355
#define GEN_TYPE_PAIRS() \
341-
GEN_DEST_TYPES(int2), GEN_DEST_TYPES(uint2), GEN_DEST_TYPES(int4), \
342-
GEN_DEST_TYPES(uint4)
356+
GEN_DEST_TYPES(int1), GEN_DEST_TYPES(uint1), GEN_DEST_TYPES(int2), \
357+
GEN_DEST_TYPES(uint2), GEN_DEST_TYPES(int4), GEN_DEST_TYPES(uint4)
343358

344359
using IntNCastTypePairs = ::testing::Types<GEN_TYPE_PAIRS()>;
345360
template <typename CastPair>

0 commit comments

Comments
 (0)