This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit cc9d0b6

EmilsOzolins authored and copybara-github committed
Merge of PR #1805
PiperOrigin-RevId: 316719305
1 parent ea1c771 commit cc9d0b6

File tree

2 files changed (+50, -39 lines)


tensor2tensor/models/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -47,9 +47,9 @@
 from tensor2tensor.models.research import cycle_gan
 from tensor2tensor.models.research import gene_expression
 from tensor2tensor.models.research import neural_stack
+from tensor2tensor.models.research import residual_shuffle_exchange
 from tensor2tensor.models.research import rl
 from tensor2tensor.models.research import shuffle_network
-from tensor2tensor.models.research import residual_shuffle_exchange
 from tensor2tensor.models.research import similarity_transformer
 from tensor2tensor.models.research import super_lm
 from tensor2tensor.models.research import transformer_moe
```

tensor2tensor/models/research/residual_shuffle_exchange.py

Lines changed: 49 additions & 38 deletions
```diff
@@ -27,13 +27,12 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensor2tensor.models.research.shuffle_network import ShuffleNetwork
-from tensor2tensor.models.research.shuffle_network import shuffle_layer
-from tensor2tensor.models.research.shuffle_network import reverse_shuffle_layer
+import numpy as np
 from tensor2tensor.layers.common_layers import gelu
+from tensor2tensor.models.research.shuffle_network import reverse_shuffle_layer
+from tensor2tensor.models.research.shuffle_network import shuffle_layer
+from tensor2tensor.models.research.shuffle_network import ShuffleNetwork
 from tensor2tensor.utils import registry
-
-import numpy as np
 import tensorflow.compat.v1 as tf
 
 
```
```diff
@@ -46,29 +45,34 @@ def __init__(self, axis=1, epsilon=1e-10, **kwargs):
     Args:
       axis: Tuple or number of axis for calculating mean and variance
       epsilon: Small epsilon to avoid division by zero
+      **kwargs: keyword args passed to super.
     """
     self.axis = axis
     self.epsilon = epsilon
     self.bias = None
     super(LayerNormalization, self).__init__(**kwargs)
 
   def build(self, input_shape):
-    """ Initialize bias weights for layer normalization.
+    """Initialize bias weights for layer normalization.
+
     Args:
       input_shape: shape of input tensor
     """
     num_units = input_shape.as_list()[-1]
-    self.bias = self.add_weight("bias", [1, 1, num_units],
-                                initializer=tf.zeros_initializer)
+    self.bias = self.add_weight(
+        "bias", [1, 1, num_units], initializer=tf.zeros_initializer)
     super(LayerNormalization, self).build(input_shape)
 
   def call(self, inputs, **kwargs):
-    """ Apply Layer Normalization without output bias and gain.
+    """Apply Layer Normalization without output bias and gain.
 
     Args:
-      inputs: tensor to be normalized. Axis should be smaller than input
-        tensor dimensions.
+      inputs: tensor to be normalized. Axis should be smaller than input tensor
+        dimensions.
       **kwargs: more arguments (unused)
+
+    Returns:
+      tensor output.
     """
     inputs -= tf.reduce_mean(inputs, axis=self.axis, keepdims=True)
     inputs += self.bias
```
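The class above normalizes with a learned bias but no output gain. A minimal standalone sketch of that computation, assuming the part of `call` outside the diff context scales to unit variance as usual (the variance step is an assumption, not shown in the hunk):

```python
import numpy as np


def layer_norm_no_gain(x, bias, axis=1, epsilon=1e-10):
  """Bias-only layer norm sketch mirroring the visible lines of call()."""
  x = x - x.mean(axis=axis, keepdims=True)  # inputs -= tf.reduce_mean(...)
  x = x + bias                              # inputs += self.bias
  # Assumed remainder (outside the diff): divide by the standard deviation.
  return x / np.sqrt((x**2).mean(axis=axis, keepdims=True) + epsilon)


x = np.random.randn(2, 8, 4)                     # [batch, length, num_units]
out = layer_norm_no_gain(x, np.zeros((1, 1, 4)))  # bias shape as in build()
print(out.shape)  # (2, 8, 4)
```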
```diff
@@ -81,6 +85,9 @@ def inv_sigmoid(y):
 
   Args:
     y: float in range 0 to 1
+
+  Returns:
+    the inverse sigmoid.
   """
   return np.log(y / (1 - y))
 
```
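`inv_sigmoid` is the logit function, so it round-trips with the sigmoid. A quick check (hypothetical snippet, not part of the commit):

```python
import numpy as np


def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))


def inv_sigmoid(y):
  return np.log(y / (1 - y))  # as defined in the diff above


x = inv_sigmoid(0.9)                # ln(0.9 / 0.1) = ln(9) ≈ 2.1972
print(np.isclose(sigmoid(x), 0.9))  # True
```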
```diff
@@ -107,7 +114,7 @@ def __init__(self, prefix, dropout, mode, **kwargs):
     self.residual_scale = None
 
     residual_weight = 0.9
-    self.candidate_weight = np.sqrt(1 - residual_weight ** 2) * 0.25
+    self.candidate_weight = np.sqrt(1 - residual_weight**2) * 0.25
     self.init_value = inv_sigmoid(residual_weight)
 
   def build(self, input_shape):
```
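With the default `residual_weight = 0.9`, the constants above evaluate as follows (a quick arithmetic check, not part of the diff):

```python
import numpy as np

residual_weight = 0.9
candidate_weight = np.sqrt(1 - residual_weight**2) * 0.25
init_value = np.log(residual_weight / (1 - residual_weight))  # inv_sigmoid(0.9)
print(round(candidate_weight, 4))  # 0.109  (sqrt(0.19) * 0.25)
print(round(init_value, 4))        # 2.1972 (ln 9), so sigmoid(init_value) == 0.9
```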
```diff
@@ -119,33 +126,35 @@ def build(self, input_shape):
     in_units = input_shape[-1]
     middle_units = in_units * 4
     out_units = in_units * 2
-    init = tf.variance_scaling_initializer(scale=1.0, mode="fan_avg",
-                                           distribution="uniform")
+    init = tf.variance_scaling_initializer(
+        scale=1.0, mode="fan_avg", distribution="uniform")
 
-    self.first_linear = tf.keras.layers.Dense(middle_units,
-                                              use_bias=False,
-                                              kernel_initializer=init,
-                                              name=self.prefix + "/cand1")
+    self.first_linear = tf.keras.layers.Dense(
+        middle_units,
+        use_bias=False,
+        kernel_initializer=init,
+        name=self.prefix + "/cand1")
 
-    self.second_linear = tf.keras.layers.Dense(out_units,
-                                               kernel_initializer=init,
-                                               name=self.prefix + "/cand2")
+    self.second_linear = tf.keras.layers.Dense(
+        out_units, kernel_initializer=init, name=self.prefix + "/cand2")
     self.layer_norm = LayerNormalization()
 
     init = tf.constant_initializer(self.init_value)
-    self.residual_scale = self.add_weight(self.prefix + "/residual",
-                                          [out_units], initializer=init)
+    self.residual_scale = self.add_weight(
+        self.prefix + "/residual", [out_units], initializer=init)
     super(RSU, self).build(input_shape)
 
   def call(self, inputs, **kwargs):
     """Apply Residual Switch Layer to inputs.
 
     Args:
-      inputs: Input tensor
+      inputs: Input tensor.
+      **kwargs: unused kwargs.
 
     Returns:
       tf.Tensor: New candidate value
     """
+    del kwargs
     input_shape = tf.shape(inputs)
     batch_size = input_shape[0]
     length = input_shape[1]
```
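The reformatted `build` wires a `units -> 4*units -> 2*units` candidate path. A hypothetical shape check under those definitions (the eager setup and the wiring between the layers are assumptions for illustration; only the layer definitions appear in the diff):

```python
import tensorflow.compat.v1 as tf

tf.enable_eager_execution()

in_units = 8
init = tf.variance_scaling_initializer(
    scale=1.0, mode="fan_avg", distribution="uniform")
first_linear = tf.keras.layers.Dense(
    in_units * 4, use_bias=False, kernel_initializer=init)  # middle_units
second_linear = tf.keras.layers.Dense(
    in_units * 2, kernel_initializer=init)                  # out_units

x = tf.zeros([2, 16, in_units])  # [batch, length, in_units]
print(second_linear(first_linear(x)).shape)  # (2, 16, 16): out = 2 * in_units
```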
```diff
@@ -201,7 +210,7 @@ def residual_shuffle_network(inputs, hparams):
 
 
 def reverse_part(inputs, hparams, n_bits):
-  """ Reverse part of Beneš block.
+  """Reverse part of Benes block.
 
   Repeatably applies interleaved Residual Switch layer and Reverse Shuffle
   Layer. One set of weights used for all Switch layers.
```
```diff
@@ -222,24 +231,23 @@ def reverse_step(state, _):
     return reverse_shuffle_layer(new_state)
 
   reverse_outputs = tf.scan(
-    reverse_step,
-    tf.range(n_bits, n_bits * 2),
-    initializer=inputs,
-    parallel_iterations=1,
-    swap_memory=True)
+      reverse_step,
+      tf.range(n_bits, n_bits * 2),
+      initializer=inputs,
+      parallel_iterations=1,
+      swap_memory=True)
 
   return reverse_outputs[-1, :, :, :]
 
 
 def forward_part(block_out, hparams, n_bits):
-  """ Forward part of Beneš block.
+  """Forward part of Benes block.
 
   Repeatably applies interleaved Residual Switch layer and Shuffle
   Layer. One set of weights used for all Switch layers.
 
   Args:
-    inputs: inputs for forward part. Should be inputs from previous layers
-      or Beneš block.
+    block_out: TODO(authors) document.
     hparams: params of the network.
     n_bits: count of repeated layer applications.
 
```
```diff
@@ -254,11 +262,11 @@ def forward_step(state, _):
     return shuffle_layer(new_state)
 
   forward_outputs = tf.scan(
-    forward_step,
-    tf.range(0, n_bits),
-    initializer=block_out,
-    parallel_iterations=1,
-    swap_memory=True)
+      forward_step,
+      tf.range(0, n_bits),
+      initializer=block_out,
+      parallel_iterations=1,
+      swap_memory=True)
 
   return forward_outputs[-1, :, :, :]
 
```
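Both `forward_part` and `reverse_part` rely on the same `tf.scan` pattern: one step function applied `n_bits` times with shared weights, keeping only the final state. A minimal sketch of that pattern (the trivial step function is a stand-in for the switch and shuffle layers):

```python
import tensorflow.compat.v1 as tf

tf.enable_eager_execution()


def step(state, _):
  return state + 1.0  # stand-in for RSU + (reverse_)shuffle_layer


n_bits = 3
outputs = tf.scan(
    step,
    tf.range(0, n_bits),              # one element per repeated application
    initializer=tf.zeros([2, 4, 8]),  # [batch, length, units]
    parallel_iterations=1,
    swap_memory=True)
print(outputs[-1, :, :, :].shape)     # (2, 4, 8): last of the n_bits states
```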
```diff
@@ -272,6 +280,9 @@ def body(self, features):
 
     Args:
       features: dictionary of inputs and targets
+
+    Returns:
+      the network output.
     """
 
     inputs = tf.squeeze(features["inputs"], axis=2)
```
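In `body`, the squeeze drops a size-1 axis from the feature tensor (assuming the usual tensor2tensor `[batch, length, 1, channels]` layout, which is consistent with the call above):

```python
import tensorflow.compat.v1 as tf

tf.enable_eager_execution()

inputs = tf.zeros([2, 16, 1, 8])         # [batch, length, 1, channels]
print(tf.squeeze(inputs, axis=2).shape)  # (2, 16, 8)
```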
