-
Notifications
You must be signed in to change notification settings - Fork 199
Shard several of the most costly targets. #2266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
2d74123
Shard several of the most costly targets.
shumway e62f7c7
fix clang format
illsilin 1a8db6e
Fix build errors in instantiation code.
shumway 0dd29e8
Fix link errors from mismatched declarations.
shumway 52c368d
Migrate the design to a code-generation approach.
shumway fa71c45
Shard the longest 2D convolution builds
shumway 77c6945
Use PROJECT_SOURCE_DIR for submodule compatibility
shumway 1f79ac7
Merge branch 'develop' into shumway/refactor_targets
shumway File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,3 +68,6 @@ build*/ | |
|
||
# Python cache | ||
__pycache__/ | ||
|
||
.cache/ | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# Function to generate templated instantiation functions and caller function. | ||
|
||
# In order to reduce build times, we split the instantiation of template functions into multiple files. | ||
# Developers can use ck::util::generate_sharded_instantiations to generate the instantiation functions, | ||
# which can be placed in the TEMPLATE_FILE (typically a .in file). | ||
|
||
# This CMake function generates the instantiation functions and a caller function that calls all the instantiation | ||
# functions. The ck::util::generate_sharded_instantiations function allows us to generate an arbitrary number of | ||
# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard, | ||
# and generates a caller function that calls all the instantiation functions. | ||
|
||
# The explicit instantiation pattern requires the use of `extern template` to avoid implicit instantiation | ||
# of the template functions in the caller function, and that code is automatically generated by this function. | ||
|
||
# In addition to the user-supplied template, this CMake function uses two generic templates: | ||
# | ||
# 1. `instantiate_shard.in`: This is the template for the instantiation functions. | ||
# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions. | ||
|
||
# This function takes the following arguments: | ||
# | ||
# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCES_NAME}`). | ||
# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions. | ||
# - NUM_SHARDS: The number of shards to generate. | ||
# - OUTPUT_DIR: The build directory where the generated source files will be placed. | ||
# - SRC_LIST: The list of source files to which the generated source files will be added. | ||
|
||
|
||
# generate_sharded_instantiations(
#   INSTANCES_NAME <name>
#   TEMPLATE_FILE  <path>
#   NUM_SHARDS     <positive integer>
#   OUTPUT_DIR     <dir>
#   SRC_LIST       <list variable name>)
#
# Configures TEMPLATE_FILE into <OUTPUT_DIR>/<name>.inc, writes one
# <name>_shard_<id>.cpp per shard from cmake/instantiate_shard.in, writes the
# caller <name>.cpp from cmake/call_shard.in, and appends every generated .cpp
# to the caller's list variable named by SRC_LIST.
function(generate_sharded_instantiations)
  cmake_parse_arguments(
    GEN_SHARDED
    # No boolean arguments.
    ""
    # Single-value arguments.
    "INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST"
    # No multi-value arguments.
    ""
    ${ARGN}
  )

  # Every keyword is mandatory.
  foreach(required_arg IN ITEMS INSTANCES_NAME TEMPLATE_FILE NUM_SHARDS OUTPUT_DIR SRC_LIST)
    if(NOT GEN_SHARDED_${required_arg})
      message(FATAL_ERROR "${required_arg} is required for generate_sharded_instantiations")
    endif()
  endforeach()

  # NUM_SHARDS feeds math(EXPR) and foreach(RANGE) below, so reject anything that
  # is not a plain positive decimal integer before it can produce a confusing
  # error or an unexpected shard range.
  if(NOT "${GEN_SHARDED_NUM_SHARDS}" MATCHES "^[1-9][0-9]*$")
    message(FATAL_ERROR
      "NUM_SHARDS must be a positive integer for generate_sharded_instantiations, "
      "got '${GEN_SHARDED_NUM_SHARDS}'")
  endif()

  # Surface misspelled keywords instead of silently dropping them.
  if(GEN_SHARDED_UNPARSED_ARGUMENTS)
    message(WARNING
      "generate_sharded_instantiations: ignoring unrecognized arguments: "
      "${GEN_SHARDED_UNPARSED_ARGUMENTS}")
  endif()

  file(MAKE_DIRECTORY "${GEN_SHARDED_OUTPUT_DIR}")

  message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}")

  # INSTANCES, NUM_SHARDS and SHARD_ID are substituted as @VAR@ placeholders by
  # the configure_file() calls below, so these variable names must match the
  # placeholders used in the templates.
  set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}")
  set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}")

  # Generate the inc file with the template function definitions.
  # This include file will hold the template function definitions and a using
  # alias for all the shard instantiation functions.
  configure_file(
    "${GEN_SHARDED_TEMPLATE_FILE}"
    "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc"
    @ONLY
  )

  # Generate the sharded instantiation functions.
  # This is where the build parallelization happens: each generated source file
  # holds the explicit instantiation for a single shard, and the caller function
  # invokes the shards sequentially.
  set(GENERATED_SOURCE_FILES "")
  set(EXTERN_TEMPLATE_STATEMENTS "")
  set(CALL_STATEMENTS "")
  set(SHARD_FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in")
  math(EXPR LAST_SHARD_ID "${NUM_SHARDS} - 1")
  foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID})
    set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp")
    configure_file(
      "${SHARD_FUNCTION_TEMPLATE}"
      "${SHARD_FUNCTION_PATH}"
      @ONLY
    )
    list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}")

    set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>")
    list(APPEND EXTERN_TEMPLATE_STATEMENTS
      "extern template void\n${SHARDED_FUNCTION_NAME}(\n    ${INSTANCES}& instances)")
    list(APPEND CALL_STATEMENTS "    ${SHARDED_FUNCTION_NAME}(instances)")
  endforeach()

  # Join the extern template declarations and the call statements each into a
  # single string for substitution in the caller template. The list separator
  # becomes the C++ statement terminator; the template supplies the final ';'.
  string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}")
  string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}")

  # Generate the caller function.
  set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp")
  set(FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/call_shard.in")
  configure_file(
    "${FUNCTION_TEMPLATE}"
    "${CALLER_FUNCTION_PATH}"
    @ONLY
  )
  list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}")

  # Append the generated sources to the caller's list variable. The variable is
  # read from the enclosing scope, extended locally, and written back with
  # PARENT_SCOPE because a function cannot modify caller variables directly.
  list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES})
  set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE)
endfunction()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#include "@[email protected]" | ||
|
||
namespace ck::tensor_operation::device::instance { | ||
|
||
@EXTERN_TEMPLATE_STATEMENTS@; | ||
|
||
void add_@INSTANCES@( | ||
@INSTANCES@& instances) { | ||
@CALL_STATEMENTS@; | ||
} | ||
|
||
} // namespace ck::tensor_operation::device::instance |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#include "@[email protected]" | ||
|
||
namespace ck::tensor_operation::device::instance { | ||
template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>( | ||
@INSTANCES@& instances); | ||
} // namespace ck::tensor_operation::device::instance |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#pragma once | ||
|
||
#include <tuple> | ||
#include <type_traits> | ||
#include <utility> | ||
|
||
#include "ck/utility/functional.hpp" | ||
#include "ck/utility/sequence.hpp" | ||
|
||
namespace ck::util {

// Builds the tuple type made of the elements of Tuple at indices
// Offset, Offset + Stride, Offset + 2 * Stride, ... that are within bounds.
//
// Tuple  is a std::tuple (anything usable with std::tuple_size_v and std::tuple_element_t)
// Stride is a positive integer
// Offset is a non-negative integer smaller than Stride
//
// The resulting type is exposed as the nested alias `type`.
template <typename Tuple, std::size_t Stride, std::size_t Offset>
struct filter_tuple_by_modulo
{
    // Validate Stride and Offset. Offset is unsigned, so only the upper bound needs checking.
    static_assert(Stride > 0, "Stride must be positive.");
    static_assert(Offset < Stride, "Offset must be non-negative and less than the stride.");

    // Number of selected elements: ceil((size - Offset) / Stride), written so that the
    // unsigned arithmetic cannot underflow (Stride - Offset - 1 >= 0 because Offset < Stride).
    static constexpr std::size_t new_size =
        (std::tuple_size_v<Tuple> + Stride - Offset - 1) / Stride;

    // Map 0, 1, 2, ... onto the selected source indices Offset, Offset + Stride, ...
    template <std::size_t... Is>
    static constexpr auto to_index(std::index_sequence<Is...>)
    {
        return std::index_sequence<(Offset + Is * Stride)...>{};
    }

    using filtered_indices = decltype(to_index(std::make_index_sequence<new_size>{}));

    // Helper struct to construct the new tuple type from the filtered indices.
    template <typename T, typename Indices>
    struct make_filtered_tuple_type_impl;

    template <typename T, std::size_t... Is>
    struct make_filtered_tuple_type_impl<T, std::index_sequence<Is...>>
    {
        using type = std::tuple<std::tuple_element_t<Is, T>...>;
    };

    using type = typename make_filtered_tuple_type_impl<Tuple, filtered_indices>::type;
};

// Filter a tuple with a stride and offset.
//
// Tuple is a std::tuple or equivalent
// Stride is a positive integer
// Offset is a non-negative integer smaller than Stride
//
// Evaluates to a smaller tuple type made of the elements of Tuple selected with stride
// Stride starting at index Offset.
//
// Can be used to filter a tuple of types for sharded instantiations.
template <typename Tuple, std::size_t Stride, std::size_t Offset>
using filter_tuple_by_modulo_t = typename filter_tuple_by_modulo<Tuple, Stride, Offset>::type;

// Example compile-time test:
// using OriginalTuple =
//     std::tuple<int, double, char, float, long, short, bool, char, long long, unsigned int>;
// using NewTuple_Every3rdFrom2nd = filter_tuple_by_modulo_t<OriginalTuple, 3, 1>;
// static_assert(std::is_same_v<NewTuple_Every3rdFrom2nd, std::tuple<double, long, char>>,
//               "Test Case 1 Failed: Every 3rd from 2nd");

} // namespace ck::util
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.