Blackwell matmul scheduler smem epilogue support #4541
Conversation
Review updated until commit 6be1759
!test
csrc/scheduler/matmul_hopper+.cpp (Outdated)
// TODO: should we rename use_ldst_matrix to use_tma_for_epilogue_input?
bool load_with_tma = params_->use_ldst_matrix;
I believe we should rename use_ldst_matrix; I will do it in a separate PR if there are no objections.
This is separate from use_smem_epilogue, which for Hopper+ means "use TMA for epilogue stores and loads". If that is true but use_ldst_matrix is false, then we will use vectorized loads/stores between smem and registers in the epilogue, even though we are using TMA to do the gmem<->smem transfers.
I think the most flexible parametrization would allow us to set the following for each epilogue input:
- Whether to use TMA to load from gmem to smem
- Whether to use ldmatrix

and for each output:
- Whether to use stmatrix
- Whether to use TMA to store from smem to gmem

In other words we could have this:
struct EpilogueInputConfig {
bool use_tma_load;
bool use_ldmatrix; // used if use_tma_load == true
};
std::vector<EpilogueInputConfig> epilogue_input_configs;
struct OutputConfig {
bool use_tma_store;
bool use_stmatrix; // used if use_tma_store == true
};
std::vector<OutputConfig> output_configs;
The order of these vectors would correspond to the order found in tensor_roles, which is deterministic.
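For illustration, a standalone sketch of how such per-tensor configs could be represented and validated (the `isValid` helpers and the consistency rule are hypothetical, not nvFuser's actual API; only the two structs come from the proposal above):

```cpp
#include <cassert>
#include <vector>

// Structs mirroring the proposal above.
struct EpilogueInputConfig {
  bool use_tma_load;
  bool use_ldmatrix;  // only meaningful when use_tma_load == true
};

struct OutputConfig {
  bool use_tma_store;
  bool use_stmatrix;  // only meaningful when use_tma_store == true
};

// Hypothetical consistency check: ldmatrix/stmatrix only applies to
// tensors staged through smem, so requesting it without TMA is invalid.
inline bool isValid(const EpilogueInputConfig& c) {
  return c.use_tma_load || !c.use_ldmatrix;
}
inline bool isValid(const OutputConfig& c) {
  return c.use_tma_store || !c.use_stmatrix;
}
```

Each scheduler pass could then walk `epilogue_input_configs` and `output_configs` in `tensor_roles` order and pick the copy op per tensor independently.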
This is separate from use_smem_epilogue, which for Hopper+ means "use TMA for epilogue stores and loads". If that is true but use_ldst_matrix is false, then we will use vectorized loads/stores between smem and registers in the epilogue even though we are using TMA to do the gmem<->smem transfers.
I think this is not how it is currently implemented.
Here:
Fuser/csrc/scheduler/matmul_hopper+.cpp, lines 788 to 790 in 8ea060c:
// Schedule TMA load into shared memory for epilogue input
c_cache->definition()->as<LoadStoreOp>()->setOpType(
    LoadStoreOpType::CpAsyncBulkTensorTile);
We only use TMA to load an epilogue input if use_ldst_matrix is true, which seems weird to me.
We only use TMA to load epilogue input if use_ldst_matrix is true, which seems weird to me.
You're right. We should change that condition.
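A minimal sketch of the decoupling being suggested here, assuming the semantics described earlier in the thread (the struct and helper names are illustrative, not the actual scheduler code):

```cpp
#include <cassert>

// Illustrative stand-in for the relevant scheduler parameters.
struct SketchParams {
  bool use_smem_epilogue;  // Hopper+: use TMA for epilogue loads and stores
  bool use_ldst_matrix;    // use ldmatrix/stmatrix between smem and registers
};

// TMA for the gmem->smem epilogue-input copy follows use_smem_epilogue
// alone, instead of being gated on use_ldst_matrix as in the quoted code.
inline bool loadEpilogueInputWithTma(const SketchParams& p) {
  return p.use_smem_epilogue;
}

// ldmatrix only applies when the input was staged in smem via TMA.
inline bool useLdMatrixForEpilogueInput(const SketchParams& p) {
  return p.use_smem_epilogue && p.use_ldst_matrix;
}
```

With this split, turning use_ldst_matrix off falls back to vectorized smem<->register copies while still using TMA for the gmem<->smem transfer.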
Removed use_ldst_matrix from the Blackwell scheduler.
LGTM.
constexpr int64_t ldst_matrix_tile_m = 16;
constexpr int64_t ldst_matrix_tile_n = 16;
fusion_->manage("ldst_matrix_m_tile", ldst_matrix_tile_m);
fusion_->manage("ldst_matrix_n_tile", ldst_matrix_tile_n);
fusion_->manage("ldst_matrix_m_smem", params_->tile_sizes.warp_tile.m);
fusion_->manage("ldst_matrix_n_smem", params_->tile_sizes.warp_tile.n);

// Apply LdMatrix to any epilogue inputs loaded to smem with TMA.
Was this unused?
csrc/scheduler/matmul_hopper+.cpp (Outdated)
@@ -926,6 +923,122 @@ void HopperPlus::scheduleEpilogueWithSmemEpilogue() {
constexpr int64_t hardcoded_smem_vectorize_factor = 4;
Do we want to move this constant to the top of the file or to a more central location?
Moved this and hardcoded_blackwell_splitk_vectorization_factor up.
!test
Example kernel for `General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/MN_512_256_128_MmaMacro_m128_n128_k16_tma_store`