Skip to content

Commit d462da2

Browse files
rbngzNicolasHug
andauthored
Added 'test' split support for Places365 dataset (#8928)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 86f8eb0 commit d462da2

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

torchvision/datasets/places365.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from os import path
33
from pathlib import Path
4-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
55
from urllib.parse import urljoin
66

77
from .folder import default_loader
@@ -15,7 +15,7 @@ class Places365(VisionDataset):
1515
Args:
1616
root (str or ``pathlib.Path``): Root directory of the Places365 dataset.
1717
split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
18-
``val``.
18+
``val``, ``test``.
1919
small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
2020
high resolution ones.
2121
download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
@@ -36,7 +36,8 @@ class Places365(VisionDataset):
3636
RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
3737
RuntimeError: If ``download is True`` and the image archive is already extracted.
3838
"""
39-
_SPLITS = ("train-standard", "train-challenge", "val")
39+
40+
_SPLITS = ("train-standard", "train-challenge", "val", "test")
4041
_BASE_URL = "http://data.csail.mit.edu/places/places365/"
4142
# {variant: (archive, md5)}
4243
_DEVKIT_META = {
@@ -50,15 +51,18 @@ class Places365(VisionDataset):
5051
"train-standard": ("places365_train_standard.txt", "30f37515461640559006b8329efbed1a"),
5152
"train-challenge": ("places365_train_challenge.txt", "b2931dc997b8c33c27e7329c073a6b57"),
5253
"val": ("places365_val.txt", "e9f2fd57bfd9d07630173f4e8708e4b1"),
54+
"test": ("places365_test.txt", "2fce8233fe493576d724142e45d93653"),
5355
}
5456
# {(split, small): (file, md5)}
5557
_IMAGES_META = {
5658
("train-standard", False): ("train_large_places365standard.tar", "67e186b496a84c929568076ed01a8aa1"),
5759
("train-challenge", False): ("train_large_places365challenge.tar", "605f18e68e510c82b958664ea134545f"),
5860
("val", False): ("val_large.tar", "9b71c4993ad89d2d8bcbdc4aef38042f"),
61+
("test", False): ("test_large.tar", "41a4b6b724b1d2cd862fb3871ed59913"),
5962
("train-standard", True): ("train_256_places365standard.tar", "53ca1c756c3d1e7809517cc47c5561c5"),
6063
("train-challenge", True): ("train_256_places365challenge.tar", "741915038a5e3471ec7332404dfb64ef"),
6164
("val", True): ("val_256.tar", "e27b17d8d44f4af9a78502beb927f808"),
65+
("test", True): ("test_256.tar", "f532f6ad7b582262a2ec8009075e186b"),
6266
}
6367

6468
def __init__(
@@ -123,10 +127,14 @@ def process(line: str) -> Tuple[str, int]:
123127

124128
return sorted(class_to_idx.keys()), class_to_idx
125129

126-
def load_file_list(self, download: bool = True) -> Tuple[List[Tuple[str, int]], List[int]]:
127-
def process(line: str, sep="/") -> Tuple[str, int]:
128-
image, idx = line.split()
129-
return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), int(idx)
130+
def load_file_list(
131+
self, download: bool = True
132+
) -> Tuple[List[Tuple[str, Union[int, None]]], List[Union[int, None]]]:
133+
def process(line: str, sep="/") -> Tuple[str, Union[int, None]]:
134+
image, idx = (line.split() + [None])[:2]
135+
image = cast(str, image)
136+
idx = int(idx) if idx is not None else None
137+
return path.join(self.images_dir, image.lstrip(sep).replace(sep, os.sep)), idx
130138

131139
file, md5 = self._FILE_LIST_META[self.split]
132140
file = path.join(self.root, file)

0 commit comments

Comments
 (0)