Skip to content

support y-direction step length greater than 1 for SimplifiedGenericAttentionMask #2338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
183 changes: 143 additions & 40 deletions include/ck_tile/ops/fmha/block/block_masking.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,21 +234,42 @@ struct GenericAttentionMask

// clang-format off
namespace impl {
template <bool IsMasking_> struct SimplifiedMaskName;
template<> struct SimplifiedMaskName<false> { static constexpr const char * name = "nomask"; };
template<> struct SimplifiedMaskName<true> { static constexpr const char * name = "mask"; };
template <bool IsMasking_, bool EnableRatio_> struct SimplifiedMaskName;
template<> struct SimplifiedMaskName<false, false> { static constexpr const char * name = "nomask"; };
template<> struct SimplifiedMaskName<true, false> { static constexpr const char * name = "mask"; };
template<> struct SimplifiedMaskName<false, true> { static constexpr const char * name = "nomask_ratio"; };
template<> struct SimplifiedMaskName<true, true> { static constexpr const char * name = "mask_ratio"; };
}
// clang-format on

// this version only have 2 variation: masking and non-masking
// This is more friendly to codegen (e.g. need generate less kernel)
// ... with the trade-off that may have more instruction in causal mode
template <bool IsMasking_ = true>

// clang-format off
/* y_ratio is used to describe the step length of y-direction changes
in certain performance optimization scenarios like merging seqlen
and qk_head_ratio, for example:

x=1/y=6/y_ratio=2(top-left)
1 * * * * * * *
1 * * * * * * *
1 1 * * * * * *
1 1 * * * * * *
1 1 1 * * * * *
1 1 1 * * * * *

set EnableRatio_=true to enable this feature
*/
// clang-format on
template <bool IsMasking_ = true, bool EnableRatio_ = false>
struct SimplifiedGenericAttentionMask
{
static constexpr bool IsMasking = IsMasking_; // false will disable masking

static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
static constexpr bool EnableRatio = EnableRatio_; // false will disable y-ratio

static constexpr const char* name = impl::SimplifiedMaskName<IsMasking, EnableRatio>::name;

CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
Expand All @@ -260,6 +281,13 @@ struct SimplifiedGenericAttentionMask
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
{
}
CK_TILE_HOST_DEVICE
SimplifiedGenericAttentionMask(
index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_ratio_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_), y_ratio(y_ratio_)
{
y_ratio_mdiv = mdiv{static_cast<uint32_t>(y_ratio_)};
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
Expand All @@ -282,20 +310,46 @@ struct SimplifiedGenericAttentionMask
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}();
if constexpr(!EnableRatio)
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}();

// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();

return ck_tile::make_tuple(x_start, x_end);
return ck_tile::make_tuple(x_start, x_end);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp_offset = -y + i_y + y_ratio;
index_t tmp_div =
static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(tmp_offset)));
index_t tmp = tmp_offset > 0 ? tmp_div : 0; // clamp by zero

return (tmp / XTile) * XTile; // round to tile aligned
}();

// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
uint32_t y_offset = i_y + YTile - 1;
index_t tmp =
min(static_cast<index_t>(y_ratio_mdiv.div(y_offset)) + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();

return ck_tile::make_tuple(x_start, x_end);
}
}
}

Expand Down Expand Up @@ -329,20 +383,40 @@ struct SimplifiedGenericAttentionMask
}
else
{
// get the tile start/end range assum we loop over along Y tile by tile
index_t y_start = [&]() {
index_t tmp = max(-x + i_x + 1, 0);
return (tmp / YTile) * YTile; // round to tile aligned
}();

// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t y_end = [&]() {
index_t tmp = min(i_x + XTile - 1 + y, y_total);
return ((tmp + YTile - 1) / YTile) * YTile;
}();

return ck_tile::make_tuple(y_start, y_end);
if constexpr(!EnableRatio)
{
// get the tile start/end range assum we loop over along Y tile by tile
index_t y_start = [&]() {
index_t tmp = max(-x + i_x + 1, 0);
return (tmp / YTile) * YTile; // round to tile aligned
}();

// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t y_end = [&]() {
index_t tmp = min(i_x + XTile - 1 + y, y_total);
return ((tmp + YTile - 1) / YTile) * YTile;
}();

return ck_tile::make_tuple(y_start, y_end);
}
else
{
// get the tile start/end range assum we loop over along Y tile by tile
index_t y_start = [&]() {
index_t tmp = max((-x + i_x + 1) * y_ratio, 0);
return (tmp / YTile) * YTile; // round to tile aligned
}();

// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t y_end = [&]() {
index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total);
return ((tmp + YTile - 1) / YTile) * YTile;
}();

return ck_tile::make_tuple(y_start, y_end);
}
}
}

Expand All @@ -357,10 +431,24 @@ struct SimplifiedGenericAttentionMask
}
else
{
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded

return i_x < x_start || i_x >= x_end || i_y >= y_total;
if constexpr(!EnableRatio)
{
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = min(i_y + x, x_total); // need min in case x is padded
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
else
{
index_t start_tmp = -y + i_y + y_ratio;
index_t start_tmp_div =
static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(start_tmp)));
index_t x_start = start_tmp > 0 ? start_tmp_div : 0; // clamp by zero

uint32_t end_tmp = static_cast<uint32_t>(i_y);
index_t x_end = min(static_cast<index_t>(y_ratio_mdiv.div(end_tmp)) + x,
x_total); // need min in case x is padded
return i_x < x_start || i_x >= x_end || i_y >= y_total;
}
}
}

Expand Down Expand Up @@ -388,17 +476,32 @@ struct SimplifiedGenericAttentionMask
index_t i_y_end = i_y + TileHeight;
// index_t x_end = min(i_y + x, x_total);

bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now

return top_right_edge || bottom_left_edge;
if constexpr(!EnableRatio)
{
bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for
// now
return top_right_edge || bottom_left_edge;
}
else
{
uint32_t y_tmp = static_cast<uint32_t>(i_y);
bool top_right_edge =
i_x_end > min(static_cast<index_t>(y_ratio_mdiv.div(y_tmp)) + x,
x_total); // consider right pad
bool bottom_left_edge =
i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
return top_right_edge || bottom_left_edge;
}
}
}

private:
index_t y, x;
index_t y_total, x_total;
index_t y_ratio = 1;
mdiv y_ratio_mdiv;
};

// TODO: prefer use this function in host code
Expand Down