Improve fmha_bwd tests performance #2376

Open · wants to merge 4 commits into base: develop
47 changes: 20 additions & 27 deletions example/ck_tile/01_fmha/fmha_bwd.cpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
- // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "fmha_bwd.hpp"
#include "ck_tile/host.hpp"
@@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser)

if(p_drop > 0)
{
- p_hp_host_ref.ForEach(
- [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); });
+ p_dropped_hp_host_ref = p_hp_host_ref;
randval_host_ref.ForEach([&](auto& self, auto idx) {
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
});
ck_tile::reference_batched_dropout(
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
- p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) {
- p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
- });
+ p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
}
else
{
- p_hp_host_ref.ForEach([&](auto& self, auto idx) {
- p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
- });
+ p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
}

// O = P * V
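The ForEach lambdas removed above invoked a per-element callback (with a multi-index to linear-offset computation inside) just to convert and copy an entire tensor; the new code replaces each of them with a single tensor assignment or a bulk CopyAsType call. A minimal sketch of why the bulk form is cheaper, using a simplified flat tensor (FlatTensor and copy_as are hypothetical stand-ins, not the ck_tile types):

#include <algorithm>
#include <vector>

template <typename T>
struct FlatTensor
{
    std::vector<T> data;

    // Bulk, typed copy: one linear pass over the backing store, which the
    // compiler can vectorize, instead of one callback plus index arithmetic
    // per element as in the removed ForEach version.
    template <typename U>
    FlatTensor<U> copy_as() const
    {
        FlatTensor<U> out;
        out.data.resize(data.size());
        std::transform(data.begin(), data.end(), out.data.begin(),
                       [](const T& v) { return static_cast<U>(v); });
        return out;
    }
};

In the diff, p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>(); plays the same role: the type conversion happens in one pass over contiguous data.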
@@ -854,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
}

// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
- ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) {
- AccDataType do_dot_o = 0;
- for(int o = 0; o < hdim_v; o++)
- {
- auto idx_gmo = idx_gmn;
- idx_gmo[2] = o;
- do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
- ck_tile::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
- }
- self(idx_gmn) = ck_tile::type_convert<AccDataType>(
- p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o));
- });
+ ck_tile::make_ParallelTensorFunctor(
+ [&](auto i0, auto i1, auto i2) {
+ AccDataType do_dot_o = 0;
+ for(int o = 0; o < hdim_v; o++)
+ {
+ do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
+ ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
+ }
+ ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
+ p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
+ },
+ ds_hp_host_ref.mDesc.get_lengths()[0],
+ ds_hp_host_ref.mDesc.get_lengths()[1],
+ ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());

if(use_dbias)
{
- ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
- dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
- });
+ dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
}

- ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
- ds_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
- });
+ ds_lp_host_ref = ds_hp_host_ref.template CopyAsType<GemmDataType>();

// dV = P_drop^T@dO^T
// dV = P^T@dO^T w/o dropout
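For the dS term the math is unchanged, dS_ij = P_ij * (dP_ij - dO_i · O_i); the speedup comes from two changes: the lambda now indexes with (i0, i1, o) directly instead of copying and mutating a std::vector multi-index inside the inner loop, and make_ParallelTensorFunctor splits the 3-D index space across std::thread::hardware_concurrency() host threads. A rough sketch of that parallel-for pattern, assuming a plain std::thread fan-out over the outermost dimension (parallel_for_3d is a hypothetical illustration of the idea, not ck_tile's actual functor):

#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical helper: run f(i0, i1, i2) over an n0 x n1 x n2 index space,
// splitting the outermost dimension across num_threads worker threads.
template <typename F>
void parallel_for_3d(F f, std::size_t n0, std::size_t n1, std::size_t n2, unsigned num_threads)
{
    num_threads = std::max(1u, num_threads); // hardware_concurrency() may return 0
    std::vector<std::thread> workers;
    const std::size_t chunk = (n0 + num_threads - 1) / num_threads;
    for(unsigned t = 0; t < num_threads; ++t)
    {
        const std::size_t begin = t * chunk;
        const std::size_t end   = std::min(n0, begin + chunk);
        if(begin >= end)
            break;
        // Each worker owns a contiguous slab of i0; no synchronization is
        // needed because every (i0, i1, i2) writes a distinct output element.
        workers.emplace_back([=] {
            for(std::size_t i0 = begin; i0 < end; ++i0)
                for(std::size_t i1 = 0; i1 < n1; ++i1)
                    for(std::size_t i2 = 0; i2 < n2; ++i2)
                        f(i0, i1, i2);
        });
    }
    for(auto& w : workers)
        w.join();
}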
6 changes: 3 additions & 3 deletions include/ck/library/utility/host_tensor.hpp
@@ -167,7 +167,7 @@ struct HostTensorDescriptor
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}

- std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
+ std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
{
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
@@ -600,12 +600,12 @@ struct Tensor
ck::packed_size_v<ck::remove_cvref_t<T>>];
}

- T& operator()(std::vector<std::size_t> idx)
+ T& operator()(const std::vector<std::size_t>& idx)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
}

- const T& operator()(std::vector<std::size_t> idx) const
+ const T& operator()(const std::vector<std::size_t>& idx) const
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v<ck::remove_cvref_t<T>>];
}
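The changes in both host_tensor.hpp headers are the same fix: GetOffsetFromMultiIndex and the vector-index operator() overloads used to take std::vector<std::size_t> by value, so every element access through a multi-index paid a copy (and typically a heap allocation) of the index vector; taking it by const reference removes that cost from the host reference paths the fmha_bwd test exercises. A small before/after sketch of the signature change, with the surrounding class reduced to the relevant members (OffsetDemo and the method names here are hypothetical; the inner_product body is the one from the diff):

#include <cstddef>
#include <numeric>
#include <vector>

struct OffsetDemo
{
    std::vector<std::size_t> mStrides;

    // Before: by value -- each call copies the index vector, i.e. an
    // allocation on a very hot path of the host reference computation.
    std::size_t offset_by_value(std::vector<std::size_t> iss) const
    {
        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
    }

    // After: by const reference -- same arithmetic, no copy, no allocation.
    std::size_t offset_by_cref(const std::vector<std::size_t>& iss) const
    {
        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
    }
};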
9 changes: 6 additions & 3 deletions include/ck_tile/host/host_tensor.hpp
@@ -230,7 +230,7 @@ struct HostTensorDescriptor
* @param iss Vector containing the multi-dimensional indices
* @return The calculated linear offset as a size_t
*/
- std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
+ std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
{
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
@@ -540,9 +540,12 @@ struct HostTensor
return mData[GetOffsetFromMultiIndex(is...)];
}

- T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
+ T& operator()(const std::vector<std::size_t>& idx)
+ {
+ return mData[GetOffsetFromMultiIndex(idx)];
+ }

- const T& operator()(std::vector<std::size_t> idx) const
+ const T& operator()(const std::vector<std::size_t>& idx) const
{
return mData[GetOffsetFromMultiIndex(idx)];
}