diff --git a/examples/task_adaptation/image_classification/utils.py b/examples/task_adaptation/image_classification/utils.py index 052b29e2..3f97c533 100644 --- a/examples/task_adaptation/image_classification/utils.py +++ b/examples/task_adaptation/image_classification/utils.py @@ -41,8 +41,9 @@ def get_model(model_name, pretrained_checkpoint=None): backbone = timm.create_model(model_name, pretrained=True) try: backbone.out_features = backbone.get_classifier().in_features + backbone.fc_imagenet = backbone.fc backbone.reset_classifier(0, '') - backbone.copy_head = backbone.get_classifier + backbone.copy_head = lambda: copy.deepcopy(backbone.fc_imagenet) except: backbone.out_features = backbone.head.in_features backbone.head = nn.Identity() diff --git a/tllib/normalization/stochnorm.py b/tllib/normalization/stochnorm.py index b6b3b304..ef4807f3 100644 --- a/tllib/normalization/stochnorm.py +++ b/tllib/normalization/stochnorm.py @@ -63,17 +63,20 @@ def forward(self, input): if input.dim() == 2: s = torch.from_numpy( np.random.binomial(n=1, p=self.p, size=self.num_features).reshape(1, - self.num_features)).float().cuda() + self.num_features)).float() elif input.dim() == 3: s = torch.from_numpy( np.random.binomial(n=1, p=self.p, size=self.num_features).reshape(1, self.num_features, - 1)).float().cuda() + 1)).float() elif input.dim() == 4: s = torch.from_numpy( np.random.binomial(n=1, p=self.p, size=self.num_features).reshape(1, self.num_features, 1, - 1)).float().cuda() + 1)).float() else: raise BaseException() + + if torch.cuda.is_available(): + s = s.cuda() z = (1 - s) * z_0 + s * z_1 else: