Skip to content

[CK_TILE] Blockwise GEMM Pipeline V5 #2360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: develop
Choose a base branch
from
Open

[CK_TILE] Blockwise GEMM Pipeline V5 #2360

wants to merge 16 commits into from

Conversation

aledudek
Copy link
Contributor

Proposed changes

Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You changed file mode

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You changed file mode

using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably the same for A

static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }

// TODO check KRepeat
static constexpr index_t KRepeat = KPerBlock / GetSmemPackA();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it needed?

Comment on lines +392 to +410
auto&& aWindows = PipelineImplBase::GetAWindows(
a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
auto&& bWindows = PipelineImplBase::GetBWindows(
b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);

// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
acopy_dram_type& a_copy_dram_window = aWindows.at(I0);
a_copy_lds_window_type& a_copy_lds_window = aWindows.at(I1);
a_lds_load_tile_distr_type& a_lds_gemm_window = aWindows.at(I2);

// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
bcopy_dram_type& b_copy_dram_window = bWindows.at(I0);
b_copy_lds_window_type& b_copy_lds_window = bWindows.at(I1);
b_lds_load_tile_distr_type& b_lds_gemm_window = bWindows.at(I2);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can change all these acopy_dram_type,a_copy_lds_window_type,a_lds_load_tile_distr_type... to the auto since we get this from tuple

Comment on lines +359 to +390
using acopy_dram_type =
remove_cvref_t<decltype(PipelineImplBase::GetAWindows(a_dram_block_window_tmp,
a_lds_block,
a_lds_load_tile_distr)
.at(I0))>;
using bcopy_dram_type =
remove_cvref_t<decltype(PipelineImplBase::GetBWindows(b_dram_block_window_tmp,
b_lds_block,
b_lds_load_tile_distr)
.at(I0))>;

using a_copy_lds_window_type =
remove_cvref_t<decltype(PipelineImplBase::GetAWindows(a_dram_block_window_tmp,
a_lds_block,
a_lds_load_tile_distr)
.at(I1))>;
using b_copy_lds_window_type =
remove_cvref_t<decltype(PipelineImplBase::GetBWindows(b_dram_block_window_tmp,
b_lds_block,
b_lds_load_tile_distr)
.at(I1))>;

using a_lds_load_tile_distr_type =
remove_cvref_t<decltype(PipelineImplBase::GetAWindows(a_dram_block_window_tmp,
a_lds_block,
a_lds_load_tile_distr)
.at(I2))>;
using b_lds_load_tile_distr_type =
remove_cvref_t<decltype(PipelineImplBase::GetBWindows(b_dram_block_window_tmp,
b_lds_block,
b_lds_load_tile_distr)
.at(I2))>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
using acopy_dram_type =
remove_cvref_t<decltype(PipelineImplBase::GetAWindows(a_dram_block_window_tmp,
a_lds_block,
a_lds_load_tile_distr)
.at(I0))>;
using bcopy_dram_type =
remove_cvref_t<decltype(PipelineImplBase::GetBWindows(b_dram_block_window_tmp,
b_lds_block,
b_lds_load_tile_distr)
.at(I0))>;
using a_copy_lds_window_type =
remove_cvref_t<decltype(PipelineImplBase::GetAWindows(a_dram_block_window_tmp,
a_lds_block,
a_lds_load_tile_distr)
.at(I1))>;
using b_copy_lds_window_type =
remove_cvref_t<decltype(PipelineImplBase::GetBWindows(b_dram_block_window_tmp,
b_lds_block,
b_lds_load_tile_distr)
.at(I1))>;
using a_lds_load_tile_distr_type =
remove_cvref_t<decltype(PipelineImplBase::GetAWindows(a_dram_block_window_tmp,
a_lds_block,
a_lds_load_tile_distr)
.at(I2))>;
using b_lds_load_tile_distr_type =
remove_cvref_t<decltype(PipelineImplBase::GetBWindows(b_dram_block_window_tmp,
b_lds_block,
b_lds_load_tile_distr)
.at(I2))>;

Copy link
Contributor

@bartekxk bartekxk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants