
Blackwell matmul scheduler smem epilogue support #4541


Merged · 43 commits · Jun 3, 2025

Commits (43)
f0d3866
blackwell smem epilogue
zasdfgbnm May 29, 2025
e4cd526
save
zasdfgbnm May 29, 2025
defa155
unskip
zasdfgbnm May 29, 2025
e0be8db
save
zasdfgbnm May 29, 2025
d8a41d0
errmsg
zasdfgbnm May 29, 2025
378a25f
save
zasdfgbnm May 29, 2025
4bd5d83
save
zasdfgbnm May 29, 2025
61b5324
save
zasdfgbnm May 29, 2025
0905925
save
zasdfgbnm May 29, 2025
232d8b2
save
zasdfgbnm May 29, 2025
9b7f2dd
save
zasdfgbnm May 29, 2025
9af0554
cleanup print
zasdfgbnm May 29, 2025
e961520
save
zasdfgbnm May 29, 2025
e4fe3f7
save
zasdfgbnm May 29, 2025
1073fe1
save
zasdfgbnm May 29, 2025
5809090
save
zasdfgbnm May 29, 2025
2937b01
save
zasdfgbnm May 29, 2025
b9a80c9
fix
zasdfgbnm May 29, 2025
f248b29
save
zasdfgbnm May 29, 2025
07209b1
save
zasdfgbnm May 29, 2025
ad007cc
save
zasdfgbnm May 29, 2025
a902814
save
zasdfgbnm May 29, 2025
e15f235
save
zasdfgbnm May 29, 2025
abd3bef
save
zasdfgbnm May 29, 2025
2dd0fba
save
zasdfgbnm May 30, 2025
0500579
try1
zasdfgbnm May 30, 2025
a14764f
try
zasdfgbnm May 30, 2025
9983de7
try
zasdfgbnm May 30, 2025
685f4bd
save
zasdfgbnm May 30, 2025
b6a80de
save
zasdfgbnm May 30, 2025
721bc09
save
zasdfgbnm May 30, 2025
2dc4f7c
save
zasdfgbnm May 30, 2025
2c7abf5
save
zasdfgbnm May 30, 2025
6f64440
save
zasdfgbnm May 30, 2025
b1078fd
save
zasdfgbnm May 30, 2025
8e8c28f
save
zasdfgbnm May 30, 2025
20fee8b
save
zasdfgbnm May 30, 2025
8e00e7d
save doc
zasdfgbnm May 30, 2025
09985a3
try
zasdfgbnm May 30, 2025
1965ea7
save
zasdfgbnm May 30, 2025
a5f4020
move up
zasdfgbnm Jun 3, 2025
c57e3fa
remove ldmatrix
zasdfgbnm Jun 3, 2025
6be1759
Merge branch 'main' into smem-epilogue
zasdfgbnm Jun 3, 2025
123 changes: 116 additions & 7 deletions csrc/scheduler/matmul_hopper+.cpp
@@ -34,6 +34,9 @@ namespace schedule_matmul {

namespace {

constexpr int64_t hardcoded_smem_vectorize_factor = 4;
constexpr int64_t hardcoded_blackwell_splitk_vectorization_factor = 4;

// Find the first MatmulDimRole from left to right in a vector of roles
int64_t findFirstRole(
std::vector<MatmulDimRole>& roles,
@@ -649,6 +652,7 @@ int64_t HopperPlus::getLdTMemVectorizeFactor() const {

void HopperPlus::scheduleEpilogueWithoutSmemEpilogueBlackwell() {
const bool has_splitk = params_->splitk_factor != 1;
int64_t tmem_vectorize_factor = getLdTMemVectorizeFactor();
std::vector<TensorView*> cached_tvs;
std::vector<TensorView*> propagate_to =
splitk_sums_.empty() ? mma_results_ : splitk_sums_;
@@ -674,7 +678,6 @@ void HopperPlus::scheduleEpilogueWithoutSmemEpilogueBlackwell() {
// vectorize the TMem load with a factor of v (tmem_vectorize_factor).
// [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v (Vectorize)]
d->axis(-2)->parallelize(ParallelType::TIDx);
int64_t tmem_vectorize_factor = getLdTMemVectorizeFactor();
if (tmem_vectorize_factor < getN(params_->mma_macro)) {
d->split(-1, tmem_vectorize_factor);
}
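
As a worked example of the layout named in the comment above (a sketch only; the macro size and vectorization factor below are assumptions, not values taken from this diff):

  // Assume getN(params_->mma_macro) == 256 and tmem_vectorize_factor v == 4.
  // After transformLikeMmaOutputWithoutK, d's loop domain is
  //   [..., Mo * No, Mw, Nw, Mi, Ni] with Ni == 256.
  d->axis(-2)->parallelize(ParallelType::TIDx); // Mi -> TIDx
  d->split(-1, 4);                              // 4 < 256, so Ni -> [64, 4]
  // Result: [..., Mo * No, Mw, Nw, Mi (TIDx), 64, 4]; the innermost extent-4
  // IterDomain is what later receives ParallelType::Vectorize for the TMem load.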
@@ -767,17 +770,14 @@ void HopperPlus::scheduleEpilogueWithoutSmemEpilogue() {
}
}

void HopperPlus::scheduleEpilogueWithSmemEpilogue() {
void HopperPlus::scheduleEpilogueWithSmemEpilogueHopper() {
constexpr int64_t ldst_matrix_tile_m = 16;
constexpr int64_t ldst_matrix_tile_n = 16;
fusion_->manage("ldst_matrix_m_tile", ldst_matrix_tile_m);
fusion_->manage("ldst_matrix_n_tile", ldst_matrix_tile_n);
fusion_->manage("ldst_matrix_m_smem", params_->tile_sizes.warp_tile.m);
fusion_->manage("ldst_matrix_n_smem", params_->tile_sizes.warp_tile.n);

// Apply LdMatrix to any epilogue inputs loaded to smem with TMA.
Collaborator review comment: Was this unused?

std::vector<TensorView*> tma_load_epilogue_inputs;

// Propagate to (not including) the splitk output if there is a splitk
// else this is just mma_results_
std::vector<TensorView*> propagate_to =
@@ -930,6 +930,117 @@ void HopperPlus::scheduleEpilogueWithSmemEpilogue() {
}
}

void HopperPlus::scheduleEpilogueWithSmemEpilogueBlackwell() {
const bool has_splitk = params_->splitk_factor != 1;
int64_t tmem_vectorize_factor = getLdTMemVectorizeFactor();

std::vector<TensorView*> tmem_ld_tvs =
!has_splitk ? createTMemLoad() : std::vector<TensorView*>{};

// Propagate to (not including) the splitk output if there is a splitk
// else this is just mma_results_
std::vector<TensorView*> register_tvs;
std::vector<TensorView*> propagate_to =
splitk_sums_.empty() ? mma_results_ : splitk_sums_;
for (auto& [c, c_cache] : cached_epilogue_inputs_) {
bool is_2d_epilogue_input =
TensorDomain::noBroadcasts(c_cache->domain()->logical()).size() == 2;
if (is_2d_epilogue_input && params_->async_gmem_load_operands) {
// Schedule TMA load into shared memory for epilogue input
c_cache->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);
c_cache->setMemoryType(MemoryType::Shared);
blockTileTensors({c_cache});
parallelizeBlocks({c_cache});
transformLikeMmaOutputWithoutK(c_cache);
c_cache->setAllocationDomain(c_cache->getLoopDomain(), true);
for (int64_t i = -5; i <= -1; i++) {
c_cache->axis(i)->parallelize(ParallelType::Bulk);
}

// Schedule smem->register load for epilogue input
TensorView* reg_tv = cacheAfter(c_cache);
register_tvs.push_back(reg_tv);
blockTileTensors({reg_tv});
parallelizeBlocks({reg_tv});
transformLikeMmaOutputWithoutK(reg_tv);
}
// Propagate changes to the cache_after tensor
propagate_to.push_back(c);
}

// TMem load is scheduled separately, so don't propagate to it.
propagate_to.insert(
propagate_to.end(), tmem_ld_tvs.begin(), tmem_ld_tvs.end());

// The chain of operations storing data to global memory:
// dc (registers) -> d_smem -> [tma_store] -> d (gmem)
// We schedule d_smem and propagate it back.
for (Val* dv : fusion_->outputs()) {
TensorView* d = dv->as<TensorView>();
NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
TensorView* dc = d->definition()->input(0)->as<TensorView>();
TensorView* d_smem = cacheBefore(d, LoadStoreOpType::Set);
dc->setMemoryType(MemoryType::Local);
d_smem->setMemoryType(MemoryType::Shared);

// We schedule the epilogue like:
// (v = tmem_vectorize_factor, vv = smem_vectorize_factor)
// [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v/vv, vv]
blockTileTensors({d, d_smem});
parallelizeBlocks({d, d_smem});
for (auto tv : {d, d_smem}) {
transformLikeMmaOutputWithoutK(tv);
tv->axis(-2)->parallelize(ParallelType::TIDx);
if (tmem_vectorize_factor < getN(params_->mma_macro)) {
tv->split(-1, tmem_vectorize_factor);
}
}
if (tmem_vectorize_factor > hardcoded_smem_vectorize_factor) {
d_smem->split(-1, hardcoded_smem_vectorize_factor);
}

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
d_smem,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());

d_smem->axis(-1)->parallelize(ParallelType::Vectorize);
d_smem->setAllocationDomain(d_smem->getLoopDomain(), true);

// Schedule the global memory output, which is written by a TMA store
d->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);
for (int64_t i = -5; i <= -1; i++) {
d->axis(i)->parallelize(ParallelType::Bulk);
}
}

// Schedule TMem load as:
// (v = tmem_vectorize_factor)
// [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v (Vectorize)]
blockTileTensors(tmem_ld_tvs);
parallelizeBlocks(tmem_ld_tvs);
for (TensorView* tmem_ld_tv : tmem_ld_tvs) {
transformLikeMmaOutputWithoutK(tmem_ld_tv);
tmem_ld_tv->axis(-2)->parallelize(ParallelType::TIDx);
if (tmem_vectorize_factor < getN(params_->mma_macro)) {
tmem_ld_tv->split(-1, tmem_vectorize_factor);
}
tmem_ld_tv->axis(-1)->parallelize(ParallelType::Vectorize);
}
}

void HopperPlus::scheduleEpilogueWithSmemEpilogue() {
if (isHopper(params_->mma_macro)) {
scheduleEpilogueWithSmemEpilogueHopper();
} else {
scheduleEpilogueWithSmemEpilogueBlackwell();
}
}

void HopperPlus::scheduleEpilogue() {
if (params_->use_smem_epilogue) {
scheduleEpilogueWithSmemEpilogue();
@@ -954,8 +1065,6 @@ void HopperPlus::scheduleSplitKSumHopper() {
}
}

constexpr int64_t hardcoded_blackwell_splitk_vectorization_factor = 4;

// Schedule TMem load tv and splitk_sum tv as follows:
// v = vectorization factor for TMem load
// vv = vectorization factor for splitk_sum, hardcoded to 4
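Because the new Blackwell function is long, a compact shape-tracking sketch of its store chain (dc registers -> d_smem -> TMA store -> d) may help. This is a sketch only, not the scheduler itself: the concrete numbers (N = 256, v = 8, vv = 4) are assumptions chosen so that both splits fire, and d_smem is taken to already have the block-tiled loop domain produced by transformLikeMmaOutputWithoutK.

  // Assumed: N = getN(params_->mma_macro) = 256, v = tmem_vectorize_factor = 8,
  // vv = hardcoded_smem_vectorize_factor = 4.
  // Start: d_smem loop domain is [..., Mo * No, Mw, Nw, Mi, Ni], Ni == 256.
  d_smem->axis(-2)->parallelize(ParallelType::TIDx); // Mi -> TIDx
  d_smem->split(-1, 8); // v < N:  Ni -> [Ni / v, v]   => [..., Mi (TIDx), 32, 8]
  d_smem->split(-1, 4); // v > vv: v  -> [v / vv, vv]  => [..., Mi (TIDx), 32, 2, 4]
  // After the backward transform propagation, the innermost extent-4 IterDomain
  // is vectorized, e.g. 4 x 4 B = 16 B per smem store if the output is fp32.
  d_smem->axis(-1)->parallelize(ParallelType::Vectorize);
  d_smem->setAllocationDomain(d_smem->getLoopDomain(), true);
  // The gmem output d is left at [..., Mi, Ni / v, v] (no vv split); its op type
  // is set to CpAsyncBulkTensorTile and its innermost five IterDomains are
  // parallelized with ParallelType::Bulk, making the smem -> gmem copy a TMA store.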
2 changes: 2 additions & 0 deletions csrc/scheduler/matmul_hopper+.h
@@ -183,6 +183,8 @@ class HopperPlus : public Common {
void scheduleEpilogueWithoutSmemEpilogueHopper();
void scheduleEpilogueWithoutSmemEpilogueBlackwell();
void scheduleEpilogueWithoutSmemEpilogue();
void scheduleEpilogueWithSmemEpilogueHopper();
void scheduleEpilogueWithSmemEpilogueBlackwell();
void scheduleEpilogueWithSmemEpilogue();
void scheduleEpilogue();

5 changes: 3 additions & 2 deletions csrc/tensor_view.cpp
@@ -465,8 +465,9 @@ TensorView* TensorView::split(int64_t axis, Val* factor, bool inner_split) {

NVF_CHECK(
this->axis(axis)->getParallelType() == ParallelType::Serial,
"Splitting an axis of non-Serial parallel type is not supported at this "
"time."
"Splitting an axis (",
this->axis(axis)->toString(),
") of non-Serial parallel type is not supported at this time."
" Parallelization strategy must be set after calling split: ",
toString());

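For context on the expanded NVF_CHECK message above, here is a hedged illustration of the situation it reports (a hypothetical fragment, not code from this PR; tv stands for any TensorView):

  // Hypothetical: the axis being split has already been parallelized.
  tv->axis(-1)->parallelize(ParallelType::TIDx);
  // TensorView::split now includes this->axis(axis)->toString() in the
  // message, so the error names the exact offending IterDomain instead of
  // only restating the rule.
  tv->split(-1, 4);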
3 changes: 0 additions & 3 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -3330,9 +3330,6 @@ class HopperPlusMatmulSchedulerTest
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0);
} else {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(10, 0, 11, 0);
if (use_smem_epilogue) {
GTEST_SKIP() << "TMA store is not supported for Blackwell yet.";
}
}

if (a_k_inner) {