Skip to content

Shard several of the most costly targets. #2266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,6 @@ build*/

# Python cache
__pycache__/

.cache/

116 changes: 116 additions & 0 deletions cmake/ShardInstantiation.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Function to generate templated instantiation functions and caller function.

# In order to reduce build times, we split the instantiation of template functions into multiple files.
# Developers can use ck::util::generate_sharded_instantiations to generate the instantiation functions,
# which can be placed in the TEMPLATE_FILE (typically a .in file).

# This CMake function generates the instantiation functions and a caller function that calls all the instantiation
# functions. The ck::util::generate_sharded_instantiations function allows us to generate an arbitrary number of
# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard,
# and generates a caller function that calls all the instantiation functions.

# The explicit instantiation pattern requires the use of `extern template` to avoid implicit instantiation
# of the template functions in the caller function, and that code is automatically generated by this function.

# In addition to the user-supplied template, this CMake function uses two generic templates:
#
# 1. `instantiate_shard.in`: This is the template for the instantiation functions.
# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions.

# This function takes the following arguments:
#
# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCES_NAME}`).
# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions.
# - NUM_SHARDS: The number of shards to generate.
# - OUTPUT_DIR: The build directory where the generated source files will be placed.
# - SRC_LIST: The list of source files to which the generated source files will be added.


# Generate sharded explicit-instantiation sources plus a caller source.
#
# Splits the instantiation of heavy template functions across NUM_SHARDS
# generated .cpp files (one explicit instantiation per shard) so they compile
# in parallel, and generates one caller .cpp that invokes every shard function
# in sequence.
#
# Arguments (all required, single-value):
#   INSTANCES_NAME - Name of the instances; the caller function is named
#                    add_${INSTANCES_NAME}.
#   TEMPLATE_FILE  - Template (.in) file with the templated instantiation
#                    function definitions; configured into ${INSTANCES_NAME}.inc.
#   NUM_SHARDS     - Number of shard source files to generate.
#   OUTPUT_DIR     - Build directory that receives the generated sources.
#   SRC_LIST       - Name of a list variable in the caller's scope to which the
#                    generated source files are appended.
function(generate_sharded_instantiations)
    cmake_parse_arguments(
        GEN_SHARDED
        ""  # no boolean arguments
        "INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST"  # single-value arguments
        ""  # no multi-value arguments
        ${ARGN}
    )

    # Every argument is mandatory; fail fast on any omission.
    foreach(required_arg IN ITEMS INSTANCES_NAME TEMPLATE_FILE NUM_SHARDS OUTPUT_DIR SRC_LIST)
        if(NOT GEN_SHARDED_${required_arg})
            message(FATAL_ERROR "${required_arg} is required for generate_sharded_instantiations")
        endif()
    endforeach()

    file(MAKE_DIRECTORY ${GEN_SHARDED_OUTPUT_DIR})

    set(GENERATED_SOURCE_FILES "")
    set(EXTERN_TEMPLATE_STATEMENTS "")
    set(CALL_STATEMENTS "")
    message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}")

    # @INSTANCES@ and @NUM_SHARDS@ are consumed by the configure_file(... @ONLY)
    # calls below.
    set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}")
    set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}")

    # Generate the .inc file holding the template function definitions and a
    # using-alias shared by all the shard instantiation functions.
    configure_file(
        "${GEN_SHARDED_TEMPLATE_FILE}"
        "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc"
        @ONLY
    )

    # Generate the sharded instantiation functions. This is where the build
    # parallelization happens: each source file explicitly instantiates a
    # single shard function, which the caller function invokes sequentially.
    math(EXPR LAST_SHARD_ID "${GEN_SHARDED_NUM_SHARDS} - 1")
    foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID})
        set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp")
        configure_file(
            "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in"
            "${SHARD_FUNCTION_PATH}"
            @ONLY
        )
        list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}")
        set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>")
        list(APPEND EXTERN_TEMPLATE_STATEMENTS "extern template void\n${SHARDED_FUNCTION_NAME}(\n ${INSTANCES}& instances)")
        list(APPEND CALL_STATEMENTS " ${SHARDED_FUNCTION_NAME}(instances)")
    endforeach()

    # Join the extern template declarations and the call statements each into a
    # single string (one statement per line) for variable substitution in the
    # caller function. The list separator ';' doubles as the C++ statement
    # terminator in the generated code.
    string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}")
    string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}")

    # Generate the caller function that runs every shard.
    set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp")
    configure_file(
        "${PROJECT_SOURCE_DIR}/cmake/call_shard.in"
        "${CALLER_FUNCTION_PATH}"
        @ONLY
    )
    list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}")

    # Publish the generated source files to the caller's source list so they
    # are included in the build.
    list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES})
    set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE)
endfunction()
15 changes: 15 additions & 0 deletions cmake/call_shard.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

// Configure-time template (see cmake/ShardInstantiation.cmake): the @VAR@
// placeholders are substituted by configure_file(... @ONLY) to produce the
// caller translation unit that invokes every shard instantiation function.

#include "@[email protected]"

namespace ck::tensor_operation::device::instance {

// Suppress implicit instantiation here: each shard function is explicitly
// instantiated in its own generated source file (see instantiate_shard.in).
@EXTERN_TEMPLATE_STATEMENTS@;

// Calls every shard instantiation function in sequence.
void add_@INSTANCES@(
@INSTANCES@& instances) {
@CALL_STATEMENTS@;
}

} // namespace ck::tensor_operation::device::instance
9 changes: 9 additions & 0 deletions cmake/instantiate_shard.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

// Configure-time template (see cmake/ShardInstantiation.cmake): the @VAR@
// placeholders are substituted by configure_file(... @ONLY). Each generated
// copy explicitly instantiates one shard of the add_*_shard function template,
// so the shards compile as separate, parallel translation units.

#include "@[email protected]"

namespace ck::tensor_operation::device::instance {
template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>(
@INSTANCES@& instances);
} // namespace ck::tensor_operation::device::instance
66 changes: 66 additions & 0 deletions include/ck/utility/filter_tuple.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <tuple>
#include <type_traits>
#include <utility>

#include "ck/utility/functional.hpp"
#include "ck/utility/sequence.hpp"

namespace ck::util {

template <typename Tuple, std::size_t Stride, std::size_t Offset>
struct filter_tuple_by_modulo
{
// Validate Stride and Offset.
static_assert(Stride > 0, "Offset must be positive.");
static_assert(Offset >= 0 && Offset < Stride,
"Offset must be positive and less than the stride.");

// Generate filtered indices for this stride and offset.
static constexpr int new_size = (std::tuple_size_v<Tuple> + Stride - Offset - 1) / Stride;

template <std::size_t... Is>
static constexpr auto to_index(std::index_sequence<Is...>)
{
return std::index_sequence<(Offset + Is * Stride)...>{};
}

using filtered_indices = decltype(to_index(std::make_index_sequence<new_size>{}));

// Helper struct to construct the new tuple type from the filtered indices.
template <typename T, typename Indices>
struct make_filtered_tuple_type_impl;

template <typename T, std::size_t... Is>
struct make_filtered_tuple_type_impl<T, std::index_sequence<Is...>>
{
using type = std::tuple<std::tuple_element_t<Is, T>...>;
};

using type = typename make_filtered_tuple_type_impl<Tuple, filtered_indices>::type;
};

// Filter a tuple with a stride and offset.
//
// Tuple is a std::tuple or equivalent
// Stride is a positive integer
// Offset is a non-negative integer smaller than Stride
//
// Evaluates to a smaller tuple type from elements of T with stride M and offset I.
//
// Can be used to filter a tuple of types for sharded instantiations.
template <typename Tuple, std::size_t Stride, std::size_t Offset>
using filter_tuple_by_modulo_t = typename filter_tuple_by_modulo<Tuple, Stride, Offset>::type;

// Example compile-time test:
// using OriginalTuple =
// std::tuple<int, double, char, float, long, short, bool, char, long long, unsigned int>;
// using NewTuple_Every3rdFrom2nd = filter_tuple_by_modulo_t<OriginalTuple, 3, 1>;
// static_assert(std::is_same_v<NewTuple_Every3rdFrom2nd, std::tuple<double, long, char>>,
// "Test Case 1 Failed: Every 3rd from 2nd");

} // namespace ck::util
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down Expand Up @@ -688,7 +688,6 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NGCDHW,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# XDL_DL_WMMA_KERNELS
add_instance_library(device_grouped_conv2d_fwd_instance
set(GROUPED_CONV2D_FWD
#xdl
# GNHWC, GKYXC, GNHWK
xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
Expand All @@ -19,8 +19,6 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp
# NGCHW, GKCYX, NGKHW
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp
Expand All @@ -46,12 +44,10 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp
# NHWGC, GKYXC, NHWGK
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp
# NGCHW, GKCYX, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instance.cpp
Expand All @@ -71,7 +67,6 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp
# NGCHW, GKCYX, NGKHW
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instance.cpp
Expand Down Expand Up @@ -105,3 +100,47 @@ add_instance_library(device_grouped_conv2d_fwd_instance
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp
)
# Add generated files for sharded instantiations.
include(ShardInstantiation)

# All generated sources live under a single build-tree directory; set it once
# (the previous code redundantly re-set it before every call).
set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)

generate_sharded_instantiations(
    INSTANCES_NAME device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances
    TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in
    NUM_SHARDS 16
    SRC_LIST GROUPED_CONV2D_FWD
    OUTPUT_DIR ${GENERATED_DIR}/xdl
)
generate_sharded_instantiations(
    INSTANCES_NAME device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances
    TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in
    NUM_SHARDS 16
    SRC_LIST GROUPED_CONV2D_FWD
    OUTPUT_DIR ${GENERATED_DIR}/xdl
)
generate_sharded_instantiations(
    INSTANCES_NAME device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances
    TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.in
    NUM_SHARDS 16
    SRC_LIST GROUPED_CONV2D_FWD
    OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
)
generate_sharded_instantiations(
    INSTANCES_NAME device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances
    TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in
    NUM_SHARDS 21
    SRC_LIST GROUPED_CONV2D_FWD
    OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)
generate_sharded_instantiations(
    INSTANCES_NAME device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances
    TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in
    NUM_SHARDS 21
    SRC_LIST GROUPED_CONV2D_FWD
    OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
)

add_instance_library(device_grouped_conv2d_fwd_instance ${GROUPED_CONV2D_FWD})
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/utility/filter_tuple.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances(
namespace ck::tensor_operation::device::instance {

using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances =
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
GKCYX,
Expand All @@ -22,19 +20,23 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances(
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
PassThrough>>>;

// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
template <int Shards, int ShardIndex>
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances_shard([[maybe_unused]]
device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>{});
ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
NGCHW,
GKCYX,
Empty_Tuple,
NGKHW,
ConvFwdDefault>,
Shards,
ShardIndex>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
} // namespace ck::tensor_operation::device::instance
Loading