Commit 79da54b

Merge branch 'main' into xpu_nondeterministic_roi_align
2 parents 332e2be + d462da2

File tree

13 files changed: +135 -32 lines changed

setup.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 import distutils.spawn
 import glob
 import os
+import shlex
 import shutil
 import subprocess
 import sys
@@ -123,7 +124,7 @@ def get_macros_and_flags():
         if NVCC_FLAGS is None:
             nvcc_flags = []
         else:
-            nvcc_flags = NVCC_FLAGS.split(" ")
+            nvcc_flags = shlex.split(NVCC_FLAGS)
         extra_compile_args["nvcc"] = nvcc_flags

    if sys.platform == "win32":
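
The point of this change: str.split(" ") tears quoted compiler flags apart, while shlex.split respects shell quoting. A minimal sketch of the difference (the flags string below is an invented example, not taken from the repo):

import shlex

# Hypothetical NVCC_FLAGS value containing a quoted argument with a space.
flags = '-O2 --compiler-options "-fPIC -pthread"'

print(flags.split(" "))
# ['-O2', '--compiler-options', '"-fPIC', '-pthread"']  -- quoting broken apart
print(shlex.split(flags))
# ['-O2', '--compiler-options', '-fPIC -pthread']       -- quoting respected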

test/datasets_utils.py

Lines changed: 18 additions & 5 deletions
@@ -611,6 +611,7 @@ class ImageDatasetTestCase(DatasetTestCase):
     """

     FEATURE_TYPES = (PIL.Image.Image, int)
+    SUPPORT_TV_IMAGE_DECODE: bool = False

     @contextlib.contextmanager
     def create_dataset(
@@ -632,22 +633,34 @@ def create_dataset(
         # This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types open an
         # image, but never use the underlying data. During normal operation it is reasonable to assume that the
         # user wants to work with the image he just opened rather than deleting the underlying file.
-        with self._force_load_images():
+        with self._force_load_images(loader=(config or {}).get("loader", None)):
             yield dataset, info

     @contextlib.contextmanager
-    def _force_load_images(self):
-        open = PIL.Image.open
+    def _force_load_images(self, loader: Optional[Callable[[str], Any]] = None):
+        open = loader or PIL.Image.open

         def new(fp, *args, **kwargs):
             image = open(fp, *args, **kwargs)
-            if isinstance(fp, (str, pathlib.Path)):
+            if isinstance(fp, (str, pathlib.Path)) and isinstance(image, PIL.Image.Image):
                 image.load()
             return image

-        with unittest.mock.patch("PIL.Image.open", new=new):
+        with unittest.mock.patch(open.__module__ + "." + open.__qualname__, new=new):
             yield

+    def test_tv_decode_image_support(self):
+        if not self.SUPPORT_TV_IMAGE_DECODE:
+            pytest.skip(f"{self.DATASET_CLASS.__name__} does not support torchvision.io.decode_image.")
+
+        with self.create_dataset(
+            config=dict(
+                loader=torchvision.io.decode_image,
+            )
+        ) as (dataset, _):
+            image = dataset[0][0]
+            assert isinstance(image, torch.Tensor)
+

 class VideoDatasetTestCase(DatasetTestCase):
     """Abstract base class for video dataset testcases.

test/test_datasets.py

Lines changed: 5 additions & 0 deletions
@@ -405,6 +405,8 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
     REQUIRED_PACKAGES = ("scipy",)
     ADDITIONAL_CONFIGS = combinations_grid(split=("train", "val"))

+    SUPPORT_TV_IMAGE_DECODE = True
+
     def inject_fake_data(self, tmpdir, config):
         tmpdir = pathlib.Path(tmpdir)

@@ -2308,6 +2310,7 @@ def inject_fake_data(self, tmpdir, config):
 class EuroSATTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.EuroSAT
     FEATURE_TYPES = (PIL.Image.Image, int)
+    SUPPORT_TV_IMAGE_DECODE = True

     def inject_fake_data(self, tmpdir, config):
         data_folder = os.path.join(tmpdir, "eurosat", "2750")
@@ -2749,6 +2752,8 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase):

     ADDITIONAL_CONFIGS = combinations_grid(split=("train", "valid", "test"))

+    SUPPORT_TV_IMAGE_DECODE = True
+
     def inject_fake_data(self, tmpdir: str, config):
         split_folder = pathlib.Path(tmpdir) / "country211" / config["split"]
         split_folder.mkdir(parents=True, exist_ok=True)

test/test_image.py

Lines changed: 36 additions & 0 deletions
@@ -623,6 +623,42 @@ def test_encode_jpeg_cuda(img_path, scripted, contiguous):
     assert abs_mean_diff < 3


+@needs_cuda
+def test_encode_jpeg_cuda_sync():
+    """
+    Non-regression test for https://github.com/pytorch/vision/issues/8587.
+    Attempts to reproduce an intermittent CUDA stream synchronization bug
+    by randomly creating images and round-tripping them via encode_jpeg
+    and decode_jpeg on the GPU. Fails if the mean difference in uint8 range
+    exceeds 5.
+    """
+    torch.manual_seed(42)
+
+    # manual testing shows this bug appearing often in iterations between 50 and 100
+    # as a synchronization bug, this can't be reliably reproduced
+    max_iterations = 100
+    threshold = 5.0  # in [0..255]
+
+    device = torch.device("cuda")
+
+    for iteration in range(max_iterations):
+        height, width = torch.randint(4000, 5000, size=(2,))
+
+        image = torch.linspace(0, 1, steps=height * width, device=device)
+        image = image.view(1, height, width).expand(3, -1, -1)
+
+        image = (image * 255).clamp(0, 255).to(torch.uint8)
+        jpeg_bytes = encode_jpeg(image, quality=100)
+
+        decoded_image = decode_jpeg(jpeg_bytes.cpu(), device=device)
+        mean_difference = (image.float() - decoded_image.float()).abs().mean().item()
+
+        assert mean_difference <= threshold, (
+            f"Encode/decode mismatch at iteration={iteration}, "
+            f"size={height}x{width}, mean diff={mean_difference:.2f}"
+        )
+
+
 @pytest.mark.parametrize("device", cpu_and_cuda())
 @pytest.mark.parametrize("scripted", (True, False))
 @pytest.mark.parametrize("contiguous", (True, False))

torchvision/csrc/io/image/cuda/encode_jpegs_cuda.cpp

Lines changed: 16 additions & 10 deletions
@@ -94,12 +94,12 @@ std::vector<torch::Tensor> encode_jpegs_cuda(

   cudaJpegEncoder->set_quality(quality);
   std::vector<torch::Tensor> encoded_images;
-  at::cuda::CUDAEvent event;
-  event.record(cudaJpegEncoder->stream);
   for (const auto& image : contig_images) {
     auto encoded_image = cudaJpegEncoder->encode_jpeg(image);
     encoded_images.push_back(encoded_image);
   }
+  at::cuda::CUDAEvent event;
+  event.record(cudaJpegEncoder->stream);

   // We use a dedicated stream to do the encoding and even though the results
   // may be ready on that stream we cannot assume that they are also available
@@ -108,10 +108,7 @@
   // do not want to block the host at this particular point
   // (which is what cudaStreamSynchronize would do.) Events allow us to
   // synchronize the streams without blocking the host.
-  event.block(at::cuda::getCurrentCUDAStream(
-      cudaJpegEncoder->original_device.has_index()
-          ? cudaJpegEncoder->original_device.index()
-          : 0));
+  event.block(cudaJpegEncoder->current_stream);
   return encoded_images;
 }

@@ -121,7 +118,11 @@ CUDAJpegEncoder::CUDAJpegEncoder(const torch::Device& target_device)
       stream{
           target_device.has_index()
              ? at::cuda::getStreamFromPool(false, target_device.index())
-              : at::cuda::getStreamFromPool(false)} {
+              : at::cuda::getStreamFromPool(false)},
+      current_stream{
+          original_device.has_index()
+              ? at::cuda::getCurrentCUDAStream(original_device.index())
+              : at::cuda::getCurrentCUDAStream()} {
   nvjpegStatus_t status;
   status = nvjpegCreateSimple(&nvjpeg_handle);
   TORCH_CHECK(
@@ -186,12 +187,17 @@ CUDAJpegEncoder::~CUDAJpegEncoder() {
 }

 torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
+  nvjpegStatus_t status;
+  cudaError_t cudaStatus;
+
+  // Ensure that the incoming src_image is safe to use
+  cudaStatus = cudaStreamSynchronize(current_stream);
+  TORCH_CHECK(cudaStatus == cudaSuccess, "CUDA ERROR: ", cudaStatus);
+
   int channels = src_image.size(0);
   int height = src_image.size(1);
   int width = src_image.size(2);

-  nvjpegStatus_t status;
-  cudaError_t cudaStatus;
   status = nvjpegEncoderParamsSetSamplingFactors(
       nv_enc_params, NVJPEG_CSS_444, stream);
   TORCH_CHECK(
@@ -251,7 +257,7 @@ torch::Tensor CUDAJpegEncoder::encode_jpeg(const torch::Tensor& src_image) {
       nv_enc_state,
       encoded_image.data_ptr<uint8_t>(),
       &length,
-      0);
+      stream);
   TORCH_CHECK(
       status == NVJPEG_STATUS_SUCCESS,
       "Failed to retrieve encoded image: ",

torchvision/csrc/io/image/cuda/encode_jpegs_cuda.h

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ class CUDAJpegEncoder {
   const torch::Device original_device;
   const torch::Device target_device;
   const c10::cuda::CUDAStream stream;
+  const c10::cuda::CUDAStream current_stream;

  protected:
   nvjpegEncoderState_t nv_enc_state;

torchvision/datasets/coco.py

Lines changed: 4 additions & 2 deletions
@@ -10,7 +10,8 @@
 class CocoDetection(VisionDataset):
     """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

-    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
+    It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
+    which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.

     Args:
         root (str or ``pathlib.Path``): Root directory where images are downloaded to.
@@ -65,7 +66,8 @@ def __len__(self) -> int:
 class CocoCaptions(CocoDetection):
     """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

-    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
+    It requires `pycocotools <https://github.com/ppwwyyxx/cocoapi>`_ to be installed,
+    which could be installed via ``pip install pycocotools`` or ``conda install conda-forge::pycocotools``.

     Args:
         root (str or ``pathlib.Path``): Root directory where images are downloaded to.
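
For reference, a minimal usage sketch of the class whose docstring changed, assuming pycocotools is installed as described above (paths are placeholders):

from torchvision.datasets import CocoDetection

dataset = CocoDetection(
    root="path/to/val2017",                                # image directory (placeholder)
    annFile="path/to/annotations/instances_val2017.json",  # annotation file (placeholder)
)
image, target = dataset[0]  # PIL image and a list of annotation dicts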

torchvision/datasets/country211.py

Lines changed: 12 additions & 3 deletions
@@ -1,7 +1,7 @@
 from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union

-from .folder import ImageFolder
+from .folder import default_loader, ImageFolder
 from .utils import download_and_extract_archive, verify_str_arg


@@ -21,6 +21,9 @@ class Country211(ImageFolder):
         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
         download (bool, optional): If True, downloads the dataset from the internet and puts it into
             ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
     """

     _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
@@ -33,6 +36,7 @@ def __init__(
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
     ) -> None:
         self._split = verify_str_arg(split, "split", ("train", "valid", "test"))

@@ -46,7 +50,12 @@ def __init__(
         if not self._check_exists():
             raise RuntimeError("Dataset not found. You can use download=True to download it")

-        super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
+        super().__init__(
+            str(self._base_folder / self._split),
+            transform=transform,
+            target_transform=target_transform,
+            loader=loader,
+        )
         self.root = str(root)

     def _check_exists(self) -> bool:
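
A minimal sketch of the new loader hook documented above (the root path is a placeholder): with decode_image as the loader, samples come back as uint8 CHW tensors instead of PIL images. The same pattern applies to the EuroSAT and ImageNet changes below.

from torchvision.datasets import Country211
from torchvision.io import decode_image

dataset = Country211(root="data", split="valid", download=True, loader=decode_image)
image, label = dataset[0]
print(type(image), image.dtype, image.shape)  # torch.Tensor, torch.uint8, CHW layout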

torchvision/datasets/eurosat.py

Lines changed: 12 additions & 3 deletions
@@ -1,8 +1,8 @@
 import os
 from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Optional, Union

-from .folder import ImageFolder
+from .folder import default_loader, ImageFolder
 from .utils import download_and_extract_archive


@@ -21,6 +21,9 @@ class EuroSAT(ImageFolder):
         download (bool, optional): If True, downloads the dataset from the internet and
             puts it in root directory. If dataset is already downloaded, it is not
             downloaded again. Default is False.
+        loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.
     """

     def __init__(
@@ -29,6 +32,7 @@ def __init__(
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         download: bool = False,
+        loader: Callable[[str], Any] = default_loader,
     ) -> None:
         self.root = os.path.expanduser(root)
         self._base_folder = os.path.join(self.root, "eurosat")
@@ -40,7 +44,12 @@ def __init__(
         if not self._check_exists():
             raise RuntimeError("Dataset not found. You can use download=True to download it")

-        super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
+        super().__init__(
+            self._data_folder,
+            transform=transform,
+            target_transform=target_transform,
+            loader=loader,
+        )
         self.root = os.path.expanduser(root)

     def __len__(self) -> int:

torchvision/datasets/imagenet.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ class ImageNet(ImageFolder):
         target_transform (callable, optional): A function/transform that takes in the
             target and transforms it.
         loader (callable, optional): A function to load an image given its path.
+            By default, it uses PIL as its image loader, but users could also pass in
+            ``torchvision.io.decode_image`` for decoding image data into tensors directly.

     Attributes:
         classes (list): List of the class name tuples.

torchvision/datasets/places365.py

Lines changed: 15 additions & 7 deletions
@@ -1,7 +1,7 @@
 import os
 from os import path
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 from urllib.parse import urljoin

 from .folder import default_loader
@@ -15,7 +15,7 @@ class Places365(VisionDataset):
     Args:
         root (str or ``pathlib.Path``): Root directory of the Places365 dataset.
         split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
-            ``val``.
+            ``val``, ``test``.
         small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
             high resolution ones.
         download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
@@ -36,7 +36,8 @@ class Places365(VisionDataset):
         RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
         RuntimeError: If ``download is True`` and the image archive is already extracted.
     """
-    _SPLITS = ("train-standard", "train-challenge", "val")
+
+    _SPLITS = ("train-standard", "train-challenge", "val", "test")
     _BASE_URL = "http://data.csail.mit.edu/places/places365/"
     # {variant: (archive, md5)}
     _DEVKIT_META = {
@@ -50,15 +51,18 @@
         "train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
         "train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
         "val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
+        "test": ("places365_test.txt", "2fce8233fe493576d724142e45d93653"),
     }
     # {(split, small): (file, md5)}
     _IMAGES_META = {
         ("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
         ("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
         ("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
+        ("test", False): ("test_large.tar", "41a4b6b724b1d2cd862fb3871ed59913"),
         ("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
         ("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
         ("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
+        ("test", True): ("test_256.tar", "f532f6ad7b582262a2ec8009075e186b"),
     }

     def __init__(
@@ -123,10 +127,14 @@ def process(line: str) -> Tuple[str, int]:

         return sorted(class_to_idx.keys()), class_to_idx

-    def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
-        def process(line: str, sep="/") -> Tuple[str, int]:
-            image, idx = line.split()
-            return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)
+    def load_file_list(
+        self, download: bool = True
+    ) -> Tuple[List[Tuple[str, Union[int, None]]], List[Union[int, None]]]:
+        def process(line: str, sep="/") -> Tuple[str, Union[int, None]]:
+            image, idx = (line.split() + [None])[:2]
+            image = cast(str, image)
+            idx = int(idx) if idx is not None else None
+            return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), idx

         file, md5 = self._FILE_LIST_META[self.split]
         file = path.join(self.root, file)
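
The reworked process helper is what lets the unannotated ``test`` file list parse: lines without a class index yield ``None``. A simplified standalone sketch of that behaviour (dropping the images_dir join; the example lines are illustrative):

def process(line, sep="/"):
    image, idx = (line.split() + [None])[:2]
    return image.lstrip(sep), (int(idx) if idx is not None else None)

print(process("/a/abbey/00000001.jpg 0"))      # ('a/abbey/00000001.jpg', 0)
print(process("Places365_test_00000001.jpg"))  # ('Places365_test_00000001.jpg', None)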
