This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Residual Shuffle-Exchange network #1805

Merged: 2 commits, merged on Jun 16, 2020
8 changes: 4 additions & 4 deletions tensor2tensor/layers/common_layers.py
@@ -2743,7 +2743,7 @@ def _fn_with_custom_grad(fn, inputs, grad_fn, use_global_vars=False):
Returns:
fn(*inputs)
"""
vs = tf.compat.v1.get_variable_scope()
vs = tf.get_variable_scope()
get_vars_fn = (
vs.global_variables if use_global_vars else vs.trainable_variables)
len_before_vars = len(get_vars_fn())
@@ -3145,7 +3145,7 @@ def grad_fn(inputs, variables, outputs, output_grads):

@fn_with_custom_grad(grad_fn)
def fn_with_recompute(*args):
cached_vs.append(tf.compat.v1.get_variable_scope())
cached_vs.append(tf.get_variable_scope())
cached_arg_scope.append(contrib.framework().current_arg_scope())
return fn(*args)

@@ -3160,7 +3160,7 @@ def dense(x, units, **kwargs):
# We need to find the layer parameters using scope name for the layer, so
# check that the layer is named. Otherwise parameters for different layers
# may get mixed up.
layer_name = tf.compat.v1.get_variable_scope().name
layer_name = tf.get_variable_scope().name
if (not layer_name) or ("name" not in kwargs):
raise ValueError(
"Variable scope and layer name cannot be empty. Actual: "
@@ -3491,7 +3491,7 @@ def should_generate_summaries():
if name_scope and "while/" in name_scope:
# Summaries don't work well within tf.while_loop()
return False
if tf.compat.v1.get_variable_scope().reuse:
if tf.get_variable_scope().reuse:
# Avoid generating separate summaries for different data shards
return False
return True
1 change: 1 addition & 0 deletions tensor2tensor/models/__init__.py
@@ -59,6 +59,7 @@
from tensor2tensor.models.research import neural_stack
from tensor2tensor.models.research import rl
from tensor2tensor.models.research import shuffle_network
from tensor2tensor.models.research import residual_shuffle_exchange
from tensor2tensor.models.research import similarity_transformer
from tensor2tensor.models.research import super_lm
from tensor2tensor.models.research import transformer_moe
279 changes: 279 additions & 0 deletions tensor2tensor/models/research/residual_shuffle_exchange.py
@@ -0,0 +1,279 @@
# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Residual Shuffle-Exchange Network.

Implementation of the paper
"Residual Shuffle-Exchange Networks for Fast Processing of Long Sequences"
by A. Draguns, E. Ozolins, A. Sostaks, M. Apinis and K. Freivalds.

Paper: https://arxiv.org/abs/2004.04662
Original code: https://github.com/LUMII-Syslab/RSE
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.models.research.shuffle_network import ShuffleNetwork
from tensor2tensor.models.research.shuffle_network import shuffle_layer
from tensor2tensor.models.research.shuffle_network import reverse_shuffle_layer
from tensor2tensor.layers.common_layers import gelu
from tensor2tensor.utils import registry

import numpy as np
import tensorflow.compat.v1 as tf


class LayerNormalization(tf.keras.layers.Layer):
"""Layer Normalization (LayerNorm) without output bias and gain."""

def __init__(self, axis=1, epsilon=1e-10, **kwargs):
"""Initialize Layer Normalization layer.

Args:
      axis: Axis or tuple of axes over which the mean and variance are computed
      epsilon: Small constant added to the variance to avoid division by zero
      **kwargs: Additional keyword arguments passed to the base Layer
"""
self.axis = axis
self.epsilon = epsilon
self.bias = None
super(LayerNormalization, self).__init__(**kwargs)

def build(self, input_shape):
""" Initialize bias weights for layer normalization.
Args:
input_shape: shape of input tensor
"""
num_units = input_shape.as_list()[-1]
self.bias = self.add_weight("bias", [1, 1, num_units],
initializer=tf.zeros_initializer)
super(LayerNormalization, self).build(input_shape)

def call(self, inputs, **kwargs):
""" Apply Layer Normalization without output bias and gain.

Args:
      inputs: tensor to be normalized. The configured normalization axis must
        be a valid axis of this tensor.
**kwargs: more arguments (unused)
"""
inputs -= tf.reduce_mean(inputs, axis=self.axis, keepdims=True)
inputs += self.bias
variance = tf.reduce_mean(tf.square(inputs), self.axis, keepdims=True)
return inputs * tf.math.rsqrt(variance + self.epsilon)


def inv_sigmoid(y):
"""Inverse sigmoid function.

Args:
    y: float strictly between 0 and 1

  Returns:
    The value x such that sigmoid(x) equals y.
  """
return np.log(y / (1 - y))


class RSU(tf.keras.layers.Layer):
"""Residual Switch Unit of Residual Shuffle-Exchange network."""

def __init__(self, prefix, dropout, mode, **kwargs):
"""Initialize Switch Layer.

Args:
prefix: Name prefix for switch layer
dropout: Dropout rate
      mode: A tf.estimator.ModeKeys mode (controls whether training noise is applied)
**kwargs: more arguments (unused)
"""
super().__init__(**kwargs)
self.prefix = prefix
self.dropout = dropout
self.mode = mode
self.first_linear = None
self.second_linear = None
self.layer_norm = None
self.residual_scale = None

residual_weight = 0.9
self.candidate_weight = np.sqrt(1 - residual_weight ** 2) * 0.25
self.init_value = inv_sigmoid(residual_weight)
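    # These constants follow the RSE paper: sigmoid(residual_scale) starts
    # near residual_weight (0.9), and the candidate branch weight
    # sqrt(1 - 0.9**2) * 0.25 keeps the variance of the combined signal
    # roughly stable at initialization.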

def build(self, input_shape):
"""Initialize layer weights and sublayers.

Args:
input_shape: shape of inputs
"""
in_units = input_shape[-1]
middle_units = in_units * 4
out_units = in_units * 2
init = tf.variance_scaling_initializer(scale=1.0, mode="fan_avg",
distribution="uniform")

self.first_linear = tf.keras.layers.Dense(middle_units,
use_bias=False,
kernel_initializer=init,
name=self.prefix + "/cand1")

self.second_linear = tf.keras.layers.Dense(out_units,
kernel_initializer=init,
name=self.prefix + "/cand2")
self.layer_norm = LayerNormalization()

init = tf.constant_initializer(self.init_value)
self.residual_scale = self.add_weight(self.prefix + "/residual",
[out_units], initializer=init)
super(RSU, self).build(input_shape)

def call(self, inputs, **kwargs):
"""Apply Residual Switch Layer to inputs.

Args:
inputs: Input tensor

Returns:
tf.Tensor: New candidate value
"""
input_shape = tf.shape(inputs)
batch_size = input_shape[0]
length = input_shape[1]
num_units = inputs.shape.as_list()[2]

n_bits = tf.log(tf.cast(length - 1, tf.float32)) / tf.log(2.0)
n_bits = tf.floor(n_bits) + 1
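    # n_bits is roughly log2 of the sequence length; it scales the dropout
    # rate below. Adjacent position pairs are then merged so the switch unit
    # operates on 2 * num_units features at a time.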

reshape_shape = [batch_size, length // 2, num_units * 2]
reshaped_inputs = tf.reshape(inputs, reshape_shape)

first_linear = self.first_linear(reshaped_inputs)
first_linear = self.layer_norm(first_linear)
first_linear = gelu(first_linear)
candidate = self.second_linear(first_linear)

residual = tf.sigmoid(self.residual_scale) * reshaped_inputs
candidate = residual + candidate * self.candidate_weight
candidate = tf.reshape(candidate, input_shape)

if self.dropout > 0:
candidate = tf.nn.dropout(candidate, rate=self.dropout / n_bits)
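    # Small multiplicative noise, applied only in TRAIN mode.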
if self.dropout != 0.0 and self.mode == tf.estimator.ModeKeys.TRAIN:
noise = tf.random_normal(tf.shape(candidate), mean=1.0, stddev=0.001)
candidate = candidate * noise

return candidate


def residual_shuffle_network(inputs, hparams):
"""Residual Shuffle-Exchange network with weight sharing.

Args:
    inputs: inputs to the Shuffle-Exchange network. The sequence length must
      be a power of 2.
hparams: Model configuration

Returns:
tf.Tensor: Outputs of the Shuffle-Exchange last layer
"""
input_shape = tf.shape(inputs)
n_bits = tf.log(tf.cast(input_shape[1] - 1, tf.float32)) / tf.log(2.0)
n_bits = tf.cast(n_bits, tf.int32) + 1

block_out = inputs

for k in range(hparams.num_hidden_layers):
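    # One Beneš block per iteration: a forward pass of interleaved switch and
    # shuffle layers, then a reverse pass with reverse-shuffle layers. Weights
    # are shared across the ~log2(n) layers inside each pass.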
with tf.variable_scope("benes_block_" + str(k), reuse=tf.AUTO_REUSE):
forward_output = forward_part(block_out, hparams, n_bits)
block_out = reverse_part(forward_output, hparams, n_bits)

return RSU("last_layer", hparams.dropout, hparams.mode)(block_out)


def reverse_part(inputs, hparams, n_bits):
""" Reverse part of Beneš block.

Repeatably applies interleaved Residual Switch layer and Reverse Shuffle
Layer. One set of weights used for all Switch layers.

Args:
inputs: inputs for reverse part. Should be outputs from forward part.
hparams: params of the network.
n_bits: count of repeated layer applications.

Returns:
tf.Tensor: output of reverse part.
"""
reverse_rsu = RSU("reverse_switch", hparams.dropout, hparams.mode)

def reverse_step(state, _):
with tf.variable_scope("reverse"):
new_state = reverse_rsu(state)
return reverse_shuffle_layer(new_state)

reverse_outputs = tf.scan(
reverse_step,
tf.range(n_bits, n_bits * 2),
initializer=inputs,
parallel_iterations=1,
swap_memory=True)

return reverse_outputs[-1, :, :, :]


def forward_part(block_out, hparams, n_bits):
""" Forward part of Beneš block.

Repeatably applies interleaved Residual Switch layer and Shuffle
Layer. One set of weights used for all Switch layers.

Args:
inputs: inputs for forward part. Should be inputs from previous layers
or Beneš block.
hparams: params of the network.
n_bits: count of repeated layer applications.

Returns:
tf.Tensor: output of forward part.
"""
forward_rsu = RSU("switch", hparams.dropout, hparams.mode)

def forward_step(state, _):
with tf.variable_scope("forward"):
new_state = forward_rsu(state)
return shuffle_layer(new_state)

forward_outputs = tf.scan(
forward_step,
tf.range(0, n_bits),
initializer=block_out,
parallel_iterations=1,
swap_memory=True)

return forward_outputs[-1, :, :, :]


@registry.register_model
class ResidualShuffleExchange(ShuffleNetwork):
"""T2T implementation of Residual Shuffle-Exchange network."""

def body(self, features):
"""Body of Residual Shuffle-Exchange network.

Args:
      features: dictionary of inputs and targets

    Returns:
      tf.Tensor: logits of the Residual Shuffle-Exchange network with a
        restored size-1 feature axis.
    """

inputs = tf.squeeze(features["inputs"], axis=2)
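    # Run the network on [batch, length, depth] and restore the size-1 axis
    # that the T2T modality layers expect.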
logits = residual_shuffle_network(inputs, self._hparams)
return tf.expand_dims(logits, axis=2)
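
Below is a minimal smoke-test sketch (not part of the PR) showing how the new module could be exercised directly. It assumes TF1 graph mode, that residual_shuffle_network only reads num_hidden_layers, dropout and mode from its hparams argument, and it substitutes a plain namespace object for a registered hparams set.

import types

import tensorflow.compat.v1 as tf

from tensor2tensor.models.research.residual_shuffle_exchange import (
    residual_shuffle_network)

tf.disable_v2_behavior()

# Stand-in for a registered hparams set; only these three fields are read here.
hparams = types.SimpleNamespace(
    num_hidden_layers=2, dropout=0.0, mode=tf.estimator.ModeKeys.EVAL)

# The sequence length must be a power of 2: 16 positions, 8 features each.
inputs = tf.zeros([1, 16, 8])
outputs = residual_shuffle_network(inputs, hparams)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(outputs).shape)  # (1, 16, 8)

In a full T2T run the model would instead be selected by its registered name, residual_shuffle_exchange, the snake_case form of the ResidualShuffleExchange class name.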