Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Fix cuda initialize issue #497

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions generative/networks/layers/vector_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def __init__(
range(1, self.spatial_dims + 1)
)

@torch.cuda.amp.autocast(enabled=False)
def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss.
Expand All @@ -100,28 +99,28 @@ def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, to
encoding_indices_view = list(inputs.shape)
del encoding_indices_view[1]

inputs = inputs.float()
with torch.cuda.amp.autocast(enabled=False):
inputs = inputs.float()

# Converting to channel last format
flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)
# Converting to channel last format
flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)

# Calculate Euclidean distances
distances = (
(flat_input**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
- 2 * torch.mm(flat_input, self.embedding.weight.t())
)
# Calculate Euclidean distances
distances = (
(flat_input**2).sum(dim=1, keepdim=True)
+ (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
- 2 * torch.mm(flat_input, self.embedding.weight.t())
)

# Mapping distances to indexes
encoding_indices = torch.max(-distances, dim=1)[1]
encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()
# Mapping distances to indexes
encoding_indices = torch.max(-distances, dim=1)[1]
encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()

# Quantize and reshape
encoding_indices = encoding_indices.view(encoding_indices_view)
# Quantize and reshape
encoding_indices = encoding_indices.view(encoding_indices_view)

return flat_input, encodings, encoding_indices

@torch.cuda.amp.autocast(enabled=False)
def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
"""
Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space
Expand All @@ -135,7 +134,8 @@ def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: Quantize space representation of encoding_indices in channel first format.
"""
return self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
with torch.cuda.amp.autocast(enabled=False):
return self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()

@torch.jit.unused
def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:
Expand Down