Respect min and max of inputs to create more precise repro scripts #4535


Open
wants to merge 10 commits into main
82 changes: 59 additions & 23 deletions python/nvfuser/__init__.py
@@ -12,6 +12,7 @@
import os
import re
from typing import Callable
from numbers import Number
import warnings

import torch
@@ -55,6 +56,53 @@ def disable_automatic_serialization():
    atexit.unregister(_C.serialize)


# NOTE(crcrpar): The main motivation of this class is to avoid unexpected negative values
# being supplied to indexing ops such as embedding. See NVIDIA/Fuser#4529.
class InputTensorFactory:
    """Describe an input tensor so that it can be recreated in a repro script."""

    low: Number
    high: Number
    device: str
    dtype: torch.dtype
    size: tuple[int, ...]
    strides: tuple[int, ...]
    storage_offset: int
    requires_grad: bool
    is_contiguous: bool

    def __init__(self, tensor: torch.Tensor) -> None:
        if type(tensor) is not torch.Tensor:
            msg = f"Repro script generation only supports {torch.Tensor}, but got {type(tensor)}"
            raise RuntimeError(msg)
        self.low, self.high = InputTensorFactory._get_min_and_max(tensor)
        self.dtype = tensor.dtype
        # The device is stored pre-quoted so it can be embedded directly in the generated script.
        self.device = f'"{str(tensor.device)}"'
        self.size = tuple(tensor.size())
        self.strides = tuple(tensor.stride())
        self.storage_offset = tensor.storage_offset()
        self.requires_grad = tensor.requires_grad
        self.is_contiguous = tensor.is_contiguous()

    @staticmethod
    @torch.inference_mode()
    def _get_min_and_max(tensor: torch.Tensor) -> tuple[Number | None, Number | None]:
        if tensor.dtype is torch.bool:
            # `make_tensor` samples from the half-open interval [low, high),
            # so [0, 2) covers both False and True.
            return 0, 2

        t = tensor
        if t.dtype.is_floating_point and t.dtype.itemsize == 1:
            # Upcast 8-bit floating-point dtypes so that `torch.aminmax` is supported.
            t = t.float()

        min_max = torch.aminmax(t)
        return min_max[0].cpu().item(), min_max[1].cpu().item()

    # We might want to have a method that returns `torch.Tensor` in the future.
    def __str__(self) -> str:
        """Render as a tensor-factory call carrying this tensor's metadata."""
        return (
            f"torch.testing.make_tensor({self.size}, dtype={self.dtype}, device={self.device}, "
            f"low={self.low}, high={self.high}, requires_grad={self.requires_grad})"
            f".as_strided({self.size}, {self.strides}, {self.storage_offset})"
        )

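For illustration (not part of the diff), a strided CUDA tensor would stringify roughly as follows; the tensor here is a made-up example, and the low/high values come from the tensor's observed aminmax:

t = torch.arange(4, dtype=torch.float32, device="cuda:0").reshape(2, 2).t()
print(InputTensorFactory(t))
# torch.testing.make_tensor((2, 2), dtype=torch.float32, device="cuda:0", low=0.0, high=3.0, requires_grad=False).as_strided((2, 2), (1, 2), 0)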

class FusionDefinition(_C._FusionDefinition):
    def __init__(
        self,
@@ -310,11 +358,13 @@ def execute(
                self.schedule()
                self._finalize_schedule(inputs)

        # TODO(crcrpar): Think about deprecating `self.fake_inputs` in favor of `InputTensorFactory`.
        if save_repro_inputs:
            from torch._subclasses.fake_tensor import FakeTensorMode

            fake_mode = FakeTensorMode()
            self.fake_inputs = [fake_mode.from_tensor(inp) for inp in inputs]
            self.input_descriptors = [InputTensorFactory(tensor) for tensor in inputs]

        if hasattr(self, "segments") and len(self.segments) > 0:
            return self._execute_segments(inputs, device=device, profile=profile)
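For context, the intended usage flow looks roughly like this (a sketch; `fusion_func` stands in for any fusion definition):

with FusionDefinition() as fd:
    fusion_func(fd)
outputs = fd.execute(inputs, save_repro_inputs=True)
print(fd.last_repro_script())  # embeds make_tensor(...) calls with the observed low/high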
@@ -499,10 +549,13 @@ def last_repro_script(self) -> str:
        assert (
            self.fake_inputs is not None
        ), "fd.last_repro_script() cannot provide a repro because fd.execute(inputs, save_repro_inputs=True) was not executed!"
        script = self.repro_script_for(self.fake_inputs)
        script = self.repro_script_for(self.input_descriptors)
        return script

    def repro_script_for(self, inputs: list | None = None) -> str:
    def repro_script_for(
        self,
        inputs: list[torch.Tensor] | list[InputTensorFactory] | None = None,
    ) -> str:
        msg = "# CUDA devices:\n"
        for i in range(torch.cuda.device_count()):
            msg += f"# {i}: {torch.cuda.get_device_name(i)}\n"
@@ -519,28 +572,11 @@ def repro_script_for(self, inputs: list | None = None) -> str:
        if inputs is not None:
            msg += "\ninputs = [\n"
            for i in inputs:
                # TODO(crcrpar): Think about how to support tensor wrapper subclasses such as DTensor
                if isinstance(i, torch.Tensor):
                    if i.is_contiguous():
                        msg += f"    torch.testing.make_tensor({tuple(i.size())}, dtype={i.dtype}, device='{i.device}'),\n"
                    else:
                        # max linear index determines number of elements to generate
                        sz = 1
                        for szi, stri in zip(i.size(), i.stride()):
                            if szi == 0:
                                sz = 0
                                break
                            sz += (szi - 1) * stri
                        if i.dtype.is_floating_point:
                            msg += (
                                f"    torch.randn({sz}, dtype={i.dtype}, device='{i.device}')"
                                f".as_strided({tuple(i.size())}, {tuple(i.stride())}),\n"
                            )
                        else:
                            upper_bound = 2 if i.dtype == torch.bool else 10
                            msg += (
                                f"    torch.randint(0, {upper_bound}, ({sz},), dtype={i.dtype}, device='{i.device}')"
                                f".as_strided({tuple(i.size())}, {tuple(i.stride())}),\n"
                            )
                    i = InputTensorFactory(i)
                if isinstance(i, InputTensorFactory):
                    msg += f"    {i},\n"
                else:
                    input_as_string = str(i)
                    # `nan` and `inf` are stringified as is, which are not
12 changes: 8 additions & 4 deletions tests/python/test_python_frontend.py
@@ -4515,15 +4515,19 @@ def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
nvfuser_fusion_id0(fd)

inputs = [
    torch.testing.make_tensor((4, 4), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((4, 4), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((4, 4), dtype=torch.float32, device="cuda:0", low=0.0, high=1.0, requires_grad=False).as_strided((4, 4), (4, 1), 0),
    torch.testing.make_tensor((4, 4), dtype=torch.float32, device="cuda:0", low=0.0, high=1.0, requires_grad=False).as_strided((4, 4), (4, 1), 0),
]
fd.execute(inputs)
"""

        @torch.inference_mode()
        def generate_sample_with_deterministic_low_and_high(shape):
            # Threshold to {0.0, 1.0}, then pin one element to each extreme so
            # the min/max are exactly 0.0 and 1.0 regardless of the random draw.
            t = torch.where(torch.rand(shape, device="cuda:0") < 0.5, 0.0, 1.0)
            t.view(-1)[0] = 0.0
            t.view(-1)[-1] = 1.0
            return t

        inputs = [
            torch.randn(4, 4, device="cuda:0"),
            torch.randn(4, 4, device="cuda:0"),
            generate_sample_with_deterministic_low_and_high((4, 4)) for _ in range(2)
        ]

        def fusion_func(fd: FusionDefinition):