
Commit 1a57b55

Authored by marksgraham, KumoLiu, kvttt, heyufan1995, and binliunls
7227 refactor transformer and diffusion model unet (#7715)
Part of #7227.

### Description
A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: KumoLiu <[email protected]>
Signed-off-by: kaibo <[email protected]>
Signed-off-by: heyufan1995 <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Signed-off-by: binliu <[email protected]>
Signed-off-by: dependabot[bot] <[email protected]>
Signed-off-by: axel.vlaminck <[email protected]>
Signed-off-by: monai-bot <[email protected]>
Signed-off-by: Ibrahim Hadzic <[email protected]>
Signed-off-by: Behrooz <[email protected]>
Signed-off-by: Timothy Baker <[email protected]>
Signed-off-by: Mathijs de Boer <[email protected]>
Signed-off-by: Fabian Klopfer <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: chaoliu <[email protected]>
Signed-off-by: cxlcl <[email protected]>
Signed-off-by: chaoliu <[email protected]>
Signed-off-by: Suraj Pai <[email protected]>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez <[email protected]>
Signed-off-by: elitap <[email protected]>
Signed-off-by: Felix Schnabel <[email protected]>
Signed-off-by: YanxuanLiu <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: Dženan Zukić <[email protected]>
Signed-off-by: Ishan Dutta <[email protected]>
Signed-off-by: John Zielke <[email protected]>
Signed-off-by: Mingxin Zheng <[email protected]>
Signed-off-by: Vladimir Chernyi <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Signed-off-by: Szabolcs Botond Lorincz Molnar <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: Mingxin <[email protected]>
Signed-off-by: Han Wang <[email protected]>
Signed-off-by: Konstantin Sukharev <[email protected]>
Signed-off-by: Ben Murray <[email protected]>
Signed-off-by: Matthew Vine <[email protected]>
Signed-off-by: Mark Graham <[email protected]>
Signed-off-by: Peter Kaplinsky <[email protected]>
Signed-off-by: Simon Jensen <[email protected]>
Signed-off-by: NabJa <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Kaibo Tang <[email protected]>
Co-authored-by: Yufan He <[email protected]>
Co-authored-by: binliunls <[email protected]>
Co-authored-by: Ben Murray <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: axel.vlaminck <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mingxin Zheng <[email protected]>
Co-authored-by: monai-bot <[email protected]>
Co-authored-by: Ibrahim Hadzic <[email protected]>
Co-authored-by: Dr. Behrooz Hashemian <[email protected]>
Co-authored-by: Timothy J. Baker <[email protected]>
Co-authored-by: Mathijs de Boer <[email protected]>
Co-authored-by: Mathijs de Boer <[email protected]>
Co-authored-by: Fabian Klopfer <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Lucas Robinet <[email protected]>
Co-authored-by: Lucas Robinet <[email protected]>
Co-authored-by: cxlcl <[email protected]>
Co-authored-by: Suraj Pai <[email protected]>
Co-authored-by: Juampa <[email protected]>
Co-authored-by: elitap <[email protected]>
Co-authored-by: Felix Schnabel <[email protected]>
Co-authored-by: YanxuanLiu <[email protected]>
Co-authored-by: ytl0623 <[email protected]>
Co-authored-by: Dženan Zukić <[email protected]>
Co-authored-by: Ishan Dutta <[email protected]>
Co-authored-by: johnzielke <[email protected]>
Co-authored-by: Vladimir Chernyi <[email protected]>
Co-authored-by: Lőrincz-Molnár Szabolcs-Botond <[email protected]>
Co-authored-by: Nic Ma <[email protected]>
Co-authored-by: Lucas Robinet <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: Konstantin Sukharev <[email protected]>
Co-authored-by: Matthew Vine <[email protected]>
Co-authored-by: Pkaps25 <[email protected]>
Co-authored-by: Peter Kaplinsky <[email protected]>
Co-authored-by: Simon Jensen <[email protected]>
Co-authored-by: NabJa <[email protected]>
1 parent ba188e2 commit 1a57b55

90 files changed: +1327 -921 lines changed


.github/workflows/pythonapp-min.yml

Lines changed: 2 additions & 0 deletions

@@ -9,6 +9,8 @@ on:
      - main
      - releasing/*
  pull_request:
+    head_ref-ignore:
+      - dev

concurrency:
  # automatically cancel the previously triggered workflows when there's a newer version

.github/workflows/pythonapp.yml

Lines changed: 4 additions & 2 deletions

@@ -9,6 +9,8 @@ on:
      - main
      - releasing/*
  pull_request:
+    head_ref-ignore:
+      - dev

concurrency:
  # automatically cancel the previously triggered workflows when there's a newer version

@@ -68,10 +70,10 @@ jobs:
        maximum-size: 16GB
        disk-root: "D:"
      - uses: actions/checkout@v4
-     - name: Set up Python 3.8
+     - name: Set up Python 3.9
        uses: actions/setup-python@v5
        with:
-         python-version: '3.8'
+         python-version: '3.9'
      - name: Prepare pip wheel
        run: |
          which python

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion

@@ -69,7 +69,7 @@ repos:
        )$

  - repo: https://github.com/hadialqattan/pycln
-    rev: v2.1.3
+    rev: v2.4.0
    hooks:
      - id: pycln
        args: [--config=pyproject.toml]

Dockerfile

Lines changed: 4 additions & 5 deletions

@@ -18,11 +18,10 @@ LABEL maintainer="[email protected]"

# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)
RUN if [[ $(uname -m) =~ "aarch64" ]]; then \
-    cd /opt && \
-    git clone --branch v0.12.1 --recursive https://github.com/zarr-developers/numcodecs && \
-    pip wheel numcodecs && \
-    rm -r /opt/*.whl && \
-    rm -rf /opt/numcodecs; \
+    export CFLAGS="-O3" && \
+    export DISABLE_NUMCODECS_SSE2=true && \
+    export DISABLE_NUMCODECS_AVX2=true && \
+    pip install numcodecs; \
    fi

WORKDIR /opt/monai

docs/source/networks.rst

Lines changed: 1 addition & 0 deletions

@@ -426,6 +426,7 @@ Layers
.. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer
    :members:

+=======
`ConjugateGradient`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: ConjugateGradient

monai/apps/utils.py

Lines changed: 6 additions & 1 deletion

@@ -135,7 +135,12 @@ def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5
        logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.")
        return True
    actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES)
-   actual_hash = actual_hash_func()
+
+   if sys.version_info >= (3, 9):
+       actual_hash = actual_hash_func(usedforsecurity=False)  # allows checks on FIPS enabled machines
+   else:
+       actual_hash = actual_hash_func()
+
    try:
        with open(filepath, "rb") as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
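As context for the change above: the standard `hashlib` constructors only accept `usedforsecurity` from Python 3.9 onward, which is why the version guard is needed. A minimal standalone sketch of the same pattern; the `file_md5` helper and its path argument are illustrative, not part of MONAI:

```python
import hashlib
import sys


def file_md5(path: str) -> str:
    # On Python 3.9+, usedforsecurity=False marks the digest as non-cryptographic,
    # so integrity checks keep working on FIPS-enabled machines.
    if sys.version_info >= (3, 9):
        md5 = hashlib.md5(usedforsecurity=False)
    else:
        md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            md5.update(chunk)
    return md5.hexdigest()
```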

monai/bundle/workflows.py

Lines changed: 7 additions & 3 deletions

@@ -239,6 +239,7 @@ class ConfigWorkflow(BundleWorkflow):
        logging_file: config file for `logging` module in the program. for more details:
            https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
            If None, default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo.
+           If False, the logging logic for the bundle will not be modified.
        init_id: ID name of the expected config expression to initialize before running, default to "initialize".
            allow a config to have no `initialize` logic and the ID.
        run_id: ID name of the expected config expression to run, default to "run".

@@ -278,7 +279,7 @@ def __init__(
        self,
        config_file: str | Sequence[str],
        meta_file: str | Sequence[str] | None = None,
-       logging_file: str | None = None,
+       logging_file: str | bool | None = None,
        init_id: str = "initialize",
        run_id: str = "run",
        final_id: str = "finalize",

@@ -307,15 +308,18 @@ def __init__(
        super().__init__(workflow_type=workflow_type, meta_file=meta_file, properties_path=properties_path)
        self.config_root_path = config_root_path
        logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file
-       if logging_file is not None:
+
+       if logging_file is False:
+           logger.warn(f"Logging file is set to {logging_file}, skipping logging.")
+       else:
            if not os.path.isfile(logging_file):
                if logging_file == str(self.config_root_path / "logging.conf"):
                    logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
                else:
                    raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
            else:
                logger.info(f"Setting logging properties based on config: {logging_file}.")
-               fileConfig(logging_file, disable_existing_loggers=False)
+               fileConfig(str(logging_file), disable_existing_loggers=False)

        self.parser = ConfigParser()
        self.parser.read_config(f=config_file)
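The net effect is that `logging_file` now accepts `False` to leave the caller's logging configuration untouched, while a string path still loads that logging config. A minimal usage sketch, assuming a bundle with a hypothetical `configs/train.json`:

```python
from monai.bundle import ConfigWorkflow

# logging_file=False skips the fileConfig(...) call shown above;
# a string path would load that logging config instead.
workflow = ConfigWorkflow(
    config_file="configs/train.json",  # hypothetical bundle config path
    logging_file=False,
    workflow_type="train",
)
workflow.initialize()
```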

monai/fl/client/monai_algo.py

Lines changed: 9 additions & 5 deletions

@@ -134,12 +134,14 @@ def initialize(self, extra=None):

        Args:
            extra: Dict with additional information that should be provided by FL system,
-               i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
+               i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.
+               You can disable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False.

        """
        if extra is None:
            extra = {}
        self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
+       logging_file = extra.get(ExtraItems.LOGGING_FILE, None)
        self.logger.info(f"Initializing {self.client_name} ...")

        # FL platform needs to provide filepath to configuration files

@@ -149,7 +151,7 @@ def initialize(self, extra=None):
        if self.workflow is None:
            config_train_files = self._add_config_files(self.config_train_filename)
            self.workflow = ConfigWorkflow(
-               config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train"
+               config_file=config_train_files, meta_file=None, logging_file=logging_file, workflow_type="train"
            )
            self.workflow.initialize()
            self.workflow.bundle_root = self.bundle_root

@@ -412,13 +414,15 @@ def initialize(self, extra=None):

        Args:
            extra: Dict with additional information that should be provided by FL system,
-               i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
+               i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.
+               You can disable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False.

        """
        self._set_cuda_device()
        if extra is None:
            extra = {}
        self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
+       logging_file = extra.get(ExtraItems.LOGGING_FILE, None)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        self.logger.info(f"Initializing {self.client_name} ...")
        # FL platform needs to provide filepath to configuration files

@@ -434,7 +438,7 @@ def initialize(self, extra=None):
            self.train_workflow = ConfigWorkflow(
                config_file=config_train_files,
                meta_file=None,
-               logging_file=None,
+               logging_file=logging_file,
                workflow_type="train",
                **self.train_kwargs,
            )

@@ -459,7 +463,7 @@ def initialize(self, extra=None):
            self.eval_workflow = ConfigWorkflow(
                config_file=config_eval_files,
                meta_file=None,
-               logging_file=None,
+               logging_file=logging_file,
                workflow_type=self.eval_workflow_name,
                **self.eval_kwargs,
            )

monai/fl/utils/constants.py

Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ class ExtraItems(StrEnum):
    CLIENT_NAME = "fl_client_name"
    APP_ROOT = "fl_app_root"
    STATS_SENDER = "fl_stats_sender"
+   LOGGING_FILE = "logging_file"


class FlPhase(StrEnum):
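Together with the `monai_algo.py` change above, this new key lets an FL system opt a client bundle out of logging reconfiguration. A hedged sketch of how a client might be initialized; the bundle path and client name are illustrative:

```python
from monai.fl.client import MonaiAlgo
from monai.fl.utils.constants import ExtraItems

algo = MonaiAlgo(bundle_root="./my_bundle")  # illustrative bundle location
# LOGGING_FILE=False is forwarded to ConfigWorkflow(logging_file=False),
# so the bundle does not override the FL system's logging setup.
algo.initialize(extra={ExtraItems.CLIENT_NAME: "site-1", ExtraItems.LOGGING_FILE: False})
```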

monai/losses/ds_loss.py

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: list[float] |
        weight_mode: {``"same"``, ``"exp"``, ``"two"``}
            Specifies the weights calculation for each image level. Defaults to ``"exp"``.
            - ``"same"``: all weights are equal to 1.
-           - ``"exp"``: exponentially decreasing weights by a power of 2: 0, 0.5, 0.25, 0.125, etc .
+           - ``"exp"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc .
            - ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc
        weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used
            regardless of the weight_mode
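The corrected sequence is a plain halving series starting at 1; a quick illustration of the documented ``"exp"`` weights (this only reproduces the sequence from the docstring, not the loss implementation itself):

```python
# "exp" mode: weights halve at each deeper supervision level, starting from 1.
num_levels = 5
exp_weights = [2.0 ** (-i) for i in range(num_levels)]
print(exp_weights)  # [1.0, 0.5, 0.25, 0.125, 0.0625]
```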

monai/networks/blocks/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,7 @@
from .backbone_fpn_utils import BackboneWithFPN
from .convolutions import Convolution, ResidualUnit
from .crf import CRF
+from .crossattention import CrossAttentionBlock
from .denseblock import ConvDenseBlock, DenseBlock
from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
from .downsample import MaxAvgPool

@@ -31,6 +32,7 @@
from .segresnet_block import ResBlock
from .selfattention import SABlock
from .spade_norm import SPADE
+from .spatialattention import SpatialAttentionBlock
from .squeeze_and_excitation import (
    ChannelSELayer,
    ResidualSELayer,
monai/networks/blocks/crossattention.py (new file)

Lines changed: 166 additions & 0 deletions

@@ -0,0 +1,166 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


class CrossAttentionBlock(nn.Module):
    """
    A cross-attention block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        hidden_input_size: int | None = None,
        context_input_size: int | None = None,
        dim_head: int | None = None,
        qkv_bias: bool = False,
        save_attn: bool = False,
        causal: bool = False,
        sequence_length: int | None = None,
        rel_pos_embedding: Optional[str] = None,
        input_size: Optional[Tuple] = None,
        attention_dtype: Optional[torch.dtype] = None,
    ) -> None:
        """
        Args:
            hidden_size (int): dimension of hidden layer.
            num_heads (int): number of attention heads.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
            context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.
            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
            causal: whether to use causal attention.
            sequence_length: if causal is True, it is necessary to specify the sequence length.
            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
                For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
                positional parameter size.
            attention_dtype: cast attention operations to this dtype.
        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if dim_head:
            inner_size = num_heads * dim_head
            self.head_dim = dim_head
        else:
            if hidden_size % num_heads != 0:
                raise ValueError("hidden size should be divisible by num_heads.")
            inner_size = hidden_size
            self.head_dim = hidden_size // num_heads

        if causal and sequence_length is None:
            raise ValueError("sequence_length is necessary for causal attention.")

        self.num_heads = num_heads
        self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
        self.context_input_size = context_input_size if context_input_size else hidden_size
        self.out_proj = nn.Linear(inner_size, self.hidden_input_size)
        # key, query, value projections
        self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias)
        self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
        self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
        self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)

        self.out_rearrange = Rearrange("b h l d -> b l (h d)")
        self.drop_output = nn.Dropout(dropout_rate)
        self.drop_weights = nn.Dropout(dropout_rate)

        self.scale = self.head_dim**-0.5
        self.save_attn = save_attn
        self.attention_dtype = attention_dtype

        self.causal = causal
        self.sequence_length = sequence_length

        if causal and sequence_length is not None:
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer(
                "causal_mask",
                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
            )
            self.causal_mask: torch.Tensor

        self.att_mat = torch.Tensor()
        self.rel_positional_embedding = (
            get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
            if rel_pos_embedding is not None
            else None
        )
        self.input_size = input_size

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
        """
        Args:
            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
            context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C

        Return:
            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
        """
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        b, t, c = x.size()  # batch size, sequence length, embedding dimensionality (hidden_size)

        q = self.to_q(x)
        kv = context if context is not None else x
        _, kv_t, _ = kv.size()
        k = self.to_k(kv)
        v = self.to_v(kv)

        if self.attention_dtype is not None:
            q = q.to(self.attention_dtype)
            k = k.to(self.attention_dtype)

        q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
        k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
        v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

        # apply relative positional embedding if defined
        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat

        if self.causal:
            att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

        att_mat = att_mat.softmax(dim=-1)

        if self.save_attn:
            # no gradients and new tensor;
            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
            self.att_mat = att_mat.detach()

        att_mat = self.drop_weights(att_mat)
        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
        x = self.out_rearrange(x)
        x = self.out_proj(x)
        x = self.drop_output(x)
        return x
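A short usage sketch for the new block, assuming `einops` is installed (it backs the `Rearrange` layers); the tensor sizes are illustrative:

```python
import torch

from monai.networks.blocks import CrossAttentionBlock

# 64-dim tokens split across 4 heads; queries come from x, keys/values from context.
block = CrossAttentionBlock(hidden_size=64, num_heads=4, dropout_rate=0.1)
x = torch.randn(2, 16, 64)       # B x query tokens x hidden_size
context = torch.randn(2, 8, 64)  # B x context tokens x hidden_size
out = block(x, context=context)
print(out.shape)  # torch.Size([2, 16, 64])
```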
