Skip to content

Commit 8b70ef2

Browse files
authored
Add bindings for FusionExecutorCache (#4513)
This PR creates the bindings for `FusionExecutorCache`, allowing fusions to run CUDA kernels in `nvfuser_direct`. Functions bound for `FusionExecutorCache`: * get_cuda_kernel * get_most_recent_scheduled_ir * get_scheduled_ir * is_compiled * execute Create `python/python_direct/direct_utils.h` for `python_direct`-only helper functions * Add `from_pyiterable` and `to_tensor_vector` to and from `at::Tensor` and `KernelArgumentHolder` New function for python `FusionDefinition`: * `execute` - It creates FusionExecutorCache if it exists and runs the fusion with given input arguments. Testing * Created `test_fusion_execution_cache` and `test_define_tensor` PR Stack: #4409 Create python FusionDefinition for nvfuser_next #4513 Add bindings for FusionExecutorCache **<<< This PR.** #4516 Add the remaining binary operations #4517 Add the bindings for unary operations #4518 Add the bindings for reduction operations #4519 Move helper functions from python_frontend to python_common #4520 Create python reproducer from Fusion IR for nvfuser_direct #4521 Recreate python_frontend test_basic for nvfuser_direct
1 parent 4ebbab9 commit 8b70ef2

File tree

6 files changed

+436
-8
lines changed

6 files changed

+436
-8
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,7 @@ if(BUILD_PYTHON)
572572
${NVFUSER_PYTHON_DIRECT_BINDINGS}/ir.cpp
573573
${NVFUSER_PYTHON_DIRECT_BINDINGS}/ops.cpp
574574
${NVFUSER_PYTHON_DIRECT_BINDINGS}/runtime.cpp
575+
${NVFUSER_PYTHON_DIRECT_BINDINGS}/direct_utils.cpp
575576
)
576577
add_library(nvf_py_direct_internal OBJECT ${NVFUSER_PYTHON_DIRECT_SRCS})
577578

python/nvfuser_direct/__init__.py

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import os
1212
import torch
13+
import traceback
1314

1415
# This is needed when libnvfuser_direct.so is patched and doesn't have the pytorch library location available.
1516
pytorch_lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib")
@@ -42,25 +43,35 @@ def __init__(self):
4243
# Monkey patching nvfuser_direct.ops submodule to mimic python_frontend
4344
# FusionDefinition.ops API. This is to maintain backwards compatibilty.
4445
self.ops = _C_DIRECT.ops
45-
self.fusion = _C_DIRECT.Fusion()
46-
self.fusion_guard = None
46+
self._fusion = None
47+
self._fusion_guard = None
48+
49+
@property
50+
def fusion(self):
51+
if not hasattr(self, "fec"):
52+
return self._fusion
53+
else:
54+
return self.fec.fusion()
4755

4856
def __enter__(self):
4957
"""
5058
Enter the context manager.
59+
5160
Returns
5261
-------
5362
FusionDefinition
5463
The FusionDefinition instance
5564
"""
56-
self.fusion_guard = _C_DIRECT.FusionGuard(self.fusion)
65+
self._fusion = _C_DIRECT.Fusion()
66+
self._fusion_guard = _C_DIRECT.FusionGuard(self._fusion)
5767
return self
5868

5969
def __exit__(self, exception_type, exception_value, exception_traceback):
6070
"""
6171
Exit the context manager and handle any exceptions.
6272
This method is called when exiting the 'with' block, whether normally or due to an exception.
6373
The arguments provide information about any exception that occurred:
74+
6475
Parameters
6576
----------
6677
excecption_type : type or None
@@ -73,7 +84,7 @@ def __exit__(self, exception_type, exception_value, exception_traceback):
7384
The traceback object containing the call stack.
7485
None if no exception occurred.
7586
"""
76-
self.fusion_guard = None
87+
del self._fusion_guard
7788
if exception_type is not None:
7889
print(f"Exception occurred: {exception_type.__name__}: {exception_value}")
7990
if exception_traceback is not None:
@@ -84,29 +95,60 @@ def __exit__(self, exception_type, exception_value, exception_traceback):
8495
def define_tensor(self, *args, **kwargs):
8596
"""
8697
Define a new tensor input for the fusion.
98+
8799
Parameters
88100
----------
89101
*args
90102
Positional arguments passed to _C_DIRECT.define_tensor
91103
**kwargs
92104
Keyword arguments passed to _C_DIRECT.define_tensor
105+
93106
Returns
94107
-------
95108
Tensor
96109
The defined tensor
97110
"""
98111
tv = _C_DIRECT.define_tensor(*args, **kwargs)
99-
self.fusion.add_input(tv)
112+
self._fusion.add_input(tv)
100113
return tv
101114

102115
def add_output(self, *args, **kwargs):
103116
"""
104117
Add an output to the fusion.
118+
105119
Parameters
106120
----------
107121
*args
108122
Positional arguments passed to fusion.add_output
109123
**kwargs
110124
Keyword arguments passed to fusion.add_output
111125
"""
112-
self.fusion.add_output(*args, **kwargs)
126+
self._fusion.add_output(*args, **kwargs)
127+
128+
def execute(self, inputs, *, device=None, auto_schedule=True) -> list[torch.Tensor]:
129+
"""
130+
Execute the fusion with the given inputs.
131+
132+
Parameters
133+
----------
134+
inputs : list of torch.Tensor
135+
Input tensors and scalars to the fusion
136+
device : torch.device, optional
137+
Device to execute the fusion on
138+
auto_schedule : bool, default=True
139+
Whether to use automatic scheduling
140+
141+
Returns
142+
-------
143+
list of torch.Tensor
144+
Output tensors from the fusion
145+
"""
146+
if auto_schedule:
147+
if not hasattr(self, "fec"):
148+
self.fec = _C_DIRECT.FusionExecutorCache(self._fusion)
149+
# A copy of fusion is created after construction FusionExecutorCache
150+
# Delete the _fusion and reference the fusion inside FusionExecutorCache
151+
del self._fusion
152+
return self.fec.execute(inputs)
153+
else:
154+
raise RuntimeError("Manual scheduling is not supported yet.")

python/python_direct/direct_utils.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// clang-format off
2+
/*
3+
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
4+
* All rights reserved.
5+
* SPDX-License-Identifier: BSD-3-Clause
6+
*/
7+
// clang-format on
8+
9+
#include <direct_utils.h>
10+
#include <algorithm>
11+
12+
namespace nvfuser::python {
13+
14+
KernelArgumentHolder from_pyiterable(
15+
const py::iterable& iter,
16+
std::optional<int64_t> device) {
17+
KernelArgumentHolder args;
18+
for (py::handle obj : iter) {
19+
// Allows for a Vector of Sizes to be inputed as a list/tuple
20+
if (py::isinstance<py::list>(obj) || py::isinstance<py::tuple>(obj)) {
21+
for (py::handle item : obj) {
22+
args.push(torch::jit::toIValue(item, c10::AnyType::get()));
23+
}
24+
} else {
25+
args.push(torch::jit::toIValue(obj, c10::AnyType::get()));
26+
}
27+
}
28+
29+
// Transform int64_t device to int8_t
30+
std::optional<int8_t> selected_device = std::nullopt;
31+
if (device.has_value()) {
32+
NVF_CHECK(device.value() < 256, "Maximum device index is 255");
33+
selected_device = (int8_t)device.value();
34+
}
35+
args.setDeviceIndex(selected_device);
36+
return args;
37+
}
38+
39+
std::vector<at::Tensor> to_tensor_vector(const KernelArgumentHolder& outputs) {
40+
// Convert outputs KernelArgumentHolder to std::vector<at::Tensor>
41+
std::vector<at::Tensor> out_tensors;
42+
out_tensors.reserve(outputs.size());
43+
std::transform(
44+
outputs.begin(),
45+
outputs.end(),
46+
std::back_inserter(out_tensors),
47+
[](const PolymorphicValue& out) { return out.as<at::Tensor>(); });
48+
return out_tensors;
49+
}
50+
51+
} // namespace nvfuser::python

python/python_direct/direct_utils.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// clang-format off
2+
/*
3+
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
4+
* All rights reserved.
5+
* SPDX-License-Identifier: BSD-3-Clause
6+
*/
7+
// clang-format on
8+
#pragma once
9+
10+
#include <runtime/executor_kernel_arg.h>
11+
#include <torch/csrc/jit/python/pybind_utils.h>
12+
#include <optional>
13+
#include <vector>
14+
15+
namespace nvfuser::python {
16+
17+
// Convert a py::iterable to a KernelArgumentHolder
18+
nvfuser::KernelArgumentHolder from_pyiterable(
19+
const py::iterable& iter,
20+
std::optional<int64_t> device = std::nullopt);
21+
22+
// Convert a KernelArgumentHolder to a std::vector<at::Tensor>
23+
std::vector<at::Tensor> to_tensor_vector(
24+
const nvfuser::KernelArgumentHolder& outputs);
25+
26+
} // namespace nvfuser::python

0 commit comments

Comments
 (0)