
Revert "Shard several of the most costly targets." #2361

Merged · 1 commit · Jun 17, 2025
3 changes: 0 additions & 3 deletions .gitignore
@@ -68,6 +68,3 @@ build*/
 
 # Python cache
 __pycache__/
-
-.cache/
-
116 changes: 0 additions & 116 deletions cmake/ShardInstantiation.cmake

This file was deleted.

15 changes: 0 additions & 15 deletions cmake/call_shard.in

This file was deleted.

9 changes: 0 additions & 9 deletions cmake/instantiate_shard.in

This file was deleted.
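
For reference: the deleted ShardInstantiation.cmake expanded call_shard.in and instantiate_shard.in once per shard, producing small translation units that forwarded to the `_shard<Shards, ShardIndex>` function templates declared in the instance sources (visible in the deleted code further down). A minimal sketch of what one generated shard source plausibly looked like; the include path and the choice of explicit instantiation are assumptions for illustration, while the function template, the instance alias, and the shard count of 16 appear in this diff:

```cpp
// Hypothetical output of instantiate_shard.in for shard 3 of 16 (sketch only;
// the real template was deleted by this revert and may have differed).
#include "device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.in"

namespace ck::tensor_operation::device::instance {

// Pin one slice of the instance list into its own translation unit so the
// shards build in parallel instead of as one oversized object file.
template void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances_shard<16, 3>(
    device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances& instances);

} // namespace ck::tensor_operation::device::instance
```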

66 changes: 0 additions & 66 deletions include/ck/utility/filter_tuple.hpp

This file was deleted.
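
The deleted include/ck/utility/filter_tuple.hpp supplied ck::util::filter_tuple_by_modulo_t, which the sharded sources below used to keep every Shards-th entry of an instance tuple starting at ShardIndex. A self-contained sketch of such a utility, inferred from its call sites in this diff (the deleted implementation may have differed in detail):

```cpp
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

namespace ck::util {

// Keep the elements of Tuple whose index Is satisfies Is % Shards == ShardIndex.
template <typename Tuple, int Shards, int ShardIndex, typename Indices>
struct filter_tuple_by_modulo;

template <typename Tuple, int Shards, int ShardIndex, std::size_t... Is>
struct filter_tuple_by_modulo<Tuple, Shards, ShardIndex, std::index_sequence<Is...>>
{
    // Concatenate a one-element tuple for every kept index, an empty tuple otherwise.
    using type = decltype(std::tuple_cat(
        std::declval<
            std::conditional_t<Is % Shards == static_cast<std::size_t>(ShardIndex),
                               std::tuple<std::tuple_element_t<Is, Tuple>>,
                               std::tuple<>>>()...));
};

template <typename Tuple, int Shards, int ShardIndex>
using filter_tuple_by_modulo_t =
    typename filter_tuple_by_modulo<Tuple,
                                    Shards,
                                    ShardIndex,
                                    std::make_index_sequence<std::tuple_size_v<Tuple>>>::type;

} // namespace ck::util

// Shard 1 of 2 over four types keeps indices 1 and 3.
static_assert(std::is_same_v<
    ck::util::filter_tuple_by_modulo_t<std::tuple<int, short, float, double>, 2, 1>,
    std::tuple<short, double>>);
```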

@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once

@@ -688,6 +688,7 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances(
     PassThrough,
     PassThrough,
     PassThrough>>>& instances);
+
 void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
         NGCDHW,
@@ -1,5 +1,5 @@
 # XDL_DL_WMMA_KERNELS
-set(GROUPED_CONV2D_FWD
+add_instance_library(device_grouped_conv2d_fwd_instance
 #xdl
 # GNHWC, GKYXC, GNHWK
 xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
@@ -19,6 +19,8 @@ set(GROUPED_CONV2D_FWD
 xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
 xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp
 # NGCHW, GKCYX, NGKHW
+xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp
+xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp
 xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp
 xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp
 xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp
@@ -44,10 +46,12 @@ set(GROUPED_CONV2D_FWD
 xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
 xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp
 xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp
+xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp
 # NHWGC, GKYXC, NHWGK
 xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
 xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
 xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
+xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp
 # NGCHW, GKCYX, NGKHW
 xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instance.cpp
 xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_mem_intra_instance.cpp
@@ -67,6 +71,7 @@ set(GROUPED_CONV2D_FWD
 xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp
 xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp
 # NGCHW, GKCYX, NGKHW
+xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp
 xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_comp_instance.cpp
 xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_comp_instance.cpp
 xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_2x_instance.cpp
@@ -100,47 +105,3 @@ set(GROUPED_CONV2D_FWD
 wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp
 wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp
 )
-# Add generated files for sharded instantiations.
-include(ShardInstantiation)
-
-set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
-generate_sharded_instantiations(
-    INSTANCES_NAME device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances
-    TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in
-    NUM_SHARDS 16
-    SRC_LIST GROUPED_CONV2D_FWD
-    OUTPUT_DIR ${GENERATED_DIR}/xdl
-)
-set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
-generate_sharded_instantiations(
-    INSTANCES_NAME device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances
-    TEMPLATE_FILE xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in
-    NUM_SHARDS 16
-    SRC_LIST GROUPED_CONV2D_FWD
-    OUTPUT_DIR ${GENERATED_DIR}/xdl
-)
-set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
-generate_sharded_instantiations(
-    INSTANCES_NAME device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances
-    TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.in
-    NUM_SHARDS 16
-    SRC_LIST GROUPED_CONV2D_FWD
-    OUTPUT_DIR ${GENERATED_DIR}/xdl/comp
-)
-set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
-generate_sharded_instantiations(
-    INSTANCES_NAME device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances
-    TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in
-    NUM_SHARDS 21
-    SRC_LIST GROUPED_CONV2D_FWD
-    OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
-)
-set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated)
-generate_sharded_instantiations(
-    INSTANCES_NAME device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances
-    TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in
-    NUM_SHARDS 21
-    SRC_LIST GROUPED_CONV2D_FWD
-    OUTPUT_DIR ${GENERATED_DIR}/xdl/mem
-)
-add_instance_library(device_grouped_conv2d_fwd_instance ${GROUPED_CONV2D_FWD})
@@ -1,14 +1,16 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
 #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
 #include "ck/host_utility/device_prop.hpp"
-#include "ck/utility/filter_tuple.hpp"
 
-namespace ck::tensor_operation::device::instance {
-
-using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances =
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
+void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
         NGCHW,
         GKCYX,
@@ -20,23 +22,19 @@ using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances =
         BF16,
         PassThrough,
         PassThrough,
-        PassThrough>>>;
-
-// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
-template <int Shards, int ShardIndex>
-void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances_shard([[maybe_unused]]
-    device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances& instances)
+        PassThrough>>>& instances)
 {
     add_device_operation_instances(
         instances,
-        ck::util::filter_tuple_by_modulo_t<device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
-                                           NGCHW,
-                                           GKCYX,
-                                           Empty_Tuple,
-                                           NGKHW,
-                                           ConvFwdDefault>,
-                                           Shards,
-                                           ShardIndex>{});
+        device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
+                                                        NGCHW,
+                                                        GKCYX,
+                                                        Empty_Tuple,
+                                                        NGKHW,
+                                                        ConvFwdDefault>{});
 }
 
-} // namespace ck::tensor_operation::device::instance
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck