Skip to content

Commit 93a34ea

Browse files
Arkabandhu Chowdhuryfacebook-github-bot
authored andcommitted
Implement resize and train XRayVideo A/V with only resizing (facebookresearch#796)
Summary: Pull Request resolved: facebookresearch#796 We want to check whether training XRayVideo with simply video resizing (in addition to other existing transformation like horizontal flipping and normalization) without random corp is sufficient. The resize dimension is used as 224*224. workflow: f362077622 (Note: in the workflow `fcc_mvit_dataset_v4p2_arkc.yaml` is used which I renamed to `fcc_mvit_dataset_v4p2_onlyresize.yaml` in this diff.) As can be seen, the validation MAP goes to around .422 as opposed to 0.46 when random resized crop is used (f355567669) and rest of the configuration is kept the same. Hence, it is better to keep random resized crop. Differential Revision: D38522980 fbshipit-source-id: 273f0cb1ebe644c13a739d720344fe31fd25fa17
1 parent a34ccc5 commit 93a34ea

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

classy_vision/dataset/transforms/util_video.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class VideoConstants:
3232
MEAN = ImagenetConstants.MEAN #
3333
STD = ImagenetConstants.STD
3434
SIZE_RANGE = (128, 160)
35+
RESCALE_SIZE = (224, 224)
3536
CROP_SIZE = 112
3637

3738

@@ -141,6 +142,44 @@ def __call__(self, clip):
141142
return clip
142143

143144

145+
@register_transform("video_simple_resize")
146+
class VideoSimpleResize(ClassyTransform):
147+
"""Given an input size, rescale the clip to the given size both in
148+
height and width.
149+
"""
150+
151+
def __init__(self, rescale_size: List[int], interpolation_mode: str = "bilinear"):
152+
"""The constructor method of VideoClipResize class.
153+
154+
Args:
155+
rescale_size: size of the rescaled clip
156+
interpolation_mode: Default: "bilinear". See valid values in
157+
(https://pytorch.org/docs/stable/nn.functional.html#torch.nn.
158+
functional.interpolate)
159+
160+
"""
161+
self.interpolation_mode = interpolation_mode
162+
assert (
163+
len(rescale_size) == 2
164+
), "rescale_size should be a list of size 2 (height, width)"
165+
self.rescale_size = rescale_size
166+
167+
def __call__(self, clip):
168+
"""Callable function which applies the tranform to the input clip.
169+
170+
Args:
171+
clip (torch.Tensor): input clip tensor
172+
173+
"""
174+
# clip size: C x T x H x W
175+
clip = torch.nn.functional.interpolate(
176+
clip,
177+
size=self.rescale_size,
178+
mode=self.interpolation_mode,
179+
)
180+
return clip
181+
182+
144183
@register_transform("video_default_augment")
145184
class VideoDefaultAugmentTransform(ClassyTransform):
146185
"""This is the default video transform with data augmentation which is useful for
@@ -190,6 +229,53 @@ def __call__(self, video):
190229
return self._transform(video)
191230

192231

232+
@register_transform("video_resize_augment")
233+
class VideoResizeAugmentTransform(ClassyTransform):
234+
"""This is the resize video transform with data augmentation which is useful for
235+
training.
236+
237+
It sequentially prepares a torch.Tensor of video data,
238+
resizes the video clip to specified size, randomly flips the
239+
video clip horizontally, and normalizes the pixel values by mean subtraction
240+
and standard deviation division.
241+
242+
"""
243+
244+
def __init__(
245+
self,
246+
rescale_size: List[int] = VideoConstants.RESCALE_SIZE,
247+
mean: List[float] = VideoConstants.MEAN,
248+
std: List[float] = VideoConstants.STD,
249+
):
250+
"""The constructor method of VideoResizeAugmentTransform class.
251+
252+
Args:
253+
size: the short edge of rescaled video clip
254+
mean: a 3-tuple denoting the pixel RGB mean
255+
std: a 3-tuple denoting the pixel RGB standard deviation
256+
257+
"""
258+
259+
self._transform = transforms.Compose(
260+
[
261+
transforms_video.ToTensorVideo(),
262+
# TODO(zyan3): migrate VideoClipRandomResizeCrop to TorchVision
263+
VideoSimpleResize(rescale_size),
264+
transforms_video.RandomHorizontalFlipVideo(),
265+
transforms_video.NormalizeVideo(mean=mean, std=std),
266+
]
267+
)
268+
269+
def __call__(self, video):
270+
"""Apply the resize transform with data augmentation to video.
271+
272+
Args:
273+
video: input video that will undergo the transform
274+
275+
"""
276+
return self._transform(video)
277+
278+
193279
@register_transform("video_default_no_augment")
194280
class VideoDefaultNoAugmentTransform(ClassyTransform):
195281
"""This is the default video transform without data augmentation which is useful

0 commit comments

Comments
 (0)