
Commit 9ea9bbe

fix mps
1 parent da6fd60 commit 9ea9bbe

File tree

1 file changed (+3, -1 lines)


src/diffusers/models/transformers/transformer_lumina2.py

Lines changed: 3 additions & 1 deletion
@@ -241,8 +241,10 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300,
 
     def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
         freqs_cis = []
+        # Use float32 for MPS compatibility
+        dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
         for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
-            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64)
+            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype)
             freqs_cis.append(emb)
         return freqs_cis
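For context, the hard-coded torch.float64 is replaced by a runtime-selected dtype because the MPS backend (Apple Silicon GPUs) does not support 64-bit floats, so the precomputed rotary frequencies must fall back to float32 there. Below is a minimal, self-contained sketch of the same pattern; pick_freqs_dtype and rotary_freqs are illustrative names, not part of the diffusers API, and the frequency computation is a simplified stand-in for get_1d_rotary_pos_embed.

```python
import torch


def pick_freqs_dtype() -> torch.dtype:
    # MPS has no float64 support, so fall back to float32 there;
    # everywhere else keep float64 for extra precision in the frequency table.
    return torch.float32 if torch.backends.mps.is_available() else torch.float64


def rotary_freqs(dim: int, max_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Simplified 1D rotary frequency table: outer product of positions and
    # inverse frequencies, computed in the backend-appropriate dtype.
    dtype = pick_freqs_dtype()
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype) / dim))
    positions = torch.arange(max_len, dtype=dtype)
    return torch.outer(positions, inv_freq)  # shape: (max_len, dim // 2)


if __name__ == "__main__":
    freqs = rotary_freqs(dim=64, max_len=300)
    print(freqs.shape, freqs.dtype)  # float32 on MPS machines, float64 elsewhere
```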
