Skip to content

Commit cc2840e

Browse files
Jiali Duan, facebook-github-bot
Jiali Duan
authored and committed
Write meshes to GLB
Summary: Write the amalgamated mesh from the Mesh module to glb. In this version, the json header and the binary data specified by the buffer are merged into glb. The image texture attributes are added. Reviewed By: bottler Differential Revision: D41489778 fbshipit-source-id: 3af0e9a8f9e9098e73737a254177802e0fb6bd3c
1 parent dba48fb commit cc2840e

File tree

2 files changed

+329
-10
lines changed

2 files changed

+329
-10
lines changed

pytorch3d/io/experimental_gltf_io.py

Lines changed: 257 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import struct
4040
import warnings
4141
from base64 import b64decode
42-
from collections import deque
42+
from collections import defaultdict, deque
4343
from enum import IntEnum
4444
from io import BytesIO
4545
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
@@ -102,14 +102,34 @@ class _ComponentType(IntEnum):
102102
"MAT4": (4, 4),
103103
}
104104

105+
_DTYPE_BYTES: Dict[Any, int] = {
106+
np.int8: 1,
107+
np.uint8: 1,
108+
np.int16: 2,
109+
np.uint16: 2,
110+
np.uint32: 4,
111+
np.float32: 4,
112+
}
113+
114+
115+
class _TargetType(IntEnum):
116+
ARRAY_BUFFER = 34962
117+
ELEMENT_ARRAY_BUFFER = 34963
118+
119+
120+
class OurEncoder(json.JSONEncoder):
    """
    JSON encoder which also accepts NumPy scalar values.

    NumPy integers and floats are emitted as JSON numbers, so that numeric
    fields such as counts and byte offsets are not turned into strings.
    """

    def default(self, obj):
        """
        Convert objects the stock encoder rejects.

        Args:
            obj: the value json could not serialise by itself.

        Returns:
            A plain Python int or float equivalent of a NumPy scalar.

        Raises:
            TypeError: (from the base class) for any other unsupported type.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)
125+
105126

106127
def _read_header(stream: BinaryIO) -> Optional[Tuple[int, int]]:
107128
header = stream.read(12)
108129
magic, version, length = struct.unpack("<III", header)
109130

110131
if magic != _GLTF_MAGIC:
111132
return None
112-
113133
return version, length
114134

115135

@@ -227,7 +247,6 @@ def _access_image(self, image_index: int) -> np.ndarray:
227247
offset = buffer_view.get("byteOffset", 0)
228248

229249
binary_data = self.get_binary_data(buffer_view["buffer"])
230-
231250
bytesio = BytesIO(binary_data[offset : offset + length].tobytes())
232251
with Image.open(bytesio) as f:
233252
array = np.array(f)
@@ -521,6 +540,223 @@ def load_meshes(
521540
return names_meshes_list
522541

523542

543+
class _GLTFWriter:
    """
    Writes a mesh to a binary glTF 2 (GLB) stream.

    The output has one scene, one node and one triangle primitive built from
    the packed vertices and faces of the mesh. If the mesh has textures, the
    first UV set and the first texture map are written too.

    NOTE(review): UVs are written as one pair per vertex and only element 0 of
    verts_uvs_list() / maps_list() is used. This presumably requires a single
    mesh whose UVs are indexed the same way as its vertices - confirm against
    callers.
    """

    def __init__(self, data: Meshes, buffer_stream: BinaryIO) -> None:
        """
        Args:
            data: the mesh to save.
            buffer_stream: binary stream which receives the GLB bytes.
        """
        self._json_data = defaultdict(list)
        self.mesh = data
        self.buffer_stream = buffer_stream

        # initialize json with one scene and one node
        node_index = 0
        # pyre-fixme[6]: Incompatible parameter type
        self._json_data["scene"] = 0
        self._json_data["scenes"].append({"nodes": [node_index]})
        # pyre-fixme[6]: Incompatible parameter type
        self._json_data["asset"] = {"version": "2.0"}
        self._json_data["nodes"].append({"name": "Node", "mesh": 0})

        # mesh primitive: accessor 0 holds positions and accessor 1 holds
        # indices (the order in which save() creates them). TEXCOORD_0 is
        # only added by save() when UVs are actually written.
        primitive = {
            "attributes": {"POSITION": 0},
            "indices": 1,
            "material": 0,  # default material
            "mode": _PrimitiveMode.TRIANGLES,
        }
        self._json_data["meshes"].append(
            {"name": "Node-Mesh", "primitives": [primitive]}
        )

        # default material; the base colour texture is linked in
        # _write_image_buffer only when an image is embedded.
        material = {
            "name": "material_1",
            "pbrMetallicRoughness": {
                "baseColorFactor": [1, 1, 1, 1],
                "metallicFactor": 0,
                "roughnessFactor": 0.99,
            },
            "emissiveFactor": [0, 0, 0],
            "alphaMode": "OPAQUE",
        }
        self._json_data["materials"].append(material)

    @staticmethod
    def _pad_to_4(raw: bytes, fill: bytes) -> bytes:
        """Return raw extended with fill bytes to a multiple of 4 in length."""
        return raw + fill * (-len(raw) % 4)

    def _write_accessor_json(self, key: str) -> Tuple[int, np.ndarray]:
        """
        Add the accessor entry for one kind of mesh data to the json.

        Args:
            key: one of "positions", "indices" or "texcoords".

        Returns:
            (byte_length, data): the unpadded size in bytes of the data and
            the array, already converted to the dtype it is stored with.

        Raises:
            NotImplementedError: if key is not one of the supported names.
            ValueError: if the data is empty, or if a face index does not
                fit in the integer type used to store indices.
        """
        name = "Node-Mesh_%s" % key
        byte_offset = 0
        if key == "positions":
            component_type = _ComponentType.FLOAT
            element_type = "VEC3"
            buffer_view = 0
            source = self.mesh.verts_packed()
        elif key == "texcoords":
            component_type = _ComponentType.FLOAT
            element_type = "VEC2"
            buffer_view = 2
            source = self.mesh.textures.verts_uvs_list()[0]
        elif key == "indices":
            component_type = _ComponentType.UNSIGNED_SHORT
            element_type = "SCALAR"
            buffer_view = 1
            source = self.mesh.faces_packed()
        else:
            raise NotImplementedError(
                "invalid key accessor, should be one of positions, indices or texcoords"
            )

        item_type = _ITEM_TYPES[component_type]
        raw = source.detach().cpu().numpy()
        if raw.shape[0] == 0:
            raise ValueError("invalid mesh to save, %s are empty" % key)
        if key == "indices" and int(raw.max()) > np.iinfo(item_type).max:
            # Casting would silently wrap the indices and corrupt the faces.
            raise ValueError(
                "mesh has too many vertices to store face indices as %s"
                % np.dtype(item_type).name
            )

        # astype returns a new array, so the in-place flip below cannot
        # write into memory owned by the mesh. It also guarantees the bytes
        # written match the declared componentType.
        data = raw.astype(item_type)

        if key == "texcoords":
            data[:, 1] = 1 - data[:, 1]  # flip y tex-coordinate

        if key == "indices":
            # indices are one scalar per corner, so every value is an element
            element_min = [int(data.min())]
            element_max = [int(data.max())]
            count = int(data.size)
        else:
            element_min = data.min(axis=0).tolist()
            element_max = data.max(axis=0).tolist()
            count = int(data.shape[0])

        accessor_json = {
            "name": name,
            "componentType": component_type,
            "type": element_type,
            "bufferView": buffer_view,
            "byteOffset": byte_offset,
            "min": element_min,
            "max": element_max,
            "count": count,
        }
        self._json_data["accessors"].append(accessor_json)
        return (int(data.nbytes), data)

    def _write_bufferview(self, key: str, **kwargs):
        """
        Add the bufferView entry for one kind of mesh data to the json.

        Args:
            key: one of "positions", "texcoords" or "indices".
            **kwargs: "offset" is the position of the data inside the binary
                buffer (0 if omitted) and "byte_length" its size in bytes.

        Raises:
            ValueError: if key is not one of the supported names.
        """
        if key not in ["positions", "texcoords", "indices"]:
            raise ValueError("key must be one of positions, texcoords or indices")

        bufferview = {
            "name": "bufferView_%s" % key,
            "buffer": 0,
        }
        if key == "indices":
            target = _TargetType.ELEMENT_ARRAY_BUFFER
        else:
            target = _TargetType.ARRAY_BUFFER
            components = 3 if key == "positions" else 2
            float_bytes = _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
            bufferview["byteStride"] = int(components * float_bytes)

        bufferview["target"] = target
        bufferview["byteOffset"] = kwargs.get("offset", 0)
        bufferview["byteLength"] = kwargs.get("byte_length")
        self._json_data["bufferViews"].append(bufferview)

    def _write_image_buffer(self, **kwargs) -> Tuple[int, bytes]:
        """
        Encode the first texture map as PNG and register it in the json.

        Adds the image's bufferView, the image itself, a sampler and a
        texture, and makes the default material use that texture.

        Args:
            **kwargs: "offset" is the position of the image bytes inside the
                binary buffer (0 if omitted).

        Returns:
            (byte_length, image_data): the unpadded PNG size and its bytes.
        """
        image_np = self.mesh.textures.maps_list()[0].detach().cpu().numpy()
        # Clip before the cast so out-of-range values saturate instead of
        # wrapping around. Presumably the map holds values in [0, 1].
        image_array = (np.clip(image_np, 0.0, 1.0) * 255.0).astype(np.uint8)
        im = Image.fromarray(image_array)
        with BytesIO() as f:
            im.save(f, format="PNG")
            image_data = f.getvalue()

        image_data_byte_length = len(image_data)
        bufferview_image = {
            "buffer": 0,
            "byteOffset": kwargs.get("offset", 0),
            "byteLength": image_data_byte_length,
        }
        # The index depends on which buffer views were written before this
        # one (texcoords are optional), so it is looked up, not hard-coded.
        bufferview_index = len(self._json_data["bufferViews"])
        self._json_data["bufferViews"].append(bufferview_image)

        image_index = len(self._json_data["images"])
        self._json_data["images"].append(
            {
                "name": "texture",
                "mimeType": "image/png",
                "bufferView": bufferview_index,
            }
        )

        # default sampler
        sampler_index = len(self._json_data["samplers"])
        self._json_data["samplers"].append(
            {"magFilter": 9729, "minFilter": 9986, "wrapS": 10497, "wrapT": 10497}
        )

        # default texture, referenced by the default material
        texture_index = len(self._json_data["textures"])
        self._json_data["textures"].append(
            {"sampler": sampler_index, "source": image_index}
        )
        pbr = self._json_data["materials"][0]["pbrMetallicRoughness"]
        pbr["baseColorTexture"] = {"index": texture_index}
        return (image_data_byte_length, image_data)

    def save(self):
        """
        Write the mesh to buffer_stream as a GLB file.

        The file consists of a 12-byte header, a JSON chunk and a binary
        chunk. Both chunks, and every section inside the binary chunk, are
        padded to a multiple of 4 bytes.

        Raises:
            ValueError: if the mesh has no vertices or faces, or if its face
                indices cannot be stored.
        """
        # check validity of mesh
        if self.mesh.verts_packed() is None or self.mesh.faces_packed() is None:
            raise ValueError("invalid mesh to save, verts or face indices are empty")

        textures = self.mesh.textures
        include_textures = (
            textures is not None and textures.verts_uvs_list()[0] is not None
        )
        include_image = textures is not None and textures.maps_list()[0] is not None

        # accessors for positions, face indices and texture uvs
        pos_byte, pos_data = self._write_accessor_json("positions")
        idx_byte, idx_data = self._write_accessor_json("indices")
        sections = [("positions", pos_byte, pos_data), ("indices", idx_byte, idx_data)]
        if include_textures:
            tex_byte, tex_data = self._write_accessor_json("texcoords")
            sections.append(("texcoords", tex_byte, tex_data))
            attributes = self._json_data["meshes"][0]["primitives"][0]["attributes"]
            attributes["TEXCOORD_0"] = len(self._json_data["accessors"]) - 1

        # bufferViews for positions, indices and texture coords; every
        # section is zero padded so that the next one starts 4-byte aligned
        payload = []
        byte_offset = 0
        for key, byte_length, array in sections:
            self._write_bufferview(key, byte_length=byte_length, offset=byte_offset)
            padded = self._pad_to_4(array.tobytes(), b"\x00")
            payload.append(padded)
            byte_offset += len(padded)

        # image bufferView
        if include_image:
            _, image_data = self._write_image_buffer(offset=byte_offset)
            padded = self._pad_to_4(image_data, b"\x00")
            payload.append(padded)
            byte_offset += len(padded)

        binary_chunk = b"".join(payload)

        # buffers
        self._json_data["buffers"].append({"byteLength": len(binary_chunk)})

        # organize into a glb; the json chunk is padded with spaces
        json_chunk = self._pad_to_4(
            json.dumps(self._json_data, cls=OurEncoder).encode("utf-8"), b" "
        )

        # write header: the length field covers the whole file, i.e. the
        # 12-byte header plus an 8-byte header and the data of each chunk
        total_length = 12 + (8 + len(json_chunk)) + (8 + len(binary_chunk))
        stream = self.buffer_stream
        stream.write(struct.pack("<III", _GLTF_MAGIC, 2, total_length))

        # write json
        stream.write(struct.pack("<II", len(json_chunk), _JSON_CHUNK_TYPE))
        stream.write(json_chunk)

        # write binary data
        stream.write(struct.pack("<II", len(binary_chunk), _BINARY_CHUNK_TYPE))
        stream.write(binary_chunk)
758+
759+
524760
class MeshGlbFormat(MeshFormatInterpreter):
525761
"""
526762
Implements loading meshes from glTF 2 assets stored in a
@@ -570,4 +806,21 @@ def save(
570806
binary: Optional[bool],
571807
**kwargs,
572808
) -> bool:
573-
return False
809+
"""
810+
Writes all the meshes from the default scene to GLB file.
811+
812+
Args:
813+
data: meshes to save
814+
path: path of the GLB file to write into
815+
path_manager: PathManager object for interpreting the path
816+
817+
Return True if saving succeeds and False otherwise
818+
"""
819+
820+
if not endswith(path, self.known_suffixes):
821+
return False
822+
823+
with _open_file(path, path_manager, "wb") as f:
824+
writer = _GLTFWriter(data, cast(BinaryIO, f))
825+
writer.save()
826+
return True

0 commit comments

Comments
 (0)