4
4
import os
5
5
import tempfile
6
6
import unittest
7
- from typing import Tuple , List
8
7
9
8
import numpy as np
10
9
import pytest
@@ -736,7 +735,7 @@ def forward(self, x):
736
735
def create_pytorch_module_with_nested_inputs (tmp_dir ):
737
736
class PTModel (torch .nn .Module ):
738
737
739
- def forward (self , z : Tuple [torch .Tensor , torch .Tensor ]):
738
+ def forward (self , z : tuple [torch .Tensor , torch .Tensor ]):
740
739
z1 , z2 = z
741
740
zeros1 = torch .zeros ((1 , 1 ))
742
741
zeros2 = torch .zeros ((1 , 5 , 1 ))
@@ -761,7 +760,7 @@ def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
761
760
def create_pytorch_module_with_nested_inputs_compress_to_fp16_default (tmp_dir ):
762
761
class PTModel (torch .nn .Module ):
763
762
764
- def forward (self , z : Tuple [torch .Tensor , torch .Tensor ]):
763
+ def forward (self , z : tuple [torch .Tensor , torch .Tensor ]):
765
764
z1 , z2 = z
766
765
zeros1 = torch .zeros ((1 , 1 ))
767
766
zeros2 = torch .zeros ((1 , 5 , 1 ))
@@ -787,7 +786,7 @@ def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
787
786
def create_pytorch_module_with_nested_inputs2 (tmp_dir ):
788
787
class PTModel (torch .nn .Module ):
789
788
790
- def forward (self , x : torch .Tensor , z : Tuple [torch .Tensor , torch .Tensor ]):
789
+ def forward (self , x : torch .Tensor , z : tuple [torch .Tensor , torch .Tensor ]):
791
790
z1 , z2 = z
792
791
zeros1 = torch .zeros ((1 , 1 ))
793
792
zeros2 = torch .zeros ((1 , 5 , 1 ))
@@ -815,7 +814,7 @@ def forward(self, x: torch.Tensor, z: Tuple[torch.Tensor, torch.Tensor]):
815
814
def create_pytorch_module_with_nested_inputs3 (tmp_dir ):
816
815
class PTModel (torch .nn .Module ):
817
816
818
- def forward (self , z : Tuple [torch .Tensor , torch .Tensor ], x : torch .Tensor ):
817
+ def forward (self , z : tuple [torch .Tensor , torch .Tensor ], x : torch .Tensor ):
819
818
z1 , z2 = z
820
819
zeros1 = torch .zeros ((1 , 1 ))
821
820
zeros2 = torch .zeros ((1 , 5 , 1 ))
@@ -843,7 +842,7 @@ def forward(self, z: Tuple[torch.Tensor, torch.Tensor], x: torch.Tensor):
843
842
def create_pytorch_module_with_nested_inputs4 (tmp_dir ):
844
843
class PTModel (torch .nn .Module ):
845
844
846
- def forward (self , x : torch .Tensor , z : Tuple [torch .Tensor , torch .Tensor ], y : torch .Tensor ):
845
+ def forward (self , x : torch .Tensor , z : tuple [torch .Tensor , torch .Tensor ], y : torch .Tensor ):
847
846
z1 , z2 = z
848
847
zeros1 = torch .zeros ((1 , 1 ))
849
848
zeros2 = torch .zeros ((1 , 5 , 1 ))
@@ -874,7 +873,7 @@ def forward(self, x: torch.Tensor, z: Tuple[torch.Tensor, torch.Tensor], y: torc
874
873
def create_pytorch_module_with_nested_inputs5 (tmp_dir ):
875
874
class PTModel (torch .nn .Module ):
876
875
877
- def forward (self , x : torch .Tensor , z : Tuple [torch .Tensor , torch .Tensor ], y : torch .Tensor ):
876
+ def forward (self , x : torch .Tensor , z : tuple [torch .Tensor , torch .Tensor ], y : torch .Tensor ):
878
877
z1 , z2 = z
879
878
zeros1 = torch .zeros ((1 , 1 ))
880
879
zeros2 = torch .zeros ((1 , 5 , 1 ))
@@ -904,7 +903,7 @@ def forward(self, x: torch.Tensor, z: Tuple[torch.Tensor, torch.Tensor], y: torc
904
903
def create_pytorch_module_with_nested_inputs6 (tmp_dir ):
905
904
class PTModel (torch .nn .Module ):
906
905
907
- def forward (self , x : torch .Tensor , y : torch .Tensor = None , z : Tuple [torch .Tensor , torch .Tensor ] = None ):
906
+ def forward (self , x : torch .Tensor , y : torch .Tensor = None , z : tuple [torch .Tensor , torch .Tensor ] = None ):
908
907
z1 , z2 = z
909
908
zeros1 = torch .zeros ((1 , 1 ))
910
909
zeros2 = torch .zeros ((1 , 5 , 1 ))
@@ -933,7 +932,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor = None, z: Tuple[torch.Tensor
933
932
934
933
def create_pytorch_module_with_nested_list_and_single_input (tmp_dir ):
935
934
class PTModel (torch .nn .Module ):
936
- def forward (self , x : List [torch .Tensor ]):
935
+ def forward (self , x : list [torch .Tensor ]):
937
936
x0 = x [0 ]
938
937
x0 = torch .cat ([x0 , torch .zeros (1 , 1 )], 1 )
939
938
return x0 + torch .ones ((1 , 1 ))
0 commit comments