JeffreyXiang's picture
Finalize
a1e3f5f
from typing import *
import torch
from ..voxel import Voxel
import cumesh
from flex_gemm.ops.grid_sample import grid_sample_3d
class Mesh:
    """
    Container for mesh geometry: vertex positions, face indices and optional
    per-vertex attributes.

    The topology-editing methods (``fill_holes``, ``remove_faces`` and
    ``simplify``) run on CUDA through ``cumesh`` and overwrite ``vertices`` /
    ``faces`` in place, moving the result back to the device the mesh was on
    before the call.

    NOTE(review): those methods never touch ``vertex_attrs``; if the vertex
    set changes, previously stored attributes may no longer line up with the
    vertices -- confirm with callers.
    """

    def __init__(self,
        vertices,
        faces,
        vertex_attrs=None
    ):
        """
        Args:
            vertices: Tensor of vertex positions; stored as float32.
            faces: Tensor of vertex indices per face; stored as int32.
            vertex_attrs: Optional tensor of per-vertex attributes, stored
                as given (no dtype conversion).
        """
        self.vertices = vertices.float()
        self.faces = faces.int()
        self.vertex_attrs = vertex_attrs

    @property
    def device(self):
        """Device of the vertex tensor (taken as the device of the mesh)."""
        return self.vertices.device

    def to(self, device, non_blocking=False):
        """Return a new ``Mesh`` whose tensors have been moved to ``device``."""
        return Mesh(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
        )

    def cuda(self, non_blocking=False):
        """Return a copy of the mesh on the ``'cuda'`` device."""
        return self.to('cuda', non_blocking=non_blocking)

    def cpu(self):
        """Return a copy of the mesh on the CPU."""
        return self.to('cpu')

    def _to_cumesh(self):
        """Upload vertices/faces to CUDA and wrap them in a ``cumesh.CuMesh``."""
        vertices = self.vertices.cuda()
        faces = self.faces.cuda()
        mesh = cumesh.CuMesh()
        mesh.init(vertices, faces)
        return mesh

    def _read_back(self, mesh):
        """Replace this mesh's geometry with what ``mesh.read()`` returns."""
        # Capture the device first: ``self.device`` is derived from
        # ``self.vertices``, which is about to be replaced.
        device = self.device
        new_vertices, new_faces = mesh.read()
        self.vertices = new_vertices.to(device)
        self.faces = new_faces.to(device)

    def fill_holes(self, max_hole_perimeter=3e-2):
        """
        Fill boundary loops of the mesh in place using ``cumesh``.

        The mesh is left untouched when ``cumesh`` reports no boundaries or
        no boundary loops.

        Args:
            max_hole_perimeter: Forwarded to ``CuMesh.fill_holes``.
        """
        mesh = self._to_cumesh()
        mesh.get_edges()
        mesh.get_boundary_info()
        if mesh.num_boundaries == 0:
            return
        mesh.get_vertex_edge_adjacency()
        mesh.get_vertex_boundary_adjacency()
        mesh.get_manifold_boundary_adjacency()
        mesh.read_manifold_boundary_adjacency()
        mesh.get_boundary_connected_components()
        mesh.get_boundary_loops()
        if mesh.num_boundary_loops == 0:
            return
        mesh.fill_holes(max_hole_perimeter=max_hole_perimeter)
        self._read_back(mesh)

    def remove_faces(self, face_mask: torch.Tensor):
        """
        Remove faces in place according to ``face_mask`` using ``cumesh``.

        Args:
            face_mask: Per-face mask forwarded to ``CuMesh.remove_faces``.
                It is moved to CUDA first so it lives alongside the uploaded
                geometry regardless of where the caller built it.
        """
        mesh = self._to_cumesh()
        mesh.remove_faces(face_mask.cuda())
        self._read_back(mesh)

    def simplify(self, target=1000000, verbose: bool=False, options: Optional[dict]=None):
        """
        Simplify the mesh in place using ``cumesh``.

        Args:
            target: Forwarded to ``CuMesh.simplify`` as its first argument.
            verbose: Forwarded to ``CuMesh.simplify``.
            options: Extra options forwarded to ``CuMesh.simplify``. ``None``
                (the default) is passed on as a fresh empty dict so no dict
                object is shared between calls.
        """
        if options is None:
            options = {}
        mesh = self._to_cumesh()
        mesh.simplify(target, verbose=verbose, options=options)
        self._read_back(mesh)
class TextureFilterMode:
    """Integer codes naming the available texture filtering modes."""

    # Plain class attributes (not an Enum): members compare equal to ints.
    CLOSEST = 0
    LINEAR = 1
class TextureWrapMode:
    """Integer codes naming the available texture wrapping modes."""

    # Plain class attributes (not an Enum): members compare equal to ints.
    CLAMP_TO_EDGE = 0
    REPEAT = 1
    MIRRORED_REPEAT = 2
class AlphaMode:
    """Integer codes naming the available material alpha modes."""

    # Plain class attributes (not an Enum): members compare equal to ints.
    OPAQUE = 0
    MASK = 1
    BLEND = 2
class Texture:
    """
    An image tensor bundled with the filter and wrap modes to sample it with.

    All three values are stored exactly as passed; nothing is validated or
    copied.
    """

    def __init__(
        self,
        image: torch.Tensor,
        filter_mode: TextureFilterMode = TextureFilterMode.LINEAR,
        wrap_mode: TextureWrapMode = TextureWrapMode.REPEAT
    ):
        """
        Args:
            image: Tensor holding the texture data.
            filter_mode: One of the ``TextureFilterMode`` codes.
            wrap_mode: One of the ``TextureWrapMode`` codes.
        """
        self.image, self.filter_mode, self.wrap_mode = image, filter_mode, wrap_mode

    def to(self, device, non_blocking=False):
        """Return a new ``Texture`` with the image on ``device`` and the same modes."""
        relocated_image = self.image.to(device, non_blocking=non_blocking)
        return Texture(
            relocated_image,
            filter_mode=self.filter_mode,
            wrap_mode=self.wrap_mode,
        )
class PbrMaterial:
    """
    Bundle of PBR material parameters: optional textures plus scalar/vector
    factors for base color, metallic, roughness and alpha.

    ``base_color_factor`` is always stored as a float32 tensor truncated to
    its first three entries; every other argument is stored as passed.
    """

    def __init__(
        self,
        base_color_texture: Optional[Texture] = None,
        base_color_factor: Union[torch.Tensor, List[float]] = [1.0, 1.0, 1.0],
        metallic_texture: Optional[Texture] = None,
        metallic_factor: float = 1.0,
        roughness_texture: Optional[Texture] = None,
        roughness_factor: float = 1.0,
        alpha_texture: Optional[Texture] = None,
        alpha_factor: float = 1.0,
        alpha_mode: AlphaMode = AlphaMode.OPAQUE,
        alpha_cutoff: float = 0.5,
    ):
        """
        Args:
            base_color_texture: Optional base color texture.
            base_color_factor: Tensor or sequence of floats; only the first
                three entries are kept. The input is copied, never aliased
                or mutated.
            metallic_texture: Optional metallic texture.
            metallic_factor: Metallic factor.
            roughness_texture: Optional roughness texture.
            roughness_factor: Roughness factor.
            alpha_texture: Optional alpha texture.
            alpha_factor: Alpha factor.
            alpha_mode: One of the ``AlphaMode`` codes.
            alpha_cutoff: Alpha cutoff value.
        """
        self.base_color_texture = base_color_texture
        if isinstance(base_color_factor, torch.Tensor):
            # ``torch.tensor(existing_tensor)`` emits a copy-construct
            # UserWarning (hit on every ``to()`` call, which passes a tensor
            # back in). ``detach().clone()`` is the documented equivalent and
            # likewise keeps the source tensor's device.
            factor = base_color_factor.detach().clone().to(torch.float32)
        else:
            factor = torch.tensor(base_color_factor, dtype=torch.float32)
        self.base_color_factor = factor[:3]
        self.metallic_texture = metallic_texture
        self.metallic_factor = metallic_factor
        self.roughness_texture = roughness_texture
        self.roughness_factor = roughness_factor
        self.alpha_texture = alpha_texture
        self.alpha_factor = alpha_factor
        self.alpha_mode = alpha_mode
        self.alpha_cutoff = alpha_cutoff

    def to(self, device, non_blocking=False):
        """
        Return a new ``PbrMaterial`` with all textures and the base color
        factor moved to ``device``; the remaining values are passed through
        unchanged.
        """
        return PbrMaterial(
            base_color_texture=self.base_color_texture.to(device, non_blocking=non_blocking) if self.base_color_texture is not None else None,
            base_color_factor=self.base_color_factor.to(device, non_blocking=non_blocking),
            metallic_texture=self.metallic_texture.to(device, non_blocking=non_blocking) if self.metallic_texture is not None else None,
            metallic_factor=self.metallic_factor,
            roughness_texture=self.roughness_texture.to(device, non_blocking=non_blocking) if self.roughness_texture is not None else None,
            roughness_factor=self.roughness_factor,
            alpha_texture=self.alpha_texture.to(device, non_blocking=non_blocking) if self.alpha_texture is not None else None,
            alpha_factor=self.alpha_factor,
            alpha_mode=self.alpha_mode,
            alpha_cutoff=self.alpha_cutoff,
        )
class MeshWithPbrMaterial(Mesh):
    """
    Mesh whose faces reference PBR materials and carry UV coordinates.

    ``layout`` maps each material channel name to a ``slice`` covering
    indices 0..6 (presumably channels of a packed attribute vector -- confirm
    with the code that consumes it).

    NOTE(review): the inherited ``fill_holes`` / ``remove_faces`` /
    ``simplify`` only rewrite ``vertices`` and ``faces``; ``material_ids``
    and ``uv_coords`` are not updated by them.
    """

    def __init__(self,
        vertices,
        faces,
        material_ids,
        uv_coords,
        materials: List[PbrMaterial],
    ):
        """
        Args:
            vertices: Tensor of vertex positions; stored as float32.
            faces: Tensor of vertex indices per face; stored as int32.
            material_ids: Tensor of material ids, annotated ``[M]``.
            uv_coords: Tensor of UV coordinates, annotated ``[M, 3, 2]``.
            materials: List of ``PbrMaterial`` objects.
        """
        # Delegate to the base initializer so every attribute ``Mesh``
        # defines (including ``vertex_attrs``, left as None) exists here too.
        super().__init__(vertices, faces)
        self.material_ids = material_ids  # [M]
        self.uv_coords = uv_coords  # [M, 3, 2]
        self.materials = materials
        self.layout = {
            'base_color': slice(0, 3),
            'metallic': slice(3, 4),
            'roughness': slice(4, 5),
            'alpha': slice(5, 6),
        }

    def to(self, device, non_blocking=False):
        """Return a new ``MeshWithPbrMaterial`` with tensors and materials on ``device``."""
        return MeshWithPbrMaterial(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.material_ids.to(device, non_blocking=non_blocking),
            self.uv_coords.to(device, non_blocking=non_blocking),
            [material.to(device, non_blocking=non_blocking) for material in self.materials],
        )
class MeshWithVoxel(Mesh, Voxel):
    """
    Mesh paired with a sparse voxel grid of attributes that can be sampled at
    arbitrary positions (see ``query_attrs``).

    NOTE(review): neither base-class initializer is invoked; only the
    attributes assigned in ``__init__`` below are set.
    """

    def __init__(self,
        vertices: torch.Tensor,
        faces: torch.Tensor,
        origin: list,
        voxel_size: float,
        coords: torch.Tensor,
        attrs: torch.Tensor,
        voxel_shape: torch.Size,
        layout: Optional[Dict] = None,
    ):
        """
        Args:
            vertices: Tensor of vertex positions; stored as float32.
            faces: Tensor of vertex indices per face; stored as int32.
            origin: Grid origin; converted to a float32 tensor on the same
                device as ``vertices``.
            voxel_size: Edge length used to convert positions to grid units.
            coords: Voxel coordinates tensor.
            attrs: Voxel attributes tensor.
            voxel_shape: Shape passed to ``grid_sample_3d``.
            layout: Optional layout dict. ``None`` (the default) gives each
                instance its own new empty dict instead of one shared
                module-wide default object.
        """
        self.vertices = vertices.float()
        self.faces = faces.int()
        # ``self.device`` reads ``self.vertices``, so it must be set first.
        self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        self.voxel_shape = voxel_shape
        self.layout = {} if layout is None else layout

    def to(self, device, non_blocking=False):
        """
        Return a new ``MeshWithVoxel`` with its tensors moved to ``device``.

        ``origin`` is round-tripped through a Python list and rebuilt on the
        new device by the constructor; ``layout`` is passed by reference.
        """
        return MeshWithVoxel(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.origin.tolist(),
            self.voxel_size,
            self.coords.to(device, non_blocking=non_blocking),
            self.attrs.to(device, non_blocking=non_blocking),
            self.voxel_shape,
            self.layout,
        )

    def query_attrs(self, xyz):
        """
        Sample the voxel attributes at the positions ``xyz``.

        Positions are converted to grid units via ``origin`` and
        ``voxel_size``, flattened to ``(1, -1, 3)`` and passed to
        ``grid_sample_3d`` in ``'trilinear'`` mode.

        Returns:
            The first element (index 0) of the ``grid_sample_3d`` result.
        """
        grid = ((xyz - self.origin) / self.voxel_size).reshape(1, -1, 3)
        # A zero column is prepended to the voxel coordinates -- presumably
        # the batch index expected by ``grid_sample_3d`` (TODO confirm).
        batched_coords = torch.cat([torch.zeros_like(self.coords[..., :1]), self.coords], dim=-1)
        vertex_attrs = grid_sample_3d(
            self.attrs,
            batched_coords,
            self.voxel_shape,
            grid,
            mode='trilinear'
        )[0]
        return vertex_attrs

    def query_vertex_attrs(self):
        """Sample the voxel attributes at this mesh's own vertex positions."""
        return self.query_attrs(self.vertices)