Blackwell matmul scheduler smem epilogue support #4541

Merged (43 commits) on Jun 3, 2025

Conversation

@zasdfgbnm (Collaborator) commented May 29, 2025

Example kernel for `General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/MN_512_256_128_MmaMacro_m128_n128_k16_tma_store`
// Codegen generated code
__device__ __inline__ void fenceAsyncProxy() {
  asm volatile("fence.proxy.async;\n");
}
__device__ __inline__ void cpAsyncBulkCommitGroup() {
  asm volatile("cp.async.bulk.commit_group;\n");
}
template <int64_t in0>
__device__ __inline__ void cpAsyncBulkWaitGroup() {
  asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(in0):"memory");
}

namespace tcgen05 {
__device__ __inline__ void alloc(uint32_t in0, uint32_t in1) {
  asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;\n"::"r"(in0), "r"(in1));
}
__device__ __inline__ void relinquishAllocPermit() {
  asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;\n");
}
__device__ __inline__ void mma_f16(uint32_t in0, uint64_t in1, uint64_t in2, uint32_t in3, bool in4) {
  asm volatile(
    "{\n"
    "  .reg .pred p0; \n"
    "  setp.ne.b32 p0, %4, 0;\n"
    "  tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p0;\n"
    "}\n"
    :
    :"r"(in0),
     "l"(in1),
     "l"(in2),
     "r"(in3),
     "r"((uint32_t)(in4))
  );
}
__device__ __inline__ void commit(uint32_t in0) {
  asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];\n"::"r"(in0));
}
__device__ __inline__ void load32x32b(Array<float, 128, 1>& out0, uint32_t in0) {
  asm(
    "tcgen05.ld.sync.aligned.32x32b.x128.b32 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, [%128];\n"
    :"=f"(out0[0]),
     "=f"(out0[1]),
     "=f"(out0[2]),
     "=f"(out0[3]),
     "=f"(out0[4]),
     "=f"(out0[5]),
     "=f"(out0[6]),
     "=f"(out0[7]),
     "=f"(out0[8]),
     "=f"(out0[9]),
     "=f"(out0[10]),
     "=f"(out0[11]),
     "=f"(out0[12]),
     "=f"(out0[13]),
     "=f"(out0[14]),
     "=f"(out0[15]),
     "=f"(out0[16]),
     "=f"(out0[17]),
     "=f"(out0[18]),
     "=f"(out0[19]),
     "=f"(out0[20]),
     "=f"(out0[21]),
     "=f"(out0[22]),
     "=f"(out0[23]),
     "=f"(out0[24]),
     "=f"(out0[25]),
     "=f"(out0[26]),
     "=f"(out0[27]),
     "=f"(out0[28]),
     "=f"(out0[29]),
     "=f"(out0[30]),
     "=f"(out0[31]),
     "=f"(out0[32]),
     "=f"(out0[33]),
     "=f"(out0[34]),
     "=f"(out0[35]),
     "=f"(out0[36]),
     "=f"(out0[37]),
     "=f"(out0[38]),
     "=f"(out0[39]),
     "=f"(out0[40]),
     "=f"(out0[41]),
     "=f"(out0[42]),
     "=f"(out0[43]),
     "=f"(out0[44]),
     "=f"(out0[45]),
     "=f"(out0[46]),
     "=f"(out0[47]),
     "=f"(out0[48]),
     "=f"(out0[49]),
     "=f"(out0[50]),
     "=f"(out0[51]),
     "=f"(out0[52]),
     "=f"(out0[53]),
     "=f"(out0[54]),
     "=f"(out0[55]),
     "=f"(out0[56]),
     "=f"(out0[57]),
     "=f"(out0[58]),
     "=f"(out0[59]),
     "=f"(out0[60]),
     "=f"(out0[61]),
     "=f"(out0[62]),
     "=f"(out0[63]),
     "=f"(out0[64]),
     "=f"(out0[65]),
     "=f"(out0[66]),
     "=f"(out0[67]),
     "=f"(out0[68]),
     "=f"(out0[69]),
     "=f"(out0[70]),
     "=f"(out0[71]),
     "=f"(out0[72]),
     "=f"(out0[73]),
     "=f"(out0[74]),
     "=f"(out0[75]),
     "=f"(out0[76]),
     "=f"(out0[77]),
     "=f"(out0[78]),
     "=f"(out0[79]),
     "=f"(out0[80]),
     "=f"(out0[81]),
     "=f"(out0[82]),
     "=f"(out0[83]),
     "=f"(out0[84]),
     "=f"(out0[85]),
     "=f"(out0[86]),
     "=f"(out0[87]),
     "=f"(out0[88]),
     "=f"(out0[89]),
     "=f"(out0[90]),
     "=f"(out0[91]),
     "=f"(out0[92]),
     "=f"(out0[93]),
     "=f"(out0[94]),
     "=f"(out0[95]),
     "=f"(out0[96]),
     "=f"(out0[97]),
     "=f"(out0[98]),
     "=f"(out0[99]),
     "=f"(out0[100]),
     "=f"(out0[101]),
     "=f"(out0[102]),
     "=f"(out0[103]),
     "=f"(out0[104]),
     "=f"(out0[105]),
     "=f"(out0[106]),
     "=f"(out0[107]),
     "=f"(out0[108]),
     "=f"(out0[109]),
     "=f"(out0[110]),
     "=f"(out0[111]),
     "=f"(out0[112]),
     "=f"(out0[113]),
     "=f"(out0[114]),
     "=f"(out0[115]),
     "=f"(out0[116]),
     "=f"(out0[117]),
     "=f"(out0[118]),
     "=f"(out0[119]),
     "=f"(out0[120]),
     "=f"(out0[121]),
     "=f"(out0[122]),
     "=f"(out0[123]),
     "=f"(out0[124]),
     "=f"(out0[125]),
     "=f"(out0[126]),
     "=f"(out0[127])
    :"r"(in0)
  );
}
__device__ __inline__ void waitLoad() {
  asm volatile("tcgen05.wait::ld.sync.aligned;\n");
}
__device__ __inline__ void dealloc(uint32_t in0, uint32_t in1) {
  asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;\n"::"r"(in0), "r"(in1));
}
} // namespace tcgen05
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__half, 3, 3> T0, Tensor<__half, 3, 3> T1, const __grid_constant__ TensorMap var0, const __grid_constant__ TensorMap var1, const __grid_constant__ TensorMap var2, Tensor<__half, 2, 2> T3) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  nvfuser_index_t i3;
  i3 = ceilDiv(T0.logical_size[0LL], 32);
  nvfuser_index_t i4;
  i4 = -1 + i3;
  nvfuser_index_t i5;
  i5 = i4 % 2;
  uint32_t i6;
  i6 = (uint32_t)(((i4 / 2) % 2));
  const TensorMap* ptr7;
  ptr7 = &var0;
  nvfuser_index_t i8;
  i8 = 128 * ((nvfuser_index_t)blockIdx.x);
  __half* T5 = reinterpret_cast<__half*>(array + smem_offset + 33792);
  uint32_t i9;
  i9 = toSmem(T5);
  const TensorMap* ptr10;
  ptr10 = &var1;
  nvfuser_index_t i11;
  i11 = 256 * ((nvfuser_index_t)blockIdx.y);
  __half* T4 = reinterpret_cast<__half*>(array + smem_offset + 1024);
  uint32_t i12;
  i12 = toSmem(T4);
  nvfuser_index_t i13;
  i13 = 8192 * ((nvfuser_index_t)threadIdx.y);
  uint32_t i14;
  i14 = i12 + i13;
  nvfuser_index_t i15;
  i15 = 128 * ((nvfuser_index_t)threadIdx.y);
  uint16_t i16;
  i16 = (uint16_t)(i15);
  Array<uint16_t, 2, 1> a17;
  a17 = Array<uint16_t, 2, 1>{0, i16};
  uint32_t i18;
  i18 = (i12 + (16384 * i5)) + i13;
  uint32_t i19;
  i19 = i9 + (8192 * i5);
  nvfuser_index_t i20;
  i20 = (128 * ((nvfuser_index_t)threadIdx.x)) + (16384 * ((nvfuser_index_t)threadIdx.y));
  bool b21;
  b21 = (((nvfuser_index_t)threadIdx.x) + (((nvfuser_index_t)threadIdx.y) * 128)) < 32LL;
  bool b22;
  b22 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
  bool b23;
  b23 = ((nvfuser_index_t)threadIdx.y) == 0ULL;
  __half* T8 = reinterpret_cast<__half*>(array + smem_offset + 128);
  uint32_t* T9 = reinterpret_cast<uint32_t*>(array + smem_offset + 0);
  if (b21) {
    tcgen05::alloc((uint32_t)(toSmem(T9)), 256U);
  }
  if (b21) {
    tcgen05::relinquishAllocPermit();
  }
  __syncthreads();
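  // T9 holds the TMem address returned by tcgen05::alloc above; T2 wraps it as
  // a TMemTensor used to address the accumulator for tcgen05::mma and
  // tcgen05::ld below.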
  TMemTensor T2(T9[0], 0, (uint16_t)(0));
  uint64_t* T11 = reinterpret_cast<uint64_t*>(array + smem_offset + 16);
  #pragma unroll
  for(nvfuser_index_t i24 = 0; i24 < 2; ++i24) {
    if (((Hopper::electSync(4294967295U) && b22) && b23)) {
      mbarrier::init(toSmem((&T11[i24])), 2U);
    }
  }
  __syncthreads();
  if (((Hopper::electSync(4294967295U) && b22) && b23)) {
    mbarrier::arriveExpectTX(toSmem((&T11[0])), 8192U);
    #pragma unroll
    for(nvfuser_index_t i25 = 0; i25 < 2; ++i25) {
      Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr7, (Array<int, 2, 1>{(int32_t)((i8 + (64 * i25))), 0}), toSmem((&T11[0])) }), (i9 + (4096 * i25)));
    }
    mbarrier::arriveExpectTX(toSmem((&T11[0])), 16384U);
    #pragma unroll
    for(nvfuser_index_t i26 = 0; i26 < 4; ++i26) {
      Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr10, (Array<int, 2, 1>{(int32_t)((i11 + (64 * i26))), 0}), toSmem((&T11[0])) }), (i12 + (4096 * i26)));
    }
  }
  #pragma unroll 1
  for(nvfuser_index_t i27 = 0; i27 < i4; ++i27) {
    int i28;
    i28 = (int32_t)((32 + (32 * i27)));
    nvfuser_index_t i29;
    i29 = (1 + i27) % 2;
    uint32_t i30;
    i30 = i9 + (8192 * i29);
    uint32_t i31;
    i31 = i12 + (16384 * i29);
    nvfuser_index_t i32;
    i32 = i27 % 2;
    uint32_t i33;
    i33 = i14 + (16384 * i32);
    uint32_t i34;
    i34 = i9 + (8192 * i32);
    if (((Hopper::electSync(4294967295U) && b22) && b23)) {
      mbarrier::arriveExpectTX(toSmem((&T11[((1LL + i27) % 2)])), 8192U);
      #pragma unroll
      for(nvfuser_index_t i25 = 0; i25 < 2; ++i25) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr7, (Array<int, 2, 1>{(int32_t)((i8 + (64 * i25))), i28}), toSmem((&T11[((1LL + i27) % 2)])) }), (i30 + (4096 * i25)));
      }
      mbarrier::arriveExpectTX(toSmem((&T11[((1LL + i27) % 2)])), 16384U);
      #pragma unroll
      for(nvfuser_index_t i26 = 0; i26 < 4; ++i26) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr10, (Array<int, 2, 1>{(int32_t)((i11 + (64 * i26))), i28}), toSmem((&T11[((1LL + i27) % 2)])) }), (i31 + (4096 * i26)));
      }
    }
    mbarrier::waitParity(toSmem((&T11[(i27 % 2)])), (uint32_t)(((i27 / 2) % 2)));
    #pragma unroll
    for(nvfuser_index_t i35 = 0; i35 < 2; ++i35) {
      nvfuser_index_t i36;
      i36 = 2048 * i35;
      uint32_t i37;
      i37 = i33 + i36;
      uint32_t i38;
      i38 = i34 + i36;
      uint64_t* T10 = reinterpret_cast<uint64_t*>(array + smem_offset + 50176);
      mbarrier::init(toSmem(T10), 1U);
      __syncthreads();
      if ((Hopper::electSync(4294967295U) && b22)) {
        tcgen05::mma_f16((uint32_t)(T2 + a17), (4611756662066249728ULL | ((262143ULL & (uint64_t)(i37)) >> 4ULL)), (4611756662066249728ULL | ((262143ULL & (uint64_t)(i38)) >> 4ULL)), 136413200U, (!((i27 == 0) && (i35 == 0))));
        tcgen05::commit(toSmem(T10));
      }
      mbarrier::waitParity(toSmem((&T10[0])), 0U);
      __syncthreads();
      mbarrier::inval(toSmem(T10));
    }
  }
  mbarrier::waitParity(toSmem((&T11[i5])), i6);
  #pragma unroll
  for(nvfuser_index_t i35 = 0; i35 < 2; ++i35) {
    nvfuser_index_t i39;
    i39 = 2048 * i35;
    uint32_t i40;
    i40 = i18 + i39;
    uint32_t i41;
    i41 = i19 + i39;
    uint64_t* T10 = reinterpret_cast<uint64_t*>(array + smem_offset + 50176);
    mbarrier::init(toSmem(T10), 1U);
    __syncthreads();
    if ((Hopper::electSync(4294967295U) && b22)) {
      tcgen05::mma_f16((uint32_t)(T2 + a17), (4611756662066249728ULL | ((262143ULL & (uint64_t)(i40)) >> 4ULL)), (4611756662066249728ULL | ((262143ULL & (uint64_t)(i41)) >> 4ULL)), 136413200U, true);
      tcgen05::commit(toSmem(T10));
    }
    mbarrier::waitParity(toSmem((&T10[0])), 0U);
    __syncthreads();
    mbarrier::inval(toSmem(T10));
  }
  #pragma unroll
  for(nvfuser_index_t i42 = 0; i42 < 2; ++i42) {
    if (((Hopper::electSync(4294967295U) && b22) && b23)) {
      mbarrier::inval(toSmem((&T11[i42])));
    }
  }
  Array<float, 128, 128> T7;
  tcgen05::load32x32b((*reinterpret_cast<Array<float, 128, 1>*>(&T7[0])), (uint32_t)(T2 + (Array<uint16_t, 2, 1>{(uint16_t)((32LL * (((nvfuser_index_t)threadIdx.x) / 32LL))), i16})));
  tcgen05::waitLoad();
  __syncthreads();
  if (((((nvfuser_index_t)threadIdx.x) + i15) < 32LL)) {
    tcgen05::dealloc(T9[0], 256U);
  }
  #pragma unroll
  for(nvfuser_index_t i43 = 0; i43 < 32; ++i43) {
    nvfuser_index_t i44;
    i44 = 4 * i43;
    Array<__half, 4, 4> T6;
    #pragma unroll
    for(nvfuser_index_t i45 = 0; i45 < 4; ++i45) {
      T6[i45]
         = __float2half(T7[(i44 + i45)]);
    }
    loadGeneric<__half, 4>( &T8[(i20 + i44)],  &T6[0]);
  }
  __syncthreads();
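  // Shared memory epilogue: the __half results were staged in T8 (shared
  // memory) above; fence the async proxy, then issue a TMA store from T8 to
  // global memory and wait for the bulk group to complete.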
  fenceAsyncProxy();
  if ((b22 && b23)) {
    Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ (&var2), (Array<int, 2, 1>{(int32_t)(i8), (int32_t)(i11)}) }), toSmem(T8));
  }
  cpAsyncBulkCommitGroup();
  cpAsyncBulkWaitGroup<0LL>();
}


github-actions bot commented May 29, 2025

Review updated until commit 6be1759

Description

  • Added support for shared memory epilogue in Blackwell.

  • Refactored scheduleEpilogueWithSmemEpilogue to handle both Hopper and Blackwell (see the sketch below).

  • Enhanced error message in TensorView::split.

  • Removed skipped tests for TMA store in Blackwell.
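
A minimal sketch of the dispatch described above, mirroring the scheduleEpilogueWithSmemEpilogue excerpt quoted in the reviewer guide further down:

// The shared memory epilogue entry point now forwards to an
// architecture-specific implementation instead of assuming Hopper.
void HopperPlus::scheduleEpilogueWithSmemEpilogue() {
  if (isHopper(params_->mma_macro)) {
    scheduleEpilogueWithSmemEpilogueHopper();    // stmatrix + TMA store path
  } else {
    scheduleEpilogueWithSmemEpilogueBlackwell(); // TMem -> regs -> smem -> TMA store
  }
}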


Changes walkthrough 📝

Relevant files

Enhancement

matmul_hopper+.cpp: Add Blackwell-specific shared memory epilogue scheduling
(csrc/scheduler/matmul_hopper+.cpp, +116/-7)

  • Added hardcoded constants for vectorization factors.
  • Refactored scheduleEpilogueWithoutSmemEpilogueBlackwell to use tmem_vectorize_factor.
  • Added scheduleEpilogueWithSmemEpilogueBlackwell for Blackwell-specific scheduling.
  • Refactored scheduleEpilogueWithSmemEpilogue to handle both Hopper and Blackwell.

tensor_view.cpp: Improve error message in TensorView::split
(csrc/tensor_view.cpp, +3/-2)

  • Enhanced error message in TensorView::split to include axis information.

matmul_hopper+.h: Add declarations for Blackwell-specific scheduling functions
(csrc/scheduler/matmul_hopper+.h, +2/-0)

  • Added declarations for scheduleEpilogueWithSmemEpilogueHopper and scheduleEpilogueWithSmemEpilogueBlackwell.

Tests

test_matmul_scheduler.cpp: Remove skipped TMA store tests for Blackwell
(tests/cpp/test_matmul_scheduler.cpp, +0/-3)

  • Removed skipped tests for TMA store in Blackwell.

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests

⚡ Recommended focus areas for review

Performance Goals

The PR description does not provide specific performance goals or metrics. It is important to clearly state the expected performance improvements and include data to support these claims.

    namespace schedule_matmul {
    
    namespace {
    
    constexpr int64_t hardcoded_smem_vectorize_factor = 4;
    constexpr int64_t hardcoded_blackwell_splitk_vectorization_factor = 4;
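    // With 4-byte fp32 epilogue elements, a vectorization factor of 4
    // corresponds to 16-byte accesses, the widest vectorized global memory
    // transaction, which is presumably why 4 is hardcoded here.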
    
    // Find the first MatmulDimRole from left to right in a vector of roles
    int64_t findFirstRole(
        std::vector<MatmulDimRole>& roles,
        MatmulDimRole role_to_find) {
      auto role_iter =
          std::find_if(roles.begin(), roles.end(), [&](MatmulDimRole role) {
            return role == role_to_find;
          });
      if (role_iter == roles.end()) {
        return -1;
      }
      return std::distance(roles.begin(), role_iter);
    }
    
    } // namespace
    
    void HopperPlus::transformLikeMmaOutputWithK(TensorView* tv) {
      NVF_ERROR(tv->axis(-1)->isReduction(), "Inner axis should be Reduction.");
      // The input is originally block tiled so that the inner dims are the CTA tile
      // size
      //
      // We split this into warp tiles then instruction tiles
      // Original: [..., M, N, K]
      tv->split(-3, params_->tile_sizes.warp_tile.m);
      tv->split(-3, getM(params_->mma_macro));
      tv->split(-2, params_->tile_sizes.warp_tile.n);
      tv->split(-2, getN(params_->mma_macro));
      // K dimension is present for mma_result
      // We don't need to split by warp_tile.k, since we always have
      // cta_tile.k == warp_tile.k
      tv->split(-1, getK(params_->mma_macro));
      // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
      tv->reorder({
          {-8, -8}, // Mo
          {-7, -6}, // Mw
          {-6, -3}, // Mi
          {-5, -7}, // No
          {-4, -5}, // Nw
          {-3, -2}, // Ni
          {-2, -4}, // Kw
          {-1, -1}, // Ki
      });
      // After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
      tv->merge(-8);
      // After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni, Ki]
      if (isCooperative()) {
        tv->axis(-7)->parallelize(ParallelType::TIDy);
        // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
      }
    }
    
    void HopperPlus::transformLikeMmaOutputWithoutK(TensorView* tv) {
      NVF_ERROR(
          tv->domain()->loop().size() >= 4,
          "transformLikeMmaOutputWithoutK requires at least four iterDomains but ",
          tv->toString(),
          " only has ",
          tv->domain()->loop().size(),
          ".");
      NVF_ERROR(
          !tv->axis(-1)->isReduction(), "Inner axis should not be Reduction.");
    
      // The input is originally block tiled so that the inner dims are the CTA tile
      // size
      // Original: [..., M, N]
      // We split this into warp tiles then instruction tiles
      tv->split(-2, params_->tile_sizes.warp_tile.m);
      tv->split(-2, getM(params_->mma_macro));
      tv->split(-1, params_->tile_sizes.warp_tile.n);
      tv->split(-1, getN(params_->mma_macro));
      // After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
      tv->reorder({
          {-3, -5},
          {-2, -3},
      });
      // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
      tv->merge(-6);
      // After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
      if (isCooperative()) {
        tv->axis(-5)->parallelize(ParallelType::TIDy);
        // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
      }
    }
    
    MatmulDimRole HopperPlus::findMatmulDimRole(IterDomain* id) {
      ValGroup vg = graph_->toGroup(id);
      auto it = id_roles_.find(vg);
      NVF_ERROR(it != id_roles_.end());
      return it->second;
    }
    
    void HopperPlus::validate() const {
      const auto device_prop = at::cuda::getCurrentDeviceProperties();
      const int cc = device_prop->major * 10 + device_prop->minor;
      NVF_ERROR(
          cc >= 90, "This matmul scheduler is restricted to Hopper & Blackwell.");
    
      if (params_->tiling_strategy != MatmulParams::TilingStrategy::OneTilePerCTA) {
        NVF_CHECK(
            params_->splitk_factor == 1,
            "Hopper+ matmul scheduler does not support scheduling persistent "
            "split-K kernels");
      }
    
      NVF_CHECK(
          params_->tiling_strategy !=
              MatmulParams::TilingStrategy::DistributeStagesAcrossSMs,
          "Hopper+ matmul scheduler does not support distributing stages across "
          "SMs a la stream-K");
    
      NVF_CHECK(
          isCooperative(),
          "Hopper+ matmul scheduler only supports cooperatively buffering at the "
          "CTA level (no ping-pong)");
      if (isCooperative()) {
        NVF_CHECK(
            params_->tile_sizes.cta_tile.m % params_->tile_sizes.warp_tile.m == 0,
            "Expected m dimension for cta_tile to be divisible by warp_tile.");
        NVF_CHECK(
            params_->tile_sizes.cta_tile.n % params_->tile_sizes.warp_tile.n == 0,
            "Expected n dimension for cta_tile to be divisible by warp_tile.");
        NVF_CHECK(
            params_->tile_sizes.cta_tile.k % params_->tile_sizes.warp_tile.k == 0,
            "Expected k dimension for cta_tile to be divisible by warp_tile.");
      } else if (isPingPong()) {
        NVF_CHECK(
            params_->tile_sizes.cta_tile == params_->tile_sizes.warp_tile,
            "Expected cta_tile and warp_tile to be the same for Ping-Pong Matmul "
            "Kernels");
      }
    }
    
    void HopperPlus::run() {
      // Finds matmul patterns and translates them to MmaOps, then finds tensor
      // and dimension roles for all tensors in the fusion
      findPatterns();
      translatePatterns();
      // We use the tensor roles to cache operands and epilogue inputs differently
      findRoles();
    
      // Clears memory spaces on intermediate tensors, calls
      // cache{After,Before,Fork} on inputs and outputs.
      // Defines acw_smem/bcw_smem and acr/bcr by possibly calling cacheAfter.
      cacheInputsAndOutputs(/*skip_intermediates=*/true);
    
      // We need to find roles again after caching, since we will need to rebuild
      // the IdModel.
      // TODO: update the val graph on the fly in cacheInputsAndOutputs using
      // cacheAfter and missing cacheFork and cacheBefore utilities instead of doing
      // a full rebuild here
      findRoles();
    
      inspectPrologues();
    
      setCGADims();
    
      scheduleOperands();
    
      // schedule mma instruction output (mma_result)
      scheduleMmaResults();
    
      // schedule epilogue
      scheduleEpilogue();
    
      // schedule splitk_sum
      scheduleSplitKSum();
    
      setUpInlining();
    
      // set up circular buffering. This must come after everything up to
      // mma_result is scheduled, since everything in the main loop will need to
      // be rotated
      setUpCircularBuffering();
    }
    
    void HopperPlus::reorderBlockTileTraversal(
        TensorView* tv,
        std::vector<MatmulDimRole>& outer_dim_roles) {
      NVF_ERROR(params_->grid_traversal_factor.first >= 1);
      NVF_ERROR(params_->grid_traversal_factor.second >= 1);
    
      // short-circuit: If grid traversal factor is 1x1, we don't need to reorder.
      if (params_->grid_traversal_factor.first == 1 &&
          params_->grid_traversal_factor.second == 1) {
        return;
      }
    
      // Find position of outer M and N dims in schedule_.tiled
      int64_t Mo_pos = findFirstRole(outer_dim_roles, MatmulDimRole::M);
      int64_t No_pos = findFirstRole(outer_dim_roles, MatmulDimRole::N);
    
      // Multi-factor grid traversal.
      // M and N roles must be present and consecutive.
      if (params_->grid_traversal_factor.first > 1 &&
          params_->grid_traversal_factor.second > 1) {
        NVF_ERROR(
            Mo_pos >= 0 || No_pos >= 0, "Either M or N role must be present.");
        NVF_ERROR(
            Mo_pos != No_pos, "The position of M and N roles must be different.");
        NVF_ERROR(abs(Mo_pos - No_pos) == 1, "M and N roles must be consecutive.");
    
        bool is_M_present = Mo_pos >= 0;
        bool is_N_present = No_pos >= 0;
        bool is_N_right_of_M = No_pos > Mo_pos;
        const int64_t min_axis_pos = std::min(Mo_pos, No_pos);
    
        // original: [M, N]
        // split:   [M, N/second_factor, second_factor]
        // split: [M/first_factor, first_factor, N/second_factor, second_factor]
        // reorder: [M/first_factor, N/second_factor, first_factor,
        // second_factor]
        // merge:
        // [M/first_factor * N/second_factor, first_factor, second_factor]
        // merge:
        // [M/first_factor * N/second_factor, first_factor * second_factor]
    
        // If N axis exists, then split by second grid traversal factor.
        if (is_N_present) {
          // split:   [M, N/second_factor, second_factor]
          tv->split(No_pos, params_->grid_traversal_factor.second);
        }
        // If N is to the left of M, then shift M by 1 because of second factor.
        if (!is_N_right_of_M) {
          Mo_pos++;
        }
    
        // If M axis exists, then split by first grid traversal factor.
        if (is_M_present) {
          // split: [M/first_factor, first_factor, N/second_factor, second_factor]
          tv->split(Mo_pos, params_->grid_traversal_factor.first);
        }
        // If N is to the right of M, then shift N by 1 because of the first factor.
        if (is_N_right_of_M) {
          No_pos++;
        }
    
        if (is_N_present && is_M_present) {
          NVF_ERROR(min_axis_pos >= 0, "Both M and N roles must exist.");
          // reorder: [M/first_factor, N/second_factor, first_factor,
          // second_factor]
          tv->reorder(
              {{min_axis_pos + 1, min_axis_pos + 2},
               {min_axis_pos + 2, min_axis_pos + 1}});
          // merge:
          // [M/first_factor * N/second_factor, first_factor, second_factor]
          tv->merge(min_axis_pos, min_axis_pos + 1);
          // merge:
          // [M/first_factor * N/second_factor, first_factor * second_factor]
          tv->merge(min_axis_pos + 1, min_axis_pos + 2);
        } else if (is_N_present) {
          // M is missing, so we skip the merge above. In this case we
          // should update the dim roles to reflect the new split axis.
          outer_dim_roles.insert(
              outer_dim_roles.begin() + No_pos, MatmulDimRole::N);
        } else if (is_M_present) {
          // N is missing, so we skip the merge above. In this case we
          // should update the dim roles to reflect the new split axis.
          outer_dim_roles.insert(
              outer_dim_roles.begin() + Mo_pos, MatmulDimRole::M);
        }
        return;
      }
    
      // Single factor grid traversal.
      NVF_ERROR(params_->grid_traversal_factor.first > 1);
      NVF_ERROR(params_->grid_traversal_factor.second == 1);
      int factor = params_->grid_traversal_factor.first;
      switch (params_->cta_order) {
        case MatmulParams::TileRasterizationOrder::ColumnMajor: {
          // split   [I1, I2/factor, factor]
          // reorder [I1, factor, I2/factor]
          // merge   [I1*factor, I2/factor]
          // where I1 and I2 are the outer M and N dimensions, respectively
          if (No_pos >= 0) {
            tv->split(No_pos, factor);
            // If No_pos < Mo_pos, then the split above shifts Mo_pos by one
            if (No_pos < Mo_pos) {
              Mo_pos++;
            }
            tv->reorder({{No_pos, No_pos + 1}});
            if (Mo_pos >= 0) {
              tv->merge(Mo_pos, No_pos);
            } else {
              // M is missing, so we skip the merge above. In this case we
              // should update the dim roles to reflect the new split axis.
              outer_dim_roles.insert(
                  outer_dim_roles.begin() + No_pos, MatmulDimRole::N);
            }
          }
          break;
        }
    
        case MatmulParams::TileRasterizationOrder::RowMajor: {
          // split   [I1/factor, factor, I2]
          // reorder [I1/factor, I2, factor]
          // merge   [I1/factor, I2*factor]
          // where I1 and I2 are the outer M and N dimensions, respectively
          if (Mo_pos >= 0) {
            tv->split(Mo_pos, factor);
            // If No_pos > Mo_pos, then the split above shifts No_pos by one
            if (No_pos > Mo_pos) {
              No_pos++;
            }
            if (No_pos >= 0) {
              tv->reorder({{Mo_pos + 1, No_pos}});
              tv->merge(Mo_pos + 1, No_pos);
            } else {
              // N is missing, so we skip the merge above. In this case we
              // should update the dim roles to reflect the new split axis.
              outer_dim_roles.insert(
                  outer_dim_roles.begin() + Mo_pos, MatmulDimRole::M);
            }
          }
          break;
        }
        default:
          NVF_THROW("Invalid TileRasterizationOrder passed to Matmul scheduler");
      }
    }
    
    std::vector<std::vector<MatmulDimRole>> HopperPlus::blockTileTensors(
        const std::vector<TensorView*>& tvs) {
      if (canonical_dim_ordering_.empty()) {
        canonical_dim_ordering_ =
            mma_utils::canonicalDimOrdering(tensor_roles_, id_roles_, *graph_);
      }
    
      std::vector<std::vector<MatmulDimRole>> all_merged_roles;
      for (TensorView* tv : tvs) {
        // Find dimensions in canonical_dim_ordering_ that exist in tv's loop
        // domain, then reorder them according to the canonical dim ordering.
        std::unordered_map<ValGroup, IterDomain*> tv_dims;
        std::unordered_set<MatmulDimRole> axis_roles;
        for (IterDomain* id : tv->getLoopDomain()) {
          ValGroup vg = graph_->toGroup(id);
          tv_dims.emplace(vg, id);
          // track axis roles in this tensor to use in makeTile
          auto it = id_roles_.find(vg);
          NVF_ERROR(it != id_roles_.end());
          axis_roles.insert(it->second);
        }
        std::vector<IterDomain*> new_loop;
        new_loop.reserve(tv->nDims());
        for (const ValGroup& vg : canonical_dim_ordering_) {
          auto it = tv_dims.find(vg);
          if (it != tv_dims.end()) {
            new_loop.push_back(it->second);
          }
        }
        NVF_ERROR((int64_t)new_loop.size() == tv->nDims());
        tv->setLoopDomain(new_loop);
    
        // There could be multiple dimensions with the same role at this point, so
        // now we collect them. After this, tv will be at most 4 dimensions e.g.
        // BMNK based on canonical_dim_ordering_, with any of these dimensions
        // possibly missing.
        mma_utils::mergeConsecutiveAxesWithSameRole(tv, id_roles_, graph_);
    
        // Find the order of the axes that are present in the merged tensor
        std::vector<MatmulDimRole> merged_roles;
        merged_roles.reserve(tv->nDims());
        for (const ValGroup& vg : canonical_dim_ordering_) {
          MatmulDimRole role = id_roles_[vg];
          if (axis_roles.count(role) != 0) {
            if (merged_roles.empty() || merged_roles.back() != role) {
              merged_roles.push_back(role);
            }
          }
        }
        NVF_ERROR(merged_roles.size() == axis_roles.size());
    
        // TODO: (to be pursued after the multi-matmul refactor is fully merged)
        // this currently creates a separate AbstractMatmulTensor for each
        // TensorView. Instead, we should create a single AbstractMatmulTensor
        // then apply it (with "forwarding") to each TV instead. We already cache
        // a vector<ValGroup> as canonical_dim_ordering_ so AbstractTensor
        // scheduling is the next step in this modernization.
        mma_utils::makeTile(tv, params_->tile_sizes.cta_tile, merged_roles);
    
        reorderBlockTileTraversal(tv, merged_roles);
    
        if (params_->splitk_factor > 1) {
          // Outer K dimension in tv is in same position found in merged_roles
          for (size_t i : arange(merged_roles.size())) {
            if (merged_roles[i] == MatmulDimRole::K) {
              tv->split((int64_t)i, params_->splitk_factor, /*inner*/ false);
            }
          }
        }
    
        // Merge in batch dims to the BIDy dim for non-persistent
        if (params_->tiling_strategy ==
            MatmulParams::TilingStrategy::OneTilePerCTA) {
          if (num_local_batch_dims_ > 0) {
            NVF_ERROR(merged_roles.front() == MatmulDimRole::Batch);
            // Merge batch dim into the dimension that will be parallelized BIDy
            if (params_->cta_order ==
                MatmulParams::TileRasterizationOrder::ColumnMajor) {
              int64_t outer_grid_dim = num_device_dims_ + 2L;
              // [..., Batch, M, N, ...]
              tv->merge(num_device_dims_, outer_grid_dim);
              // [..., Batch*N, M, ...]
              // Now we need to transpose so that Batch*N is to the right of M
              tv->reorder({{num_device_dims_, num_device_dims_ + 1}});
            } else { // row major
              int64_t outer_grid_dim = num_device_dims_ + 1L;
              tv->merge(num_device_dims_, outer_grid_dim);
            }
            merged_roles.erase(merged_roles.begin());
          }
        } else if (
            params_->tiling_strategy ==
            MatmulParams::TilingStrategy::DistributeTilesAcrossSMs) {
          // Persistent kernel scheduling
          if (params_->cta_order ==
              MatmulParams::TileRasterizationOrder::ColumnMajor) {
            tv->reorder(
                {{num_device_and_batch_dims_, num_device_and_batch_dims_ + 1}});
          }
          tv->merge(num_device_and_batch_dims_, num_device_and_batch_dims_ + 1);
    
          if (num_local_batch_dims_ > 0) {
            NVF_ERROR(merged_roles.front() == MatmulDimRole::Batch);
            // Merge batch dims before doing the persistent split
            tv->merge(num_device_dims_);
            merged_roles.erase(merged_roles.begin());
          }
    
          const int64_t num_sms =
              at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
          tv->split(num_device_dims_, num_sms);
        } else {
          NVF_THROW("Unsupported tiling strategy");
        }
    
        all_merged_roles.push_back(merged_roles);
      }
      return all_merged_roles;
    }
    
    void HopperPlus::inspectPrologues() const {
      for (TensorView* mma_result : mma_results_) {
        for (Val* v : mma_result->definition()->inputs()) {
          TensorView* op_input = v->as<TensorView>();
    
          // We currently require all operands to lie in smem, meaning we cannot yet
          // handle any prologue computation. This includes `BroadcastOp` which
          // might be introduced when translating a MatmulOp or LinearOp to MmaOp.
          Expr* def = op_input->definition();
          NVF_ERROR(def != nullptr && def->isA<LoadStoreOp>());
          NVF_ERROR(
              def->input(0)->as<TensorView>()->getMemoryType() ==
              MemoryType::Global);
        }
      }
    }
    
    void HopperPlus::scheduleOperands() {
      NVF_CHECK(
          params_->async_gmem_load_operands,
          "Hopper+ matmul scheduler currently requires TMA to be enabled");
      auto scheduleBranch = [&](const std::vector<TensorView*>& gmem_operands,
                                const std::vector<TensorView*>& smem_operands,
                                MmaOperand operand_type) {
        blockTileTensors(smem_operands);
        parallelizeBlocks(smem_operands);
        for (TensorView* tv : smem_operands) {
          if (params_->promote_prologue_smem_reuse) {
            tv->promoteReuse();
          }
          mma_utils::orderTiledConcreteIdAsMaybeAllocationDomain(tv);
          MmaInputSmemSwizzle swizzle_type = mma_utils::tmaSwizzleSharedMemory(tv);
          tv->applyMmaSwizzleForTMALoad(swizzle_type);
        }
      };
      scheduleBranch(as_, acw_smems_, MmaOperand::A);
      scheduleBranch(bs_, bcw_smems_, MmaOperand::B);
    }
    
    void HopperPlus::parallelizeBlocks(const std::vector<TensorView*>& tvs) const {
      for (TensorView* tv : tvs) {
        switch (params_->tiling_strategy) {
          case MatmulParams::TilingStrategy::OneTilePerCTA:
            // Data-parallel kernels are parallelized BIDx BIDy
            switch (params_->cta_order) {
              // TODO: Should we instead check the roles of these dimensions to take
              // the outermost two M or N axes?
              case MatmulParams::TileRasterizationOrder::ColumnMajor:
                tv->axis(num_device_dims_)->parallelize(ParallelType::BIDx);
                tv->axis(num_device_dims_ + 1)->parallelize(ParallelType::BIDy);
                break;
              case MatmulParams::TileRasterizationOrder::RowMajor:
                tv->axis(num_device_dims_)->parallelize(ParallelType::BIDy);
                tv->axis(num_device_dims_ + 1)->parallelize(ParallelType::BIDx);
                break;
              default:
                NVF_THROW(
                    "Invalid TileRasterizationOrder passed to Matmul scheduler");
            }
            break;
          case MatmulParams::TilingStrategy::DistributeTilesAcrossSMs:
          case MatmulParams::TilingStrategy::DistributeStagesAcrossSMs:
            // For persistent kernels, we just parallelize the SM dimension
            tv->axis(num_device_dims_ + 1)->parallelize(ParallelType::BIDx);
            break;
        }
      }
    }
    
    void HopperPlus::setMmaResultAllocationDomain(TensorView* mma_result) {
      if (isBlackwell(params_->mma_macro)) {
        mma_result->setMemoryType(MemoryType::Tensor);
        // So far, we only support M128 Blackwell MMA macros. For these macros,
        // rows of the accumulator span all 128 lanes of TMem. That is, the
        // allocation domain should be [Mi, (DimSep), ...other]
        // We want to move Mi to the front of the domain.
        std::vector<IterDomain*> allocation_domain = mma_result->getLoopDomain();
        auto item = allocation_domain[allocation_domain.size() - 3];
        allocation_domain.erase(
            allocation_domain.begin() + allocation_domain.size() - 3);
        allocation_domain.insert(allocation_domain.begin(), item);
        mma_result->setAllocationDomain(allocation_domain, true);
        mma_result->setTMemDimSepPos(1);
      } else {
        auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
            mma_result->getLoopDomain());
        mma_result->setAllocationDomain(s.as<IterDomain*>(), true);
      }
    }
    
    void HopperPlus::scheduleMmaResults() {
      GemmTile instruction_tile = getMmaOpShape(params_->mma_macro);
      NVF_CHECK(
          params_->tile_sizes.cta_tile.k == params_->tile_sizes.warp_tile.k,
          "CTA tile must match warp tile K dimension for Hopper+ matmul but found ",
          toString(params_->tile_sizes));
      // If cta_tile is not divisible by instruction tile the mma instruction will
      // be predicated.
      NVF_CHECK(
          params_->tile_sizes.cta_tile.m % instruction_tile.m == 0 &&
              params_->tile_sizes.cta_tile.n % instruction_tile.n == 0 &&
              params_->tile_sizes.cta_tile.k % instruction_tile.k == 0,
          "CTA tile must be divisible by macro size but found cta_tile: ",
          toString(params_->tile_sizes.cta_tile),
          " and macro: ",
          toString(params_->mma_macro));
    
      // Schedule mma results and propagate forward
      auto all_merged_roles = blockTileTensors(mma_results_);
      parallelizeBlocks(mma_results_);
      for (auto&& [i, mma_result] : enumerate(mma_results_)) {
        const std::vector<MatmulDimRole>& merged_roles = all_merged_roles[i];
    
        // Test that mma_result logical is MNK
        // TODO: This currently checks leaf domain only which does not necessarily
        // match logical
        // TODO: Lift this constraint. Use commitLeafToLogical if necessary. We
        // might just want to match using id_roles_
        NVF_ERROR(merged_roles.size() >= 3);
        const auto checkSingleDimRole =
            [&merged_roles](int64_t pos, MatmulDimRole expected_role) {
              if (pos < 0) {
                pos += (int64_t)merged_roles.size();
              }
              NVF_ERROR(pos >= 0);
              NVF_ERROR(pos < (int64_t)merged_roles.size());
              const auto& actual_role = merged_roles[(size_t)pos];
              NVF_ERROR(actual_role == expected_role);
            };
        checkSingleDimRole(-3, MatmulDimRole::M);
        checkSingleDimRole(-2, MatmulDimRole::N);
        checkSingleDimRole(-1, MatmulDimRole::K);
    
        // do split-K rFactor to define splitk_sum and smem_epilogue
        if (params_->splitk_factor != 1) {
          // Note that the split-K split is already done in blockTileTensors
          TensorView* splitk_sum = mma_result->rFactor({-4, -1});
          std::swap(splitk_sum, mma_result);
          splitk_sums_.push_back(splitk_sum);
        }
    
        transformLikeMmaOutputWithK(mma_result);
        setMmaResultAllocationDomain(mma_result);
    
        mma_result->axis(-1)->parallelize(ParallelType::Mma);
        mma_result->axis(-2)->parallelize(ParallelType::Mma);
        mma_result->axis(-3)->parallelize(ParallelType::Mma);
      }
    }
    
    std::vector<TensorView*> HopperPlus::createTMemLoad() {
      if (!isBlackwell(params_->mma_macro)) {
        return {};
      }
      std::vector<TensorView*> tmem_ld_tvs;
      for (auto mma_result : mma_results_) {
        TensorView* tmem_ld_tv = cacheAfter(mma_result);
        tmem_ld_tv->definition()->as<LoadStoreOp>()->setOpType(
            LoadStoreOpType::LdTMem);
        tmem_ld_tvs.push_back(tmem_ld_tv);
      }
      return tmem_ld_tvs;
    }
    
    int64_t HopperPlus::getLdTMemVectorizeFactor() const {
      const int64_t n_mma = getN(params_->mma_macro);
      int64_t tmem_vectorize_factor = 1;
      while (n_mma % tmem_vectorize_factor == 0 && tmem_vectorize_factor <= 128) {
        tmem_vectorize_factor *= 2;
      }
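      // Net effect: the largest power of two that divides
      // getN(params_->mma_macro), capped at 128 (512 bytes of fp32 per TMem
      // access). For example, N = 256 gives 128, N = 192 gives 64, and
      // N = 16 gives 16.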
      return tmem_vectorize_factor / 2;
    }
    
    void HopperPlus::scheduleEpilogueWithoutSmemEpilogueBlackwell() {
      const bool has_splitk = params_->splitk_factor != 1;
      int64_t tmem_vectorize_factor = getLdTMemVectorizeFactor();
      std::vector<TensorView*> cached_tvs;
      std::vector<TensorView*> propagate_to =
          splitk_sums_.empty() ? mma_results_ : splitk_sums_;
      // When there is a split-K, the TMem load happens before split-K sum,
      // when there is no split-K, the TMem load happens in the epilogue.
      std::vector<TensorView*> tmem_ld_tvs =
          !has_splitk ? createTMemLoad() : std::vector<TensorView*>{};
      for (auto& [c, c_cache] : cached_epilogue_inputs_) {
        cached_tvs.push_back(c_cache);
        propagate_to.push_back(c);
      }
      for (Val* dv : fusion_->outputs()) {
        TensorView* d = dv->as<TensorView>();
        NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
    
        // Apply the default scheduling that is common to all register
        // TensorViews after wgmma.
        blockTileTensors({d});
        parallelizeBlocks({d});
        transformLikeMmaOutputWithoutK(d);
    
        // TIDx is 128, so we use it for lanes of the accumulator. Also, we
        // vectorize the TMem load with a factor of v (tmem_vectorize_factor).
        // [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v (Vectorize)]
        d->axis(-2)->parallelize(ParallelType::TIDx);
        if (tmem_vectorize_factor < getN(params_->mma_macro)) {
          d->split(-1, tmem_vectorize_factor);
        }
    
        // TODO: We need to check bank conflicts in this path.
        // Propagate schedule changes back to the outputs of the Mma op.
        scheduler_utils::BoundedDirectionalTransformPropagator::backward(
            d,
            -1,
            propagate_to,
            scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                .propagateParallelType());
    
        // Vectorize the epilogue input load and output store. TMem load can
        // be vectorized to 512 bytes, but gmem load/store can only be vectorized
        // to 16 bytes. So we need to further split the last dimension and use
        // multiple vector loads/stores for each TMem load/store.
        // After split and parallelization:
        // (v = tmem_vectorize_factor, vv = params_->supported_vec_size.epilogue)
        // [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v/vv, vv]
        // TODO: Support vectorization_factor in MatmulParams
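        // For example (assuming an fp32 epilogue): v = 128 elements is one
        // 512-byte TMem access, while vv = 4 keeps each gmem load/store within
        // the 16-byte vectorization limit.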
        if (tmem_vectorize_factor > params_->supported_vec_size.epilogue) {
          d->split(-1, params_->supported_vec_size.epilogue);
          for (auto c : cached_tvs) {
            bool is_2d_epilogue_input =
                TensorDomain::noBroadcasts(c->domain()->logical()).size() == 2;
            if (is_2d_epilogue_input) {
              c->split(-1, params_->supported_vec_size.epilogue);
            }
          }
        }
        d->axis(-1)->parallelize(ParallelType::Vectorize);
        if (!cached_tvs.empty()) {
          scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
        }
      }
      // Vectorize the TMem load, if any.
      for (auto tmem_ld_tv : tmem_ld_tvs) {
        tmem_ld_tv->axis(-1)->parallelize(ParallelType::Vectorize);
      }
    }
    
    void HopperPlus::scheduleEpilogueWithoutSmemEpilogueHopper() {
      std::vector<TensorView*> cached_tvs;
      std::vector<TensorView*> propagate_to =
          splitk_sums_.empty() ? mma_results_ : splitk_sums_;
      for (auto& [c, c_cache] : cached_epilogue_inputs_) {
        cached_tvs.push_back(c_cache);
        propagate_to.push_back(c);
      }
      for (Val* dv : fusion_->outputs()) {
        TensorView* d = dv->as<TensorView>();
        NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
    
        // Apply the default scheduling that is common to all register
        // TensorViews after wgmma.
        blockTileTensors({d});
        parallelizeBlocks({d});
        transformLikeMmaOutputWithoutK(d);
    
        const AbstractTensor s =
            mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(d->getLoopDomain());
        d->setLoopDomain(s.as<IterDomain*>());
    
        // TODO: We need to check bank conflicts in this path.
        // Propagate schedule changes back to the outputs of the Mma op.
        scheduler_utils::BoundedDirectionalTransformPropagator::backward(
            d,
            -1,
            propagate_to,
            scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                .propagateParallelType());
    
        // We do not respect the vectorization_factor parameter, but always
        // vectorize the inner-dim with extent 2.
        NVF_ERROR(params_->supported_vec_size.epilogue >= 2);
        // TODO: Support vectorization_factor in MatmulParams
        d->axis(-1)->parallelize(ParallelType::Vectorize);
        if (!cached_tvs.empty()) {
          scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
        }
      }
    }
    
    void HopperPlus::scheduleEpilogueWithoutSmemEpilogue() {
      if (isBlackwell(params_->mma_macro)) {
        scheduleEpilogueWithoutSmemEpilogueBlackwell();
      } else {
        scheduleEpilogueWithoutSmemEpilogueHopper();
      }
    }
    
    void HopperPlus::scheduleEpilogueWithSmemEpilogueHopper() {
      constexpr int64_t ldst_matrix_tile_m = 16;
      constexpr int64_t ldst_matrix_tile_n = 16;
      fusion_->manage("ldst_matrix_m_tile", ldst_matrix_tile_m);
      fusion_->manage("ldst_matrix_n_tile", ldst_matrix_tile_n);
      fusion_->manage("ldst_matrix_m_smem", params_->tile_sizes.warp_tile.m);
      fusion_->manage("ldst_matrix_n_smem", params_->tile_sizes.warp_tile.n);
    
      // Propagate to (not including) the splitk output if there is a splitk
      // else this is just mma_results_
      std::vector<TensorView*> propagate_to =
          splitk_sums_.empty() ? mma_results_ : splitk_sums_;
      for (auto& [c, c_cache] : cached_epilogue_inputs_) {
        bool load_with_ldmatrix =
            params_->use_ldst_matrix && dataTypeSize(c_cache->dtype()) == 2;
        bool is_2d_epilogue_input =
            TensorDomain::noBroadcasts(c_cache->domain()->logical()).size() == 2;
        if (load_with_ldmatrix && is_2d_epilogue_input &&
            params_->async_gmem_load_operands) {
          // Schedule TMA load into shared memory for epilogue input
          c_cache->definition()->as<LoadStoreOp>()->setOpType(
              LoadStoreOpType::CpAsyncBulkTensorTile);
          c_cache->setMemoryType(MemoryType::Shared);
    
          // Apply the default scheduling that is common to all register
          // TensorViews after wgmma.
          blockTileTensors({c_cache});
          parallelizeBlocks({c_cache});
          transformLikeMmaOutputWithoutK(c_cache);
    
          // Swizzle to avoid shared memory bank conflicts
          MmaInputSmemSwizzle swizzle_type =
              mma_utils::tmaSwizzleSharedMemory(c_cache);
          c_cache->applyMmaSwizzleForTMALoad(swizzle_type);
    
          TensorView* reg_tv = cacheAfter(c_cache);
          reg_tv->definition()->as<LoadStoreOp>()->setOpType(
              LoadStoreOpType::LdMatrix);
    
          // Apply the default scheduling that is common to all register
          // TensorViews after wgmma.
          blockTileTensors({reg_tv});
          parallelizeBlocks({reg_tv});
          transformLikeMmaOutputWithoutK(reg_tv);
    
          // Schedule the loop and allocation domain of LdMatrix like the
          // accumulation register TensorView of wgmma.
          AbstractTensor s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
              reg_tv->getLoopDomain());
          reg_tv->setLoopDomain(s.as<IterDomain*>());
          reg_tv->setAllocationDomain(
              reg_tv->getLoopDomain(), /*new_contiguity=*/true);
    
          // Apply LdStMatrix scheduling to the wgmma loop domain
          mma_utils::scheduleLdStMatrixForMmaOutput(
              reg_tv, ldst_matrix_tile_m, ldst_matrix_tile_n);
    
          // Vectorize last iterDomain because LdMatrix loads all eight values with
          // a single LdMatrix.x4 operation
          reg_tv->axis(-1)->parallelize(ParallelType::Vectorize);
    
          // Do not propagate any other changes to LdMatrix.
          propagate_to.push_back(reg_tv);
        } else {
          // Propagate changes to the cache_after tensor if not using TMA load.
          propagate_to.push_back(c);
        }
      }
    
      // Manually schedule register cache and output TensorView
      for (Val* dv : fusion_->outputs()) {
        TensorView* d = dv->as<TensorView>();
        NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
        TensorView* dc = d->definition()->input(0)->as<TensorView>();
    
        // The chain of operations storing data to global memory:
        //   registers -> (stmatrix) -> smem -> (tma_store) -> gmem
        TensorView* d_smem = cacheBefore(d, LoadStoreOpType::Set);
    
        std::vector<TensorView*> tvs_to_schedule{d, d_smem};
        bool dc_is_mma_result =
            std::find(mma_results_.begin(), mma_results_.end(), dc) !=
            mma_results_.end();
        bool dc_is_splitk_sum = params_->splitk_factor > 1 &&
            std::find(splitk_sums_.begin(), splitk_sums_.end(), dc) !=
                splitk_sums_.end();
    
        if (!dc_is_mma_result && !dc_is_splitk_sum) {
          // Skip scheduling dc if it is an mma_result. This can happen if we are
          // not casting back to half-precision in the output
          tvs_to_schedule.push_back(dc);
        }
    
        // Set MemoryType
        dc->setMemoryType(MemoryType::Local);
        d_smem->setMemoryType(MemoryType::Shared);
    
        // Set LoadStoreOpType
        bool store_with_stmatrix =
            params_->use_ldst_matrix && dataTypeSize(dc->dtype()) == 2;
        if (store_with_stmatrix) {
          d_smem->definition()->as<LoadStoreOp>()->setOpType(
              LoadStoreOpType::StMatrix);
        }
        d->definition()->as<LoadStoreOp>()->setOpType(
            LoadStoreOpType::CpAsyncBulkTensorTile);
    
        // Apply the common transforms to dc, d_smem, d
        // After these transforms we schedule the inner two non-reduction loops
        // (instruction tile) of dc and propagate it back to the outputs of mma.
        blockTileTensors(tvs_to_schedule);
        parallelizeBlocks(tvs_to_schedule);
        for (auto tv : tvs_to_schedule) {
          transformLikeMmaOutputWithoutK(tv);
        }
    
        // Should not propagate if the dc is a mma output as the mma output has
        // already been scheduled.
        if (!dc_is_mma_result && !dc_is_splitk_sum) {
          auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
              dc->getLoopDomain());
          dc->setLoopDomain(s.as<IterDomain*>());
          dc->setAllocationDomain(s.as<IterDomain*>(), true);
    
          scheduler_utils::BoundedDirectionalTransformPropagator::backward(
              dc,
              -1,
              propagate_to,
              scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                  .propagateParallelType());
        }
    
        // Determine swizzle for TMA Store
        MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);
    
        // First, create loop domain that matches wgmma register accumulator using
        // original loop domain.
        const AbstractTensor s =
            mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
                d_smem->getLoopDomain());
        // Create allocation domain with swizzle for TMA Store.
        // This step modifies the loop domain and creates a new allocation domain.
        if (swizzle != MmaInputSmemSwizzle::None) {
          mma_utils::scheduleTMAStoreForMmaOutput(d_smem, swizzle);
        }
        // Finally, set loop domain using saved AbstractTensor.
        d_smem->setLoopDomain(s.as<IterDomain*>());
    
        if (store_with_stmatrix) {
          // Apply LdStMatrix scheduling to the wgmma loop domain
          mma_utils::scheduleLdStMatrixForMmaOutput(
              d_smem, ldst_matrix_tile_m, ldst_matrix_tile_n);
        }
        d_smem->axis(-1)->parallelize(ParallelType::Vectorize);
    
        // Schedule global memory output; Output from TMA Store
        mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle);
      }
    }
    
    void HopperPlus::scheduleEpilogueWithSmemEpilogueBlackwell() {
      const bool has_splitk = params_->splitk_factor != 1;
      int64_t tmem_vectorize_factor = getLdTMemVectorizeFactor();
    
      std::vector<TensorView*> tmem_ld_tvs =
          !has_splitk ? createTMemLoad() : std::vector<TensorView*>{};
    
      // Propagate to (not including) the splitk output if there is a splitk
      // else this is just mma_results_
      std::vector<TensorView*> register_tvs;
      std::vector<TensorView*> propagate_to =
          splitk_sums_.empty() ? mma_results_ : splitk_sums_;
      for (auto& [c, c_cache] : cached_epilogue_inputs_) {
        bool is_2d_epilogue_input =
            TensorDomain::noBroadcasts(c_cache->domain()->logical()).size() == 2;
        if (is_2d_epilogue_input && params_->async_gmem_load_operands) {
          // Schedule TMA load into shared memory for epilogue input
          c_cache->definition()->as<LoadStoreOp>()->setOpType(
              LoadStoreOpType::CpAsyncBulkTensorTile);
          c_cache->setMemoryType(MemoryType::Shared);
          blockTileTensors({c_cache});
          parallelizeBlocks({c_cache});
          transformLikeMmaOutputWithoutK(c_cache);
          c_cache->setAllocationDomain(c_cache->getLoopDomain(), true);
          for (int64_t i = -5; i <= -1; i++) {
            c_cache->axis(i)->parallelize(ParallelType::Bulk);
          }
    
          // Schedule smem->register load for epilogue input
          TensorView* reg_tv = cacheAfter(c_cache);
          register_tvs.push_back(reg_tv);
          blockTileTensors({reg_tv});
          parallelizeBlocks({reg_tv});
          transformLikeMmaOutputWithoutK(reg_tv);
        }
        // Propagate changes to the cache_after tensor
        propagate_to.push_back(c);
      }
    
      // TMem load is scheduled separately, so don't propagate to it.
      propagate_to.insert(
          propagate_to.end(), tmem_ld_tvs.begin(), tmem_ld_tvs.end());
    
      // The chain of operations storing data to global memory:
      //   dc (registers) -> d_smem -> [tma_store] -> d (gmem)
      // We schedule d_smem and propagate it back.
      for (Val* dv : fusion_->outputs()) {
        TensorView* d = dv->as<TensorView>();
        NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
        TensorView* dc = d->definition()->input(0)->as<TensorView>();
        TensorView* d_smem = cacheBefore(d, LoadStoreOpType::Set);
        dc->setMemoryType(MemoryType::Local);
        d_smem->setMemoryType(MemoryType::Shared);
    
        // We schedule the epilogue like:
        // (v = tmem_vectorize_factor, vv = smem_vectorize_factor)
        // [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v/vv, vv]
        blockTileTensors({d, d_smem});
        parallelizeBlocks({d, d_smem});
        for (auto tv : {d, d_smem}) {
          transformLikeMmaOutputWithoutK(tv);
          tv->axis(-2)->parallelize(ParallelType::TIDx);
          if (tmem_vectorize_factor < getN(params_->mma_macro)) {
            tv->split(-1, tmem_vectorize_factor);
          }
        }
        if (tmem_vectorize_factor > hardcoded_smem_vectorize_factor) {
          d_smem->split(-1, hardcoded_smem_vectorize_factor);
        }
    
        scheduler_utils::BoundedDirectionalTransformPropagator::backward(
            d_smem,
            -1,
            propagate_to,
            scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                .propagateParallelType());
    
        d_smem->axis(-1)->parallelize(ParallelType::Vectorize);
        d_smem->setAllocationDomain(d_smem->getLoopDomain(), true);
    
        // Schedule global memory output; Output from TMA Store
        d->definition()->as<LoadStoreOp>()->setOpType(
            LoadStoreOpType::CpAsyncBulkTensorTile);
        for (int64_t i = -5; i <= -1; i++) {
          d->axis(i)->parallelize(ParallelType::Bulk);
        }
      }
    
      // Schedule TMem load as:
      // (v = tmem_vectorize_factor)
      // [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v (Vectorize)]
      blockTileTensors(tmem_ld_tvs);
      parallelizeBlocks(tmem_ld_tvs);
      for (TensorView* tmem_ld_tv : tmem_ld_tvs) {
        transformLikeMmaOutputWithoutK(tmem_ld_tv);
        tmem_ld_tv->axis(-2)->parallelize(ParallelType::TIDx);
        if (tmem_vectorize_factor < getN(params_->mma_macro)) {
          tmem_ld_tv->split(-1, tmem_vectorize_factor);
        }
        tmem_ld_tv->axis(-1)->parallelize(ParallelType::Vectorize);
      }
    }
    
    void HopperPlus::scheduleEpilogueWithSmemEpilogue() {
      if (isHopper(params_->mma_macro)) {
        scheduleEpilogueWithSmemEpilogueHopper();
      } else {
        scheduleEpilogueWithSmemEpilogueBlackwell();
      }
    }
    
    void HopperPlus::scheduleEpilogue() {
    Code Duplication

    There is significant code duplication between scheduleEpilogueWithoutSmemEpilogueBlackwell and scheduleEpilogueWithSmemEpilogueBlackwell. Consider refactoring to reduce duplication and improve maintainability.
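
    A hedged sketch of the suggestion above (not part of the PR): the tile
    transform and parallelization tail that both Blackwell epilogue paths apply
    to each output-side TensorView could be pulled into one helper. The helper
    name below is invented for illustration; the calls inside it only reuse
    functions that already appear in this scheduler.

    void HopperPlus::scheduleBlackwellEpilogueTileTail(
        TensorView* tv,
        int64_t tmem_vectorize_factor) {
      // Common tail: block tiling, block parallelization, and the MMA-output
      // transform, then the TIDx / vectorization split on the inner dimensions.
      blockTileTensors({tv});
      parallelizeBlocks({tv});
      transformLikeMmaOutputWithoutK(tv);
      // [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v]
      tv->axis(-2)->parallelize(ParallelType::TIDx);
      if (tmem_vectorize_factor < getN(params_->mma_macro)) {
        tv->split(-1, tmem_vectorize_factor);
      }
    }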

    void HopperPlus::scheduleEpilogueWithoutSmemEpilogueBlackwell() {
      const bool has_splitk = params_->splitk_factor != 1;
      int64_t tmem_vectorize_factor = getLdTMemVectorizeFactor();
      std::vector<TensorView*> cached_tvs;
      std::vector<TensorView*> propagate_to =
          splitk_sums_.empty() ? mma_results_ : splitk_sums_;
      // When there is a split-K, the TMem load happens before the split-K sum;
      // when there is no split-K, the TMem load happens in the epilogue.
      std::vector<TensorView*> tmem_ld_tvs =
          !has_splitk ? createTMemLoad() : std::vector<TensorView*>{};
      for (auto& [c, c_cache] : cached_epilogue_inputs_) {
        cached_tvs.push_back(c_cache);
        propagate_to.push_back(c);
      }
      for (Val* dv : fusion_->outputs()) {
        TensorView* d = dv->as<TensorView>();
        NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
    
        // Apply the default scheduling that is common to all register
        // TensorViews after wgmma.
        blockTileTensors({d});
        parallelizeBlocks({d});
        transformLikeMmaOutputWithoutK(d);
    
        // TIDx is 128, so we use it for lanes of the accumulator. Also, we
        // vectorize the TMem load with a factor of v (tmem_vectorize_factor).
        // [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v (Vectorize)]
        d->axis(-2)->parallelize(ParallelType::TIDx);
        if (tmem_vectorize_factor < getN(params_->mma_macro)) {
          d->split(-1, tmem_vectorize_factor);
        }
    
        // TODO: We need to check bank conflicts in this path.
        // Propagate schedule changes back to the outputs of the Mma op.
        scheduler_utils::BoundedDirectionalTransformPropagator::backward(
            d,
            -1,
            propagate_to,
            scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                .propagateParallelType());
    
        // Vectorize the epilogue input load and output store. A TMem load can
        // be vectorized up to 512 bytes, but gmem loads/stores can only be
        // vectorized up to 16 bytes, so we further split the last dimension
        // and issue multiple vectorized gmem loads/stores per TMem load.
        // After split and parallelization:
        // (v = tmem_vectorize_factor, vv = params_->supported_vec_size.epilogue)
        // [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v/vv, vv]
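        // As a worked example (hypothetical numbers, assuming a 4-byte epilogue
        // data type and the maximum 512-byte TMem vectorization):
        // v = 512 B / 4 B = 128 elements and vv = 16 B / 4 B = 4 elements, so
        // the elements of one TMem load are written to gmem in 128 / 4 = 32
        // vectorized transactions.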
        // TODO: Support vectorization_factor in MatmulParams
        if (tmem_vectorize_factor > params_->supported_vec_size.epilogue) {
          d->split(-1, params_->supported_vec_size.epilogue);
          for (auto c : cached_tvs) {
            bool is_2d_epilogue_input =
                TensorDomain::noBroadcasts(c->domain()->logical()).size() == 2;
            if (is_2d_epilogue_input) {
              c->split(-1, params_->supported_vec_size.epilogue);
            }
          }
        }
        d->axis(-1)->parallelize(ParallelType::Vectorize);
        if (!cached_tvs.empty()) {
          scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
        }
      }
      // Vectorize the TMem load, if any.
      for (auto tmem_ld_tv : tmem_ld_tvs) {
        tmem_ld_tv->axis(-1)->parallelize(ParallelType::Vectorize);
      }
    }
    
    void HopperPlus::scheduleEpilogueWithoutSmemEpilogueHopper() {
      std::vector<TensorView*> cached_tvs;
      std::vector<TensorView*> propagate_to =
          splitk_sums_.empty() ? mma_results_ : splitk_sums_;
      for (auto& [c, c_cache] : cached_epilogue_inputs_) {
        cached_tvs.push_back(c_cache);
        propagate_to.push_back(c);
      }
      for (Val* dv : fusion_->outputs()) {
        TensorView* d = dv->as<TensorView>();
        NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
    
        // Apply the default scheduling that is common to all register
        // TensorViews after wgmma.
        blockTileTensors({d});
        parallelizeBlocks({d});
        transformLikeMmaOutputWithoutK(d);
    
        const AbstractTensor s =
            mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(d->getLoopDomain());
        d->setLoopDomain(s.as<IterDomain*>());
    
        // TODO: We need to check bank conflicts in this path.
        // Propagate schedule changes back to the outputs of the Mma op.
        scheduler_utils::BoundedDirectionalTransformPropagator::backward(
            d,
            -1,
            propagate_to,
            scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                .propagateParallelType());
    
        // We do not respect the vectorization_factor parameter, but always
        // vectorize the inner-dim with extent 2.
        NVF_ERROR(params_->supported_vec_size.epilogue >= 2);
        // TODO: Support vectorization_factor in MatmulParams
        d->axis(-1)->parallelize(ParallelType::Vectorize);
        if (!cached_tvs.empty()) {
          scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
        }
      }
    }
    
    void HopperPlus::scheduleEpilogueWithoutSmemEpilogue() {
      if (isBlackwell(params_->mma_macro)) {
        scheduleEpilogueWithoutSmemEpilogueBlackwell();
      } else {
        scheduleEpilogueWithoutSmemEpilogueHopper();
      }
    }
    
    void HopperPlus::scheduleEpilogueWithSmemEpilogueHopper() {
      constexpr int64_t ldst_matrix_tile_m = 16;
      constexpr int64_t ldst_matrix_tile_n = 16;
      fusion_->manage("ldst_matrix_m_tile", ldst_matrix_tile_m);
      fusion_->manage("ldst_matrix_n_tile", ldst_matrix_tile_n);
      fusion_->manage("ldst_matrix_m_smem", params_->tile_sizes.warp_tile.m);
      fusion_->manage("ldst_matrix_n_smem", params_->tile_sizes.warp_tile.n);
    
      // Propagate to (not including) the splitk output if there is a splitk
      // else this is just mma_results_
      std::vector<TensorView*> propagate_to =
          splitk_sums_.empty() ? mma_results_ : splitk_sums_;
      for (auto& [c, c_cache] : cached_epilogue_inputs_) {
        bool load_with_ldmatrix =
            params_->use_ldst_matrix && dataTypeSize(c_cache->dtype()) == 2;
        bool is_2d_epilogue_input =
            TensorDomain::noBroadcasts(c_cache->domain()->logical()).size() == 2;
        if (load_with_ldmatrix && is_2d_epilogue_input &&
            params_->async_gmem_load_operands) {
          // Schedule TMA load into shared memory for epilogue input
          c_cache->definition()->as<LoadStoreOp>()->setOpType(
              LoadStoreOpType::CpAsyncBulkTensorTile);
          c_cache->setMemoryType(MemoryType::Shared);
    
          // Apply the default scheduling that is common to all register
          // TensorViews after wgmma.
          blockTileTensors({c_cache});
          parallelizeBlocks({c_cache});
          transformLikeMmaOutputWithoutK(c_cache);
    
          // Swizzle to avoid shared memory bank conflicts
          MmaInputSmemSwizzle swizzle_type =
              mma_utils::tmaSwizzleSharedMemory(c_cache);
          c_cache->applyMmaSwizzleForTMALoad(swizzle_type);
    
          TensorView* reg_tv = cacheAfter(c_cache);
          reg_tv->definition()->as<LoadStoreOp>()->setOpType(
              LoadStoreOpType::LdMatrix);
    
          // Apply the default scheduling that is common to all register
          // TensorViews after wgmma.
          blockTileTensors({reg_tv});
          parallelizeBlocks({reg_tv});
          transformLikeMmaOutputWithoutK(reg_tv);
    
          // Schedule the loop and allocation domain of LdMatrix like the
          // accumulation register TensorView of wgmma.
          AbstractTensor s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
              reg_tv->getLoopDomain());
          reg_tv->setLoopDomain(s.as<IterDomain*>());
          reg_tv->setAllocationDomain(
              reg_tv->getLoopDomain(), /*new_contiguity=*/true);
    
          // Apply LdStMatrix scheduling to the wgmma loop domain
          mma_utils::scheduleLdStMatrixForMmaOutput(
              reg_tv, ldst_matrix_tile_m, ldst_matrix_tile_n);
    
          // Vectorize the last iterDomain because LdMatrix loads all eight
          // values with a single LdMatrix.x4 operation.
          reg_tv->axis(-1)->parallelize(ParallelType::Vectorize);
    
          // Do not propagate any other changes to LdMatrix.
          propagate_to.push_back(reg_tv);
        } else {
          // Propagate changes to the cache_after tensor if not using TMA load.
          propagate_to.push_back(c);
        }
      }
    
      // Manually schedule register cache and output TensorView
      for (Val* dv : fusion_->outputs()) {
        TensorView* d = dv->as<TensorView>();
        NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
        TensorView* dc = d->definition()->input(0)->as<TensorView>();
    
        // The chain of operations storing data to global memory:
        //   registers -> (stmatrix) -> smem -> (tma_store) -> gmem
        TensorView* d_smem = cacheBefore(d, LoadStoreOpType::Set);
    
        std::vector<TensorView*> tvs_to_schedule{d, d_smem};
        bool dc_is_mma_result =
            std::find(mma_results_.begin(), mma_results_.end(), dc) !=
            mma_results_.end();
        bool dc_is_splitk_sum = params_->splitk_factor > 1 &&
            std::find(splitk_sums_.begin(), splitk_sums_.end(), dc) !=
                splitk_sums_.end();
    
        if (!dc_is_mma_result && !dc_is_splitk_sum) {
          // Only schedule dc here when it is not an mma_result or split-K sum;
          // dc can be the mma_result itself if we are not casting back to
          // half precision in the output.
          tvs_to_schedule.push_back(dc);
        }
    
        // Set MemoryType
        dc->setMemoryType(MemoryType::Local);
        d_smem->setMemoryType(MemoryType::Shared);
    
        // Set LoadStoreOpType
        bool store_with_stmatrix =
            params_->use_ldst_matrix && dataTypeSize(dc->dtype()) == 2;
        if (store_with_stmatrix) {
          d_smem->definition()->as<LoadStoreOp>()->setOpType(
              LoadStoreOpType::StMatrix);
        }
        d->definition()->as<LoadStoreOp>()->setOpType(
            LoadStoreOpType::CpAsyncBulkTensorTile);
    
        // Apply the common transforms to dc, d_smem, and d.
        // After these transforms, we schedule the inner two non-reduction loops
        // (instruction tile) of dc and propagate it back to the outputs of the
        // mma.
        blockTileTensors(tvs_to_schedule);
        parallelizeBlocks(tvs_to_schedule);
        for (auto tv : tvs_to_schedule) {
          transformLikeMmaOutputWithoutK(tv);
        }
    
        // Do not propagate if dc is an mma output, as the mma output has
        // already been scheduled.
        if (!dc_is_mma_result && !dc_is_splitk_sum) {
          auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
              dc->getLoopDomain());
          dc->setLoopDomain(s.as<IterDomain*>());
          dc->setAllocationDomain(s.as<IterDomain*>(), true);
    
          scheduler_utils::BoundedDirectionalTransformPropagator::backward(
              dc,
              -1,
              propagate_to,
              scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                  .propagateParallelType());
        }
    
        // Determine swizzle for TMA Store
        MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);
    
        // First, create loop domain that matches wgmma register accumulator using
        // original loop domain.
        const AbstractTensor s =
            mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
                d_smem->getLoopDomain());
        // Create allocation domain with swizzle for TMA Store.
        // This step modifies the loop domain and creates a new allocation domain.
        if (swizzle != MmaInputSmemSwizzle::None) {
          mma_utils::scheduleTMAStoreForMmaOutput(d_smem, swizzle);
        }
        // Finally, set loop domain using saved AbstractTensor.
        d_smem->setLoopDomain(s.as<IterDomain*>());
    
        if (store_with_stmatrix) {
          // Apply LdStMatrix scheduling to the wgmma loop domain
          mma_utils::scheduleLdStMatrixForMmaOutput(
              d_smem, ldst_matrix_tile_m, ldst_matrix_tile_n);
        }
        d_smem->axis(-1)->parallelize(ParallelType::Vectorize);
    
        // Schedule global memory output; Output from TMA Store
        mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle);
      }
    }
    
    void HopperPlus::scheduleEpilogueWithSmemEpilogueBlackwell() {
      const bool has_splitk = params_->splitk_factor != 1;
      int64_t tmem_vectorize_factor = getLdTMemVectorizeFactor();
    
      std::vector<TensorView*> tmem_ld_tvs =
          !has_splitk ? createTMemLoad() : std::vector<TensorView*>{};
    
      // Propagate to (not including) the splitk output if there is a splitk
      // else this is just mma_results_
      std::vector<TensorView*> register_tvs;
      std::vector<TensorView*> propagate_to =
          splitk_sums_.empty() ? mma_results_ : splitk_sums_;
      for (auto& [c, c_cache] : cached_epilogue_inputs_) {
        bool is_2d_epilogue_input =
            TensorDomain::noBroadcasts(c_cache->domain()->logical()).size() == 2;
        if (is_2d_epilogue_input && params_->async_gmem_load_operands) {
          // Schedule TMA load into shared memory for epilogue input
          c_cache->definition()->as<LoadStoreOp>()->setOpType(
              LoadStoreOpType::CpAsyncBulkTensorTile);
          c_cache->setMemoryType(MemoryType::Shared);
          blockTileTensors({c_cache});
          parallelizeBlocks({c_cache});
          transformLikeMmaOutputWithoutK(c_cache);
          c_cache->setAllocationDomain(c_cache->getLoopDomain(), true);
          for (int64_t i = -5; i <= -1; i++) {
            c_cache->axis(i)->parallelize(ParallelType::Bulk);
          }
    
          // Schedule smem->register load for epilogue input
          TensorView* reg_tv = cacheAfter(c_cache);
          register_tvs.push_back(reg_tv);
          blockTileTensors({reg_tv});
          parallelizeBlocks({reg_tv});
          transformLikeMmaOutputWithoutK(reg_tv);
        }
        // Propagate changes to the cache_after tensor
        propagate_to.push_back(c);
      }
    
      // TMem load is scheduled separately, so don't propagate to it.
      propagate_to.insert(
          propagate_to.end(), tmem_ld_tvs.begin(), tmem_ld_tvs.end());
    
      // The chain of operations storing data to global memory:
      //   dc (registers) -> d_smem -> [tma_store] -> d (gmem)
      // We schedule d_smem and propagate it back.
      for (Val* dv : fusion_->outputs()) {
        TensorView* d = dv->as<TensorView>();
        NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
        TensorView* dc = d->definition()->input(0)->as<TensorView>();
        TensorView* d_smem = cacheBefore(d, LoadStoreOpType::Set);
        dc->setMemoryType(MemoryType::Local);
        d_smem->setMemoryType(MemoryType::Shared);
    
        // We schedule the epilogue like:
        // (v = tmem_vectorize_factor, vv = hardcoded_smem_vectorize_factor)
        // [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v/vv, vv]
        blockTileTensors({d, d_smem});
        parallelizeBlocks({d, d_smem});
        for (auto tv : {d, d_smem}) {
          transformLikeMmaOutputWithoutK(tv);
          tv->axis(-2)->parallelize(ParallelType::TIDx);
          if (tmem_vectorize_factor < getN(params_->mma_macro)) {
            tv->split(-1, tmem_vectorize_factor);
          }
        }
        if (tmem_vectorize_factor > hardcoded_smem_vectorize_factor) {
          d_smem->split(-1, hardcoded_smem_vectorize_factor);
        }
    
        scheduler_utils::BoundedDirectionalTransformPropagator::backward(
            d_smem,
            -1,
            propagate_to,
            scheduler_utils::BoundedDirectionalTransformPropagator::Options()
                .propagateParallelType());
    
        d_smem->axis(-1)->parallelize(ParallelType::Vectorize);
        d_smem->setAllocationDomain(d_smem->getLoopDomain(), true);
    
        // Schedule global memory output; Output from TMA Store
        d->definition()->as<LoadStoreOp>()->setOpType(
            LoadStoreOpType::CpAsyncBulkTensorTile);
        for (int64_t i = -5; i <= -1; i++) {
          d->axis(i)->parallelize(ParallelType::Bulk);
        }
      }
    
      // Schedule TMem load as:
      // (v = tmem_vectorize_factor)
      // [..., Mo * No, Mw, Nw, Mi (TIDx), Ni / v, v (Vectorize)]
      blockTileTensors(tmem_ld_tvs);
      parallelizeBlocks(tmem_ld_tvs);
      for (TensorView* tmem_ld_tv : tmem_ld_tvs) {
        transformLikeMmaOutputWithoutK(tmem_ld_tv);
        tmem_ld_tv->axis(-2)->parallelize(ParallelType::TIDx);
        if (tmem_vectorize_factor < getN(params_->mma_macro)) {
          tmem_ld_tv->split(-1, tmem_vectorize_factor);
        }
        tmem_ld_tv->axis(-1)->parallelize(ParallelType::Vectorize);
      }
    }
    Test Skips

    The HopperPlusMatmulSchedulerTest suite skips certain configurations on Blackwell. Either all configurations should be tested, or a clear rationale should be documented for the ones that are skipped (see the hedged sketch after the snippet below).

        splitk_factor) = GetParam();
    
    if (isHopper(mma_macro)) {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0);
    } else {
      NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(10, 0, 11, 0);
    }
    
    if (a_k_inner) {
      layout = b_k_inner ? MmaLayout::TN : MmaLayout::TT;
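
    A minimal sketch of a skip that documents its rationale (hypothetical: the
    condition is a placeholder, and only the GTEST_SKIP macro from googletest
    plus the isBlackwell helper already used here are assumed):

    if (isBlackwell(mma_macro) && config_is_not_yet_supported_on_blackwell) {
      GTEST_SKIP() << "Skipped on Blackwell: <state the specific reason, e.g. "
                      "the scheduler does not yet support this configuration>";
    }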

    @zasdfgbnm zasdfgbnm changed the title blackwell smem epilogue Blackwell matmul scheduler smem epilogue support May 30, 2025
    @zasdfgbnm
    Collaborator Author

    !test

    @zasdfgbnm zasdfgbnm marked this pull request as ready for review May 30, 2025 01:43
    Comment on lines 941 to 942
    // TODO: should we rename use_ldst_matrix to use_tma_for_epilogue_input?
    bool load_with_tma = params_->use_ldst_matrix;
    Collaborator Author


    I believe we should rename use_ldst_matrix; I will do it in a separate PR if there is no objection.

    Collaborator


    This is separate from use_smem_epilogue which for Hopper+ means "use TMA for epilogue stores and loads". If that is true but use_ldst_matrix is false, then we will use vectorized loads/stores between smem and registers in the epilogue even though we are using TMA to do the gmem<->smem transfers.

    I think the most flexible parametrization would allow us to set the following for each epilogue input:

    • Whether to use TMA to load from gmem to smem
    • Whether to use ldmatrix

    and for each output:

    • Whether to use stmatrix
    • Whether to use TMA to store from smem to gmem

    In other words we could have this:

      struct EpilogueInputConfig {
        bool use_tma_load;
        bool use_ldmatrix;   // used if use_tma_load == true
      };
      std::vector<EpilogueInputConfig> epilogue_input_configs;
    
      struct OutputConfig {
        bool use_tma_store;
        bool use_stmatrix;   // used if use_tma_store == true
      };
      std::vector<OutputConfig> output_configs;
    

    The order of these vectors would correspond to the order found in tensor_roles which is deterministic.
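
    A hedged sketch of how such per-tensor configs might be consulted
    (hypothetical: EpilogueInputConfig and epilogue_input_configs are the
    structs proposed above, indexing by position mirrors the deterministic
    tensor_roles order mentioned here, and everything else reuses identifiers
    already present in the scheduler code):

    int64_t idx = 0;
    for (auto& [c, c_cache] : cached_epilogue_inputs_) {
      const EpilogueInputConfig& cfg = epilogue_input_configs.at(idx++);
      if (cfg.use_tma_load) {
        // TMA gmem->smem load for this epilogue input
        c_cache->definition()->as<LoadStoreOp>()->setOpType(
            LoadStoreOpType::CpAsyncBulkTensorTile);
        c_cache->setMemoryType(MemoryType::Shared);
        if (cfg.use_ldmatrix) {
          // ...then schedule an LdMatrix smem->register load, as in the Hopper path
        }
      }
    }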

    Collaborator Author


    > This is separate from use_smem_epilogue which for Hopper+ means "use TMA for epilogue stores and loads". If that is true but use_ldst_matrix is false, then we will use vectorized loads/stores between smem and registers in the epilogue even though we are using TMA to do the gmem<->smem transfers.

    I think this is not how it is currently implemented.

    Here:

    // Schedule TMA load into shared memory for epilogue input
    c_cache->definition()->as<LoadStoreOp>()->setOpType(
        LoadStoreOpType::CpAsyncBulkTensorTile);

    We only use TMA to load epilogue input if use_ldst_matrix is true, which seems weird to me.

    Collaborator


    > We only use TMA to load epilogue input if use_ldst_matrix is true, which seems weird to me.

    You're right. We should change that condition.

    Collaborator Author


    Removed use_ldst_matrix from the Blackwell scheduler.

    Collaborator

    @rdspring1 rdspring1 left a comment


    LGTM.

    constexpr int64_t ldst_matrix_tile_m = 16;
    constexpr int64_t ldst_matrix_tile_n = 16;
    fusion_->manage("ldst_matrix_m_tile", ldst_matrix_tile_m);
    fusion_->manage("ldst_matrix_n_tile", ldst_matrix_tile_n);
    fusion_->manage("ldst_matrix_m_smem", params_->tile_sizes.warp_tile.m);
    fusion_->manage("ldst_matrix_n_smem", params_->tile_sizes.warp_tile.n);

    // Apply LdMatrix to any epilogue inputs loaded to smem with TMA.
    Collaborator


    Was this unused?

    @@ -926,6 +923,122 @@ void HopperPlus::scheduleEpilogueWithSmemEpilogue() {
    }
    }

    constexpr int64_t hardcoded_smem_vectorize_factor = 4;
    Collaborator


    Do we want to move this constant to the top or to a more central location?

    Collaborator Author


    Moved this and hardcoded_blackwell_splitk_vectorization_factor up.

    @zasdfgbnm
    Collaborator Author

    !test

    @zasdfgbnm zasdfgbnm merged commit 03a274d into main Jun 3, 2025
    52 of 53 checks passed
    @zasdfgbnm zasdfgbnm deleted the smem-epilogue branch June 3, 2025 20:19