from __future__ import division
from __future__ import print_function

-from tensor2tensor.models.research.shuffle_network import ShuffleNetwork
-from tensor2tensor.models.research.shuffle_network import shuffle_layer
-from tensor2tensor.models.research.shuffle_network import reverse_shuffle_layer
+import numpy as np
from tensor2tensor.layers.common_layers import gelu
+from tensor2tensor.models.research.shuffle_network import reverse_shuffle_layer
+from tensor2tensor.models.research.shuffle_network import shuffle_layer
+from tensor2tensor.models.research.shuffle_network import ShuffleNetwork
from tensor2tensor.utils import registry
-
-import numpy as np
import tensorflow.compat.v1 as tf

@@ -46,29 +45,34 @@ def __init__(self, axis=1, epsilon=1e-10, **kwargs):
    Args:
      axis: Tuple or number of axis for calculating mean and variance
      epsilon: Small epsilon to avoid division by zero
+      **kwargs: keyword args passed to super.
    """
    self.axis = axis
    self.epsilon = epsilon
    self.bias = None
    super(LayerNormalization, self).__init__(**kwargs)

  def build(self, input_shape):
-    """ Initialize bias weights for layer normalization.
+    """Initialize bias weights for layer normalization.
+
    Args:
      input_shape: shape of input tensor
    """
    num_units = input_shape.as_list()[-1]
-    self.bias = self.add_weight("bias", [1, 1, num_units],
-                                initializer=tf.zeros_initializer)
+    self.bias = self.add_weight(
+        "bias", [1, 1, num_units], initializer=tf.zeros_initializer)
    super(LayerNormalization, self).build(input_shape)

  def call(self, inputs, **kwargs):
-    """ Apply Layer Normalization without output bias and gain.
+    """Apply Layer Normalization without output bias and gain.

    Args:
-      inputs: tensor to be normalized. Axis should be smaller than input
-        tensor dimensions.
+      inputs: tensor to be normalized. Axis should be smaller than input tensor
+        dimensions.
      **kwargs: more arguments (unused)
+
+    Returns:
+      tensor output.
    """
    inputs -= tf.reduce_mean(inputs, axis=self.axis, keepdims=True)
    inputs += self.bias
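
Note (illustration only, not part of the diff): call() above centers the input along self.axis and adds the learned bias; the variance-normalization tail of the method lies outside this hunk. A minimal NumPy sketch of the gain-free layer-norm pattern, assuming the unshown tail divides by sqrt(variance + epsilon), looks roughly like this:

import numpy as np

def layer_norm_no_gain(x, bias, axis=1, epsilon=1e-10):
  # Center along the chosen axis, add the learned bias, then rescale by the
  # standard deviation. No output gain ("gamma") or post-norm bias is applied.
  x = x - x.mean(axis=axis, keepdims=True)
  x = x + bias
  variance = np.mean(np.square(x), axis=axis, keepdims=True)
  return x / np.sqrt(variance + epsilon)

x = np.random.randn(2, 8, 4)                             # [batch, length, units]
print(layer_norm_no_gain(x, np.zeros((1, 1, 4))).shape)  # (2, 8, 4)
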
@@ -81,6 +85,9 @@ def inv_sigmoid(y):

  Args:
    y: float in range 0 to 1
+
+  Returns:
+    the inverse sigmoid.
  """
  return np.log(y / (1 - y))

@@ -107,7 +114,7 @@ def __init__(self, prefix, dropout, mode, **kwargs):
    self.residual_scale = None

    residual_weight = 0.9
-    self.candidate_weight = np.sqrt(1 - residual_weight ** 2) * 0.25
+    self.candidate_weight = np.sqrt(1 - residual_weight**2) * 0.25
    self.init_value = inv_sigmoid(residual_weight)

  def build(self, input_shape):
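
Note (illustration only, not part of the diff): the constants above are easy to check by hand. inv_sigmoid(0.9) = ln(0.9 / 0.1) ≈ 2.197, so initializing residual_scale at that value presumably lets a sigmoid applied to it elsewhere in the layer start the residual weighting near 0.9, while the candidate branch is scaled by sqrt(1 - 0.9^2) * 0.25 ≈ 0.109:

import numpy as np

residual_weight = 0.9
candidate_weight = np.sqrt(1 - residual_weight**2) * 0.25
init_value = np.log(residual_weight / (1 - residual_weight))  # inv_sigmoid(0.9)

print(round(candidate_weight, 3))               # 0.109
print(round(init_value, 3))                     # 2.197
print(round(1 / (1 + np.exp(-init_value)), 3))  # 0.9, sigmoid recovers the weight
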
@@ -119,33 +126,35 @@ def build(self, input_shape):
    in_units = input_shape[-1]
    middle_units = in_units * 4
    out_units = in_units * 2
-    init = tf.variance_scaling_initializer(scale=1.0, mode="fan_avg",
-                                           distribution="uniform")
+    init = tf.variance_scaling_initializer(
+        scale=1.0, mode="fan_avg", distribution="uniform")

-    self.first_linear = tf.keras.layers.Dense(middle_units,
-                                              use_bias=False,
-                                              kernel_initializer=init,
-                                              name=self.prefix + "/cand1")
+    self.first_linear = tf.keras.layers.Dense(
+        middle_units,
+        use_bias=False,
+        kernel_initializer=init,
+        name=self.prefix + "/cand1")

-    self.second_linear = tf.keras.layers.Dense(out_units,
-                                               kernel_initializer=init,
-                                               name=self.prefix + "/cand2")
+    self.second_linear = tf.keras.layers.Dense(
+        out_units, kernel_initializer=init, name=self.prefix + "/cand2")
    self.layer_norm = LayerNormalization()

    init = tf.constant_initializer(self.init_value)
-    self.residual_scale = self.add_weight(self.prefix + "/residual",
-                                          [out_units], initializer=init)
+    self.residual_scale = self.add_weight(
+        self.prefix + "/residual", [out_units], initializer=init)
    super(RSU, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Apply Residual Switch Layer to inputs.

    Args:
-      inputs: Input tensor
+      inputs: Input tensor.
+      **kwargs: unused kwargs.

    Returns:
      tf.Tensor: New candidate value
    """
+    del kwargs
    input_shape = tf.shape(inputs)
    batch_size = input_shape[0]
    length = input_shape[1]
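
Note (illustration only, not part of the diff): for an input with U features per position, build() above creates a bias-free Dense projection to 4 * U units, a second Dense projection back to 2 * U units, and a residual_scale vector of length 2 * U. A quick standalone shape check with hypothetical sizes (this omits the layer normalization, gelu and residual gating that the full RSU applies):

import tensorflow.compat.v1 as tf

tf.enable_eager_execution()

in_units = 8                                 # hypothetical U
pairs = tf.random_normal([4, 16, in_units])  # [batch, length, U]

first_linear = tf.keras.layers.Dense(in_units * 4, use_bias=False)  # -> 4 * U
second_linear = tf.keras.layers.Dense(in_units * 2)                 # -> 2 * U

candidate = second_linear(first_linear(pairs))
print(candidate.shape)  # (4, 16, 16): 2 * U output features per position
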
@@ -201,7 +210,7 @@ def residual_shuffle_network(inputs, hparams):


def reverse_part(inputs, hparams, n_bits):
-  """ Reverse part of Beneš block.
+  """Reverse part of Benes block.

  Repeatably applies interleaved Residual Switch layer and Reverse Shuffle
  Layer. One set of weights used for all Switch layers.
@@ -222,24 +231,23 @@ def reverse_step(state, _):
    return reverse_shuffle_layer(new_state)

  reverse_outputs = tf.scan(
-        reverse_step,
-        tf.range(n_bits, n_bits * 2),
-        initializer=inputs,
-        parallel_iterations=1,
-        swap_memory=True)
+      reverse_step,
+      tf.range(n_bits, n_bits * 2),
+      initializer=inputs,
+      parallel_iterations=1,
+      swap_memory=True)

  return reverse_outputs[-1, :, :, :]


def forward_part(block_out, hparams, n_bits):
-  """ Forward part of Beneš block.
+  """Forward part of Benes block.

  Repeatably applies interleaved Residual Switch layer and Shuffle
  Layer. One set of weights used for all Switch layers.

  Args:
-    inputs: inputs for forward part. Should be inputs from previous layers
-      or Beneš block.
+    block_out: TODO(authors) document.
    hparams: params of the network.
    n_bits: count of repeated layer applications.

@@ -254,11 +262,11 @@ def forward_step(state, _):
    return shuffle_layer(new_state)

  forward_outputs = tf.scan(
-        forward_step,
-        tf.range(0, n_bits),
-        initializer=block_out,
-        parallel_iterations=1,
-        swap_memory=True)
+      forward_step,
+      tf.range(0, n_bits),
+      initializer=block_out,
+      parallel_iterations=1,
+      swap_memory=True)

  return forward_outputs[-1, :, :, :]

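
Note (illustration only, not part of the diff): both scans above apply one step function, and therefore one shared set of Switch-layer weights, n_bits times, and keep only the last of the stacked intermediate states. A toy standalone sketch of that tf.scan pattern, with a stand-in step instead of the real RSU-plus-shuffle step:

import tensorflow.compat.v1 as tf

tf.enable_eager_execution()

def toy_step(state, _):
  # Stand-in for "apply the shared Residual Switch Unit, then shuffle_layer".
  return state * 0.5 + 1.0

inputs = tf.ones([2, 8, 4])  # [batch, length, units]
n_bits = 3                   # number of repeated layer applications

# tf.scan stacks every intermediate state along a new leading axis, so the
# block output is the last slice, as in forward_part/reverse_part above.
outputs = tf.scan(toy_step, tf.range(0, n_bits), initializer=inputs)
print(outputs[-1, :, :, :].shape)  # (2, 8, 4)
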
@@ -272,6 +280,9 @@ def body(self, features):

    Args:
      features: dictionary of inputs and targets
+
+    Returns:
+      the network output.
    """

    inputs = tf.squeeze(features["inputs"], axis=2)