
Commit 5b5af38

ChaiBapchya authored and Rohit Kumar Srivastava committed
Add Large Tensor Support for Sequence, NN Ops (apache#15807)
* sequence_last, sequence_reverse, sequence_mask
* working softmax_cross_entropy
* fix linting, add index_copy
* add softmax output
* add leaky relu
* add pooling
* add layernorm
* add dropout, activation, batchnorm and update layernorm
* address comments to remove some comments
* handling imports
1 parent 29d0592 commit 5b5af38
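
Before the file diffs, a minimal sketch (not part of this commit) of the sequence-op semantics that the new nightly tests exercise. Tiny shapes stand in for LARGE_X/SMALL_Y, and it assumes only the MXNet 1.x NDArray calls that the tests themselves use:

from mxnet import nd

# Layout expected by the sequence ops: [max_sequence_length, batch_size, features]
a = nd.arange(0, 3 * 2 * 4).reshape(3, 2, 4)

# Keep only the first timestep of each batch element; fill the rest with -1
masked = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
                         use_sequence_length=True, value=-1)

# Reverse the first 2 timesteps of batch 0 and all 3 timesteps of batch 1
rev = nd.SequenceReverse(a, sequence_length=nd.array([2, 3]),
                         use_sequence_length=True)

# Pick the last valid timestep per batch element (index 1 for batch 0, index 2 for batch 1)
last = nd.SequenceLast(a, sequence_length=nd.array([2, 3]),
                       use_sequence_length=True)

print(masked.asnumpy(), rev.asnumpy(), last.asnumpy(), sep="\n\n")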

File tree

2 files changed (+293, -8 lines)


tests/nightly/test_large_array.py

Lines changed: 292 additions & 8 deletions
@@ -15,9 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.

+import math
 import numpy as np
 import mxnet as mx
-from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
+
+from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context
 from mxnet import gluon, nd
 from tests.python.unittest.common import with_seed

@@ -299,9 +301,11 @@ def test_pick():
 def test_depthtospace():
     def numpy_depth_to_space(x, blocksize):
         b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
-        tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w])
+        tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h,
+                             w])
         tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2])
-        y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize])
+        y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize,
+                             w * blocksize])
         return y

     shape_inp = (LARGE_X, 8, 4, 2)
@@ -315,9 +319,11 @@ def numpy_depth_to_space(x, blocksize):
 def test_spacetodepth():
     def numpy_space_to_depth(x, blocksize):
         b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
-        tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize])
+        tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize,
+                             blocksize])
         tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4])
-        y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // blocksize])
+        y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize,
+                             w // blocksize])
         return y

     shape_inp = (LARGE_X, 2, 8, 4)
@@ -327,6 +333,7 @@ def numpy_space_to_depth(x, blocksize):
     output = mx.nd.space_to_depth(data, 2)
     assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)

+
 @with_seed()
 def test_diag():
     a_np = np.random.random((LARGE_X, SMALL_Y)).astype(np.float32)
@@ -358,7 +365,8 @@ def test_ravel_multi_index():
     x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, 9, SMALL_Y)
     x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
     indices_2d = [[x1, x2, x3], [y1, y2, y3]]
-    idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
+    idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64),
+                                  shape=(LARGE_X, SMALL_Y))
     idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, SMALL_Y))
     assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3

@@ -370,7 +378,8 @@ def test_unravel_index():
     x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
     original_2d_indices = [[x1, x2, x3], [y1, y2, y3]]
     idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, SMALL_Y))
-    indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
+    indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64),
+                                     shape=(LARGE_X, SMALL_Y))
     assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all()

@@ -427,13 +436,288 @@ def test_topk():
     b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
     k = nd.topk(b, k=10, axis=0, dtype=np.int64)
     assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y
-    ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False)
+    ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both",
+                          is_ascend=False)
     assert np.all(ind == val)
     b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
     l = nd.topk(b, k=1, axis=-1, dtype=np.int64, ret_typ="value")
     assert l.sum() == np.sum(np.arange(0, SMALL_Y))


+def test_sequence_mask():
+    # Sequence Mask input [max_sequence_length, batch_size, other_feature_dims]
+    # test with input batch_size = 2
+    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
+
+    # test as identity operator
+    b = nd.SequenceMask(a)
+    assert b[-1][0][1] == a[-1][0][1]
+    assert b.shape == a.shape
+
+    # test with default mask
+    b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
+                        use_sequence_length=True)
+    assert b[0][1][-1] == a[0][1][-1]  # first sequence of each batch kept
+    assert b[-1][-1][-1] != a[-1][-1][-1]  # rest sequences masked
+    assert b[-1][-1][-1] == 0
+
+    # test with mask value
+    b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
+                        use_sequence_length=True, value=-1)
+    assert b[-1][-1][-1] == -1
+
+
+def test_sequence_reverse():
+    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
+    # test as reverse operator
+    b = nd.SequenceReverse(a)
+    assert b[-1][0][0] == a[0][0][0]
+    assert b.shape == a.shape
+
+    # test with sequence length
+    b = nd.SequenceReverse(a, sequence_length=[2, 3])
+    assert b[1][0][0] == a[0][0][0]  # check if reversed
+    assert b[-1][0][0] == a[-1][0][0]  # check if intact
+    assert b.shape == a.shape
+
+
+def test_sequence_last():
+    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
+
+    # test if returns last sequence
+    b = nd.SequenceLast(a)
+    assert_almost_equal(b, a[-1])  # only checks for (2,SMALL_Y) tensor
+    assert b.shape == (2, SMALL_Y)
+
+    # test with sequence length
+    # parameter sequence_length - NDArray with shape (batch_size)
+    # (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2
+    b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
+                        use_sequence_length=True)
+    # check if it takes 2nd sequence from the first batch
+    assert b[0][-1] == a[1][0][-1]
+
+
+def test_softmax_cross_entropy():
+    # dtype of input data, mxnet cross entropy set explicitly to float64
+    # numpy implicitly takes care of double precision
+    batch_size = SMALL_Y
+    num_labels = LARGE_X
+    input_data = mx.nd.ones((batch_size, num_labels), dtype="float64")
+    input_label = mx.nd.zeros((batch_size,), dtype="float64")
+
+    true_softmax = np.full((batch_size, num_labels), (1 / num_labels))
+    # use 1/batch_size when softmax axis=0
+    # here 1/num_labels since softmax_cross_entropy uses default axis
+    # by default axis=1
+    np_one_hot_label = np.zeros((batch_size, num_labels))
+    np_one_hot_label[:, 0] = 1
+
+    true_softmax_cross_entropy = np.sum(-np.log(true_softmax) *
+                                        np_one_hot_label)
+    mx_softmax_cross_entropy = mx.nd.softmax_cross_entropy(input_data,
+                                                           input_label,
+                                                           dtype="float64")
+    assert_almost_equal(mx_softmax_cross_entropy.asnumpy(),
+                        true_softmax_cross_entropy, rtol=1e-3, atol=1e-5)
+
+
+def test_index_copy():
+    x = mx.nd.zeros((LARGE_X, SMALL_Y))
+    t = mx.nd.arange(1, SMALL_Y + 1).reshape((1, SMALL_Y))
+    index = mx.nd.array([LARGE_X - 1])
+
+    x = mx.nd.contrib.index_copy(x, index, t)
+    assert x[-1][-1] == t[0][-1]
+
+
+def testSoftmaxOutput():
+    x = mx.sym.Variable('x')
+    label = mx.sym.Variable('label')
+    x_nd = mx.nd.ones((LARGE_X, SMALL_Y))
+    grad_x = mx.nd.zeros((LARGE_X, SMALL_Y))
+    label_nd = mx.nd.ones((LARGE_X))
+
+    sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0,
+                               use_ignore=False)
+    ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd},
+                  args_grad={'x': grad_x})
+
+    ex.forward(is_train=True)
+    softmax_out = ex.outputs[0][0].asnumpy()
+    expected_softmax_out = (1/SMALL_Y)*mx.nd.ones((SMALL_Y)).asnumpy()
+    assert np.isclose(softmax_out, expected_softmax_out).all()
+
+    ex.backward(is_train=True)
+    grad_out = ex.grad_arrays[0][0].asnumpy()
+    k = int(label_nd[0].asscalar())
+    expected_grad_out = np.zeros((SMALL_Y,))
+    expected_grad_out[k] = -1
+    assert np.isclose(grad_out - softmax_out, expected_grad_out).all()
+
+
+# TODO: correctness of prelu (currently flaky)
+def test_leaky_relu():
+    a = -1*mx.nd.ones((LARGE_X, SMALL_Y))
+
+    def test_leaky():
+        res = mx.nd.LeakyReLU(a, act_type="leaky", slope=0.3)
+        assert res[-1][-1].asnumpy() == 0.3*a[-1][-1].asnumpy()
+
+    def test_elu():
+        res = mx.nd.LeakyReLU(a, act_type="elu", slope=0.3)
+        assert res[-1][-1].asnumpy() == 0.3*(np.exp(a[-1][-1].asnumpy())-1)
+
+    def test_selu():
+        lam = 1.0507009873554804934193349852946
+        alpha = 1.6732632423543772848170429916717
+        res = mx.nd.LeakyReLU(a, act_type="selu")
+        assert res[-1][-1].asnumpy() == (lam * alpha * (np.exp(a[-1][-1].asnumpy())-1))
+
+    def test_rrelu():
+        lower = 0.125
+        upper = 0.333999991
+        res = mx.nd.LeakyReLU(a, act_type="rrelu")
+        assert res[-1][-1].asnumpy() == (lower + upper) / 2 * a[-1][-1].asnumpy()
+
+    test_leaky()
+    test_elu()
+    test_selu()
+    test_rrelu()
+
+
+def test_pooling():
+    a = mx.nd.ones((MEDIUM_X, MEDIUM_X, SMALL_Y, SMALL_Y))
+
+    def test_avg_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='avg')
+        assert res[-1][-1][-1][-1] == 1.0000001
+        assert res.shape == SMALL_Y - 5 + 1
+
+    def test_max_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='max')
+        assert res[-1][-1][-1][-1] == 1.
+        assert res.shape == SMALL_Y - 5 + 1
+
+    def test_sum_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='sum')
+        assert res[-1][-1][-1][-1] == 25
+        assert res.shape == SMALL_Y - 5 + 1
+
+    def test_lp_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=2)
+        assert res[-1][-1][-1][-1] == 5.
+        assert res.shape == SMALL_Y - 5 + 1
+
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=1)
+        assert res[-1][-1][-1][-1] == 25.
+        assert res.shape == SMALL_Y - 5 + 1
+
+    test_avg_pooling()
+    test_max_pooling()
+    test_sum_pooling()
+    test_lp_pooling()
+
+
+def test_layer_norm():
+    dtype = np.float32
+    forward_check_eps = 1E-3
+    axis = 1
+    eps = 1E-5
+    in_shape = (LARGE_X, SMALL_Y)
+    ctx = mx.cpu()
+
+    def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
+        if axis < 0:
+            axis += data.ndim
+        broadcast_shape = [1 for _ in range(data.ndim)]
+        broadcast_shape[axis] = data.shape[axis]
+        mean = data.mean(axis=axis, keepdims=True).astype(dtype)
+        var = data.var(axis=axis, keepdims=True).astype(dtype)
+        std = np.sqrt(var + dtype(eps)).astype(dtype)
+        out = np.reshape(gamma, broadcast_shape) * (data - mean) / std + \
+            np.reshape(beta, broadcast_shape)
+        return out
+    data = np.random.normal(0, 1, in_shape).astype(dtype)
+    gamma = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+    beta = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+    data_s = mx.symbol.Variable('data')
+    gamma_s = mx.symbol.Variable('gamma')
+    beta_s = mx.symbol.Variable('beta')
+    out_s = mx.symbol.LayerNorm(data=data_s, gamma=gamma_s, beta=beta_s,
+                                axis=axis, eps=eps)
+    exe = out_s.simple_bind(ctx, data=in_shape)
+    exe.arg_dict['data'][:] = data
+    exe.arg_dict['gamma'][:] = gamma
+    exe.arg_dict['beta'][:] = beta
+    out_nd = exe.forward()[0]
+    out = npy_layer_norm(data, gamma, beta, axis, eps)
+    assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
+                        forward_check_eps)
+
+# TODO: correctness of dropout
+# currently only test for dropout to work
+# since testing for correctness involves flakiness issue #14288
+def test_dropout():
+    shape = (10, 10)
+    x = mx.sym.var('data')
+    y = mx.sym.Dropout(x, p=1, cudnn_off=True)
+    exe = y.simple_bind(ctx=default_context(), data=shape)
+    exe.arg_arrays[0][:] = 1
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
+
+
+def test_activation():
+    a = mx.nd.ones((LARGE_X, SMALL_Y))
+    test_x = -2
+    a[-1, -1] = test_x
+
+    # Hyperbolic tangent (tanh)
+    # y = (exp(x)-exp(-x))/(exp(x)+exp(-x))
+    a = mx.nd.Activation(a, act_type="tanh")
+    tanh_x = (np.exp(-2)-np.exp(2))/(np.exp(-2)+np.exp(2))
+    assert a[-1][-1] == tanh_x
+
+    # Recitified Linear Unit (relu)
+    # y = max(x,0)
+    a = mx.nd.Activation(a, act_type="relu")
+    assert a[-1][-1] == 0
+
+    # Sigmoid
+    # y = x/(1+abs(x))
+    a = mx.nd.Activation(a, act_type="sigmoid")
+    sigmoid_x = 1/(1+math.exp(-test_x))
+    assert a[-1][-1] == sigmoid_x
+
+    # Soft Sign
+    # y = 1/(1+exp(-x))
+    a = mx.nd.Activation(a, act_type="softsign")
+    softsign_x = test_x/(1+abs(test_x))
+    assert a[-1][-1] == softsign_x
+
+
+# TODO: correctness of batchnorm
+# in future, we could test if mean, var of output
+# matches target output's mean, var
+def test_batchnorm():
+    shape = (LARGE_X, SMALL_Y)
+    axis = 1  # default
+    expand_shape = [1] * len(shape)
+    expand_shape[axis] = shape[axis]
+
+    nch = shape[axis]
+    data = mx.nd.ones(shape=shape)
+    bn_gamma = mx.nd.random.uniform(shape=(nch,))
+    bn_beta = mx.nd.random.uniform(shape=(nch,))
+    bn_running_mean = mx.nd.zeros(nch)
+    bn_running_var = mx.nd.ones(nch)
+
+    output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
+                             bn_running_mean, bn_running_var)
+    output.wait_to_read()
+
+
 def test_add():
     a = nd.ones(shape=(LARGE_X, SMALL_Y))
     b = nd.ones(shape=(LARGE_X, SMALL_Y))
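
Aside (not part of the diff): the expected value asserted in test_softmax_cross_entropy above can be sanity-checked at small scale. With all-ones logits and every label set to class 0, the softmax along the default axis=1 is uniform, so the summed cross entropy is batch_size * -log(1/num_labels). A sketch with hypothetical small shapes standing in for SMALL_Y/LARGE_X:

import numpy as np
import mxnet as mx

batch_size, num_labels = 4, 8  # tiny stand-ins for SMALL_Y, LARGE_X
data = mx.nd.ones((batch_size, num_labels), dtype="float64")
label = mx.nd.zeros((batch_size,), dtype="float64")

expected = batch_size * -np.log(1.0 / num_labels)  # 4 * log(8), roughly 8.3178
got = mx.nd.softmax_cross_entropy(data, label).asscalar()
assert np.isclose(got, expected)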

tests/nightly/test_large_vector.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@

 import numpy as np
 import mxnet as mx
+
 from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
 from mxnet import gluon, nd
 from tests.python.unittest.common import with_seed
