 # specific language governing permissions and limitations
 # under the License.

+import math
 import numpy as np
 import mxnet as mx
-from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
+
+from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context
 from mxnet import gluon, nd
 from tests.python.unittest.common import with_seed

@@ -299,9 +301,11 @@ def test_pick():
 def test_depthtospace():
     def numpy_depth_to_space(x, blocksize):
         b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
-        tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w])
+        tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h,
+                             w])
         tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2])
-        y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize])
+        y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize,
+                             w * blocksize])
         return y

     shape_inp = (LARGE_X, 8, 4, 2)
@@ -315,9 +319,11 @@ def numpy_depth_to_space(x, blocksize):
 def test_spacetodepth():
     def numpy_space_to_depth(x, blocksize):
         b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
-        tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize])
+        tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize,
+                             blocksize])
         tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4])
-        y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // blocksize])
+        y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize,
+                             w // blocksize])
         return y

     shape_inp = (LARGE_X, 2, 8, 4)
@@ -327,6 +333,7 @@ def numpy_space_to_depth(x, blocksize):
     output = mx.nd.space_to_depth(data, 2)
     assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)

+
 @with_seed()
 def test_diag():
     a_np = np.random.random((LARGE_X, SMALL_Y)).astype(np.float32)
@@ -358,7 +365,8 @@ def test_ravel_multi_index():
     x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, 9, SMALL_Y)
     x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
     indices_2d = [[x1, x2, x3], [y1, y2, y3]]
-    idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
+    idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64),
+                                  shape=(LARGE_X, SMALL_Y))
     idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, SMALL_Y))
     assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3

@@ -370,7 +378,8 @@ def test_unravel_index():
     x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
     original_2d_indices = [[x1, x2, x3], [y1, y2, y3]]
     idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, SMALL_Y))
-    indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64), shape=(LARGE_X, SMALL_Y))
+    indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64),
+                                     shape=(LARGE_X, SMALL_Y))
     assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all()


@@ -427,13 +436,288 @@ def test_topk():
     b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
     k = nd.topk(b, k=10, axis=0, dtype=np.int64)
     assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y
-    ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both", is_ascend=False)
+    ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both",
+                          is_ascend=False)
     assert np.all(ind == val)
     b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
     l = nd.topk(b, k=1, axis=-1, dtype=np.int64, ret_typ="value")
     assert l.sum() == np.sum(np.arange(0, SMALL_Y))


+def test_sequence_mask():
+    # SequenceMask input: [max_sequence_length, batch_size, other_feature_dims]
+    # test with input batch_size = 2
+    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
+
+    # test as identity operator
+    b = nd.SequenceMask(a)
+    assert b[-1][0][1] == a[-1][0][1]
+    assert b.shape == a.shape
+
+    # test with default mask
+    b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
+                        use_sequence_length=True)
+    assert b[0][1][-1] == a[0][1][-1]  # first sequence of each batch kept
+    assert b[-1][-1][-1] != a[-1][-1][-1]  # remaining sequences masked
+    assert b[-1][-1][-1] == 0
+
+    # test with mask value
+    b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
+                        use_sequence_length=True, value=-1)
+    assert b[-1][-1][-1] == -1
+
+
+def test_sequence_reverse():
+    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
+    # test as reverse operator
+    b = nd.SequenceReverse(a)
+    assert b[-1][0][0] == a[0][0][0]
+    assert b.shape == a.shape
+
+    # test with sequence length
+    b = nd.SequenceReverse(a, sequence_length=nd.array([2, 3]),
+                           use_sequence_length=True)
+    assert b[1][0][0] == a[0][0][0]  # check if reversed
+    assert b[-1][0][0] == a[-1][0][0]  # check if intact
+    assert b.shape == a.shape
+
+
+def test_sequence_last():
+    a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
+
+    # test if it returns the last sequence
+    b = nd.SequenceLast(a)
+    assert_almost_equal(b, a[-1])  # only checks the (2, SMALL_Y) tensor
+    assert b.shape == (2, SMALL_Y)
+
+    # test with sequence length
+    # parameter sequence_length - NDArray with shape (batch_size,)
+    # (2, 3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2
+    b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
+                        use_sequence_length=True)
+    # check if it takes the 2nd sequence from the first batch
+    assert b[0][-1] == a[1][0][-1]
+
+
+def test_softmax_cross_entropy():
+    # the mxnet cross entropy is set explicitly to float64 via dtype;
+    # numpy implicitly works in double precision
+    batch_size = SMALL_Y
+    num_labels = LARGE_X
+    input_data = mx.nd.ones((batch_size, num_labels), dtype="float64")
+    input_label = mx.nd.zeros((batch_size,), dtype="float64")
+
+    true_softmax = np.full((batch_size, num_labels), (1 / num_labels))
+    # softmax_cross_entropy uses the default axis=1, so the uniform softmax
+    # value is 1/num_labels (it would be 1/batch_size for axis=0)
+    np_one_hot_label = np.zeros((batch_size, num_labels))
+    np_one_hot_label[:, 0] = 1
+
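+    # with a uniform softmax and one-hot labels each sample contributes
+    # log(num_labels) to the loss, so the expected total is
+    # batch_size * log(num_labels)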
+    true_softmax_cross_entropy = np.sum(-np.log(true_softmax) *
+                                        np_one_hot_label)
+    mx_softmax_cross_entropy = mx.nd.softmax_cross_entropy(input_data,
+                                                           input_label,
+                                                           dtype="float64")
+    assert_almost_equal(mx_softmax_cross_entropy.asnumpy(),
+                        true_softmax_cross_entropy, rtol=1e-3, atol=1e-5)
+
+
+def test_index_copy():
+    x = mx.nd.zeros((LARGE_X, SMALL_Y))
+    t = mx.nd.arange(1, SMALL_Y + 1).reshape((1, SMALL_Y))
+    index = mx.nd.array([LARGE_X - 1])
+
+    x = mx.nd.contrib.index_copy(x, index, t)
+    assert x[-1][-1] == t[0][-1]
+
+
+def testSoftmaxOutput():
+    x = mx.sym.Variable('x')
+    label = mx.sym.Variable('label')
+    x_nd = mx.nd.ones((LARGE_X, SMALL_Y))
+    grad_x = mx.nd.zeros((LARGE_X, SMALL_Y))
+    label_nd = mx.nd.ones((LARGE_X))
+
+    sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0,
+                               use_ignore=False)
+    ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd},
+                  args_grad={'x': grad_x})
+
+    ex.forward(is_train=True)
+    softmax_out = ex.outputs[0][0].asnumpy()
+    expected_softmax_out = (1 / SMALL_Y) * mx.nd.ones((SMALL_Y)).asnumpy()
+    assert np.isclose(softmax_out, expected_softmax_out).all()
+
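+    # the SoftmaxOutput gradient w.r.t. the data is softmax - one_hot(label),
+    # so subtracting the softmax output should leave -1 at the label index
+    # and 0 elsewhere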
+    ex.backward(is_train=True)
+    grad_out = ex.grad_arrays[0][0].asnumpy()
+    k = int(label_nd[0].asscalar())
+    expected_grad_out = np.zeros((SMALL_Y,))
+    expected_grad_out[k] = -1
+    assert np.isclose(grad_out - softmax_out, expected_grad_out).all()
+
+
+# TODO: correctness of prelu (currently flaky)
+def test_leaky_relu():
+    a = -1 * mx.nd.ones((LARGE_X, SMALL_Y))
+
+    def test_leaky():
+        res = mx.nd.LeakyReLU(a, act_type="leaky", slope=0.3)
+        assert res[-1][-1].asnumpy() == 0.3 * a[-1][-1].asnumpy()
+
+    def test_elu():
+        res = mx.nd.LeakyReLU(a, act_type="elu", slope=0.3)
+        assert res[-1][-1].asnumpy() == 0.3 * (np.exp(a[-1][-1].asnumpy()) - 1)
+
+    def test_selu():
+        lam = 1.0507009873554804934193349852946
+        alpha = 1.6732632423543772848170429916717
+        res = mx.nd.LeakyReLU(a, act_type="selu")
+        assert res[-1][-1].asnumpy() == (lam * alpha * (np.exp(a[-1][-1].asnumpy()) - 1))
+
+    def test_rrelu():
+        lower = 0.125
+        upper = 0.333999991
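+        # at inference time rrelu applies the fixed slope
+        # (lower + upper) / 2, which is what this check relies on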
+        res = mx.nd.LeakyReLU(a, act_type="rrelu")
+        assert res[-1][-1].asnumpy() == (lower + upper) / 2 * a[-1][-1].asnumpy()
+
+    test_leaky()
+    test_elu()
+    test_selu()
+    test_rrelu()
+
+
+def test_pooling():
+    a = mx.nd.ones((MEDIUM_X, MEDIUM_X, SMALL_Y, SMALL_Y))
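+    # with a 5x5 kernel, stride 1 and no padding, each spatial dim of the
+    # output is SMALL_Y - 5 + 1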
+
+    def test_avg_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='avg')
+        assert res[-1][-1][-1][-1] == 1.0000001
+        assert res.shape[-1] == SMALL_Y - 5 + 1
+
+    def test_max_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='max')
+        assert res[-1][-1][-1][-1] == 1.
+        assert res.shape[-1] == SMALL_Y - 5 + 1
+
+    def test_sum_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='sum')
+        assert res[-1][-1][-1][-1] == 25
+        assert res.shape[-1] == SMALL_Y - 5 + 1
+
+    def test_lp_pooling():
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=2)
+        assert res[-1][-1][-1][-1] == 5.
+        assert res.shape[-1] == SMALL_Y - 5 + 1
+
+        res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=1)
+        assert res[-1][-1][-1][-1] == 25.
+        assert res.shape[-1] == SMALL_Y - 5 + 1
+
+    test_avg_pooling()
+    test_max_pooling()
+    test_sum_pooling()
+    test_lp_pooling()
+
+
+def test_layer_norm():
+    dtype = np.float32
+    forward_check_eps = 1E-3
+    axis = 1
+    eps = 1E-5
+    in_shape = (LARGE_X, SMALL_Y)
+    ctx = mx.cpu()
+
+    def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
+        if axis < 0:
+            axis += data.ndim
+        broadcast_shape = [1 for _ in range(data.ndim)]
+        broadcast_shape[axis] = data.shape[axis]
+        mean = data.mean(axis=axis, keepdims=True).astype(dtype)
+        var = data.var(axis=axis, keepdims=True).astype(dtype)
+        std = np.sqrt(var + dtype(eps)).astype(dtype)
+        out = np.reshape(gamma, broadcast_shape) * (data - mean) / std + \
+            np.reshape(beta, broadcast_shape)
+        return out
+    data = np.random.normal(0, 1, in_shape).astype(dtype)
+    gamma = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+    beta = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+    data_s = mx.symbol.Variable('data')
+    gamma_s = mx.symbol.Variable('gamma')
+    beta_s = mx.symbol.Variable('beta')
+    out_s = mx.symbol.LayerNorm(data=data_s, gamma=gamma_s, beta=beta_s,
+                                axis=axis, eps=eps)
+    exe = out_s.simple_bind(ctx, data=in_shape)
+    exe.arg_dict['data'][:] = data
+    exe.arg_dict['gamma'][:] = gamma
+    exe.arg_dict['beta'][:] = beta
+    out_nd = exe.forward()[0]
+    out = npy_layer_norm(data, gamma, beta, axis, eps)
+    assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
+                        forward_check_eps)
+
+
+# TODO: correctness of dropout
+# currently we only test that dropout runs, since testing for correctness
+# is flaky (see issue #14288)
+def test_dropout():
+    shape = (10, 10)
+    x = mx.sym.var('data')
+    y = mx.sym.Dropout(x, p=1, cudnn_off=True)
+    exe = y.simple_bind(ctx=default_context(), data=shape)
+    exe.arg_arrays[0][:] = 1
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
+
+
+def test_activation():
+    a = mx.nd.ones((LARGE_X, SMALL_Y))
+    test_x = -2
+    a[-1, -1] = test_x
+
+    # Hyperbolic tangent (tanh)
+    # y = (exp(x)-exp(-x))/(exp(x)+exp(-x))
+    out = mx.nd.Activation(a, act_type="tanh")
+    tanh_x = (np.exp(-2) - np.exp(2)) / (np.exp(-2) + np.exp(2))
+    assert out[-1][-1] == tanh_x
+
+    # Rectified Linear Unit (relu)
+    # y = max(x,0)
+    out = mx.nd.Activation(a, act_type="relu")
+    assert out[-1][-1] == 0
+
+    # Sigmoid
+    # y = 1/(1+exp(-x))
+    out = mx.nd.Activation(a, act_type="sigmoid")
+    sigmoid_x = 1 / (1 + math.exp(-test_x))
+    assert out[-1][-1] == sigmoid_x
+
+    # Soft Sign
+    # y = x/(1+abs(x))
+    out = mx.nd.Activation(a, act_type="softsign")
+    softsign_x = test_x / (1 + abs(test_x))
+    assert out[-1][-1] == softsign_x
+
+
+# TODO: correctness of batchnorm
+# in the future we could check that the output's mean and variance
+# match the target output's mean and variance
+def test_batchnorm():
+    shape = (LARGE_X, SMALL_Y)
+    axis = 1  # default
+    expand_shape = [1] * len(shape)
+    expand_shape[axis] = shape[axis]
+
+    nch = shape[axis]
+    data = mx.nd.ones(shape=shape)
+    bn_gamma = mx.nd.random.uniform(shape=(nch,))
+    bn_beta = mx.nd.random.uniform(shape=(nch,))
+    bn_running_mean = mx.nd.zeros(nch)
+    bn_running_var = mx.nd.ones(nch)
+
+    output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
+                             bn_running_mean, bn_running_var)
+    output.wait_to_read()
+
+
 def test_add():
     a = nd.ones(shape=(LARGE_X, SMALL_Y))
     b = nd.ones(shape=(LARGE_X, SMALL_Y))