Skip to content

Commit 6fa66f5

Browse files
bottlerfacebook-github-bot
authored andcommitted
PLY load normals
Summary: Add ability to load normals when they are present in a PLY file. Reviewed By: nikhilaravi Differential Revision: D26458971 fbshipit-source-id: 658270b611f7624eab4f5f62ff438038e1d25723
1 parent b314bee commit 6fa66f5

File tree

2 files changed

+105
-21
lines changed

2 files changed

+105
-21
lines changed

pytorch3d/io/ply_io.py

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -780,9 +780,9 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
780780

781781
def _get_verts_column_indices(
782782
vertex_head: _PlyElementType,
783-
) -> Tuple[List[int], Optional[List[int]], float]:
783+
) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]:
784784
"""
785-
Get the columns of verts and verts_colors in the vertex
785+
Get the columns of verts, verts_colors, and verts_normals in the vertex
786786
element of a parsed ply file, together with a color scale factor.
787787
When the colors are in byte format, they are scaled from 0..255 to [0,1].
788788
Otherwise they are not scaled.
@@ -793,11 +793,14 @@ def _get_verts_column_indices(
793793
property double x
794794
property double y
795795
property double z
796+
property double nx
797+
property double ny
798+
property double nz
796799
property uchar red
797800
property uchar green
798801
property uchar blue
799802
800-
then the return value will be ([0,1,2], [6,7,8], 1.0/255)
803+
then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5])
801804
802805
Args:
803806
vertex_head: as returned from load_ply_raw.
@@ -807,9 +810,12 @@ def _get_verts_column_indices(
807810
color_idxs: List[int] of 3 color columns if they are present,
808811
otherwise None.
809812
color_scale: value to scale colors by.
813+
normal_idxs: List[int] of 3 normals columns if they are present,
814+
otherwise None.
810815
"""
811816
point_idxs: List[Optional[int]] = [None, None, None]
812817
color_idxs: List[Optional[int]] = [None, None, None]
818+
normal_idxs: List[Optional[int]] = [None, None, None]
813819
for i, prop in enumerate(vertex_head.properties):
814820
if prop.list_size_type is not None:
815821
raise ValueError("Invalid vertices in file: did not expect list.")
@@ -819,6 +825,9 @@ def _get_verts_column_indices(
819825
for j, name in enumerate(["red", "green", "blue"]):
820826
if prop.name == name:
821827
color_idxs[j] = i
828+
for j, name in enumerate(["nx", "ny", "nz"]):
829+
if prop.name == name:
830+
normal_idxs[j] = i
822831
if None in point_idxs:
823832
raise ValueError("Invalid vertices in file.")
824833
color_scale = 1.0
@@ -831,21 +840,23 @@ def _get_verts_column_indices(
831840
point_idxs,
832841
None if None in color_idxs else cast(List[int], color_idxs),
833842
color_scale,
843+
None if None in normal_idxs else cast(List[int], normal_idxs),
834844
)
835845

836846

837847
def _get_verts(
838848
header: _PlyHeader, elements: dict
839-
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
849+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
840850
"""
841-
Get the vertex locations and colors from a parsed ply file.
851+
Get the vertex locations, colors and normals from a parsed ply file.
842852
843853
Args:
844854
header, elements: as returned from load_ply_raw.
845855
846856
Returns:
847857
verts: FloatTensor of shape (V, 3).
848858
vertex_colors: None or FloatTensor of shape (V, 3).
859+
vertex_normals: None or FloatTensor of shape (V, 3).
849860
"""
850861

851862
vertex = elements.get("vertex", None)
@@ -854,14 +865,16 @@ def _get_verts(
854865
if not isinstance(vertex, list):
855866
raise ValueError("Invalid vertices in file.")
856867
vertex_head = next(head for head in header.elements if head.name == "vertex")
857-
point_idxs, color_idxs, color_scale = _get_verts_column_indices(vertex_head)
868+
point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices(
869+
vertex_head
870+
)
858871

859872
# Case of no vertices
860873
if vertex_head.count == 0:
861874
verts = torch.zeros((0, 3), dtype=torch.float32)
862875
if color_idxs is None:
863-
return verts, None
864-
return verts, torch.zeros((0, 3), dtype=torch.float32)
876+
return verts, None, None
877+
return verts, torch.zeros((0, 3), dtype=torch.float32), None
865878

866879
# Simple case where the only data is the vertices themselves
867880
if (
@@ -870,9 +883,10 @@ def _get_verts(
870883
and vertex[0].ndim == 2
871884
and vertex[0].shape[1] == 3
872885
):
873-
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None
886+
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None
874887

875888
vertex_colors = None
889+
vertex_normals = None
876890

877891
if len(vertex) == 1:
878892
# This is the case where the whole vertex element has one type,
@@ -882,6 +896,10 @@ def _get_verts(
882896
vertex_colors = color_scale * torch.tensor(
883897
vertex[0][:, color_idxs], dtype=torch.float32
884898
)
899+
if normal_idxs is not None:
900+
vertex_normals = torch.tensor(
901+
vertex[0][:, normal_idxs], dtype=torch.float32
902+
)
885903
else:
886904
# The vertex element is heterogeneous. It was read as several arrays,
887905
# part by part, where a part is a set of properties with the same type.
@@ -913,13 +931,22 @@ def _get_verts(
913931
partnum, col = prop_to_partnum_col[color_idxs[color]]
914932
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
915933
vertex_colors *= color_scale
934+
if normal_idxs is not None:
935+
vertex_normals = torch.empty(
936+
size=(vertex_head.count, 3), dtype=torch.float32
937+
)
938+
for axis in range(3):
939+
partnum, col = prop_to_partnum_col[normal_idxs[axis]]
940+
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
916941

917-
return verts, vertex_colors
942+
return verts, vertex_colors, vertex_normals
918943

919944

920945
def _load_ply(
921946
f, *, path_manager: PathManager
922-
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
947+
) -> Tuple[
948+
torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
949+
]:
923950
"""
924951
Load the data from a .ply file.
925952
@@ -935,10 +962,11 @@ def _load_ply(
935962
verts: FloatTensor of shape (V, 3).
936963
faces: None or LongTensor of vertex indices, shape (F, 3).
937964
vertex_colors: None or FloatTensor of shape (V, 3).
965+
vertex_normals: None or FloatTensor of shape (V, 3).
938966
"""
939967
header, elements = _load_ply_raw(f, path_manager=path_manager)
940968

941-
verts, vertex_colors = _get_verts(header, elements)
969+
verts, vertex_colors, vertex_normals = _get_verts(header, elements)
942970

943971
face = elements.get("face", None)
944972
if face is not None:
@@ -976,7 +1004,7 @@ def _load_ply(
9761004
if faces is not None:
9771005
_check_faces_indices(faces, max_index=verts.shape[0])
9781006

979-
return verts, faces, vertex_colors
1007+
return verts, faces, vertex_colors, vertex_normals
9801008

9811009

9821010
def load_ply(
@@ -1031,7 +1059,7 @@ def load_ply(
10311059

10321060
if path_manager is None:
10331061
path_manager = PathManager()
1034-
verts, faces, _ = _load_ply(f, path_manager=path_manager)
1062+
verts, faces, _, _ = _load_ply(f, path_manager=path_manager)
10351063
if faces is None:
10361064
faces = torch.zeros(0, 3, dtype=torch.int64)
10371065

@@ -1211,18 +1239,23 @@ def read(
12111239
if not endswith(path, self.known_suffixes):
12121240
return None
12131241

1214-
verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager)
1242+
verts, faces, verts_colors, verts_normals = _load_ply(
1243+
f=path, path_manager=path_manager
1244+
)
12151245
if faces is None:
12161246
faces = torch.zeros(0, 3, dtype=torch.int64)
12171247

1218-
textures = None
1248+
texture = None
12191249
if include_textures and verts_colors is not None:
1220-
textures = TexturesVertex([verts_colors.to(device)])
1250+
texture = TexturesVertex([verts_colors.to(device)])
12211251

1252+
if verts_normals is not None:
1253+
verts_normals = [verts_normals]
12221254
mesh = Meshes(
12231255
verts=[verts.to(device)],
12241256
faces=[faces.to(device)],
1225-
textures=textures,
1257+
textures=texture,
1258+
verts_normals=verts_normals,
12261259
)
12271260
return mesh
12281261

@@ -1286,12 +1319,14 @@ def read(
12861319
if not endswith(path, self.known_suffixes):
12871320
return None
12881321

1289-
verts, faces, features = _load_ply(f=path, path_manager=path_manager)
1322+
verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager)
12901323
verts = verts.to(device)
12911324
if features is not None:
12921325
features = [features.to(device)]
1326+
if normals is not None:
1327+
normals = [normals.to(device)]
12931328

1294-
pointcloud = Pointclouds(points=[verts], features=features)
1329+
pointcloud = Pointclouds(points=[verts], features=features, normals=normals)
12951330
return pointcloud
12961331

12971332
def save(

tests/test_io_ply.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,14 +216,18 @@ def test_save_load_meshes(self):
216216
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
217217
)
218218
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
219+
normals = torch.tensor(
220+
[[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]], dtype=torch.float32
221+
)
219222
vert_colors = torch.rand_like(verts)
220223
texture = TexturesVertex(verts_features=[vert_colors])
221224

222-
for do_textures in itertools.product([True, False]):
225+
for do_textures, do_normals in itertools.product([True, False], [True, False]):
223226
mesh = Meshes(
224227
verts=[verts],
225228
faces=[faces],
226229
textures=texture if do_textures else None,
230+
verts_normals=[normals] if do_normals else None,
227231
)
228232
device = torch.device("cuda:0")
229233

@@ -236,12 +240,57 @@ def test_save_load_meshes(self):
236240
mesh2 = mesh2.cpu()
237241
self.assertClose(mesh2.verts_padded(), mesh.verts_padded())
238242
self.assertClose(mesh2.faces_padded(), mesh.faces_padded())
243+
if do_normals:
244+
self.assertTrue(mesh.has_verts_normals())
245+
self.assertTrue(mesh2.has_verts_normals())
246+
self.assertClose(
247+
mesh2.verts_normals_padded(), mesh.verts_normals_padded()
248+
)
249+
else:
250+
self.assertFalse(mesh.has_verts_normals())
251+
self.assertFalse(mesh2.has_verts_normals())
252+
self.assertFalse(torch.allclose(mesh2.verts_normals_padded(), normals))
239253
if do_textures:
240254
self.assertIsInstance(mesh2.textures, TexturesVertex)
241255
self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors)
242256
else:
243257
self.assertIsNone(mesh2.textures)
244258

259+
def test_save_load_with_normals(self):
260+
points = torch.tensor(
261+
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
262+
)
263+
normals = torch.tensor(
264+
[[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]], dtype=torch.float32
265+
)
266+
features = torch.rand_like(points)
267+
268+
for do_features, do_normals in itertools.product([True, False], [True, False]):
269+
cloud = Pointclouds(
270+
points=[points],
271+
features=[features] if do_features else None,
272+
normals=[normals] if do_normals else None,
273+
)
274+
device = torch.device("cuda:0")
275+
276+
io = IO()
277+
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
278+
io.save_pointcloud(cloud.cuda(), f.name)
279+
f.flush()
280+
cloud2 = io.load_pointcloud(f.name, device=device)
281+
self.assertEqual(cloud2.device, device)
282+
cloud2 = cloud2.cpu()
283+
self.assertClose(cloud2.points_padded(), cloud.points_padded())
284+
if do_normals:
285+
self.assertClose(cloud2.normals_padded(), cloud.normals_padded())
286+
else:
287+
self.assertIsNone(cloud.normals_padded())
288+
self.assertIsNone(cloud2.normals_padded())
289+
if do_features:
290+
self.assertClose(cloud2.features_packed(), features)
291+
else:
292+
self.assertIsNone(cloud2.features_packed())
293+
245294
def test_save_ply_invalid_shapes(self):
246295
# Invalid vertices shape
247296
with self.assertRaises(ValueError) as error:

0 commit comments

Comments
 (0)