TorchDynamo and C++ extensions #197
RaulPPelaez started this conversation in General
Replies: 0 comments
We have a PyTorch C++ autograd extension (called get_neighbor_pairs) that provides CPU and CUDA backends. When I try to torch.compile a model whose forward function calls this extension:
I get the following error:
The extension is defined as:
torchmd-net/torchmdnet/neighbors/neighbors.cpp, lines 1 to 5 at commit a116847
The CUDA implementation is here:
torchmd-net/torchmdnet/neighbors/neighbors_cuda.cu, lines 74 to 89 at commit a116847
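Some background on what Dynamo expects from an operator like this may help. When TorchDynamo traces a call into a custom operator, it runs the operator on fake tensors. The operator therefore needs a kernel that computes only output shapes and dtypes, and the real CPU/CUDA code is never run during tracing.

Below is a minimal sketch, assuming the op is registered with the dispatcher and returns fixed-size (padded) outputs. The namespace demo_nbr, the op get_pairs, its schema and the brute-force CPU kernel are hypothetical stand-ins, not the torchmd-net code, and the sketch leaves out the autograd registration the real extension has. It defines the op from Python with torch.library.Library and registers a Meta kernel next to the CPU one:

```python
import torch

# Hypothetical namespace, op name and schema (not the real extension's schema).
lib = torch.library.Library("demo_nbr", "DEF")
lib.define("get_pairs(Tensor positions, float cutoff, int max_num_pairs) -> (Tensor, Tensor)")


def get_pairs_cpu(positions, cutoff, max_num_pairs):
    # Brute-force O(N^2) stand-in: all pairs (i, j) with i > j closer than cutoff.
    n = positions.shape[0]
    i, j = torch.tril_indices(n, n, offset=-1, device=positions.device)
    dist = (positions[i] - positions[j]).norm(dim=-1)
    mask = dist < cutoff
    pairs = torch.stack([i[mask], j[mask]])[:, :max_num_pairs]
    dist = dist[mask][:max_num_pairs]
    # Pad to max_num_pairs so the output shape never depends on the data.
    out_pairs = torch.full((2, max_num_pairs), -1, dtype=torch.long, device=positions.device)
    out_dist = positions.new_zeros((max_num_pairs,))
    out_pairs[:, : pairs.shape[1]] = pairs
    out_dist[: dist.shape[0]] = dist
    return out_pairs, out_dist


def get_pairs_meta(positions, cutoff, max_num_pairs):
    # Shapes and dtypes only; no tensor data is read.
    return (
        positions.new_empty((2, max_num_pairs), dtype=torch.long),
        positions.new_empty((max_num_pairs,)),
    )


lib.impl("get_pairs", get_pairs_cpu, "CPU")
lib.impl("get_pairs", get_pairs_meta, "Meta")
```

You can exercise the shape logic alone by calling the op on meta tensors. For example, torch.ops.demo_nbr.get_pairs(positions.to("meta"), 1.5, 4) returns empty meta tensors of shapes (2, 4) and (4,). For a C++ extension, the equivalent is to register a Meta (or fake) implementation for the op. PyTorch 2.4 and later also provide torch.library.register_fake for this purpose.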
I JIT-compile this function at import time:
torchmd-net/torchmdnet/neighbors/__init__.py, lines 2 to 16 at commit a116847
Finally, the forward function of the model in the example above calls this extension:
torchmd-net/torchmdnet/models/utils.py, line 110 at commit a116847
That module stores the extension function as self.kernel:
torchmd-net/torchmdnet/models/utils.py, lines 234 to 237 at commit a116847
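Keeping the op in an attribute like self.kernel should not be a problem in itself, as long as the stored object is one Dynamo can place in the graph. In general, a function exposed directly through pybind11 is opaque to Dynamo. An operator registered with the dispatcher and called through torch.ops can be traced, provided it has a shape-only kernel.

Below is a sketch, assuming the kernel is registered that way (here from Python via torch.library with a Meta kernel). The names demo_pairs, pair_count and Distance are hypothetical. With fullgraph=True, torch.compile raises an error instead of silently graph-breaking if the call cannot be traced:

```python
import torch

# Hypothetical op registered with the dispatcher, with CPU and Meta kernels.
lib = torch.library.Library("demo_pairs", "DEF")
lib.define("pair_count(Tensor positions, float cutoff) -> Tensor")


def pair_count_cpu(positions, cutoff):
    dist = torch.cdist(positions, positions)
    # Neighbors per atom within the cutoff, excluding the atom itself.
    return (dist < cutoff).sum(dim=1) - 1


def pair_count_meta(positions, cutoff):
    # Output shape and dtype only, used while tracing.
    return positions.new_empty((positions.shape[0],), dtype=torch.long)


lib.impl("pair_count", pair_count_cpu, "CPU")
lib.impl("pair_count", pair_count_meta, "Meta")


class Distance(torch.nn.Module):
    def __init__(self, cutoff: float):
        super().__init__()
        self.cutoff = cutoff
        # Store the dispatcher op handle rather than a raw pybind11 function.
        self.kernel = torch.ops.demo_pairs.pair_count

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        return self.kernel(positions, self.cutoff)


model = Distance(cutoff=1.5)
compiled = torch.compile(model, backend="eager", fullgraph=True)
```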