@@ -780,9 +780,9 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
780
780
781
781
def _get_verts_column_indices (
782
782
vertex_head : _PlyElementType ,
783
- ) -> Tuple [List [int ], Optional [List [int ]], float ]:
783
+ ) -> Tuple [List [int ], Optional [List [int ]], float , Optional [ List [ int ]] ]:
784
784
"""
785
- Get the columns of verts and verts_colors in the vertex
785
+ Get the columns of verts, verts_colors, and verts_normals in the vertex
786
786
element of a parsed ply file, together with a color scale factor.
787
787
When the colors are in byte format, they are scaled from 0..255 to [0,1].
788
788
Otherwise they are not scaled.
@@ -793,11 +793,14 @@ def _get_verts_column_indices(
793
793
property double x
794
794
property double y
795
795
property double z
796
+ property double nx
797
+ property double ny
798
+ property double nz
796
799
property uchar red
797
800
property uchar green
798
801
property uchar blue
799
802
800
- then the return value will be ([0,1,2], [6,7,8], 1.0/255)
803
+ then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5] )
801
804
802
805
Args:
803
806
vertex_head: as returned from load_ply_raw.
@@ -807,9 +810,12 @@ def _get_verts_column_indices(
807
810
color_idxs: List[int] of 3 color columns if they are present,
808
811
otherwise None.
809
812
color_scale: value to scale colors by.
813
+ normal_idxs: List[int] of 3 normals columns if they are present,
814
+ otherwise None.
810
815
"""
811
816
point_idxs : List [Optional [int ]] = [None , None , None ]
812
817
color_idxs : List [Optional [int ]] = [None , None , None ]
818
+ normal_idxs : List [Optional [int ]] = [None , None , None ]
813
819
for i , prop in enumerate (vertex_head .properties ):
814
820
if prop .list_size_type is not None :
815
821
raise ValueError ("Invalid vertices in file: did not expect list." )
@@ -819,6 +825,9 @@ def _get_verts_column_indices(
819
825
for j , name in enumerate (["red" , "green" , "blue" ]):
820
826
if prop .name == name :
821
827
color_idxs [j ] = i
828
+ for j , name in enumerate (["nx" , "ny" , "nz" ]):
829
+ if prop .name == name :
830
+ normal_idxs [j ] = i
822
831
if None in point_idxs :
823
832
raise ValueError ("Invalid vertices in file." )
824
833
color_scale = 1.0
@@ -831,21 +840,23 @@ def _get_verts_column_indices(
831
840
point_idxs ,
832
841
None if None in color_idxs else cast (List [int ], color_idxs ),
833
842
color_scale ,
843
+ None if None in normal_idxs else cast (List [int ], normal_idxs ),
834
844
)
835
845
836
846
837
847
def _get_verts (
838
848
header : _PlyHeader , elements : dict
839
- ) -> Tuple [torch .Tensor , Optional [torch .Tensor ]]:
849
+ ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [ torch . Tensor ] ]:
840
850
"""
841
- Get the vertex locations and colors from a parsed ply file.
851
+ Get the vertex locations, colors and normals from a parsed ply file.
842
852
843
853
Args:
844
854
header, elements: as returned from load_ply_raw.
845
855
846
856
Returns:
847
857
verts: FloatTensor of shape (V, 3).
848
858
vertex_colors: None or FloatTensor of shape (V, 3).
859
+ vertex_normals: None or FloatTensor of shape (V, 3).
849
860
"""
850
861
851
862
vertex = elements .get ("vertex" , None )
@@ -854,14 +865,16 @@ def _get_verts(
854
865
if not isinstance (vertex , list ):
855
866
raise ValueError ("Invalid vertices in file." )
856
867
vertex_head = next (head for head in header .elements if head .name == "vertex" )
857
- point_idxs , color_idxs , color_scale = _get_verts_column_indices (vertex_head )
868
+ point_idxs , color_idxs , color_scale , normal_idxs = _get_verts_column_indices (
869
+ vertex_head
870
+ )
858
871
859
872
# Case of no vertices
860
873
if vertex_head .count == 0 :
861
874
verts = torch .zeros ((0 , 3 ), dtype = torch .float32 )
862
875
if color_idxs is None :
863
- return verts , None
864
- return verts , torch .zeros ((0 , 3 ), dtype = torch .float32 )
876
+ return verts , None , None
877
+ return verts , torch .zeros ((0 , 3 ), dtype = torch .float32 ), None
865
878
866
879
# Simple case where the only data is the vertices themselves
867
880
if (
@@ -870,9 +883,10 @@ def _get_verts(
870
883
and vertex [0 ].ndim == 2
871
884
and vertex [0 ].shape [1 ] == 3
872
885
):
873
- return _make_tensor (vertex [0 ], cols = 3 , dtype = torch .float32 ), None
886
+ return _make_tensor (vertex [0 ], cols = 3 , dtype = torch .float32 ), None , None
874
887
875
888
vertex_colors = None
889
+ vertex_normals = None
876
890
877
891
if len (vertex ) == 1 :
878
892
# This is the case where the whole vertex element has one type,
@@ -882,6 +896,10 @@ def _get_verts(
882
896
vertex_colors = color_scale * torch .tensor (
883
897
vertex [0 ][:, color_idxs ], dtype = torch .float32
884
898
)
899
+ if normal_idxs is not None :
900
+ vertex_normals = torch .tensor (
901
+ vertex [0 ][:, normal_idxs ], dtype = torch .float32
902
+ )
885
903
else :
886
904
# The vertex element is heterogeneous. It was read as several arrays,
887
905
# part by part, where a part is a set of properties with the same type.
@@ -913,13 +931,22 @@ def _get_verts(
913
931
partnum , col = prop_to_partnum_col [color_idxs [color ]]
914
932
vertex_colors .numpy ()[:, color ] = vertex [partnum ][:, col ]
915
933
vertex_colors *= color_scale
934
+ if normal_idxs is not None :
935
+ vertex_normals = torch .empty (
936
+ size = (vertex_head .count , 3 ), dtype = torch .float32
937
+ )
938
+ for axis in range (3 ):
939
+ partnum , col = prop_to_partnum_col [normal_idxs [axis ]]
940
+ vertex_normals .numpy ()[:, axis ] = vertex [partnum ][:, col ]
916
941
917
- return verts , vertex_colors
942
+ return verts , vertex_colors , vertex_normals
918
943
919
944
920
945
def _load_ply (
921
946
f , * , path_manager : PathManager
922
- ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [torch .Tensor ]]:
947
+ ) -> Tuple [
948
+ torch .Tensor , Optional [torch .Tensor ], Optional [torch .Tensor ], Optional [torch .Tensor ]
949
+ ]:
923
950
"""
924
951
Load the data from a .ply file.
925
952
@@ -935,10 +962,11 @@ def _load_ply(
935
962
verts: FloatTensor of shape (V, 3).
936
963
faces: None or LongTensor of vertex indices, shape (F, 3).
937
964
vertex_colors: None or FloatTensor of shape (V, 3).
965
+ vertex_normals: None or FloatTensor of shape (V, 3).
938
966
"""
939
967
header , elements = _load_ply_raw (f , path_manager = path_manager )
940
968
941
- verts , vertex_colors = _get_verts (header , elements )
969
+ verts , vertex_colors , vertex_normals = _get_verts (header , elements )
942
970
943
971
face = elements .get ("face" , None )
944
972
if face is not None :
@@ -976,7 +1004,7 @@ def _load_ply(
976
1004
if faces is not None :
977
1005
_check_faces_indices (faces , max_index = verts .shape [0 ])
978
1006
979
- return verts , faces , vertex_colors
1007
+ return verts , faces , vertex_colors , vertex_normals
980
1008
981
1009
982
1010
def load_ply (
@@ -1031,7 +1059,7 @@ def load_ply(
1031
1059
1032
1060
if path_manager is None :
1033
1061
path_manager = PathManager ()
1034
- verts , faces , _ = _load_ply (f , path_manager = path_manager )
1062
+ verts , faces , _ , _ = _load_ply (f , path_manager = path_manager )
1035
1063
if faces is None :
1036
1064
faces = torch .zeros (0 , 3 , dtype = torch .int64 )
1037
1065
@@ -1211,18 +1239,23 @@ def read(
1211
1239
if not endswith (path , self .known_suffixes ):
1212
1240
return None
1213
1241
1214
- verts , faces , verts_colors = _load_ply (f = path , path_manager = path_manager )
1242
+ verts , faces , verts_colors , verts_normals = _load_ply (
1243
+ f = path , path_manager = path_manager
1244
+ )
1215
1245
if faces is None :
1216
1246
faces = torch .zeros (0 , 3 , dtype = torch .int64 )
1217
1247
1218
- textures = None
1248
+ texture = None
1219
1249
if include_textures and verts_colors is not None :
1220
- textures = TexturesVertex ([verts_colors .to (device )])
1250
+ texture = TexturesVertex ([verts_colors .to (device )])
1221
1251
1252
+ if verts_normals is not None :
1253
+ verts_normals = [verts_normals ]
1222
1254
mesh = Meshes (
1223
1255
verts = [verts .to (device )],
1224
1256
faces = [faces .to (device )],
1225
- textures = textures ,
1257
+ textures = texture ,
1258
+ verts_normals = verts_normals ,
1226
1259
)
1227
1260
return mesh
1228
1261
@@ -1286,12 +1319,14 @@ def read(
1286
1319
if not endswith (path , self .known_suffixes ):
1287
1320
return None
1288
1321
1289
- verts , faces , features = _load_ply (f = path , path_manager = path_manager )
1322
+ verts , faces , features , normals = _load_ply (f = path , path_manager = path_manager )
1290
1323
verts = verts .to (device )
1291
1324
if features is not None :
1292
1325
features = [features .to (device )]
1326
+ if normals is not None :
1327
+ normals = [normals .to (device )]
1293
1328
1294
- pointcloud = Pointclouds (points = [verts ], features = features )
1329
+ pointcloud = Pointclouds (points = [verts ], features = features , normals = normals )
1295
1330
return pointcloud
1296
1331
1297
1332
def save (
0 commit comments