diff --git a/.gitattributes b/.gitattributes index 845e614ff76083c3ec3f4874d44597b728a4a900..bb6009f5153f1a1a2d6c8329fc84d8730944711f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,4 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/*.glb filter=lfs diff=lfs merge=lfs -text examples/*.png filter=lfs diff=lfs merge=lfs -text diff --git a/PartField/LICENSE b/PartField/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..8388f722de77ca12cac5a752999e51882096af68 --- /dev/null +++ b/PartField/LICENSE @@ -0,0 +1,36 @@ +NVIDIA License + +1. Definitions + +“Licensor” means any person or entity that distributes its Work. +“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license. +The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. +Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license. + +2. License Grant + +2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. + +3. Limitations + +3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. + +3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. + +3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for non-commercial research and educational purposes only. + +3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. + +3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license. + +3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + diff --git a/PartField/README.md b/PartField/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b91918fc75ca41183cb46632b045d7c866bed33b --- /dev/null +++ b/PartField/README.md @@ -0,0 +1,242 @@ +# PartField: Learning 3D Feature Fields for Part Segmentation and Beyond [ICCV 2025] +**[[Project]](https://research.nvidia.com/labs/toronto-ai/partfield-release/)** **[[PDF]](https://arxiv.org/pdf/2504.11451)** + +Minghua Liu*, Mikaela Angelina Uy*, Donglai Xiang, Hao Su, Sanja Fidler, Nicholas Sharp, Jun Gao + + + +## Overview +![Alt text](assets/teaser.png) + +PartField is a feedforward model that predicts part-based feature fields for 3D shapes. Our learned features can be clustered to yield a high-quality part decomposition, outperforming the latest open-world 3D part segmentation approaches in both quality and speed. PartField can be applied to a wide variety of inputs in terms of modality, semantic class, and style. The learned feature field exhibits consistency across shapes, enabling applications such as cosegmentation, interactive selection, and correspondence. + +## Table of Contents + +- [Pretrained Model](#pretrained-model) +- [Environment Setup](#environment-setup) +- [TLDR](#tldr) +- [Example Run](#example-run) +- [Interactive Tools and Applications](#interactive-tools-and-applications) +- [Evaluation on PartObjaverse-Tiny](#evaluation-on-partobjaverse-tiny) +- [Discussion](#discussion-clustering-with-messy-mesh-connectivities) +- [Citation](#citation) + + +## Pretrained Model +``` +mkdir model +``` +The link to download our pretrained model is here: [Trained on Objaverse](https://huggingface.co/mikaelaangel/partfield-ckpt/blob/main/model_objaverse.ckpt). Due to licensing restrictions, we are unable to release the model that was also trained on PartNet. + +## Environment Setup + +We use Python 3.10 with PyTorch 2.4 and CUDA 12.4. The environment and required packages can be installed individually as follows: +``` +conda create -n partfield python=3.10 +conda activate partfield +conda install nvidia/label/cuda-12.4.0::cuda +pip install psutil +pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124 +pip install lightning==2.2 h5py yacs trimesh scikit-image loguru boto3 +pip install mesh2sdf tetgen pymeshlab plyfile einops libigl polyscope potpourri3d simple_parsing arrgh open3d +pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cu124.html +apt install libx11-6 libgl1 libxrender1 +pip install vtk +``` + +An environment file is also provided and can be used for installation: +``` +conda env create -f environment.yml +conda activate partfield +``` + +## TLDR +1. Input data (`.obj` or `.glb` for meshes, `.ply` for splats) are stored in subfolders under `data/`. You can create a new subfolder and copy your custom files into it. +2. Extract PartField features by running the script `partfield_inference.py`, passing the arguments `result_name [FEAT_FOL]` and `dataset.data_path [DATA_PATH]`. The output features will be saved in `exp_results/partfield_features/[FEAT_FOL]`. +3. Segmented parts can be obtained by running the script `run_part_clustering.py`, using the arguments `--root exp/[FEAT_FOL]` and `--dump_dir [PART_OUT_FOL]`. The output segmentations will be saved in `exp_results/clustering/[PART_OUT_FOL]`. +4. Application demo scripts are available in the `applications/` directory and can be used after extracting PartField features (i.e., after running `partfield_inference.py` on the desired demo data). + +## Example Run +### Download Demo Data + +#### Mesh Data +We showcase the feasibility of PartField using sample meshes from Objaverse (artist-created) and Trellis3D (AI-generated). Sample data can be downloaded below: +``` +sh download_demo_data.sh +``` +Downloaded meshes can be found in `data/objaverse_samples/` and `data/trellis_samples/`. + +#### Gaussian Splats +We also demonstrate our approach using Gaussian splatting reconstructions as input. Sample splat reconstruction data from the NeRF dataset can be found [here](https://drive.google.com/drive/folders/1l0njShLq37hn1TovgeF-PVGBBrAdNQnf?usp=sharing). Download the data and place it in the `data/splat_samples/` folder. + +### Extract Feature Field +#### Mesh Data + +``` +python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/objaverse dataset.data_path data/objaverse_samples +python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/trellis dataset.data_path data/trellis_samples +``` + +#### Point Clouds / Gaussian Splats +``` +python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/splat dataset.data_path data/splat_samples is_pc True +``` + +### Part Segmentation +#### Mesh Data + +We use agglomerative clustering for part segmentation on mesh inputs. +``` +python run_part_clustering.py --root exp_results/partfield_features/objaverse --dump_dir exp_results/clustering/objaverse --source_dir data/objaverse_samples --use_agglo True --max_num_clusters 20 --option 0 +``` + +When the input mesh has multiple connected components or poor connectivity, defining face adjacency by connecting geometrically close faces can yield better results (see discussion below): +``` +python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 1 --with_knn True +``` + +Note that agglomerative clustering does not return a fixed clustering result, but rather a hierarchical part tree, where the root node represents the whole shape and each leaf node corresponds to a single triangle face. You can explore more clustering results by adaptively traversing the tree, such as deciding which part should be further segmented. + +#### Point Cloud / Gaussian Splats +We use K-Means clustering for part segmentation on point cloud inputs. +``` +python run_part_clustering.py --root exp_results/partfield_features/splat --dump_dir exp_results/clustering/splat --source_dir data/splat_samples --max_num_clusters 20 --is_pc True +``` + +## Interactive Tools and Applications +We include UI tools to demonstrate various applications of PartField. Set up and try out our demos [here](applications/)! + +![Alt text](assets/co-seg.png) + +![Alt text](assets/regression_interactive_segmentation_guitars.gif) + +## Evaluation on PartObjaverse-Tiny + +![Alt text](assets/results_combined_compressed2.gif) + +To evaluate all models in PartObjaverse-Tiny, you can download the data [here](https://github.com/Pointcept/SAMPart3D/blob/main/PartObjaverse-Tiny/PartObjaverse-Tiny.md) and run the following commands: +``` +python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/partobjtiny dataset.data_path data/PartObjaverse-Tiny/PartObjaverse-Tiny_mesh n_point_per_face 2000 n_sample_each 10000 +python run_part_clustering.py --root exp_results/partfield_features/partobjtiny/ --dump_dir exp_results/clustering/partobjtiny --source_dir data/PartObjaverse-Tiny/PartObjaverse-Tiny_mesh --use_agglo True --max_num_clusters 20 --option 0 +``` +If an OOM error occurs, you can reduce the number of points sampled per face—for example, by setting `n_point_per_face` to 500. + +Evaluation metrics can be obtained by running the command below. The per-category average mIoU reported in the paper is also computed. +``` +python compute_metric.py +``` +This evaluation code builds on top of the implementation released by [SAMPart3D](https://github.com/Pointcept/SAMPart3D). Users with their own data and corresponding ground truths can easily modify this script to compute their metrics. + +## Discussion: Clustering with Messy Mesh Connectivities + + When using agglomerative clustering for part segmentation, an adjacency matrix is passed into the algorithm, which ideally requires the mesh to be a single connected component. However, some meshes can be messy, containing multiple connected components. If the input mesh is not a single connected component, we add pseudo-edges to the adjacency matrix to make it one. By default, we take a simple approach: adding `N-1` pseudo-edges as a chain to connect `N` components together. However, this approach can lead to poor results when the mesh is poorly connected and fragmented. + + + +``` +python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis_bad --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 0 +``` + +When this occurs, we explore different options that can lead to better results: + +### 1. Preprocess Input Mesh + +We can perform a simple cleanup on the input meshes by removing duplicate vertices and faces, and by merging nearby vertices using `pymeshlab`. This preprocessing step can be enabled via a flag when generating PartField features: + +``` +python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/trellis_preprocess dataset.data_path data/trellis_samples preprocess_mesh True +``` + +When running agglomerative clustering on a cleaned-up mesh, we observe improved part segmentation: + + + +``` +python run_part_clustering.py --root exp_results/partfield_features/trellis_preprocess --dump_dir exp_results/clustering/trellis_preprocess --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 0 +``` + +### 2. Cluster with KMeans + +If modifying the input mesh is not desirable and you prefer to avoid preprocessing, an alternative is to use KMeans clustering, which does not rely on an adjacency matrix. + + + + +``` +python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis_kmeans --source_dir data/trellis_samples --max_num_clusters 20 +``` + +### 3. MST-based Adjacency Matrix + +Instead of simply chaining the connected components of the input mesh, we also explore adding pseudo-edges to the adjacency matrix by constructing a KNN graph using face centroids and computing the minimum spanning tree of that graph. + + + +``` +python run_part_clustering.py --root exp_results/partfield_features/trellis --dump_dir exp_results/clustering/trellis_faceadj --source_dir data/trellis_samples --use_agglo True --max_num_clusters 20 --option 1 --with_knn True +``` + + + +### More Challenging Meshes! +The proposed approaches improve results for some meshes, but we find that certain cases still do not produce satisfactory segmentations. We leave these challenges for future work. If you're interested, here are some examples of difficult meshes we encountered: + +**Challenging Meshes:** +``` +cd data +mkdir challenge_samples +cd challenge_samples +wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-007/00790c705e4c4a1fbc0af9bf5c9e9525.glb +wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-132/13cc3ffc69964894a2bc94154aed687f.glb +``` + +## Citation +``` +@inproceedings{partfield2025, + title={PartField: Learning 3D Feature Fields for Part Segmentation and Beyond}, + author={Minghua Liu and Mikaela Angelina Uy and Donglai Xiang and Hao Su and Sanja Fidler and Nicholas Sharp and Jun Gao}, + year={2025} +} +``` + +## References +PartField borrows code from the following repositories: +- [OpenLRM](https://github.com/3DTopia/OpenLRM) +- [PyTorch 3D UNet](https://github.com/wolny/pytorch-3dunet) +- [PVCNN](https://github.com/mit-han-lab/pvcnn) +- [SAMPart3D](https://github.com/Pointcept/SAMPart3D) — evaluation script + +Many thanks to the authors for sharing their code! diff --git a/PartField/applications/.polyscope.ini b/PartField/applications/.polyscope.ini new file mode 100644 index 0000000000000000000000000000000000000000..c186950de8973b9a6541fe7d80dcc267838ba918 --- /dev/null +++ b/PartField/applications/.polyscope.ini @@ -0,0 +1,6 @@ +{ + "windowHeight": 1104, + "windowPosX": 66, + "windowPosY": 121, + "windowWidth": 2215 +} diff --git a/PartField/applications/README.md b/PartField/applications/README.md new file mode 100644 index 0000000000000000000000000000000000000000..53289b3e6970da4356079711ce33b64de950ed9b --- /dev/null +++ b/PartField/applications/README.md @@ -0,0 +1,142 @@ +# Interactive Tools and Applications + +## Single-Shape Feature and Segmentation Visualization Tool +We can visualize the output features and segmentation of a single shape by running the script below: + +``` +cd applications/ +python single_shape.py --data_root ../exp_results/partfield_features/trellis/ --filename dwarf +``` + +- `Mode: pca, feature_viz, cluster_agglo, cluster_kmeans` +- `pca` : Visualizes the pca Partfield features of the input model as colors. +- `feature_viz` : Visualizes each dimension of the PartField features as a colormap. +- `cluster_agglo` : Visualizes the part segmentation of the input model using Agglomerative clustering. + - Number of clusters is specified with the slider. + - `Adj Matrix Def`: Specifies how the adjacency matrix is defined for the clustering algorithm by adding dummy edges to make the input mesh a single connected component. + - `Add KNN edges` : Adds additional dummy edges based on k nearest neighbors. +- `cluster_kmeans` : Visualizes the part segmentation of the input model using KMeans clustering. + - Number of clusters is specified with the slider. + + +## Shape-Pair Co-Segmentation and Feature Exploration Tool +We provide a tool to analyze and visualize a pair of shapes that has two main functionalities: 1) **Co-segmentation** via co-clustering and 2) Partfield **feature exploration** and visualization. Try it out as follows: + +``` +cd applications/ +python shape_pair.py --data_root ../exp_results/partfield_features/trellis/ --filename dwarf --filename_alt goblin +``` + +### Co-Clustering for Co-Segmentation + +Here explains the use-case for `Mode: co-segmentation`. + +![Alt text](../assets/co-seg.png) + +The shape-pair is co-segmented by running co-clustering, In this application, we use the KMeans clustering algorithm. The `first shape (left)` is separated into parts via **unsupervised clustering** of its features with KMeans, from which the parts of the `second shape (right)` are then defined. + +Below are a list parameters: +- `Source init`: + - `True`: Initializes the cluster centers of the second shape (right) with the cluster centers of the first shape (left). + - `False`: Uses KMeans++ to initialize the cluster centers for KMeans for the second shape. +- `Independent`: + - `True`: Labels after running KMeans clustering are directly used as parts for the second shape. Correspondence with the parts of the first shape is not explicitly computed after KMeans clustering. + - `False`: After KMeans clustering is ran on the features of the second shape, the mean features for each unique part is then computed. The mean part feature for each part of the first shape is also computed. Then the parts of the second shaped are assigned labels based on the nearest neighbor part of the first shape. +- `Num cluster`: + - `Model1` : A slider is used to specify the number of parts for the first shape, i.e. number of clusters for KMeans clustering. + - `Model2` : A slider is used to specify the number of parts for the second shape, i.e. number of clusters for KMeans clustering. Note: if `Source init` is set to `True` then this slider is ignored and the number of clusters for Model1 is used. + + +### Feature Exploration and Visualization + +Here explains the use-case for `Mode: feature_explore`. + +![Alt text](../assets/feature_exploration2.png) + +This feature allows us to select a query point from the first shape (left) and the feature distance to all points in the second shape (left) and itself is then visualized as a colormap. +- `range` : A slider to specify the distance radius for feature similarity visualization. Large values will result in bigger highlighter areas. +- `continuous` : + - `False` : Query point is specified with a mouse click. + - `True` : You can slide your mouse around the first mesh to visualize feature distances. + +## Multi-shape Cosegmentation Tool +We further demonstrate PartField for cosegmentation of multiple/a set of shapes. Try out our demo application as follows: + +### Dependency Installation +Let's first install the necessary dependencies for this tool: +``` +pip install cuml-cu12 +pip install xgboost +``` + +### Dataset +We use the Shape COSEG dataset for our demo. We first download the dataset here: +``` +mkdir data/coseg_guitar +cd data/coseg_guitar +wget https://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/shapes.zip +wget https://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/gt.zip +unzip shapes.zip +unzip gt.zip +``` + +Now, let's extract Partfield features for the set of shapes: +``` +python partfield_inference.py -c configs/final/demo.yaml --opts continue_ckpt model/model_objaverse.ckpt result_name partfield_features/coseg_guitar/ dataset.data_path data/coseg_guitar/shapes +``` + +Now, we're ready to run the tool! We support two modes: 1) **Few-shot** with click-based annotations and 2) **Supervised** with ground truth labels. + +### Annotate Mode +![Alt text](../assets/regression_interactive_segmentation_guitars.gif) + +We can run our few-shot segmentation tool as follows: +``` +cd applications/ +python multi_shape_cosegment.py --meshes ../exp_results/partfield_features/coseg_guitar/ +``` +We can annotate the segments with a few clicks. A classifier is then ran to obtain the part segmentation. +- N_class: number of segmentation class labels +- Annotate: + - `00, 01, 02, ...`: Select the segmentation class label, then click on a shape region of that class. + - `Undo Last Selection`: Removes and disregards the last annotation made. +- Fit: + - Fit Method: Selects the classification method used for fitting. Default uses `Logistic Regression`. + - `Update Fit`: By default, the fitting process is automatically updated. This can also be changed to a manual update. + +### Ground Truth Labels Mode + +![Alt text](../assets/gt-based_coseg.png) + +Alternatively, we can use the ground truth labels of a subset of the shapes to train the classifier. + +``` +cd applications/ +python multi_shape_cosegment.py --meshes ../exp_results/partfield_features/coseg_guitar/ --n_train_subset 15 +``` +`Fit Method` can also be selected here to choose the classifier to be used. + + +## 3D Correspondences + +First, we clone the repository [SmoothFunctionalMaps](https://github.com/RobinMagnet/SmoothFunctionalMaps) and install additional packages. +``` +pip install omegaconf robust_laplacian +git submodule init +git submodule update --recursive +``` + +Download the [DenseCorr3D dataset](https://drive.google.com/file/d/1bpgsNu8JewRafhdRN4woQL7ObQtfgcpu/view?usp=sharing) into the `data` folder. Unzip the contents and ensure that the file structure is organized so that you can access +`data/DenseCorr3D/animals/071b8_toy_animals_017`. + +Extract the PartField features. +``` +# run in root directory of this repo +python partfield_inference.py -c configs/final/correspondence_demo.yaml --opts continue_ckpt model/model_objaverse.ckpt preprocess_mesh True +``` + +Run the functional map. +``` +cd applications/ +python run_smooth_functional_map.py -c ../configs/final/correspondence_demo.yaml --opts +``` \ No newline at end of file diff --git a/PartField/applications/multi_shape_cosegment.py b/PartField/applications/multi_shape_cosegment.py new file mode 100644 index 0000000000000000000000000000000000000000..a47d8d8ca402f1d74fe401ab06b301f2e212567e --- /dev/null +++ b/PartField/applications/multi_shape_cosegment.py @@ -0,0 +1,482 @@ +import numpy as np +import torch +import argparse +from dataclasses import dataclass + +from arrgh import arrgh +import polyscope as ps +import polyscope.imgui as psim +import potpourri3d as pp3d +import trimesh + +import cuml +import xgboost as xgb + +import os, random + +import sys +sys.path.append("..") +from partfield.utils import * + +@dataclass +class State: + + objects = None + train_objects = None + + # Input options + subsample_inputs: int = -1 + n_train_subset: int = 0 + + # Label + N_class: int = 2 + + # Annotations + # A annotations (initially A = 0) + anno_feat: np.array = np.zeros((0,448), dtype=np.float32) # [A,F] + anno_label: np.array = np.zeros((0,), dtype=np.int32) # [A] + anno_pos: np.array = np.zeros((0,3), dtype=np.float32) # [A,3] + + # Intermediate selection data + is_selecting: bool = False + selection_class: int = 0 + + # Fitting algorithm + fit_to: str = "Annotations" + fit_method : str = "LogisticRegression" + auto_update_fit: bool = True + + # Training data + # T training datapoints + train_feat: np.array = np.zeros((0,448), dtype=np.float32) # [T,F] + train_label: np.array = np.zeros((0,), dtype=np.int32) # [T] + + # Viz + grid_w : int = 8 + per_obj_shift : float = 2. + anno_radius : float = 0.01 + ps_cloud_annotation = None + ps_structure_name_to_index_map = {} + + +fit_methods_list = ["LinearRegression", "LogisticRegression", "LinearSVC", "RandomForest", "NearestNeighbors", "XGBoost"] +fit_to_list = ["Annotations", "TrainingSet"] + +def load_mesh_and_features(mesh_filepath, ind, require_gt=False, gt_label_fol = ""): + + dirpath, filename = os.path.split(mesh_filepath) + filename_core = filename[9:-6] # splits off "feat_pca_" ... "_0.ply" + feature_filename = "part_feat_"+ filename_core + "_0_batch.npy" + feature_filepath = os.path.join(dirpath, feature_filename) + + gt_filename = filename_core + ".seg" + gt_filepath = os.path.join(gt_label_fol, gt_filename) + have_gt = os.path.isfile(gt_filepath) + + print(" Reading file:") + print(f" Mesh filename: {mesh_filepath}") + print(f" Feature filename: {feature_filepath}") + print(f" Ground Truth Label filename: {gt_filepath} -- present = {have_gt}") + + # load features + feat = np.load(feature_filepath, allow_pickle=False) + feat = feat.astype(np.float32) + + # load mesh things + # TODO replace this with just loading V/F from numpy archive + tm = load_mesh_util(mesh_filepath) + + V = np.array(tm.vertices, dtype=np.float32) + F = np.array(tm.faces) + + # load ground truth, if available + if have_gt: + gt_labels = np.loadtxt(gt_filepath) + gt_labels = gt_labels.astype(np.int32) - 1 + else: + if require_gt: + raise ValueError("could not find ground-truth file, but it is required") + gt_labels = None + + # pca_colors = None + + return { + 'nicename' : f"{ind:02d}_{filename_core}", + 'mesh_filepath' : mesh_filepath, + 'feature_filepath' : feature_filepath, + 'V' : V, + 'F' : F, + 'feat_np' : feat, + # 'feat_pt' : torch.tensor(feat, device='cuda'), + 'gt_labels' : gt_labels + } + +def shift_for_ind(state : State, ind): + + x_ind = ind % state.grid_w + y_ind = ind // state.grid_w + + shift = np.array([state.per_obj_shift * x_ind, 0, -state.per_obj_shift * y_ind]) + + return shift + +def viz_upper_limit(state : State, ind_count): + + x_max = min(ind_count, state.grid_w) + y_max = ind_count // state.grid_w + + bound = np.array([state.per_obj_shift * x_max, 0, -state.per_obj_shift * y_max]) + + return bound + + +def initialize_object_viz(state : State, obj, index=0): + obj['ps_mesh'] = ps.register_surface_mesh(obj['nicename'], obj['V'], obj['F'], color=(.8, .8, .8)) + shift = shift_for_ind(state, index) + obj['ps_mesh'].translate(shift) + obj['ps_mesh'].set_selection_mode('faces_only') + state.ps_structure_name_to_index_map[obj['nicename']] = index + +def update_prediction(state: State): + + print("Updating predictions..") + + N_anno = state.anno_label.shape[0] + + # Quick out if we don't have at least two distinct class labels present + if(state.fit_to == "Annotations" and len(np.unique(state.anno_label)) <= 1): + return state + + # Quick out if we don't have + if(state.fit_to == "TrainingSet" and state.train_objects is None): + return state + + if state.fit_method == "LinearRegression": + classifier = cuml.multiclass.MulticlassClassifier(cuml.linear_model.LinearRegression(), strategy='ovr') + elif state.fit_method == "LogisticRegression": + classifier = cuml.multiclass.MulticlassClassifier(cuml.linear_model.LogisticRegression(), strategy='ovr') + elif state.fit_method == "LinearSVC": + classifier = cuml.multiclass.MulticlassClassifier(cuml.svm.LinearSVC(), strategy='ovr') + elif state.fit_method == "RandomForest": + classifier = cuml.ensemble.RandomForestClassifier() + elif state.fit_method == "NearestNeighbors": + classifier = cuml.multiclass.MulticlassClassifier(cuml.neighbors.KNeighborsRegressor(n_neighbors=1), strategy='ovr') + elif state.fit_method == "XGBoost": + classifier = xgb.XGBClassifier(max_depth=7, n_estimators=1000) + else: + raise ValueError("unrecognized fit method") + + if state.fit_to == "TrainingSet": + + all_train_feats = [] + all_train_labels = [] + for obj in state.train_objects: + all_train_feats.append(obj['feat_np']) + all_train_labels.append(obj['gt_labels']) + + all_train_feats = np.concatenate(all_train_feats, axis=0) + all_train_labels = np.concatenate(all_train_labels, axis=0) + + state.N_class = np.max(all_train_labels) + 1 + + classifier.fit(all_train_feats, all_train_labels) + + + elif state.fit_to == "Annotations": + classifier.fit(state.anno_feat,state.anno_label) + else: + raise ValueError("unrecognized fit to") + + n_total = 0 + n_correct = 0 + + for obj in state.objects: + obj['pred_label'] = classifier.predict(obj['feat_np']) + + if obj['gt_labels'] is not None: + n_total += obj['gt_labels'].shape[0] + n_correct += np.sum(obj['pred_label'] == obj['gt_labels'], dtype=np.int32) + + if(state.fit_to == "TrainingSet" and n_total > 0): + frac = n_correct / n_total + print(f"Test accuracy: {n_correct:d} / {n_total:d} {100*frac:.02f}%") + + + print("Done updating predictions.") + + return state + +def update_prediction_viz(state: State): + + for obj in state.objects: + if 'pred_label' in obj: + obj['ps_mesh'].add_scalar_quantity("pred labels", obj['pred_label'], defined_on='faces', vminmax=(0,state.N_class-1), cmap='turbo', enabled=True) + + return state + +def update_annotation_viz(state: State): + + ps_cloud = ps.register_point_cloud("annotations", state.anno_pos, radius=state.anno_radius, material='candy') + ps_cloud.add_scalar_quantity("labels", state.anno_label, vminmax=(0,state.N_class-1), cmap='turbo', enabled=True) + + state.ps_cloud_annotation = ps_cloud + + return state + + +def filter_old_labels(state: State): + """ + Filter out annotations from classes that don't exist any more + """ + + keep_mask = state.anno_label < state.N_class + state.anno_feat = state.anno_feat[keep_mask,:] + state.anno_label = state.anno_label[keep_mask] + state.anno_pos = state.anno_pos[keep_mask,:] + + return state + +def undo_last_annotation(state: State): + + state.anno_feat = state.anno_feat[:-1,:] + state.anno_label = state.anno_label[:-1] + state.anno_pos = state.anno_pos[:-1,:] + + return state + +def ps_callback(state_list): + state : State = state_list[0] # hacky pass-by-reference, since we want to edit it below + + + # If we're in selection mode, that's the only thing we can do + if state.is_selecting: + + psim.TextUnformatted(f"Annotating class {state.selection_class:02d}. Click on any mesh face.") + + io = psim.GetIO() + if io.MouseClicked[0]: + screen_coords = io.MousePos + pick_result = ps.pick(screen_coords=screen_coords) + + # Check if we hit one of the meshes + if pick_result.is_hit and pick_result.structure_name in state.ps_structure_name_to_index_map: + if pick_result.structure_data['element_type'] != "face": + # shouldn't be possible + raise ValueError("pick returned non-face") + + i_obj = state.ps_structure_name_to_index_map[pick_result.structure_name] + f_hit = pick_result.structure_data['index'] + + obj = state.objects[i_obj] + V = obj['V'] + F = obj['F'] + feat = obj['feat_np'] + + face_corners = V[F[f_hit,:],:] + new_anno_feat = feat[f_hit,:] + new_anno_label = state.selection_class + new_anno_pos = np.mean(face_corners, axis=0) + shift_for_ind(state, i_obj) + + state.anno_feat = np.concatenate((state.anno_feat, new_anno_feat[None,:])) + state.anno_label = np.concatenate((state.anno_label, np.array((new_anno_label,)))) + state.anno_pos = np.concatenate((state.anno_pos, new_anno_pos[None,:])) + + state = update_annotation_viz(state) + state.is_selecting = False + needs_pred_update = True + + if state.auto_update_fit: + state = update_prediction(state) + state = update_prediction_viz(state) + + + return + + # If not selecting, build the main UI + needs_pred_update = False + + psim.PushItemWidth(150) + changed, state.N_class = psim.InputInt("N_class", state.N_class, step=1) + psim.PopItemWidth() + if changed: + state = filter_old_labels(state) + state = update_annotation_viz(state) + + + # Check for keypress annotation + io = psim.GetIO() + class_keys = { 'w' : 0, '1' : 1, '2' : 2, '3' : 3, '4' : 4, '5' : 5, '6' : 6, '7' : 7, '8' : 8, '9' : 9,} + for c in class_keys: + if class_keys[c] >= state.N_class: + continue + + if psim.IsKeyPressed(ps.get_key_code(c)): + state.is_selecting = True + state.selection_class = class_keys[c] + + + psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver) + if(psim.TreeNode("Annotate")): + + psim.TextUnformatted("New class annotation. Select class to add add annotation for:") + psim.TextUnformatted("(alternately, press key {w,1,2,3,4...})") + for i_class in range(state.N_class): + + if i_class > 0: + psim.SameLine() + + if psim.Button(f"{i_class:02d}"): + state.is_selecting = True + state.selection_class = i_class + + + if psim.Button("Undo Last Annotation"): + state = undo_last_annotation(state) + state = update_annotation_viz(state) + needs_pred_update = True + + + + psim.TreePop() + + psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver) + if(psim.TreeNode("Fit")): + + psim.PushItemWidth(150) + + changed, ind = psim.Combo("Fit To", fit_to_list.index(state.fit_to), fit_to_list) + if changed: + state.fit_to = fit_methods_list[ind] + needs_pred_update = True + + changed, ind = psim.Combo("Fit Method", fit_methods_list.index(state.fit_method), fit_methods_list) + if changed: + state.fit_method = fit_methods_list[ind] + needs_pred_update = True + + if psim.Button("Update fit"): + state = update_prediction(state) + state = update_prediction_viz(state) + + psim.SameLine() + + changed, state.auto_update_fit = psim.Checkbox("Auto-update fit", state.auto_update_fit) + if changed: + needs_pred_update = True + + + psim.PopItemWidth() + + psim.TreePop() + + psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver) + if(psim.TreeNode("Visualization")): + + psim.PushItemWidth(150) + changed, state.anno_radius = psim.SliderFloat("Annotation Point Radius", state.anno_radius, 0.00001, 0.02) + if changed: + state = update_annotation_viz(state) + psim.PopItemWidth() + + psim.TreePop() + + + if needs_pred_update and state.auto_update_fit: + state = update_prediction(state) + state = update_prediction_viz(state) + + +def main(): + + state = State() + + ## Parse args + parser = argparse.ArgumentParser() + + parser.add_argument('--meshes', nargs='+', help='List of meshes to process.', required=True) + parser.add_argument('--n_train_subset', default=0, help='How many meshes to train on.') + parser.add_argument('--gt_label_fol', default="../data/coseg_guitar/gt", help='Path where labels are stored.') + parser.add_argument('--subsample_inputs', default=state.subsample_inputs, help='Only show a random fraction of inputs') + parser.add_argument('--per_obj_shift', default=state.per_obj_shift, help='How to space out objects in UI grid') + parser.add_argument('--grid_w', default=state.grid_w, help='Grid width') + + args = parser.parse_args() + + + state.n_train_subset = int(args.n_train_subset) + state.subsample_inputs = int(args.subsample_inputs) + state.per_obj_shift = float(args.per_obj_shift) + state.grid_w = int(args.grid_w) + + ## Load data + # First, resolve directories to load all files in directory + all_filepaths = [] + print("Resolving passed directories") + for entry in args.meshes: + if os.path.isdir(entry): + dir_path = entry + print(f" processing directory {dir_path}") + for filename in os.listdir(dir_path): + file_path = os.path.join(dir_path, filename) + if os.path.isfile(file_path) and file_path.endswith(".ply") and "feat_pca" in file_path: + print(f" adding file {file_path}") + all_filepaths.append(file_path) + else: + all_filepaths.append(entry) + + random.shuffle(all_filepaths) + + if state.subsample_inputs != -1: + all_filepaths = all_filepaths[:state.subsample_inputs] + + + if state.n_train_subset != 0: + + print(state.n_train_subset) + + train_filepaths = all_filepaths[:state.n_train_subset] + all_filepaths = all_filepaths[state.n_train_subset:] + + print(f"Loading {len(train_filepaths)} files") + state.train_objects = [] + for i, file_path in enumerate(train_filepaths): + state.train_objects.append(load_mesh_and_features(file_path, i, require_gt=True, gt_label_fol=args.gt_label_fol)) + + state.fit_to = "TrainingSet" + + # Load files + print(f"Loading {len(all_filepaths)} files") + state.objects = [] + for i, file_path in enumerate(all_filepaths): + state.objects.append(load_mesh_and_features(file_path, i)) + + + ## Set up visualization + ps.init() + ps.set_automatically_compute_scene_extents(False) + lim = viz_upper_limit(state, len(state.objects)) + ps.set_length_scale(np.linalg.norm(lim) / 4.) + low = np.array((0, -1., -1.)) + high = lim + ps.set_bounding_box(low, high) + + for ind, o in enumerate(state.objects): + initialize_object_viz(state, o, ind) + + print(f"Loaded {len(state.objects)} objects") + if state.n_train_subset != 0: + print(f"Loaded {len(state.train_objects)} training objects") + + # One first prediction + # (does nothing if there is no annotatoins / training data) + state = update_prediction(state) + state = update_prediction_viz(state) + + # Start the interactive UI + ps.set_user_callback(lambda : ps_callback([state])) + ps.show() + + +if __name__ == "__main__": + main() + diff --git a/PartField/applications/pack_labels_to_obj.py b/PartField/applications/pack_labels_to_obj.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5925726dfeb675e4944d7dfad59c63505e693b --- /dev/null +++ b/PartField/applications/pack_labels_to_obj.py @@ -0,0 +1,47 @@ +import sys, os, fnmatch, re +import argparse + +import numpy as np +import matplotlib +from matplotlib import colors as mcolors +import matplotlib.cm +import potpourri3d as pp3d +import igl +from arrgh import arrgh + +def main(): + + parser = argparse.ArgumentParser() + + parser.add_argument("--input_mesh", type=str, required=True, help="The mesh to read from from, mesh file format.") + parser.add_argument("--input_labels", type=str, required=True, help="The labels, as a text file with one entry per line") + parser.add_argument("--label_count", type=int, default=-1, help="The number of labels to use for the visualization. If -1, computed as max of given labels.") + parser.add_argument("--output", type=str, required=True, help="The obj file to write output to") + + args = parser.parse_args() + + + # Read the mesh + V, F = igl.read_triangle_mesh(args.input_mesh) + + # Read the scalar function + S = np.loadtxt(args.input_labels) + + # Convert integers to scalars on [0,1] + if args.label_count == -1: + N_max = np.max(S) + 1 + else: + N_max = args.label_count + S = S.astype(np.float32) / max(N_max-1, 1) + + # Validate and write + if len(S.shape) != 1 or S.shape[0] != F.shape[0]: + raise ValueError(f"when scalar_on==faces, the scalar should be a length num-faces numpy array, but it has shape {S.shape[0]} and F={F.shape[0]}") + + S = np.stack((S, np.zeros_like(S)), axis=-1) + + pp3d.write_mesh(V, F, args.output, UV_coords=S, UV_type='per-face') + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/PartField/applications/run_smooth_functional_map.py b/PartField/applications/run_smooth_functional_map.py new file mode 100644 index 0000000000000000000000000000000000000000..8a294156aafb73d435e716b15860a71cd9971295 --- /dev/null +++ b/PartField/applications/run_smooth_functional_map.py @@ -0,0 +1,80 @@ +import os, sys +import numpy as np +import torch +import trimesh +import json + +sys.path.append("..") +sys.path.append("../third_party/SmoothFunctionalMaps") +sys.path.append("../third_party/SmoothFunctionalMaps/pyFM") + +from partfield.config import default_argument_parser, setup +from pyFM.mesh import TriMesh +from pyFM.spectral import mesh_FM_to_p2p +import DiscreteOpt + + +def vertex_color_map(vertices): + min_coord, max_coord = np.min(vertices, axis=0, keepdims=True), np.max(vertices, axis=0, keepdims=True) + cmap = (vertices - min_coord) / (max_coord - min_coord) + return cmap + + +if __name__ == '__main__': + parser = default_argument_parser() + args = parser.parse_args() + cfg = setup(args, freeze=False) + + feature_dir = os.path.join("../exp_results", cfg.result_name) + + all_files = cfg.dataset.all_files + assert len(all_files) % 2 == 0 + num_pairs = len(all_files) // 2 + + device = "cuda" + + output_dir = "../exp_results/correspondence/" + os.makedirs(output_dir, exist_ok=True) + + for i in range(num_pairs): + file0 = all_files[2 * i] + file1 = all_files[2 * i + 1] + + uid0 = file0.split(".")[-2].replace("/", "_") + uid1 = file1.split(".")[-2].replace("/", "_") + + mesh0 = trimesh.load(os.path.join(feature_dir, f"input_{uid0}_0.ply"), process=True) + mesh1 = trimesh.load(os.path.join(feature_dir, f"input_{uid1}_0.ply"), process=True) + + feat0 = np.load(os.path.join(feature_dir, f"part_feat_{uid0}_0_batch.npy")) + feat1 = np.load(os.path.join(feature_dir, f"part_feat_{uid1}_0_batch.npy")) + + assert mesh0.vertices.shape[0] == feat0.shape[0], "num of vertices should match num of features" + assert mesh1.vertices.shape[0] == feat1.shape[0], "num of vertices should match num of features" + + th_descr0 = torch.tensor(feat0, device=device, dtype=torch.float32) + th_descr1 = torch.tensor(feat1, device=device, dtype=torch.float32) + + cdist_01 = torch.cdist(th_descr0, th_descr1, p=2) + p2p_10_init = cdist_01.argmin(dim=0).cpu().numpy() + p2p_01_init = cdist_01.argmin(dim=1).cpu().numpy() + + fm_mesh0 = TriMesh(mesh0.vertices, mesh0.faces, area_normalize=True, center=True).process(k=200, intrinsic=True) + fm_mesh1 = TriMesh(mesh1.vertices, mesh1.faces, area_normalize=True, center=True).process(k=200, intrinsic=True) + + model = DiscreteOpt.SmoothDiscreteOptimization(fm_mesh0, fm_mesh1) + model.set_params("zoomout_rhm") + model.opt_params.step = 10 + model.solve_from_p2p(p2p_21=p2p_10_init, p2p_12=p2p_01_init, n_jobs=30, verbose=True) + + p2p_10_FM = mesh_FM_to_p2p(model.FM_12, fm_mesh0, fm_mesh1, use_adj=True) + + color0 = vertex_color_map(mesh0.vertices) + color1 = color0[p2p_10_FM] + + output_mesh0 = trimesh.Trimesh(mesh0.vertices, mesh0.faces, vertex_colors=color0) + output_mesh1 = trimesh.Trimesh(mesh1.vertices, mesh1.faces, vertex_colors=color1) + + output_mesh0.export(os.path.join(output_dir, f"correspondence_{uid0}_{uid1}_0.ply")) + output_mesh1.export(os.path.join(output_dir, f"correspondence_{uid0}_{uid1}_1.ply")) + diff --git a/PartField/applications/shape_pair.py b/PartField/applications/shape_pair.py new file mode 100644 index 0000000000000000000000000000000000000000..6c0ba3e29b5446e121ec426a4a46a8d12c957824 --- /dev/null +++ b/PartField/applications/shape_pair.py @@ -0,0 +1,385 @@ +import numpy as np +import torch +import polyscope as ps +import polyscope.imgui as psim +import potpourri3d as pp3d +import trimesh +import igl +from dataclasses import dataclass +from simple_parsing import ArgumentParser +from arrgh import arrgh + +### For clustering +from collections import defaultdict +from sklearn.cluster import AgglomerativeClustering, DBSCAN, KMeans +from scipy.sparse import coo_matrix, csr_matrix +from scipy.spatial import KDTree +from scipy.sparse.csgraph import connected_components +from sklearn.neighbors import NearestNeighbors +import networkx as nx + +from scipy.optimize import linear_sum_assignment + +import os, sys +sys.path.append("..") +from partfield.utils import * + +@dataclass +class Options: + + """ Basic Options """ + filename: str + filename_alt: str = None + + """System Options""" + device: str = "cuda" # Device + debug: bool = False # enable debug checks + extras: bool = False # include extra output for viz/debugging + + """ State """ + mode: str = 'co-segmentation' + m: dict = None # mesh + m_alt: dict = None # second mesh + + # pca mode + + # feature explore mode + i_feature: int = 0 + + i_cluster: int = 1 + i_cluster2: int = 1 + + i_eps: int = 0.6 + + ### For mixing in clustering + weight_dist = 1.0 + weight_feat = 1.0 + + ### For clustering visualization + independent: bool = True + source_init: bool = True + + feature_range: float = 0.1 + continuous_explore: bool = False + + viz_mode: str = "faces" + + output_fol: str = "results_pair" + + ### counter for screenshot + counter: int = 0 + +modes_list = ['feature_explore', "co-segmentation"] + +def load_features(feature_filename, mesh_filename, viz_mode): + + print("Reading features:") + print(f" Feature filename: {feature_filename}") + print(f" Mesh filename: {mesh_filename}") + + # load features + feat = np.load(feature_filename, allow_pickle=True) + feat = feat.astype(np.float32) + + # load mesh things + tm = load_mesh_util(mesh_filename) + + V = np.array(tm.vertices, dtype=np.float32) + F = np.array(tm.faces) + + if viz_mode == "faces": + pca_colors = np.array(tm.visual.face_colors, dtype=np.float32) + pca_colors = pca_colors[:,:3] / 255. + + else: + pca_colors = np.array(tm.visual.vertex_colors, dtype=np.float32) + pca_colors = pca_colors[:,:3] / 255. + + arrgh(V, F, pca_colors, feat) + + return { + 'V' : V, + 'F' : F, + 'pca_colors' : pca_colors, + 'feat_np' : feat, + 'feat_pt' : torch.tensor(feat, device='cuda'), + 'trimesh' : tm, + 'label' : None, + 'num_cluster' : 1, + 'scalar' : None + } + +def prep_feature_mesh(m, name='mesh'): + ps_mesh = ps.register_surface_mesh(name, m['V'], m['F']) + ps_mesh.set_selection_mode('faces_only') + m['ps_mesh'] = ps_mesh + +def viz_pca_colors(m): + m['ps_mesh'].add_color_quantity('pca colors', m['pca_colors'], enabled=True, defined_on=m["viz_mode"]) + +def viz_feature(m, ind): + m['ps_mesh'].add_scalar_quantity('pca colors', m['feat_np'][:,ind], cmap='turbo', enabled=True, defined_on=m["viz_mode"]) + +def feature_distance_np(feats, query_feat): + # normalize + feats = feats / np.linalg.norm(feats,axis=1)[:,None] + query_feat = query_feat / np.linalg.norm(query_feat) + # cosine distance + cos_sim = np.dot(feats, query_feat) + cos_dist = (1 - cos_sim) / 2. + return cos_dist + +def feature_distance_pt(feats, query_feat): + return (1. - torch.nn.functional.cosine_similarity(feats, query_feat[None,:], dim=-1)) / 2. + + +def ps_callback(opts): + m = opts.m + + changed, ind = psim.Combo("Mode", modes_list.index(opts.mode), modes_list) + if changed: + opts.mode = modes_list[ind] + m['ps_mesh'].remove_all_quantities() + if opts.m_alt is not None: + opts.m_alt['ps_mesh'].remove_all_quantities() + + elif opts.mode == 'feature_explore': + psim.TextUnformatted("Click on the mesh on the left") + psim.TextUnformatted("to highlight all faces within a given radius in feature space.""") + + io = psim.GetIO() + if io.MouseClicked[0] or opts.continuous_explore: + screen_coords = io.MousePos + cam_params = ps.get_view_camera_parameters() + + pick_result = ps.pick(screen_coords=screen_coords) + + # Check if we hit one of the meshes + if pick_result.is_hit and pick_result.structure_name == "mesh": + if pick_result.structure_data['element_type'] != "face": + # shouldn't be possible + raise ValueError("pick returned non-face") + + f_hit = pick_result.structure_data['index'] + bary_weights = np.array(pick_result.structure_data['bary_coords']) + + # get the feature via interpolation + point_feat = m['feat_np'][f_hit,:] + point_feat_pt = torch.tensor(point_feat, device='cuda') + + all_dists1 = feature_distance_pt(m['feat_pt'], point_feat_pt).detach().cpu().numpy() + m['ps_mesh'].add_scalar_quantity("distance", all_dists1, cmap='blues', vminmax=(0, opts.feature_range), enabled=True, defined_on=m["viz_mode"]) + opts.m['scalar'] = all_dists1 + + if opts.m_alt is not None: + all_dists2 = feature_distance_pt(opts.m_alt['feat_pt'], point_feat_pt).detach().cpu().numpy() + opts.m_alt['ps_mesh'].add_scalar_quantity("distance", all_dists2, cmap='blues', vminmax=(0, opts.feature_range), enabled=True, defined_on=m["viz_mode"]) + opts.m_alt['scalar'] = all_dists2 + + else: + # not hit + pass + + if psim.Button("Export"): + ### Save output + OUTPUT_FOL = opts.output_fol + fname1 = opts.filename + out_mesh_file = os.path.join(OUTPUT_FOL, fname1+'.obj') + + igl.write_obj(out_mesh_file, opts.m["V"], opts.m["F"]) + print("Saved '{}'.".format(out_mesh_file)) + + out_face_ids_file = os.path.join(OUTPUT_FOL, fname1 + '_feat_dist_' + str(opts.counter) +'.txt') + np.savetxt(out_face_ids_file, opts.m['scalar'], fmt='%f') + print("Saved '{}'.".format(out_face_ids_file)) + + + fname2 = opts.filename_alt + out_mesh_file = os.path.join(OUTPUT_FOL, fname2+'.obj') + + igl.write_obj(out_mesh_file, opts.m_alt["V"], opts.m_alt["F"]) + print("Saved '{}'.".format(out_mesh_file)) + + out_face_ids_file = os.path.join(OUTPUT_FOL, fname2 + '_feat_dist_' + str(opts.counter) +'.txt') + np.savetxt(out_face_ids_file, opts.m_alt['scalar'], fmt='%f') + # print("Saved '{}'.".format(out_face_ids_file)) + + opts.counter += 1 + + + _, opts.feature_range = psim.SliderFloat('range', opts.feature_range, v_min=0., v_max=1.0, power=3) + _, opts.continuous_explore = psim.Checkbox('continuous', opts.continuous_explore) + + # TODO nsharp remember how the keycodes work + if io.KeysDown[ord('q')]: + opts.feature_range += 0.01 + if io.KeysDown[ord('w')]: + opts.feature_range -= 0.01 + + + elif opts.mode == "co-segmentation": + + changed, opts.source_init = psim.Checkbox("Source Init", opts.source_init) + changed, opts.independent = psim.Checkbox("Independent", opts.independent) + + psim.TextUnformatted("Use the slider to toggle the number of desired clusters.") + cluster_changed, opts.i_cluster = psim.SliderInt("num clusters for model1", opts.i_cluster, v_min=1, v_max=30) + cluster_changed, opts.i_cluster2 = psim.SliderInt("num clusters for model2", opts.i_cluster2, v_min=1, v_max=30) + + # if cluster_changed: + if psim.Button("Recompute"): + + ### Run clustering algorithm + + ### Mesh 1 + num_clusters1 = opts.i_cluster + point_feat1 = m['feat_np'] + point_feat1 = point_feat1 / np.linalg.norm(point_feat1, axis=-1, keepdims=True) + clustering1 = KMeans(n_clusters=num_clusters1, random_state=0, n_init="auto").fit(point_feat1) + + ### Get feature means per cluster + feature_means1 = [] + for j in range(num_clusters1): + all_cluster_feat = point_feat1[clustering1.labels_==j] + mean_feat = np.mean(all_cluster_feat, axis=0) + feature_means1.append(mean_feat) + + feature_means1 = np.array(feature_means1) + tree = KDTree(feature_means1) + + + if opts.source_init: + num_clusters2 = opts.i_cluster + init_mode = np.array(feature_means1) + + ## default is kmeans++ + else: + num_clusters2 = opts.i_cluster2 + init_mode = "k-means++" + + ### Mesh 2 + point_feat2 = opts.m_alt['feat_np'] + point_feat2 = point_feat2 / np.linalg.norm(point_feat2, axis=-1, keepdims=True) + + clustering2 = KMeans(n_clusters=num_clusters2, random_state=0, init=init_mode).fit(point_feat2) + + ### Get feature means per cluster + feature_means2 = [] + for j in range(num_clusters2): + all_cluster_feat = point_feat2[clustering2.labels_==j] + mean_feat = np.mean(all_cluster_feat, axis=0) + feature_means2.append(mean_feat) + + feature_means2 = np.array(feature_means2) + _, nn_idx = tree.query(feature_means2, k=1) + + print(nn_idx) + print("Both KMeans") + print(np.unique(clustering1.labels_)) + print(np.unique(clustering2.labels_)) + + relabelled_2 = nn_idx[clustering2.labels_] + + print(np.unique(relabelled_2)) + print() + + m['ps_mesh'].add_scalar_quantity("cluster_both_kmeans", clustering1.labels_, cmap='turbo', vminmax=(0, num_clusters1-1), enabled=True, defined_on=m["viz_mode"]) + opts.m['label'] = clustering1.labels_ + opts.m['num_cluster'] = num_clusters1 + + if opts.independent: + opts.m_alt['ps_mesh'].add_scalar_quantity("cluster", clustering2.labels_, cmap='turbo', vminmax=(0, num_clusters2-1), enabled=True, defined_on=m["viz_mode"]) + opts.m_alt['label'] = clustering2.labels_ + opts.m_alt['num_cluster'] = num_clusters2 + else: + opts.m_alt['ps_mesh'].add_scalar_quantity("cluster", relabelled_2, cmap='turbo', vminmax=(0, num_clusters1-1), enabled=True, defined_on=m["viz_mode"]) + opts.m_alt['label'] = relabelled_2 + opts.m_alt['num_cluster'] = num_clusters1 + + + if psim.Button("Export"): + ### Save output + OUTPUT_FOL = opts.output_fol + fname1 = opts.filename + out_mesh_file = os.path.join(OUTPUT_FOL, fname1+'.obj') + + igl.write_obj(out_mesh_file, opts.m["V"], opts.m["F"]) + print("Saved '{}'.".format(out_mesh_file)) + + if m["viz_mode"] == "faces": + out_face_ids_file = os.path.join(OUTPUT_FOL, fname1 + "_" + str(opts.m['num_cluster']) + '_pred_face_ids.txt') + else: + out_face_ids_file = os.path.join(OUTPUT_FOL, fname1 + "_" + str(opts.m['num_cluster']) + '_pred_vertices_ids.txt') + + np.savetxt(out_face_ids_file, opts.m['label'], fmt='%d') + print("Saved '{}'.".format(out_face_ids_file)) + + + fname2 = opts.filename_alt + out_mesh_file = os.path.join(OUTPUT_FOL, fname2 +'.obj') + + igl.write_obj(out_mesh_file, opts.m_alt["V"], opts.m_alt["F"]) + print("Saved '{}'.".format(out_mesh_file)) + + if m["viz_mode"] == "faces": + out_face_ids_file = os.path.join(OUTPUT_FOL, fname2 + "_" + str(opts.m_alt['num_cluster']) + '_pred_face_ids.txt') + else: + out_face_ids_file = os.path.join(OUTPUT_FOL, fname2 + "_" + str(opts.m_alt['num_cluster']) + '_pred_vertices_ids.txt') + + np.savetxt(out_face_ids_file, opts.m_alt['label'], fmt='%d') + print("Saved '{}'.".format(out_face_ids_file)) + + +def main(): + ## Parse args + # Uses simple_parsing library to automatically construct parser from the dataclass Options + parser = ArgumentParser() + parser.add_arguments(Options, dest="options") + parser.add_argument('--data_root', default="../exp_results/partfield_features/trellis", help='Path the model features are stored.') + args = parser.parse_args() + opts: Options = args.options + + DATA_ROOT = args.data_root + + shape_1 = opts.filename + shape_2 = opts.filename_alt + + if os.path.exists(os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy")): + feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy") + feature_fname2 = os.path.join(DATA_ROOT, "part_feat_"+ shape_2 + "_0.npy") + + mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply") + mesh_fname2 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_2 + "_0.ply") + else: + feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0_batch.npy") + feature_fname2 = os.path.join(DATA_ROOT, "part_feat_"+ shape_2 + "_0_batch.npy") + + mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply") + mesh_fname2 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_2 + "_0.ply") + + #### To save output #### + os.makedirs(opts.output_fol, exist_ok=True) + ######################## + + # Initialize + ps.init() + + mesh_dict = load_features(feature_fname1, mesh_fname1, opts.viz_mode) + prep_feature_mesh(mesh_dict) + mesh_dict["viz_mode"] = opts.viz_mode + opts.m = mesh_dict + + mesh_dict_alt = load_features(feature_fname2, mesh_fname2, opts.viz_mode) + prep_feature_mesh(mesh_dict_alt, name='mesh_alt') + mesh_dict_alt['ps_mesh'].translate((2.5, 0., 0.)) + mesh_dict_alt["viz_mode"] = opts.viz_mode + opts.m_alt = mesh_dict_alt + + # Start the interactive UI + ps.set_user_callback(lambda : ps_callback(opts)) + ps.show() + + +if __name__ == "__main__": + main() + diff --git a/PartField/applications/single_shape.py b/PartField/applications/single_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..6f6af4969d87be1258fabf8c8a396db3777fffaf --- /dev/null +++ b/PartField/applications/single_shape.py @@ -0,0 +1,758 @@ +import numpy as np +import torch +import polyscope as ps +import polyscope.imgui as psim +import potpourri3d as pp3d +import trimesh +import igl +from dataclasses import dataclass +from simple_parsing import ArgumentParser +from arrgh import arrgh + +### For clustering +from collections import defaultdict +from sklearn.cluster import AgglomerativeClustering, DBSCAN, KMeans +from scipy.sparse import coo_matrix, csr_matrix +from scipy.spatial import KDTree +from scipy.sparse.csgraph import connected_components +from sklearn.neighbors import NearestNeighbors +import networkx as nx + +from scipy.optimize import linear_sum_assignment + +import os, sys +sys.path.append("..") +from partfield.utils import * + +@dataclass +class Options: + + """ Basic Options """ + filename: str + + """System Options""" + device: str = "cuda" # Device + debug: bool = False # enable debug checks + extras: bool = False # include extra output for viz/debugging + + """ State """ + mode: str = 'pca' + m: dict = None # mesh + + # pca mode + + # feature explore mode + i_feature: int = 0 + + i_cluster: int = 1 + + i_eps: int = 0.6 + + ### For mixing in clustering + weight_dist = 1.0 + weight_feat = 1.0 + + ### For clustering visualization + feature_range: float = 0.1 + continuous_explore: bool = False + + viz_mode: str = "faces" + + output_fol: str = "results_single" + + ### For adj_matrix + adj_mode: str = "Vanilla" + add_knn_edges: bool = False + + ### counter for screenshot + counter: int = 0 + +modes_list = ['pca', 'feature_viz', 'cluster_agglo', 'cluster_kmeans'] +adj_mode_list = ["Vanilla", "Face_MST", "CC_MST"] + +#### For clustering +class UnionFind: + def __init__(self, n): + self.parent = list(range(n)) + self.rank = [1] * n + + def find(self, x): + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] + + def union(self, x, y): + rootX = self.find(x) + rootY = self.find(y) + + if rootX != rootY: + if self.rank[rootX] > self.rank[rootY]: + self.parent[rootY] = rootX + elif self.rank[rootX] < self.rank[rootY]: + self.parent[rootX] = rootY + else: + self.parent[rootY] = rootX + self.rank[rootX] += 1 + +##################################### +## Face adjacency computation options +##################################### +def construct_face_adjacency_matrix_ccmst(face_list, vertices, k=10, with_knn=True): + """ + Given a list of faces (each face is a 3-tuple of vertex indices), + construct a face-based adjacency matrix of shape (num_faces, num_faces). + + Two faces are adjacent if they share an edge (the "mesh adjacency"). + If multiple connected components remain, we: + 1) Compute the centroid of each connected component as the mean of all face centroids. + 2) Use a KNN graph (k=10) based on centroid distances on each connected component. + 3) Compute MST of that KNN graph. + 4) Add MST edges that connect different components as "dummy" edges + in the face adjacency matrix, ensuring one connected component. The selected face for + each connected component is the face closest to the component centroid. + + Parameters + ---------- + face_list : list of tuples + List of faces, each face is a tuple (v0, v1, v2) of vertex indices. + vertices : np.ndarray of shape (num_vertices, 3) + Array of vertex coordinates. + k : int, optional + Number of neighbors to use in centroid KNN. Default is 10. + + Returns + ------- + face_adjacency : scipy.sparse.csr_matrix + A CSR sparse matrix of shape (num_faces, num_faces), + containing 1s for adjacent faces (shared-edge adjacency) + plus dummy edges ensuring a single connected component. + """ + num_faces = len(face_list) + if num_faces == 0: + # Return an empty matrix if no faces + return csr_matrix((0, 0)) + + #-------------------------------------------------------------------------- + # 1) Build adjacency based on shared edges. + # (Same logic as the original code, plus import statements.) + #-------------------------------------------------------------------------- + edge_to_faces = defaultdict(list) + uf = UnionFind(num_faces) + for f_idx, (v0, v1, v2) in enumerate(face_list): + # Sort each edge’s endpoints so (i, j) == (j, i) + edges = [ + tuple(sorted((v0, v1))), + tuple(sorted((v1, v2))), + tuple(sorted((v2, v0))) + ] + for e in edges: + edge_to_faces[e].append(f_idx) + + row = [] + col = [] + for edge, face_indices in edge_to_faces.items(): + unique_faces = list(set(face_indices)) + if len(unique_faces) > 1: + # For every pair of distinct faces that share this edge, + # mark them as mutually adjacent + for i in range(len(unique_faces)): + for j in range(i + 1, len(unique_faces)): + fi = unique_faces[i] + fj = unique_faces[j] + row.append(fi) + col.append(fj) + row.append(fj) + col.append(fi) + uf.union(fi, fj) + + data = np.ones(len(row), dtype=np.int8) + face_adjacency = coo_matrix( + (data, (row, col)), shape=(num_faces, num_faces) + ).tocsr() + + #-------------------------------------------------------------------------- + # 2) Check if the graph from shared edges is already connected. + #-------------------------------------------------------------------------- + n_components = 0 + for i in range(num_faces): + if uf.find(i) == i: + n_components += 1 + print("n_components", n_components) + + if n_components == 1: + # Already a single connected component, no need for dummy edges + return face_adjacency + + #-------------------------------------------------------------------------- + # 3) Compute centroids of each face for building a KNN graph. + #-------------------------------------------------------------------------- + face_centroids = [] + for (v0, v1, v2) in face_list: + centroid = (vertices[v0] + vertices[v1] + vertices[v2]) / 3.0 + face_centroids.append(centroid) + face_centroids = np.array(face_centroids) + + #-------------------------------------------------------------------------- + # 4b) Build a KNN graph on connected components + #-------------------------------------------------------------------------- + # Group faces by their root representative in the Union-Find structure + component_dict = {} + for face_idx in range(num_faces): + root = uf.find(face_idx) + if root not in component_dict: + component_dict[root] = set() + component_dict[root].add(face_idx) + + connected_components = list(component_dict.values()) + + print("Using connected component MST.") + component_centroid_face_idx = [] + connected_component_centroids = [] + knn = NearestNeighbors(n_neighbors=1, algorithm='auto') + for component in connected_components: + curr_component_faces = list(component) + curr_component_face_centroids = face_centroids[curr_component_faces] + component_centroid = np.mean(curr_component_face_centroids, axis=0) + + ### Assign a face closest to the centroid + face_idx = curr_component_faces[np.argmin(np.linalg.norm(curr_component_face_centroids-component_centroid, axis=-1))] + + connected_component_centroids.append(component_centroid) + component_centroid_face_idx.append(face_idx) + + component_centroid_face_idx = np.array(component_centroid_face_idx) + connected_component_centroids = np.array(connected_component_centroids) + + if n_components < k: + knn = NearestNeighbors(n_neighbors=n_components, algorithm='auto') + else: + knn = NearestNeighbors(n_neighbors=k, algorithm='auto') + knn.fit(connected_component_centroids) + distances, indices = knn.kneighbors(connected_component_centroids) + + #-------------------------------------------------------------------------- + # 5) Build a weighted graph in NetworkX using centroid-distances as edges + #-------------------------------------------------------------------------- + G = nx.Graph() + # Add each face as a node in the graph + G.add_nodes_from(range(num_faces)) + + # For each face i, add edges (i -> j) for each neighbor j in the KNN + for idx1 in range(n_components): + i = component_centroid_face_idx[idx1] + for idx2, dist in zip(indices[idx1], distances[idx1]): + j = component_centroid_face_idx[idx2] + if i == j: + continue # skip self-loop + # Add an undirected edge with 'weight' = distance + # NetworkX handles parallel edges gracefully via last add_edge, + # but it typically overwrites the weight if (i, j) already exists. + G.add_edge(i, j, weight=dist) + + #-------------------------------------------------------------------------- + # 6) Compute MST on that KNN graph + #-------------------------------------------------------------------------- + mst = nx.minimum_spanning_tree(G, weight='weight') + # Sort MST edges by ascending weight, so we add the shortest edges first + mst_edges_sorted = sorted( + mst.edges(data=True), key=lambda e: e[2]['weight'] + ) + print("mst edges sorted", len(mst_edges_sorted)) + #-------------------------------------------------------------------------- + # 7) Use a union-find structure to add MST edges only if they + # connect two currently disconnected components of the adjacency matrix + #-------------------------------------------------------------------------- + + # Convert face_adjacency to LIL format for efficient edge addition + adjacency_lil = face_adjacency.tolil() + + # Now, step through MST edges in ascending order + for (u, v, attr) in mst_edges_sorted: + if uf.find(u) != uf.find(v): + # These belong to different components, so unify them + uf.union(u, v) + # And add a "dummy" edge to our adjacency matrix + adjacency_lil[u, v] = 1 + adjacency_lil[v, u] = 1 + + # Convert back to CSR format and return + face_adjacency = adjacency_lil.tocsr() + + if with_knn: + print("Adding KNN edges.") + ### Add KNN edges graph too + dummy_row = [] + dummy_col = [] + for idx1 in range(n_components): + i = component_centroid_face_idx[idx1] + for idx2 in indices[idx1]: + j = component_centroid_face_idx[idx2] + dummy_row.extend([i, j]) + dummy_col.extend([j, i]) ### duplicates are handled by coo + + dummy_data = np.ones(len(dummy_row), dtype=np.int16) + dummy_mat = coo_matrix( + (dummy_data, (dummy_row, dummy_col)), + shape=(num_faces, num_faces) + ).tocsr() + face_adjacency = face_adjacency + dummy_mat + ########################### + + return face_adjacency +######################### + +def construct_face_adjacency_matrix_facemst(face_list, vertices, k=10, with_knn=True): + """ + Given a list of faces (each face is a 3-tuple of vertex indices), + construct a face-based adjacency matrix of shape (num_faces, num_faces). + + Two faces are adjacent if they share an edge (the "mesh adjacency"). + If multiple connected components remain, we: + 1) Compute the centroid of each face. + 2) Use a KNN graph (k=10) based on centroid distances. + 3) Compute MST of that KNN graph. + 4) Add MST edges that connect different components as "dummy" edges + in the face adjacency matrix, ensuring one connected component. + + Parameters + ---------- + face_list : list of tuples + List of faces, each face is a tuple (v0, v1, v2) of vertex indices. + vertices : np.ndarray of shape (num_vertices, 3) + Array of vertex coordinates. + k : int, optional + Number of neighbors to use in centroid KNN. Default is 10. + + Returns + ------- + face_adjacency : scipy.sparse.csr_matrix + A CSR sparse matrix of shape (num_faces, num_faces), + containing 1s for adjacent faces (shared-edge adjacency) + plus dummy edges ensuring a single connected component. + """ + num_faces = len(face_list) + if num_faces == 0: + # Return an empty matrix if no faces + return csr_matrix((0, 0)) + + #-------------------------------------------------------------------------- + # 1) Build adjacency based on shared edges. + # (Same logic as the original code, plus import statements.) + #-------------------------------------------------------------------------- + edge_to_faces = defaultdict(list) + uf = UnionFind(num_faces) + for f_idx, (v0, v1, v2) in enumerate(face_list): + # Sort each edge’s endpoints so (i, j) == (j, i) + edges = [ + tuple(sorted((v0, v1))), + tuple(sorted((v1, v2))), + tuple(sorted((v2, v0))) + ] + for e in edges: + edge_to_faces[e].append(f_idx) + + row = [] + col = [] + for edge, face_indices in edge_to_faces.items(): + unique_faces = list(set(face_indices)) + if len(unique_faces) > 1: + # For every pair of distinct faces that share this edge, + # mark them as mutually adjacent + for i in range(len(unique_faces)): + for j in range(i + 1, len(unique_faces)): + fi = unique_faces[i] + fj = unique_faces[j] + row.append(fi) + col.append(fj) + row.append(fj) + col.append(fi) + uf.union(fi, fj) + + data = np.ones(len(row), dtype=np.int8) + face_adjacency = coo_matrix( + (data, (row, col)), shape=(num_faces, num_faces) + ).tocsr() + + #-------------------------------------------------------------------------- + # 2) Check if the graph from shared edges is already connected. + #-------------------------------------------------------------------------- + n_components = 0 + for i in range(num_faces): + if uf.find(i) == i: + n_components += 1 + print("n_components", n_components) + + if n_components == 1: + # Already a single connected component, no need for dummy edges + return face_adjacency + #-------------------------------------------------------------------------- + # 3) Compute centroids of each face for building a KNN graph. + #-------------------------------------------------------------------------- + face_centroids = [] + for (v0, v1, v2) in face_list: + centroid = (vertices[v0] + vertices[v1] + vertices[v2]) / 3.0 + face_centroids.append(centroid) + face_centroids = np.array(face_centroids) + + #-------------------------------------------------------------------------- + # 4) Build a KNN graph (k=10) over face centroids using scikit‐learn + #-------------------------------------------------------------------------- + knn = NearestNeighbors(n_neighbors=k, algorithm='auto') + knn.fit(face_centroids) + distances, indices = knn.kneighbors(face_centroids) + # 'distances[i]' are the distances from face i to each of its 'k' neighbors + # 'indices[i]' are the face indices of those neighbors + + #-------------------------------------------------------------------------- + # 5) Build a weighted graph in NetworkX using centroid-distances as edges + #-------------------------------------------------------------------------- + G = nx.Graph() + # Add each face as a node in the graph + G.add_nodes_from(range(num_faces)) + + # For each face i, add edges (i -> j) for each neighbor j in the KNN + for i in range(num_faces): + for j, dist in zip(indices[i], distances[i]): + if i == j: + continue # skip self-loop + # Add an undirected edge with 'weight' = distance + # NetworkX handles parallel edges gracefully via last add_edge, + # but it typically overwrites the weight if (i, j) already exists. + G.add_edge(i, j, weight=dist) + + #-------------------------------------------------------------------------- + # 6) Compute MST on that KNN graph + #-------------------------------------------------------------------------- + mst = nx.minimum_spanning_tree(G, weight='weight') + # Sort MST edges by ascending weight, so we add the shortest edges first + mst_edges_sorted = sorted( + mst.edges(data=True), key=lambda e: e[2]['weight'] + ) + print("mst edges sorted", len(mst_edges_sorted)) + #-------------------------------------------------------------------------- + # 7) Use a union-find structure to add MST edges only if they + # connect two currently disconnected components of the adjacency matrix + #-------------------------------------------------------------------------- + + # Convert face_adjacency to LIL format for efficient edge addition + adjacency_lil = face_adjacency.tolil() + + # Now, step through MST edges in ascending order + for (u, v, attr) in mst_edges_sorted: + if uf.find(u) != uf.find(v): + # These belong to different components, so unify them + uf.union(u, v) + # And add a "dummy" edge to our adjacency matrix + adjacency_lil[u, v] = 1 + adjacency_lil[v, u] = 1 + + # Convert back to CSR format and return + face_adjacency = adjacency_lil.tocsr() + + if with_knn: + print("Adding KNN edges.") + ### Add KNN edges graph too + dummy_row = [] + dummy_col = [] + for i in range(num_faces): + for j in indices[i]: + dummy_row.extend([i, j]) + dummy_col.extend([j, i]) ### duplicates are handled by coo + + dummy_data = np.ones(len(dummy_row), dtype=np.int16) + dummy_mat = coo_matrix( + (dummy_data, (dummy_row, dummy_col)), + shape=(num_faces, num_faces) + ).tocsr() + face_adjacency = face_adjacency + dummy_mat + ########################### + + return face_adjacency + +def construct_face_adjacency_matrix_naive(face_list): + """ + Given a list of faces (each face is a 3-tuple of vertex indices), + construct a face-based adjacency matrix of shape (num_faces, num_faces). + Two faces are adjacent if they share an edge. + + If multiple connected components exist, dummy edges are added to + turn them into a single connected component. Edges are added naively by + randomly selecting a face and connecting consecutive components -- (comp_i, comp_i+1) ... + + Parameters + ---------- + face_list : list of tuples + List of faces, each face is a tuple (v0, v1, v2) of vertex indices. + + Returns + ------- + face_adjacency : scipy.sparse.csr_matrix + A CSR sparse matrix of shape (num_faces, num_faces), + containing 1s for adjacent faces and 0s otherwise. + Additional edges are added if the faces are in multiple components. + """ + + num_faces = len(face_list) + if num_faces == 0: + # Return an empty matrix if no faces + return csr_matrix((0, 0)) + + # Step 1: Map each undirected edge -> list of face indices that contain that edge + edge_to_faces = defaultdict(list) + + # Populate the edge_to_faces dictionary + for f_idx, (v0, v1, v2) in enumerate(face_list): + # For an edge, we always store its endpoints in sorted order + # to avoid duplication (e.g. edge (2,5) is the same as (5,2)). + edges = [ + tuple(sorted((v0, v1))), + tuple(sorted((v1, v2))), + tuple(sorted((v2, v0))) + ] + for e in edges: + edge_to_faces[e].append(f_idx) + + # Step 2: Build the adjacency (row, col) lists among faces + row = [] + col = [] + for e, faces_sharing_e in edge_to_faces.items(): + # If an edge is shared by multiple faces, make each pair of those faces adjacent + f_indices = list(set(faces_sharing_e)) # unique face indices for this edge + if len(f_indices) > 1: + # For each pair of faces, mark them as adjacent + for i in range(len(f_indices)): + for j in range(i + 1, len(f_indices)): + f_i = f_indices[i] + f_j = f_indices[j] + row.append(f_i) + col.append(f_j) + row.append(f_j) + col.append(f_i) + + # Create a COO matrix, then convert it to CSR + data = np.ones(len(row), dtype=np.int8) + face_adjacency = coo_matrix( + (data, (row, col)), + shape=(num_faces, num_faces) + ).tocsr() + + # Step 3: Ensure single connected component + # Use connected_components to see how many components exist + n_components, labels = connected_components(face_adjacency, directed=False) + + if n_components > 1: + # We have multiple components; let's "connect" them via dummy edges + # The simplest approach is to pick one face from each component + # and connect them sequentially to enforce a single component. + component_representatives = [] + + for comp_id in range(n_components): + # indices of faces in this component + faces_in_comp = np.where(labels == comp_id)[0] + if len(faces_in_comp) > 0: + # take the first face in this component as a representative + component_representatives.append(faces_in_comp[0]) + + # Now, add edges between consecutive representatives + dummy_row = [] + dummy_col = [] + for i in range(len(component_representatives) - 1): + f_i = component_representatives[i] + f_j = component_representatives[i + 1] + dummy_row.extend([f_i, f_j]) + dummy_col.extend([f_j, f_i]) + + if dummy_row: + dummy_data = np.ones(len(dummy_row), dtype=np.int8) + dummy_mat = coo_matrix( + (dummy_data, (dummy_row, dummy_col)), + shape=(num_faces, num_faces) + ).tocsr() + face_adjacency = face_adjacency + dummy_mat + + return face_adjacency +##################################### + +def load_features(feature_filename, mesh_filename, viz_mode): + + print("Reading features:") + print(f" Feature filename: {feature_filename}") + print(f" Mesh filename: {mesh_filename}") + + # load features + feat = np.load(feature_filename, allow_pickle=True) + feat = feat.astype(np.float32) + + # load mesh things + tm = load_mesh_util(mesh_filename) + + V = np.array(tm.vertices, dtype=np.float32) + F = np.array(tm.faces) + + if viz_mode == "faces": + pca_colors = np.array(tm.visual.face_colors, dtype=np.float32) + pca_colors = pca_colors[:,:3] / 255. + + else: + pca_colors = np.array(tm.visual.vertex_colors, dtype=np.float32) + pca_colors = pca_colors[:,:3] / 255. + + arrgh(V, F, pca_colors, feat) + + print(F) + print(V[F[1][0]]) + print(V[F[1][1]]) + print(V[F[1][2]]) + + return { + 'V' : V, + 'F' : F, + 'pca_colors' : pca_colors, + 'feat_np' : feat, + 'feat_pt' : torch.tensor(feat, device='cuda'), + 'trimesh' : tm, + 'label' : None, + 'num_cluster' : 1, + 'scalar' : None + } + +def prep_feature_mesh(m, name='mesh'): + ps_mesh = ps.register_surface_mesh(name, m['V'], m['F']) + ps_mesh.set_selection_mode('faces_only') + m['ps_mesh'] = ps_mesh + +def viz_pca_colors(m): + m['ps_mesh'].add_color_quantity('pca colors', m['pca_colors'], enabled=True, defined_on=m["viz_mode"]) + +def viz_feature(m, ind): + m['ps_mesh'].add_scalar_quantity('pca colors', m['feat_np'][:,ind], cmap='turbo', enabled=True, defined_on=m["viz_mode"]) + +def feature_distance_np(feats, query_feat): + # normalize + feats = feats / np.linalg.norm(feats,axis=1)[:,None] + query_feat = query_feat / np.linalg.norm(query_feat) + # cosine distance + cos_sim = np.dot(feats, query_feat) + cos_dist = (1 - cos_sim) / 2. + return cos_dist + +def feature_distance_pt(feats, query_feat): + return (1. - torch.nn.functional.cosine_similarity(feats, query_feat[None,:], dim=-1)) / 2. + + +def ps_callback(opts): + m = opts.m + + changed, ind = psim.Combo("Mode", modes_list.index(opts.mode), modes_list) + if changed: + opts.mode = modes_list[ind] + m['ps_mesh'].remove_all_quantities() + + if opts.mode == 'pca': + psim.TextUnformatted("""3-dim PCA embeddeding of features is shown as rgb color""") + viz_pca_colors(m) + + elif opts.mode == 'feature_viz': + psim.TextUnformatted("""Use the slider to scrub through all features.\nCtrl-click to type a particular index.""") + + this_changed, opts.i_feature = psim.SliderInt("feature index", opts.i_feature, v_min=0, v_max=(m['feat_np'].shape[-1]-1)) + this_changed = this_changed or changed + + if this_changed: + viz_feature(m, opts.i_feature) + + elif opts.mode == "cluster_agglo": + psim.TextUnformatted("""Use the slider to toggle the number of desired clusters.""") + cluster_changed, opts.i_cluster = psim.SliderInt("number of clusters", opts.i_cluster, v_min=1, v_max=30) + + ### To handle different face adjacency options + mode_changed, ind = psim.Combo("Adj Matrix Def", adj_mode_list.index(opts.adj_mode), adj_mode_list) + knn_changed, opts.add_knn_edges = psim.Checkbox("Add KNN edges", opts.add_knn_edges) + + if mode_changed: + opts.adj_mode = adj_mode_list[ind] + + if psim.Button("Recompute"): + + ### Run clustering algorithm + num_clusters = opts.i_cluster + + ### Mesh 1 + point_feat = m['feat_np'] + point_feat = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True) + + ### Compute adjacency matrix ### + if opts.adj_mode == "Vanilla": + adj_matrix = construct_face_adjacency_matrix_naive(opts.m["F"]) + elif opts.adj_mode == "Face_MST": + adj_matrix = construct_face_adjacency_matrix_facemst(opts.m["F"], opts.m["V"], with_knn=opts.add_knn_edges) + elif opts.adj_mode == "CC_MST": + adj_matrix = construct_face_adjacency_matrix_ccmst(opts.m["F"], opts.m["V"], with_knn=opts.add_knn_edges) + ################################ + + ## Agglomerative clustering + clustering = AgglomerativeClustering(connectivity= adj_matrix, + n_clusters=num_clusters, + ).fit(point_feat) + + m['ps_mesh'].add_scalar_quantity("cluster", clustering.labels_, cmap='turbo', vminmax=(0, num_clusters-1), enabled=True, defined_on=m["viz_mode"]) + print("Recomputed.") + + + elif opts.mode == "cluster_kmeans": + psim.TextUnformatted("""Use the slider to toggle the number of desired clusters.""") + + cluster_changed, opts.i_cluster = psim.SliderInt("number of clusters", opts.i_cluster, v_min=1, v_max=30) + + if psim.Button("Recompute"): + + ### Run clustering algorithm + num_clusters = opts.i_cluster + + ### Mesh 1 + point_feat = m['feat_np'] + point_feat = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True) + clustering = KMeans(n_clusters=num_clusters, random_state=0, n_init="auto").fit(point_feat) + + m['ps_mesh'].add_scalar_quantity("cluster", clustering.labels_, cmap='turbo', vminmax=(0, num_clusters-1), enabled=True, defined_on=m["viz_mode"]) + +def main(): + ## Parse args + # Uses simple_parsing library to automatically construct parser from the dataclass Options + parser = ArgumentParser() + parser.add_arguments(Options, dest="options") + parser.add_argument('--data_root', default="../exp_results/partfield_features/trellis/", help='Path the model features are stored.') + args = parser.parse_args() + opts: Options = args.options + + DATA_ROOT = args.data_root + + shape_1 = opts.filename + + if os.path.exists(os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy")): + feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0.npy") + mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply") + else: + feature_fname1 = os.path.join(DATA_ROOT, "part_feat_"+ shape_1 + "_0_batch.npy") + mesh_fname1 = os.path.join(DATA_ROOT, "feat_pca_"+ shape_1 + "_0.ply") + + #### To save output #### + os.makedirs(opts.output_fol, exist_ok=True) + ######################## + + # Initialize + ps.init() + + mesh_dict = load_features(feature_fname1, mesh_fname1, opts.viz_mode) + prep_feature_mesh(mesh_dict) + mesh_dict["viz_mode"] = opts.viz_mode + opts.m = mesh_dict + + # Start the interactive UI + ps.set_user_callback(lambda : ps_callback(opts)) + ps.show() + + +if __name__ == "__main__": + main() + diff --git a/PartField/compute_metric.py b/PartField/compute_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..818e4a7c55a21dd1c32885d6c146e2c1e834e0eb --- /dev/null +++ b/PartField/compute_metric.py @@ -0,0 +1,97 @@ +import numpy as np +import json +from os.path import join +from typing import List +import os + +def compute_iou(pred, gt): + intersection = np.logical_and(pred, gt).sum() + union = np.logical_or(pred, gt).sum() + if union != 0: + return (intersection / union) * 100 + else: + return 0 + +def eval_single_gt_shape(gt_label, pred_masks): + # gt: [N,], label index + # pred: [B, N], B is the number of predicted parts, binary label + unique_gt_label = np.unique(gt_label) + best_ious = [] + for label in unique_gt_label: + best_iou = 0 + if label == -1: + continue + for mask in pred_masks: + iou = compute_iou(mask, gt_label == label) + best_iou = max(best_iou, iou) + best_ious.append(best_iou) + return np.mean(best_ious) + +def eval_whole_dataset(pred_folder, merge_parts=False): + print(pred_folder) + meta = json.load(open("/home/mikaelaangel/Desktop/data/PartObjaverse-Tiny_semantic.json", "r")) + + categories = meta.keys() + results_per_cat = {} + per_cat_mious = [] + overall_mious = [] + + MAX_NUM_CLUSTERS = 20 + view_id = 0 + + for cat in categories: + results_per_cat[cat] = [] + for shape_id in meta[cat].keys(): + + try: + all_pred_labels = [] + for num_cluster in range(2, MAX_NUM_CLUSTERS): + ### load each label + fname_clustering = os.path.join(pred_folder, "cluster_out", str(shape_id) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2)) + ".npy" + pred_label = np.load(fname_clustering) + all_pred_labels.append(np.squeeze(pred_label)) + + all_pred_labels = np.array(all_pred_labels) + + except: + continue + + pred_masks = [] + + #### Path for PartObjaverseTiny Labels + gt_labels_path = "PartObjaverse-Tiny_instance_gt" + ################################# + + gt_label = np.load(os.path.join(gt_labels_path, shape_id + ".npy")) + + if merge_parts: + pred_masks = [] + for result in all_pred_labels: + pred = result + assert pred.shape[0] == gt_label.shape[0] + for label in np.unique(pred): + pred_masks.append(pred == label) + miou = eval_single_gt_shape(gt_label, np.array(pred_masks)) + results_per_cat[cat].append(miou) + else: + best_miou = 0 + for result in all_pred_labels: + pred_masks = [] + pred = result + + for label in np.unique(pred): + pred_masks.append(pred == label) + miou = eval_single_gt_shape(gt_label, np.array(pred_masks)) + best_miou = max(best_miou, miou) + results_per_cat[cat].append(best_miou) + + print(np.mean(results_per_cat[cat])) + per_cat_mious.append(np.mean(results_per_cat[cat])) + overall_mious += results_per_cat[cat] + print(np.mean(per_cat_mious)) + print(np.mean(overall_mious), len(overall_mious)) + + +if __name__ == "__main__": + eval_whole_dataset("dump_partobjtiny_clustering") + diff --git a/PartField/configs/final/correspondence_demo.yaml b/PartField/configs/final/correspondence_demo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..42272f38c0b4e06f93a91b04fce9629c2f1b712e --- /dev/null +++ b/PartField/configs/final/correspondence_demo.yaml @@ -0,0 +1,44 @@ +result_name: partfield_features/correspondence_demo + +continue_ckpt: model/model.ckpt + +triplane_channels_low: 128 +triplane_channels_high: 512 +triplane_resolution: 128 + +vertex_feature: True +n_point_per_face: 1000 +n_sample_each: 10000 +is_pc: False +remesh_demo: False +correspondence_demo: True + +preprocess_mesh: True + +dataset: + type: "Mix" + data_path: data/DenseCorr3D + train_batch_size: 1 + val_batch_size: 1 + train_num_workers: 8 + all_files: + # pairs of example to run correspondence + - animals/071b8_toy_animals_017/simple_mesh.obj + - animals/bdfd0_toy_animals_016/simple_mesh.obj + - animals/2d6b3_toy_animals_009/simple_mesh.obj + - animals/96615_toy_animals_018/simple_mesh.obj + - chairs/063d1_chair_006/simple_mesh.obj + - chairs/bea57_chair_012/simple_mesh.obj + - chairs/fe0fe_chair_004/simple_mesh.obj + - chairs/288dc_chair_011/simple_mesh.obj + # consider decimating animals/../color_mesh.obj yourself for better mesh topology than the provided simple_mesh.obj + # (e.g. <50k vertices for functional map efficiency). + +loss: + triplet: 1.0 + +use_2d_feat: False +pvcnn: + point_encoder_type: 'pvcnn' + z_triplane_channels: 256 + z_triplane_resolution: 128 \ No newline at end of file diff --git a/PartField/configs/final/demo.yaml b/PartField/configs/final/demo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..010fb8dbef280c6abfbc7fd082654402c2cbfcef --- /dev/null +++ b/PartField/configs/final/demo.yaml @@ -0,0 +1,28 @@ +result_name: demo_test + +continue_ckpt: model/model.ckpt + +triplane_channels_low: 128 +triplane_channels_high: 512 +triplane_resolution: 128 + +n_point_per_face: 1000 +n_sample_each: 10000 +is_pc : False +remesh_demo : False + +dataset: + type: "Mix" + data_path: "objaverse_data" + train_batch_size: 1 + val_batch_size: 1 + train_num_workers: 8 + +loss: + triplet: 1.0 + +use_2d_feat: False +pvcnn: + point_encoder_type: 'pvcnn' + z_triplane_channels: 256 + z_triplane_resolution: 128 \ No newline at end of file diff --git a/PartField/download_demo_data.sh b/PartField/download_demo_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..b9f31a36d46f3d7e371a8927f80346e127f253df --- /dev/null +++ b/PartField/download_demo_data.sh @@ -0,0 +1,19 @@ +#!/bin/bash +mkdir data +cd data +mkdir objaverse_samples +cd objaverse_samples +wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-050/00200996b8f34f55a2dd2f44d316d107.glb +wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-042/002e462c8bfa4267a9c9f038c7966f3b.glb +wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-046/0c3ca2b32545416f8f1e6f0e87def1a6.glb +wget https://huggingface.co/datasets/allenai/objaverse/resolve/main/glbs/000-063/65c6ffa083c6496eb84a0aa3c48d63ad.glb + +cd .. +mkdir trellis_samples +cd trellis_samples +wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/scenes/blacksmith/glbs/dwarf.glb +wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/img2/glbs/goblin.glb +wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/img2/glbs/excavator.glb +wget https://github.com/Trellis3D/trellis3d.github.io/raw/refs/heads/main/assets/img2/glbs/elephant.glb +cd .. +cd .. diff --git a/PartField/environment.yml b/PartField/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..152cafa636f9200abdd156c4f41ac56e0c270e78 --- /dev/null +++ b/PartField/environment.yml @@ -0,0 +1,772 @@ +name: partfield +channels: + - nvidia/label/cuda-12.4.0 + - conda-forge + - defaults +dependencies: + - _anaconda_depends=2025.03=py310_mkl_0 + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu + - aiobotocore=2.21.1=pyhd8ed1ab_0 + - aiohappyeyeballs=2.6.1=pyhd8ed1ab_0 + - aiohttp=3.11.14=py310h89163eb_0 + - aioitertools=0.12.0=pyhd8ed1ab_1 + - aiosignal=1.3.2=pyhd8ed1ab_0 + - alabaster=1.0.0=pyhd8ed1ab_1 + - alsa-lib=1.2.13=hb9d3cd8_0 + - altair=5.5.0=pyhd8ed1ab_1 + - anaconda=custom=py310_3 + - anyio=4.9.0=pyh29332c3_0 + - aom=3.9.1=hac33072_0 + - appdirs=1.4.4=pyhd8ed1ab_1 + - argon2-cffi=23.1.0=pyhd8ed1ab_1 + - argon2-cffi-bindings=21.2.0=py310ha75aee5_5 + - arrow=1.3.0=pyhd8ed1ab_1 + - astroid=3.3.9=py310hff52083_0 + - astropy=6.1.7=py310hf462985_0 + - astropy-iers-data=0.2025.3.31.0.36.18=pyhd8ed1ab_0 + - asttokens=3.0.0=pyhd8ed1ab_1 + - async-lru=2.0.5=pyh29332c3_0 + - async-timeout=5.0.1=pyhd8ed1ab_1 + - asyncssh=2.20.0=pyhd8ed1ab_0 + - atomicwrites=1.4.1=pyhd8ed1ab_1 + - attr=2.5.1=h166bdaf_1 + - attrs=25.3.0=pyh71513ae_0 + - automat=24.8.1=pyhd8ed1ab_1 + - autopep8=2.0.4=pyhd8ed1ab_0 + - aws-c-auth=0.8.6=hd08a7f5_4 + - aws-c-cal=0.8.7=h043a21b_0 + - aws-c-common=0.12.0=hb9d3cd8_0 + - aws-c-compression=0.3.1=h3870646_2 + - aws-c-event-stream=0.5.4=h04a3f94_2 + - aws-c-http=0.9.4=hb9b18c6_4 + - aws-c-io=0.17.0=h3dad3f2_6 + - aws-c-mqtt=0.12.2=h108da3e_2 + - aws-c-s3=0.7.13=h822ba82_2 + - aws-c-sdkutils=0.2.3=h3870646_2 + - aws-checksums=0.2.3=h3870646_2 + - aws-crt-cpp=0.31.0=h55f77e1_4 + - aws-sdk-cpp=1.11.510=h37a5c72_3 + - azure-core-cpp=1.14.0=h5cfcd09_0 + - azure-identity-cpp=1.10.0=h113e628_0 + - azure-storage-blobs-cpp=12.13.0=h3cf044e_1 + - azure-storage-common-cpp=12.8.0=h736e048_1 + - azure-storage-files-datalake-cpp=12.12.0=ha633028_1 + - babel=2.17.0=pyhd8ed1ab_0 + - backports=1.0=pyhd8ed1ab_5 + - backports.tarfile=1.2.0=pyhd8ed1ab_1 + - bcrypt=4.3.0=py310h505e2c1_0 + - beautifulsoup4=4.13.3=pyha770c72_0 + - binaryornot=0.4.4=pyhd8ed1ab_2 + - binutils=2.43=h4852527_4 + - binutils_impl_linux-64=2.43=h4bf12b8_4 + - binutils_linux-64=2.43=h4852527_4 + - black=25.1.0=pyha5154f8_0 + - blas=1.0=mkl + - bleach=6.2.0=pyh29332c3_4 + - bleach-with-css=6.2.0=h82add2a_4 + - blinker=1.9.0=pyhff2d567_0 + - blosc=1.21.6=he440d0b_1 + - bokeh=3.7.0=pyhd8ed1ab_0 + - brotli=1.1.0=hb9d3cd8_2 + - brotli-bin=1.1.0=hb9d3cd8_2 + - brotli-python=1.1.0=py310hf71b8c6_2 + - brunsli=0.1=h9c3ff4c_0 + - bzip2=1.0.8=h4bc722e_7 + - c-ares=1.34.4=hb9d3cd8_0 + - c-blosc2=2.15.2=h3122c55_1 + - c-compiler=1.9.0=h2b85faf_0 + - ca-certificates=2025.1.31=hbcca054_0 + - cached-property=1.5.2=hd8ed1ab_1 + - cached_property=1.5.2=pyha770c72_1 + - cachetools=5.5.2=pyhd8ed1ab_0 + - cairo=1.18.4=h3394656_0 + - certifi=2025.1.31=pyhd8ed1ab_0 + - cffi=1.17.1=py310h8deb56e_0 + - chardet=5.2.0=pyhd8ed1ab_3 + - charls=2.4.2=h59595ed_0 + - charset-normalizer=3.4.1=pyhd8ed1ab_0 + - click=8.1.8=pyh707e725_0 + - cloudpickle=3.1.1=pyhd8ed1ab_0 + - colorama=0.4.6=pyhd8ed1ab_1 + - colorcet=3.1.0=pyhd8ed1ab_1 + - comm=0.2.2=pyhd8ed1ab_1 + - constantly=15.1.0=py_0 + - contourpy=1.3.1=py310h3788b33_0 + - cookiecutter=2.6.0=pyhd8ed1ab_1 + - cpython=3.10.16=py310hd8ed1ab_1 + - cryptography=44.0.2=py310h6c63255_0 + - cssselect=1.2.0=pyhd8ed1ab_1 + - cuda=12.4.0=0 + - cuda-cccl_linux-64=12.8.90=ha770c72_1 + - cuda-command-line-tools=12.8.1=ha770c72_0 + - cuda-compiler=12.8.1=hbad6d8a_0 + - cuda-crt-dev_linux-64=12.8.93=ha770c72_1 + - cuda-crt-tools=12.8.93=ha770c72_1 + - cuda-cudart=12.8.90=h5888daf_1 + - cuda-cudart-dev=12.8.90=h5888daf_1 + - cuda-cudart-dev_linux-64=12.8.90=h3f2d84a_1 + - cuda-cudart-static=12.8.90=h5888daf_1 + - cuda-cudart-static_linux-64=12.8.90=h3f2d84a_1 + - cuda-cudart_linux-64=12.8.90=h3f2d84a_1 + - cuda-cuobjdump=12.8.90=hbd13f7d_1 + - cuda-cupti=12.8.90=hbd13f7d_0 + - cuda-cupti-dev=12.8.90=h5888daf_0 + - cuda-cuxxfilt=12.8.90=hbd13f7d_1 + - cuda-demo-suite=12.4.99=0 + - cuda-driver-dev=12.8.90=h5888daf_1 + - cuda-driver-dev_linux-64=12.8.90=h3f2d84a_1 + - cuda-gdb=12.8.90=h50b4baa_0 + - cuda-libraries=12.8.1=ha770c72_0 + - cuda-libraries-dev=12.8.1=ha770c72_0 + - cuda-nsight=12.8.90=h7938cbb_1 + - cuda-nvcc=12.8.93=hcdd1206_1 + - cuda-nvcc-dev_linux-64=12.8.93=he91c749_1 + - cuda-nvcc-impl=12.8.93=h85509e4_1 + - cuda-nvcc-tools=12.8.93=he02047a_1 + - cuda-nvcc_linux-64=12.8.93=h04802cd_1 + - cuda-nvdisasm=12.8.90=hbd13f7d_1 + - cuda-nvml-dev=12.8.90=hbd13f7d_0 + - cuda-nvprof=12.8.90=hbd13f7d_0 + - cuda-nvprune=12.8.90=hbd13f7d_1 + - cuda-nvrtc=12.8.93=h5888daf_1 + - cuda-nvrtc-dev=12.8.93=h5888daf_1 + - cuda-nvtx=12.8.90=hbd13f7d_0 + - cuda-nvvm-dev_linux-64=12.8.93=ha770c72_1 + - cuda-nvvm-impl=12.8.93=he02047a_1 + - cuda-nvvm-tools=12.8.93=he02047a_1 + - cuda-nvvp=12.8.93=hbd13f7d_1 + - cuda-opencl=12.8.90=hbd13f7d_0 + - cuda-opencl-dev=12.8.90=h5888daf_0 + - cuda-profiler-api=12.8.90=h7938cbb_1 + - cuda-sanitizer-api=12.8.93=hbd13f7d_1 + - cuda-toolkit=12.8.1=ha804496_0 + - cuda-tools=12.8.1=ha770c72_0 + - cuda-version=12.8=h5d125a7_3 + - cuda-visual-tools=12.8.1=ha770c72_0 + - curl=8.12.1=h332b0f4_0 + - cxx-compiler=1.9.0=h1a2810e_0 + - cycler=0.12.1=pyhd8ed1ab_1 + - cyrus-sasl=2.1.27=h54b06d7_7 + - cytoolz=1.0.1=py310ha75aee5_0 + - datashader=0.17.0=pyhd8ed1ab_0 + - dav1d=1.2.1=hd590300_0 + - dbus=1.13.6=h5008d03_3 + - debugpy=1.8.13=py310hf71b8c6_0 + - decorator=5.2.1=pyhd8ed1ab_0 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - deprecated=1.2.18=pyhd8ed1ab_0 + - diff-match-patch=20241021=pyhd8ed1ab_1 + - dill=0.3.9=pyhd8ed1ab_1 + - docstring-to-markdown=0.16=pyh29332c3_1 + - docutils=0.21.2=pyhd8ed1ab_1 + - double-conversion=3.3.1=h5888daf_0 + - et_xmlfile=2.0.0=pyhd8ed1ab_1 + - exceptiongroup=1.2.2=pyhd8ed1ab_1 + - executing=2.1.0=pyhd8ed1ab_1 + - expat=2.7.0=h5888daf_0 + - fcitx-qt5=1.2.7=h748e8b9_2 + - filelock=3.18.0=pyhd8ed1ab_0 + - flake8=7.1.2=pyhd8ed1ab_0 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_3 + - fontconfig=2.15.0=h7e30c49_1 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - fonttools=4.56.0=py310h89163eb_0 + - fqdn=1.5.1=pyhd8ed1ab_1 + - freetype=2.13.3=h48d6fc4_0 + - frozenlist=1.5.0=py310h89163eb_1 + - fzf=0.61.0=h59e48b9_0 + - gcc=13.3.0=h9576a4e_2 + - gcc_impl_linux-64=13.3.0=h1e990d8_2 + - gcc_linux-64=13.3.0=hc28eda2_8 + - gds-tools=1.13.1.3=h5888daf_0 + - gettext=0.23.1=h5888daf_0 + - gettext-tools=0.23.1=h5888daf_0 + - gflags=2.2.2=h5888daf_1005 + - giflib=5.2.2=hd590300_0 + - gitdb=4.0.12=pyhd8ed1ab_0 + - gitpython=3.1.44=pyhff2d567_0 + - glib=2.84.0=h07242d1_0 + - glib-tools=2.84.0=h4833e2c_0 + - glog=0.7.1=hbabe93e_0 + - gmp=6.3.0=hac33072_2 + - gmpy2=2.1.5=py310he8512ff_3 + - graphite2=1.3.13=h59595ed_1003 + - greenlet=3.1.1=py310hf71b8c6_1 + - gst-plugins-base=1.24.7=h0a52356_0 + - gstreamer=1.24.7=hf3bb09a_0 + - gxx=13.3.0=h9576a4e_2 + - gxx_impl_linux-64=13.3.0=hae580e1_2 + - gxx_linux-64=13.3.0=h6834431_8 + - h11=0.14.0=pyhd8ed1ab_1 + - h2=4.2.0=pyhd8ed1ab_0 + - h5py=3.13.0=nompi_py310h60e0fe6_100 + - harfbuzz=10.4.0=h76408a6_0 + - hdf5=1.14.3=nompi_h2d575fe_109 + - holoviews=1.20.2=pyhd8ed1ab_0 + - hpack=4.1.0=pyhd8ed1ab_0 + - httpcore=1.0.7=pyh29332c3_1 + - httpx=0.28.1=pyhd8ed1ab_0 + - hvplot=0.11.2=pyhd8ed1ab_0 + - hyperframe=6.1.0=pyhd8ed1ab_0 + - hyperlink=21.0.0=pyh29332c3_1 + - icu=75.1=he02047a_0 + - idna=3.10=pyhd8ed1ab_1 + - imagecodecs=2024.12.30=py310h78a9a29_0 + - imageio=2.37.0=pyhfb79c49_0 + - imagesize=1.4.1=pyhd8ed1ab_0 + - imbalanced-learn=0.13.0=pyhd8ed1ab_0 + - importlib-metadata=8.6.1=pyha770c72_0 + - importlib_resources=6.5.2=pyhd8ed1ab_0 + - incremental=24.7.2=pyhd8ed1ab_1 + - inflection=0.5.1=pyhd8ed1ab_1 + - iniconfig=2.0.0=pyhd8ed1ab_1 + - intake=2.0.8=pyhd8ed1ab_0 + - intervaltree=3.1.0=pyhd8ed1ab_1 + - ipykernel=6.29.5=pyh3099207_0 + - ipython=8.34.0=pyh907856f_0 + - ipython_genutils=0.2.0=pyhd8ed1ab_2 + - isoduration=20.11.0=pyhd8ed1ab_1 + - isort=6.0.1=pyhd8ed1ab_0 + - itemadapter=0.11.0=pyhd8ed1ab_0 + - itemloaders=1.3.2=pyhd8ed1ab_1 + - itsdangerous=2.2.0=pyhd8ed1ab_1 + - jaraco.classes=3.4.0=pyhd8ed1ab_2 + - jaraco.context=6.0.1=pyhd8ed1ab_0 + - jaraco.functools=4.1.0=pyhd8ed1ab_0 + - jedi=0.19.2=pyhd8ed1ab_1 + - jeepney=0.9.0=pyhd8ed1ab_0 + - jellyfish=1.1.3=py310h505e2c1_0 + - jinja2=3.1.6=pyhd8ed1ab_0 + - jmespath=1.0.1=pyhd8ed1ab_1 + - joblib=1.4.2=pyhd8ed1ab_1 + - jq=1.7.1=hd590300_0 + - json5=0.10.0=pyhd8ed1ab_1 + - jsonpointer=3.0.0=py310hff52083_1 + - jsonschema=4.23.0=pyhd8ed1ab_1 + - jsonschema-specifications=2024.10.1=pyhd8ed1ab_1 + - jsonschema-with-format-nongpl=4.23.0=hd8ed1ab_1 + - jupyter=1.1.1=pyhd8ed1ab_1 + - jupyter-lsp=2.2.5=pyhd8ed1ab_1 + - jupyter_client=8.6.3=pyhd8ed1ab_1 + - jupyter_console=6.6.3=pyhd8ed1ab_1 + - jupyter_core=5.7.2=pyh31011fe_1 + - jupyter_events=0.12.0=pyh29332c3_0 + - jupyter_server=2.15.0=pyhd8ed1ab_0 + - jupyter_server_terminals=0.5.3=pyhd8ed1ab_1 + - jupyterlab=4.3.6=pyhd8ed1ab_0 + - jupyterlab-variableinspector=3.2.4=pyhd8ed1ab_0 + - jupyterlab_pygments=0.3.0=pyhd8ed1ab_2 + - jupyterlab_server=2.27.3=pyhd8ed1ab_1 + - jxrlib=1.1=hd590300_3 + - kernel-headers_linux-64=3.10.0=he073ed8_18 + - keyring=25.6.0=pyha804496_0 + - keyutils=1.6.1=h166bdaf_0 + - kiwisolver=1.4.7=py310h3788b33_0 + - krb5=1.21.3=h659f571_0 + - lame=3.100=h166bdaf_1003 + - lazy-loader=0.4=pyhd8ed1ab_2 + - lazy_loader=0.4=pyhd8ed1ab_2 + - lcms2=2.17=h717163a_0 + - ld_impl_linux-64=2.43=h712a8e2_4 + - lerc=4.0.0=h27087fc_0 + - libabseil=20250127.1=cxx17_hbbce691_0 + - libaec=1.1.3=h59595ed_0 + - libarrow=19.0.1=h120c447_5_cpu + - libarrow-acero=19.0.1=hcb10f89_5_cpu + - libarrow-dataset=19.0.1=hcb10f89_5_cpu + - libarrow-substrait=19.0.1=h1bed206_5_cpu + - libasprintf=0.23.1=h8e693c7_0 + - libasprintf-devel=0.23.1=h8e693c7_0 + - libavif16=1.2.1=hbb36593_2 + - libblas=3.9.0=1_h86c2bf4_netlib + - libbrotlicommon=1.1.0=hb9d3cd8_2 + - libbrotlidec=1.1.0=hb9d3cd8_2 + - libbrotlienc=1.1.0=hb9d3cd8_2 + - libcap=2.75=h39aace5_0 + - libcblas=3.9.0=8_h3b12eaf_netlib + - libclang-cpp19.1=19.1.7=default_hb5137d0_2 + - libclang-cpp20.1=20.1.1=default_hb5137d0_0 + - libclang13=20.1.1=default_h9c6a7e4_0 + - libcrc32c=1.1.2=h9c3ff4c_0 + - libcublas=12.8.4.1=h9ab20c4_1 + - libcublas-dev=12.8.4.1=h9ab20c4_1 + - libcufft=11.3.3.83=h5888daf_1 + - libcufft-dev=11.3.3.83=h5888daf_1 + - libcufile=1.13.1.3=h12f29b5_0 + - libcufile-dev=1.13.1.3=h5888daf_0 + - libcups=2.3.3=h4637d8d_4 + - libcurand=10.3.9.90=h9ab20c4_1 + - libcurand-dev=10.3.9.90=h9ab20c4_1 + - libcurl=8.12.1=h332b0f4_0 + - libcusolver=11.7.3.90=h9ab20c4_1 + - libcusolver-dev=11.7.3.90=h9ab20c4_1 + - libcusparse=12.5.8.93=hbd13f7d_0 + - libcusparse-dev=12.5.8.93=h5888daf_0 + - libdeflate=1.23=h4ddbbb0_0 + - libdrm=2.4.124=hb9d3cd8_0 + - libedit=3.1.20250104=pl5321h7949ede_0 + - libegl=1.7.0=ha4b6fd6_2 + - libev=4.33=hd590300_2 + - libevent=2.1.12=hf998b51_1 + - libexpat=2.7.0=h5888daf_0 + - libffi=3.4.6=h2dba641_1 + - libflac=1.4.3=h59595ed_0 + - libgcc=14.2.0=h767d61c_2 + - libgcc-devel_linux-64=13.3.0=hc03c837_102 + - libgcc-ng=14.2.0=h69a702a_2 + - libgcrypt-lib=1.11.0=hb9d3cd8_2 + - libgettextpo=0.23.1=h5888daf_0 + - libgettextpo-devel=0.23.1=h5888daf_0 + - libgfortran=14.2.0=h69a702a_2 + - libgfortran-ng=14.2.0=h69a702a_2 + - libgfortran5=14.2.0=hf1ad2bd_2 + - libgl=1.7.0=ha4b6fd6_2 + - libglib=2.84.0=h2ff4ddf_0 + - libglvnd=1.7.0=ha4b6fd6_2 + - libglx=1.7.0=ha4b6fd6_2 + - libgomp=14.2.0=h767d61c_2 + - libgoogle-cloud=2.36.0=hc4361e1_1 + - libgoogle-cloud-storage=2.36.0=h0121fbd_1 + - libgpg-error=1.51=hbd13f7d_1 + - libgrpc=1.71.0=he753a82_0 + - libhwy=1.1.0=h00ab1b0_0 + - libiconv=1.18=h4ce23a2_1 + - libjpeg-turbo=3.0.0=hd590300_1 + - libjxl=0.11.1=hdb8da77_0 + - liblapack=3.9.0=8_h3b12eaf_netlib + - libllvm19=19.1.7=ha7bfdaf_1 + - libllvm20=20.1.1=ha7bfdaf_0 + - liblzma=5.6.4=hb9d3cd8_0 + - libnghttp2=1.64.0=h161d5f1_0 + - libnl=3.11.0=hb9d3cd8_0 + - libnpp=12.3.3.100=h9ab20c4_1 + - libnpp-dev=12.3.3.100=h9ab20c4_1 + - libnsl=2.0.1=hd590300_0 + - libntlm=1.8=hb9d3cd8_0 + - libnuma=2.0.18=h4ab18f5_2 + - libnvfatbin=12.8.90=hbd13f7d_0 + - libnvfatbin-dev=12.8.90=h5888daf_0 + - libnvjitlink=12.8.93=h5888daf_1 + - libnvjitlink-dev=12.8.93=h5888daf_1 + - libnvjpeg=12.3.5.92=h97fd463_0 + - libnvjpeg-dev=12.3.5.92=ha770c72_0 + - libogg=1.3.5=h4ab18f5_0 + - libopengl=1.7.0=ha4b6fd6_2 + - libopentelemetry-cpp=1.19.0=hd1b1c89_0 + - libopentelemetry-cpp-headers=1.19.0=ha770c72_0 + - libopus=1.3.1=h7f98852_1 + - libparquet=19.0.1=h081d1f1_5_cpu + - libpciaccess=0.18=hd590300_0 + - libpng=1.6.47=h943b412_0 + - libpq=17.4=h27ae623_0 + - libprotobuf=5.29.3=h501fc15_0 + - libre2-11=2024.07.02=hba17884_3 + - libsanitizer=13.3.0=he8ea267_2 + - libsndfile=1.2.2=hc60ed4a_1 + - libsodium=1.0.20=h4ab18f5_0 + - libspatialindex=2.1.0=he57a185_0 + - libsqlite=3.49.1=hee588c1_2 + - libssh2=1.11.1=hf672d98_0 + - libstdcxx=14.2.0=h8f9b012_2 + - libstdcxx-devel_linux-64=13.3.0=hc03c837_102 + - libstdcxx-ng=14.2.0=h4852527_2 + - libsystemd0=257.4=h4e0b6ca_1 + - libthrift=0.21.0=h0e7cc3e_0 + - libtiff=4.7.0=hd9ff511_3 + - libudev1=257.4=hbe16f8c_1 + - libutf8proc=2.10.0=h4c51ac1_0 + - libuuid=2.38.1=h0b41bf4_0 + - libvorbis=1.3.7=h9c3ff4c_0 + - libwebp=1.5.0=hae8dbeb_0 + - libwebp-base=1.5.0=h851e524_0 + - libxcb=1.17.0=h8a09558_0 + - libxcrypt=4.4.36=hd590300_1 + - libxkbcommon=1.8.1=hc4a0caf_0 + - libxkbfile=1.1.0=h166bdaf_1 + - libxml2=2.13.7=h8d12d68_0 + - libxslt=1.1.39=h76b75d6_0 + - libzlib=1.3.1=hb9d3cd8_2 + - libzopfli=1.0.3=h9c3ff4c_0 + - linkify-it-py=2.0.3=pyhd8ed1ab_1 + - locket=1.0.0=pyhd8ed1ab_0 + - lxml=5.3.1=py310h6ee67d5_0 + - lz4=4.3.3=py310h80b8a69_2 + - lz4-c=1.10.0=h5888daf_1 + - markdown=3.6=pyhd8ed1ab_0 + - markdown-it-py=3.0.0=pyhd8ed1ab_1 + - markupsafe=3.0.2=py310h89163eb_1 + - matplotlib=3.10.1=py310hff52083_0 + - matplotlib-base=3.10.1=py310h68603db_0 + - matplotlib-inline=0.1.7=pyhd8ed1ab_1 + - mccabe=0.7.0=pyhd8ed1ab_1 + - mdit-py-plugins=0.4.2=pyhd8ed1ab_1 + - mdurl=0.1.2=pyhd8ed1ab_1 + - mistune=3.1.3=pyh29332c3_0 + - more-itertools=10.6.0=pyhd8ed1ab_0 + - mpc=1.3.1=h24ddda3_1 + - mpfr=4.2.1=h90cbb55_3 + - mpg123=1.32.9=hc50e24c_0 + - mpmath=1.3.0=pyhd8ed1ab_1 + - msgpack-python=1.1.0=py310h3788b33_0 + - multidict=6.2.0=py310h89163eb_0 + - multipledispatch=0.6.0=pyhd8ed1ab_1 + - munkres=1.1.4=pyh9f0ad1d_0 + - mypy=1.15.0=py310ha75aee5_0 + - mypy_extensions=1.0.0=pyha770c72_1 + - mysql-common=9.0.1=h266115a_5 + - mysql-libs=9.0.1=he0572af_5 + - narwhals=1.32.0=pyhd8ed1ab_0 + - nbclient=0.10.2=pyhd8ed1ab_0 + - nbconvert=7.16.6=hb482800_0 + - nbconvert-core=7.16.6=pyh29332c3_0 + - nbconvert-pandoc=7.16.6=hed9df3c_0 + - nbformat=5.10.4=pyhd8ed1ab_1 + - ncurses=6.5=h2d0b736_3 + - nest-asyncio=1.6.0=pyhd8ed1ab_1 + - networkx=3.4.2=pyh267e887_2 + - nlohmann_json=3.11.3=he02047a_1 + - nltk=3.9.1=pyhd8ed1ab_1 + - nomkl=1.0=h5ca1d4c_0 + - notebook=7.3.3=pyhd8ed1ab_0 + - notebook-shim=0.2.4=pyhd8ed1ab_1 + - nsight-compute=2025.1.1.2=hb5ebaad_0 + - nspr=4.36=h5888daf_0 + - nss=3.110=h159eef7_0 + - numexpr=2.10.2=py310hdb6e06b_100 + - numpydoc=1.8.0=pyhd8ed1ab_1 + - ocl-icd=2.3.2=hb9d3cd8_2 + - oniguruma=6.9.10=hb9d3cd8_0 + - opencl-headers=2024.10.24=h5888daf_0 + - openjpeg=2.5.3=h5fbd93e_0 + - openldap=2.6.9=he970967_0 + - openpyxl=3.1.5=py310h0999ad4_1 + - openssl=3.4.1=h7b32b05_0 + - orc=2.1.1=h17f744e_1 + - overrides=7.7.0=pyhd8ed1ab_1 + - packaging=24.2=pyhd8ed1ab_2 + - pandas=2.2.3=py310h5eaa309_1 + - pandoc=3.6.4=ha770c72_0 + - pandocfilters=1.5.0=pyhd8ed1ab_0 + - panel=1.6.2=pyhd8ed1ab_0 + - param=2.2.0=pyhd8ed1ab_0 + - parsel=1.10.0=pyhd8ed1ab_0 + - parso=0.8.4=pyhd8ed1ab_1 + - partd=1.4.2=pyhd8ed1ab_0 + - pathspec=0.12.1=pyhd8ed1ab_1 + - patsy=1.0.1=pyhd8ed1ab_1 + - pcre2=10.44=hba22ea6_2 + - pexpect=4.9.0=pyhd8ed1ab_1 + - pickleshare=0.7.5=pyhd8ed1ab_1004 + - pillow=11.1.0=py310h7e6dc6c_0 + - pip=25.0.1=pyh8b19718_0 + - pixman=0.44.2=h29eaf8c_0 + - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_2 + - platformdirs=4.3.7=pyh29332c3_0 + - plotly=6.0.1=pyhd8ed1ab_0 + - pluggy=1.5.0=pyhd8ed1ab_1 + - ply=3.11=pyhd8ed1ab_3 + - prometheus-cpp=1.3.0=ha5d0236_0 + - prometheus_client=0.21.1=pyhd8ed1ab_0 + - prompt-toolkit=3.0.50=pyha770c72_0 + - prompt_toolkit=3.0.50=hd8ed1ab_0 + - propcache=0.2.1=py310h89163eb_1 + - protego=0.4.0=pyhd8ed1ab_0 + - protobuf=5.29.3=py310hcba5963_0 + - psutil=7.0.0=py310ha75aee5_0 + - pthread-stubs=0.4=hb9d3cd8_1002 + - ptyprocess=0.7.0=pyhd8ed1ab_1 + - pulseaudio-client=17.0=hac146a9_1 + - pure_eval=0.2.3=pyhd8ed1ab_1 + - py-cpuinfo=9.0.0=pyhd8ed1ab_1 + - pyarrow=19.0.1=py310hff52083_0 + - pyarrow-core=19.0.1=py310hac404ae_0_cpu + - pyasn1=0.6.1=pyhd8ed1ab_2 + - pyasn1-modules=0.4.2=pyhd8ed1ab_0 + - pycodestyle=2.12.1=pyhd8ed1ab_1 + - pyconify=0.2.1=pyhd8ed1ab_0 + - pycparser=2.22=pyh29332c3_1 + - pyct=0.5.0=pyhd8ed1ab_1 + - pycurl=7.45.6=py310h6811363_0 + - pydeck=0.9.1=pyhd8ed1ab_0 + - pydispatcher=2.0.5=py_1 + - pydocstyle=6.3.0=pyhd8ed1ab_1 + - pyerfa=2.0.1.5=py310hf462985_0 + - pyflakes=3.2.0=pyhd8ed1ab_1 + - pygithub=2.6.1=pyhd8ed1ab_0 + - pygments=2.19.1=pyhd8ed1ab_0 + - pyjwt=2.10.1=pyhd8ed1ab_0 + - pylint=3.3.5=pyh29332c3_0 + - pylint-venv=3.0.4=pyhd8ed1ab_1 + - pyls-spyder=0.4.0=pyhd8ed1ab_1 + - pynacl=1.5.0=py310ha75aee5_4 + - pyodbc=5.2.0=py310hf71b8c6_0 + - pyopenssl=25.0.0=pyhd8ed1ab_0 + - pyparsing=3.2.3=pyhd8ed1ab_1 + - pyqt=5.15.9=py310h04931ad_5 + - pyqt5-sip=12.12.2=py310hc6cd4ac_5 + - pyqtwebengine=5.15.9=py310h704022c_5 + - pyside6=6.8.3=py310hfd10a26_0 + - pysocks=1.7.1=pyha55dd90_7 + - pytables=3.10.1=py310h1affd9f_4 + - pytest=8.3.5=pyhd8ed1ab_0 + - python=3.10.16=he725a3c_1_cpython + - python-dateutil=2.9.0.post0=pyhff2d567_1 + - python-fastjsonschema=2.21.1=pyhd8ed1ab_0 + - python-gssapi=1.9.0=py310h695cd88_1 + - python-json-logger=2.0.7=pyhd8ed1ab_0 + - python-lsp-black=2.0.0=pyhff2d567_1 + - python-lsp-jsonrpc=1.1.2=pyhff2d567_1 + - python-lsp-server=1.12.2=pyhff2d567_0 + - python-lsp-server-base=1.12.2=pyhd8ed1ab_0 + - python-slugify=8.0.4=pyhd8ed1ab_1 + - python-tzdata=2025.2=pyhd8ed1ab_0 + - python_abi=3.10=5_cp310 + - pytoolconfig=1.2.5=pyhd8ed1ab_1 + - pytz=2024.1=pyhd8ed1ab_0 + - pyuca=1.2=pyhd8ed1ab_2 + - pyviz_comms=3.0.4=pyhd8ed1ab_1 + - pywavelets=1.8.0=py310hf462985_0 + - pyxdg=0.28=pyhd8ed1ab_0 + - pyyaml=6.0.2=py310h89163eb_2 + - pyzmq=26.3.0=py310h71f11fc_0 + - qdarkstyle=3.2.3=pyhd8ed1ab_1 + - qhull=2020.2=h434a139_5 + - qstylizer=0.2.4=pyhff2d567_0 + - qt-main=5.15.15=hc3cb62f_2 + - qt-webengine=5.15.15=h0071231_2 + - qt6-main=6.8.3=h588cce1_0 + - qtawesome=1.4.0=pyh9208f05_1 + - qtconsole=5.6.1=pyhd8ed1ab_1 + - qtconsole-base=5.6.1=pyha770c72_1 + - qtpy=2.4.3=pyhd8ed1ab_0 + - queuelib=1.8.0=pyhd8ed1ab_0 + - rav1e=0.6.6=he8a937b_2 + - rdma-core=56.0=h5888daf_0 + - re2=2024.07.02=h9925aae_3 + - readline=8.2=h8c095d6_2 + - referencing=0.36.2=pyh29332c3_0 + - regex=2024.11.6=py310ha75aee5_0 + - requests=2.32.3=pyhd8ed1ab_1 + - requests-file=2.1.0=pyhd8ed1ab_1 + - rfc3339-validator=0.1.4=pyhd8ed1ab_1 + - rfc3986-validator=0.1.1=pyh9f0ad1d_0 + - rich=14.0.0=pyh29332c3_0 + - rope=1.13.0=pyhd8ed1ab_1 + - rpds-py=0.24.0=py310hc1293b2_0 + - rtree=1.4.0=pyh11ca60a_1 + - s2n=1.5.14=h6c98b2b_0 + - s3fs=2025.3.1=pyhd8ed1ab_0 + - scikit-image=0.25.2=py310h5eaa309_0 + - scikit-learn=1.6.1=py310h27f47ee_0 + - scipy=1.15.2=py310h1d65ade_0 + - scrapy=2.12.0=py310hff52083_1 + - seaborn=0.13.2=hd8ed1ab_3 + - seaborn-base=0.13.2=pyhd8ed1ab_3 + - secretstorage=3.3.3=py310hff52083_3 + - send2trash=1.8.3=pyh0d859eb_1 + - service-identity=24.2.0=pyha770c72_1 + - service_identity=24.2.0=hd8ed1ab_1 + - setuptools=75.8.2=pyhff2d567_0 + - sip=6.7.12=py310hc6cd4ac_0 + - six=1.17.0=pyhd8ed1ab_0 + - sklearn-compat=0.1.3=pyhd8ed1ab_0 + - smmap=5.0.2=pyhd8ed1ab_0 + - snappy=1.2.1=h8bd8927_1 + - sniffio=1.3.1=pyhd8ed1ab_1 + - snowballstemmer=2.2.0=pyhd8ed1ab_0 + - sortedcontainers=2.4.0=pyhd8ed1ab_1 + - soupsieve=2.5=pyhd8ed1ab_1 + - sphinx=8.1.3=pyhd8ed1ab_1 + - sphinxcontrib-applehelp=2.0.0=pyhd8ed1ab_1 + - sphinxcontrib-devhelp=2.0.0=pyhd8ed1ab_1 + - sphinxcontrib-htmlhelp=2.1.0=pyhd8ed1ab_1 + - sphinxcontrib-jsmath=1.0.1=pyhd8ed1ab_1 + - sphinxcontrib-qthelp=2.0.0=pyhd8ed1ab_1 + - sphinxcontrib-serializinghtml=1.1.10=pyhd8ed1ab_1 + - spyder=6.0.5=hd8ed1ab_0 + - spyder-base=6.0.5=linux_pyh62a8a7d_0 + - spyder-kernels=3.0.3=unix_pyh707e725_0 + - sqlalchemy=2.0.40=py310ha75aee5_0 + - stack_data=0.6.3=pyhd8ed1ab_1 + - statsmodels=0.14.4=py310hf462985_0 + - streamlit=1.44.0=pyhd8ed1ab_1 + - superqt=0.7.3=pyhb6d5dde_0 + - svt-av1=3.0.2=h5888daf_0 + - sympy=1.13.3=pyh2585a3b_105 + - sysroot_linux-64=2.17=h0157908_18 + - tabulate=0.9.0=pyhd8ed1ab_2 + - tblib=3.0.0=pyhd8ed1ab_1 + - tenacity=9.0.0=pyhd8ed1ab_1 + - terminado=0.18.1=pyh0d859eb_0 + - text-unidecode=1.3=pyhd8ed1ab_2 + - textdistance=4.6.3=pyhd8ed1ab_1 + - threadpoolctl=3.6.0=pyhecae5ae_0 + - three-merge=0.1.1=pyhd8ed1ab_1 + - tifffile=2025.3.30=pyhd8ed1ab_0 + - tinycss2=1.4.0=pyhd8ed1ab_0 + - tk=8.6.13=noxft_h4845f30_101 + - tldextract=5.1.3=pyhd8ed1ab_1 + - toml=0.10.2=pyhd8ed1ab_1 + - tomli=2.2.1=pyhd8ed1ab_1 + - tomlkit=0.13.2=pyha770c72_1 + - toolz=1.0.0=pyhd8ed1ab_1 + - tornado=6.4.2=py310ha75aee5_0 + - tqdm=4.67.1=pyhd8ed1ab_1 + - traitlets=5.14.3=pyhd8ed1ab_1 + - twisted=24.11.0=py310ha75aee5_0 + - types-python-dateutil=2.9.0.20241206=pyhd8ed1ab_0 + - typing-extensions=4.13.0=h9fa5a19_1 + - typing_extensions=4.13.0=pyh29332c3_1 + - typing_utils=0.1.0=pyhd8ed1ab_1 + - tzdata=2025b=h78e105d_0 + - uc-micro-py=1.0.3=pyhd8ed1ab_1 + - ujson=5.10.0=py310hf71b8c6_1 + - unicodedata2=16.0.0=py310ha75aee5_0 + - unixodbc=2.3.12=h661eb56_0 + - uri-template=1.3.0=pyhd8ed1ab_1 + - urllib3=2.3.0=pyhd8ed1ab_0 + - w3lib=2.3.1=pyhd8ed1ab_0 + - watchdog=6.0.0=py310hff52083_0 + - wayland=1.23.1=h3e06ad9_0 + - wcwidth=0.2.13=pyhd8ed1ab_1 + - webcolors=24.11.1=pyhd8ed1ab_0 + - webencodings=0.5.1=pyhd8ed1ab_3 + - websocket-client=1.8.0=pyhd8ed1ab_1 + - whatthepatch=1.0.7=pyhd8ed1ab_1 + - wheel=0.45.1=pyhd8ed1ab_1 + - wrapt=1.17.2=py310ha75aee5_0 + - wurlitzer=3.1.1=pyhd8ed1ab_1 + - xarray=2025.3.1=pyhd8ed1ab_0 + - xcb-util=0.4.1=hb711507_2 + - xcb-util-cursor=0.1.5=hb9d3cd8_0 + - xcb-util-image=0.4.0=hb711507_2 + - xcb-util-keysyms=0.4.1=hb711507_0 + - xcb-util-renderutil=0.3.10=hb711507_0 + - xcb-util-wm=0.4.2=hb711507_0 + - xkeyboard-config=2.43=hb9d3cd8_0 + - xorg-libice=1.1.2=hb9d3cd8_0 + - xorg-libsm=1.2.6=he73a12e_0 + - xorg-libx11=1.8.12=h4f16b4b_0 + - xorg-libxau=1.0.12=hb9d3cd8_0 + - xorg-libxcomposite=0.4.6=hb9d3cd8_2 + - xorg-libxcursor=1.2.3=hb9d3cd8_0 + - xorg-libxdamage=1.1.6=hb9d3cd8_0 + - xorg-libxdmcp=1.1.5=hb9d3cd8_0 + - xorg-libxext=1.3.6=hb9d3cd8_0 + - xorg-libxfixes=6.0.1=hb9d3cd8_0 + - xorg-libxi=1.8.2=hb9d3cd8_0 + - xorg-libxrandr=1.5.4=hb9d3cd8_0 + - xorg-libxrender=0.9.12=hb9d3cd8_0 + - xorg-libxtst=1.2.5=hb9d3cd8_3 + - xorg-libxxf86vm=1.1.6=hb9d3cd8_0 + - xyzservices=2025.1.0=pyhd8ed1ab_0 + - yaml=0.2.5=h7f98852_2 + - yapf=0.43.0=pyhd8ed1ab_1 + - yarl=1.18.3=py310h89163eb_1 + - zeromq=4.3.5=h3b0a872_7 + - zfp=1.0.1=h5888daf_2 + - zict=3.0.0=pyhd8ed1ab_1 + - zipp=3.21.0=pyhd8ed1ab_1 + - zlib=1.3.1=hb9d3cd8_2 + - zlib-ng=2.2.4=h7955e40_0 + - zope.interface=7.2=py310ha75aee5_0 + - zstandard=0.23.0=py310ha75aee5_1 + - zstd=1.5.7=hb8e6e7a_2 + - pip: + - addict==2.4.0 + - arrgh==1.0.0 + - boto3==1.37.24 + - botocore==1.37.24 + - configargparse==1.7 + - cuda-bindings==12.8.0 + - cuda-python==12.8.0 + - cudf-cu12==25.2.2 + - cuml-cu12==25.2.1 + - cupy-cuda12x==13.4.1 + - cuvs-cu12==25.2.1 + - dash==3.0.2 + - dask==2024.12.1 + - dask-cuda==25.2.0 + - dask-cudf-cu12==25.2.2 + - dask-expr==1.1.21 + - distributed==2024.12.1 + - distributed-ucxx-cu12==0.42.0 + - docstring-parser==0.16 + - einops==0.8.1 + - fastrlock==0.8.3 + - flask==3.0.3 + - fsspec==2024.12.0 + - ipywidgets==8.1.5 + - jupyterlab-widgets==3.0.13 + - libcudf-cu12==25.2.2 + - libcuml-cu12==25.2.1 + - libcuvs-cu12==25.2.1 + - libigl==2.5.1 + - libkvikio-cu12==25.2.1 + - libraft-cu12==25.2.0 + - libucx-cu12==1.18.0 + - libucxx-cu12==0.42.0 + - lightning==2.2.0 + - lightning-utilities==0.14.2 + - llvmlite==0.43.0 + - loguru==0.7.3 + - mesh2sdf==1.1.0 + - numba==0.60.0 + - numba-cuda==0.2.0 + - numpy==2.0.2 + - nvidia-cublas-cu12==12.4.2.65 + - nvidia-cuda-cupti-cu12==12.4.99 + - nvidia-cuda-nvrtc-cu12==12.4.99 + - nvidia-cuda-runtime-cu12==12.4.99 + - nvidia-cudnn-cu12==9.1.0.70 + - nvidia-cufft-cu12==11.2.0.44 + - nvidia-curand-cu12==10.3.5.119 + - nvidia-cusolver-cu12==11.6.0.99 + - nvidia-cusparse-cu12==12.3.0.142 + - nvidia-ml-py==12.570.86 + - nvidia-nccl-cu12==2.20.5 + - nvidia-nvcomp-cu12==4.2.0.11 + - nvidia-nvjitlink-cu12==12.4.99 + - nvidia-nvtx-cu12==12.4.99 + - nvtx==0.2.11 + - open3d==0.19.0 + - plyfile==1.1 + - polyscope==2.4.0 + - pooch==1.8.2 + - potpourri3d==1.2.1 + - pylibcudf-cu12==25.2.2 + - pylibraft-cu12==25.2.0 + - pymeshlab==2023.12.post3 + - pynvjitlink-cu12==0.5.2 + - pynvml==12.0.0 + - pyquaternion==0.9.9 + - pytorch-lightning==2.5.1 + - pyvista==0.44.2 + - raft-dask-cu12==25.2.0 + - rapids-dask-dependency==25.2.0 + - retrying==1.3.4 + - rmm-cu12==25.2.0 + - s3transfer==0.11.4 + - scooby==0.10.0 + - simple-parsing==0.1.7 + - tetgen==0.6.5 + - torch==2.4.0+cu124 + - torch-scatter==2.1.2+pt24cu124 + - torchaudio==2.4.0+cu124 + - torchmetrics==1.7.0 + - torchvision==0.19.0+cu124 + - treelite==4.4.1 + - trimesh==4.6.6 + - triton==3.0.0 + - ucx-py-cu12==0.42.0 + - ucxx-cu12==0.42.0 + - vtk==9.3.1 + - werkzeug==3.0.6 + - widgetsnbextension==4.0.13 + - xgboost==3.0.0 + - yacs==0.1.8 diff --git a/PartField/partfield/__pycache__/dataloader.cpython-310.pyc b/PartField/partfield/__pycache__/dataloader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0b930e994798c163acdca1f8cb6fa16b90efab4 Binary files /dev/null and b/PartField/partfield/__pycache__/dataloader.cpython-310.pyc differ diff --git a/PartField/partfield/__pycache__/model_trainer_pvcnn_only_demo.cpython-310.pyc b/PartField/partfield/__pycache__/model_trainer_pvcnn_only_demo.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db7a74cd4e6b605004dea58eea4d3fc54dbf7c95 Binary files /dev/null and b/PartField/partfield/__pycache__/model_trainer_pvcnn_only_demo.cpython-310.pyc differ diff --git a/PartField/partfield/__pycache__/utils.cpython-310.pyc b/PartField/partfield/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47b42cf17bc28405c1a5991e1022884ccdb99bfd Binary files /dev/null and b/PartField/partfield/__pycache__/utils.cpython-310.pyc differ diff --git a/PartField/partfield/config/__init__.py b/PartField/partfield/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..39582506b85759ab473acedb5d0f15b7d7f26594 --- /dev/null +++ b/PartField/partfield/config/__init__.py @@ -0,0 +1,26 @@ +import argparse +import os.path as osp +from datetime import datetime +import pytz + +def default_argument_parser(add_help=True, default_config_file=""): + parser = argparse.ArgumentParser(add_help=add_help) + parser.add_argument("--config-file", '-c', default=default_config_file, metavar="FILE", help="path to config file") + parser.add_argument( + "--opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + return parser + +def setup(args, freeze=True): + from .defaults import _C as cfg + cfg = cfg.clone() + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + dt = datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%y%m%d-%H%M%S') + cfg.output_dir = osp.join(cfg.output_dir, cfg.name, dt) + if freeze: + cfg.freeze() + return cfg \ No newline at end of file diff --git a/PartField/partfield/config/__pycache__/__init__.cpython-310.pyc b/PartField/partfield/config/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5296568b359080307e813e6e5e39e427886716c6 Binary files /dev/null and b/PartField/partfield/config/__pycache__/__init__.cpython-310.pyc differ diff --git a/PartField/partfield/config/__pycache__/defaults.cpython-310.pyc b/PartField/partfield/config/__pycache__/defaults.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1f36d317fc60ade8627ce994465025c840367af Binary files /dev/null and b/PartField/partfield/config/__pycache__/defaults.cpython-310.pyc differ diff --git a/PartField/partfield/config/defaults.py b/PartField/partfield/config/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..14f82ed464b7a9afb6bda92bad66c59b1077289a --- /dev/null +++ b/PartField/partfield/config/defaults.py @@ -0,0 +1,92 @@ +from yacs.config import CfgNode as CN + +_C = CN() +_C.seed = 0 +_C.output_dir = "results" +_C.result_name = "test_all" + +_C.triplet_sampling = "random" +_C.load_original_mesh = False + +_C.num_pos = 64 +_C.num_neg_random = 256 +_C.num_neg_hard_pc = 128 +_C.num_neg_hard_emb = 128 + +_C.vertex_feature = False # if true, sample feature on vertices; if false, sample feature on faces +_C.n_point_per_face = 2000 +_C.n_sample_each = 10000 +_C.preprocess_mesh = False + +_C.regress_2d_feat = False + +_C.is_pc = False + +_C.cut_manifold = False +_C.remesh_demo = False +_C.correspondence_demo = False + +_C.save_every_epoch = 10 +_C.training_epochs = 30 +_C.continue_training = False + +_C.continue_ckpt = None +_C.epoch_selected = "epoch=50.ckpt" + +_C.triplane_resolution = 128 +_C.triplane_channels_low = 128 +_C.triplane_channels_high = 512 +_C.lr = 1e-3 +_C.train = True +_C.test = False + +_C.inference_save_pred_sdf_to_mesh=True +_C.inference_save_feat_pca=True +_C.name = "test" +_C.test_subset = False +_C.test_corres = False +_C.test_partobjaversetiny = False + +_C.dataset = CN() +_C.dataset.type = "Demo_Dataset" +_C.dataset.data_path = "objaverse_data/" +_C.dataset.train_num_workers = 64 +_C.dataset.val_num_workers = 32 +_C.dataset.train_batch_size = 2 +_C.dataset.val_batch_size = 2 +_C.dataset.all_files = [] # only used for correspondence demo + +_C.voxel2triplane = CN() +_C.voxel2triplane.transformer_dim = 1024 +_C.voxel2triplane.transformer_layers = 6 +_C.voxel2triplane.transformer_heads = 8 +_C.voxel2triplane.triplane_low_res = 32 +_C.voxel2triplane.triplane_high_res = 256 +_C.voxel2triplane.triplane_dim = 64 +_C.voxel2triplane.normalize_vox_feat = False + + +_C.loss = CN() +_C.loss.triplet = 0.0 +_C.loss.sdf = 1.0 +_C.loss.feat = 10.0 +_C.loss.l1 = 0.0 + +_C.use_pvcnn = False +_C.use_pvcnnonly = True + +_C.pvcnn = CN() +_C.pvcnn.point_encoder_type = 'pvcnn' +_C.pvcnn.use_point_scatter = True +_C.pvcnn.z_triplane_channels = 64 +_C.pvcnn.z_triplane_resolution = 256 +_C.pvcnn.unet_cfg = CN() +_C.pvcnn.unet_cfg.depth = 3 +_C.pvcnn.unet_cfg.enabled = True +_C.pvcnn.unet_cfg.rolled = True +_C.pvcnn.unet_cfg.use_3d_aware = True +_C.pvcnn.unet_cfg.start_hidden_channels = 32 +_C.pvcnn.unet_cfg.use_initial_conv = False + +_C.use_2d_feat = False +_C.inference_metrics_only = False diff --git a/PartField/partfield/dataloader.py b/PartField/partfield/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..c305c566618c47bc045a76d928e56f497c642812 --- /dev/null +++ b/PartField/partfield/dataloader.py @@ -0,0 +1,366 @@ +import torch +import boto3 +import json +from os import path as osp +# from botocore.config import Config +# from botocore.exceptions import ClientError +import h5py +import io +import numpy as np +import skimage +import trimesh +import os +from scipy.spatial import KDTree +import gc +from plyfile import PlyData + +## For remeshing +import mesh2sdf +import tetgen +import vtk +import math +import tempfile + +### For mesh processing +import pymeshlab + +from partfield.utils import * + +######################### +## To handle quad inputs +######################### +def quad_to_triangle_mesh(F): + """ + Converts a quad-dominant mesh into a pure triangle mesh by splitting quads into two triangles. + + Parameters: + quad_mesh (trimesh.Trimesh): Input mesh with quad faces. + + Returns: + trimesh.Trimesh: A new mesh with only triangle faces. + """ + faces = F + + ### If already a triangle mesh -- skip + if len(faces[0]) == 3: + return F + + new_faces = [] + + for face in faces: + if len(face) == 4: # Quad face + # Split into two triangles + new_faces.append([face[0], face[1], face[2]]) # Triangle 1 + new_faces.append([face[0], face[2], face[3]]) # Triangle 2 + else: + print(f"Warning: Skipping non-triangle/non-quad face {face}") + + new_faces = np.array(new_faces) + + return new_faces +######################### + +class Demo_Dataset(torch.utils.data.Dataset): + def __init__(self, cfg): + super().__init__() + + self.data_path = cfg.dataset.data_path + self.is_pc = cfg.is_pc + + all_files = os.listdir(self.data_path) + + selected = [] + for f in all_files: + if ".ply" in f and self.is_pc: + selected.append(f) + elif (".obj" in f or ".glb" in f or ".off" in f) and not self.is_pc: + selected.append(f) + + self.data_list = selected + self.pc_num_pts = 100000 + + self.preprocess_mesh = cfg.preprocess_mesh + self.result_name = cfg.result_name + + print("val dataset len:", len(self.data_list)) + + + def __len__(self): + return len(self.data_list) + + def load_ply_to_numpy(self, filename): + """ + Load a PLY file and extract the point cloud as a (N, 3) NumPy array. + + Parameters: + filename (str): Path to the PLY file. + + Returns: + numpy.ndarray: Point cloud array of shape (N, 3). + """ + ply_data = PlyData.read(filename) + + # Extract vertex data + vertex_data = ply_data["vertex"] + + # Convert to NumPy array (x, y, z) + points = np.vstack([vertex_data["x"], vertex_data["y"], vertex_data["z"]]).T + + return points + + def get_model(self, ply_file): + + uid = ply_file.split(".")[-2].replace("/", "_") + + #### + if self.is_pc: + ply_file_read = os.path.join(self.data_path, ply_file) + pc = self.load_ply_to_numpy(ply_file_read) + + bbmin = pc.min(0) + bbmax = pc.max(0) + center = (bbmin + bbmax) * 0.5 + scale = 2.0 * 0.9 / (bbmax - bbmin).max() + pc = (pc - center) * scale + + else: + obj_path = os.path.join(self.data_path, ply_file) + mesh = load_mesh_util(obj_path) + vertices = mesh.vertices + faces = mesh.faces + + bbmin = vertices.min(0) + bbmax = vertices.max(0) + center = (bbmin + bbmax) * 0.5 + scale = 2.0 * 0.9 / (bbmax - bbmin).max() + vertices = (vertices - center) * scale + mesh.vertices = vertices + + ### Make sure it is a triangle mesh -- just convert the quad + mesh.faces = quad_to_triangle_mesh(faces) + + print("before preprocessing...") + print(mesh.vertices.shape) + print(mesh.faces.shape) + print() + + ### Pre-process mesh + if self.preprocess_mesh: + # Create a PyMeshLab mesh directly from vertices and faces + ml_mesh = pymeshlab.Mesh(vertex_matrix=mesh.vertices, face_matrix=mesh.faces) + + # Create a MeshSet and add your mesh + ms = pymeshlab.MeshSet() + ms.add_mesh(ml_mesh, "from_trimesh") + + # Apply filters + ms.apply_filter('meshing_remove_duplicate_faces') + ms.apply_filter('meshing_remove_duplicate_vertices') + percentageMerge = pymeshlab.PercentageValue(0.5) + ms.apply_filter('meshing_merge_close_vertices', threshold=percentageMerge) + ms.apply_filter('meshing_remove_unreferenced_vertices') + + # Save or extract mesh + processed = ms.current_mesh() + mesh.vertices = processed.vertex_matrix() + mesh.faces = processed.face_matrix() + + print("after preprocessing...") + print(mesh.vertices.shape) + print(mesh.faces.shape) + + ### Save input + save_dir = f"exp_results/{self.result_name}" + os.makedirs(save_dir, exist_ok=True) + view_id = 0 + mesh.export(f'{save_dir}/input_{uid}_{view_id}.ply') + + + pc, _ = trimesh.sample.sample_surface(mesh, self.pc_num_pts) + + result = { + 'uid': uid + } + + result['pc'] = torch.tensor(pc, dtype=torch.float32) + + if not self.is_pc: + result['vertices'] = mesh.vertices + result['faces'] = mesh.faces + + return result + + def __getitem__(self, index): + + gc.collect() + + return self.get_model(self.data_list[index]) + +############## + +############################### +class Demo_Remesh_Dataset(torch.utils.data.Dataset): + def __init__(self, cfg): + super().__init__() + + self.data_path = cfg.dataset.data_path + + all_files = os.listdir(self.data_path) + + selected = [] + for f in all_files: + if (".obj" in f or ".glb" in f): + selected.append(f) + + self.data_list = selected + self.pc_num_pts = 100000 + + self.preprocess_mesh = cfg.preprocess_mesh + self.result_name = cfg.result_name + + print("val dataset len:", len(self.data_list)) + + + def __len__(self): + return len(self.data_list) + + + def get_model(self, ply_file): + + uid = ply_file.split(".")[-2] + + #### + obj_path = os.path.join(self.data_path, ply_file) + mesh = load_mesh_util(obj_path) + vertices = mesh.vertices + faces = mesh.faces + + bbmin = vertices.min(0) + bbmax = vertices.max(0) + center = (bbmin + bbmax) * 0.5 + scale = 2.0 * 0.9 / (bbmax - bbmin).max() + vertices = (vertices - center) * scale + mesh.vertices = vertices + + ### Pre-process mesh + if self.preprocess_mesh: + # Create a PyMeshLab mesh directly from vertices and faces + ml_mesh = pymeshlab.Mesh(vertex_matrix=mesh.vertices, face_matrix=mesh.faces) + + # Create a MeshSet and add your mesh + ms = pymeshlab.MeshSet() + ms.add_mesh(ml_mesh, "from_trimesh") + + # Apply filters + ms.apply_filter('meshing_remove_duplicate_faces') + ms.apply_filter('meshing_remove_duplicate_vertices') + percentageMerge = pymeshlab.PercentageValue(0.5) + ms.apply_filter('meshing_merge_close_vertices', threshold=percentageMerge) + ms.apply_filter('meshing_remove_unreferenced_vertices') + + + # Save or extract mesh + processed = ms.current_mesh() + mesh.vertices = processed.vertex_matrix() + mesh.faces = processed.face_matrix() + + print("after preprocessing...") + print(mesh.vertices.shape) + print(mesh.faces.shape) + + ### Save input + save_dir = f"exp_results/{self.result_name}" + os.makedirs(save_dir, exist_ok=True) + view_id = 0 + mesh.export(f'{save_dir}/input_{uid}_{view_id}.ply') + + try: + ###### Remesh ###### + size= 256 + level = 2 / size + + sdf = mesh2sdf.core.compute(mesh.vertices, mesh.faces, size) + # NOTE: the negative value is not reliable if the mesh is not watertight + udf = np.abs(sdf) + vertices, faces, _, _ = skimage.measure.marching_cubes(udf, level) + + #### Only use SDF mesh ### + # new_mesh = trimesh.Trimesh(vertices, faces) + ########################## + + #### Make tet ##### + components = trimesh.Trimesh(vertices, faces).split(only_watertight=False) + new_mesh = [] #trimesh.Trimesh() + if len(components) > 100000: + raise NotImplementedError + for i, c in enumerate(components): + c.fix_normals() + new_mesh.append(c) #trimesh.util.concatenate(new_mesh, c) + new_mesh = trimesh.util.concatenate(new_mesh) + + # generate tet mesh + tet = tetgen.TetGen(new_mesh.vertices, new_mesh.faces) + tet.tetrahedralize(plc=True, nobisect=1., quality=True, fixedvolume=True, maxvolume=math.sqrt(2) / 12 * (2 / size) ** 3) + tmp_vtk = tempfile.NamedTemporaryFile(suffix='.vtk', delete=True) + tet.grid.save(tmp_vtk.name) + + # extract surface mesh from tet mesh + reader = vtk.vtkUnstructuredGridReader() + reader.SetFileName(tmp_vtk.name) + reader.Update() + surface_filter = vtk.vtkDataSetSurfaceFilter() + surface_filter.SetInputConnection(reader.GetOutputPort()) + surface_filter.Update() + polydata = surface_filter.GetOutput() + writer = vtk.vtkOBJWriter() + tmp_obj = tempfile.NamedTemporaryFile(suffix='.obj', delete=True) + writer.SetFileName(tmp_obj.name) + writer.SetInputData(polydata) + writer.Update() + new_mesh = load_mesh_util(tmp_obj.name) + ########################## + + new_mesh.vertices = new_mesh.vertices * (2.0 / size) - 1.0 # normalize it to [-1, 1] + + mesh = new_mesh + #################### + + except: + print("Error in tet.") + mesh = mesh + + pc, _ = trimesh.sample.sample_surface(mesh, self.pc_num_pts) + + result = { + 'uid': uid + } + + result['pc'] = torch.tensor(pc, dtype=torch.float32) + result['vertices'] = mesh.vertices + result['faces'] = mesh.faces + + return result + + def __getitem__(self, index): + + gc.collect() + + return self.get_model(self.data_list[index]) + + +class Correspondence_Demo_Dataset(Demo_Dataset): + def __init__(self, cfg): + super().__init__(cfg) + + self.data_path = cfg.dataset.data_path + self.is_pc = cfg.is_pc + + self.data_list = cfg.dataset.all_files + + self.pc_num_pts = 100000 + + self.preprocess_mesh = cfg.preprocess_mesh + self.result_name = cfg.result_name + + print("val dataset len:", len(self.data_list)) + \ No newline at end of file diff --git a/PartField/partfield/model/PVCNN/__pycache__/conv_pointnet.cpython-310.pyc b/PartField/partfield/model/PVCNN/__pycache__/conv_pointnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f911d6454808f1f07087fa153774d6b7b055e2 Binary files /dev/null and b/PartField/partfield/model/PVCNN/__pycache__/conv_pointnet.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/__pycache__/dnnlib_util.cpython-310.pyc b/PartField/partfield/model/PVCNN/__pycache__/dnnlib_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ccab903a2dd05199bef0d4d8bf311dc54135238 Binary files /dev/null and b/PartField/partfield/model/PVCNN/__pycache__/dnnlib_util.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/__pycache__/encoder_pc.cpython-310.pyc b/PartField/partfield/model/PVCNN/__pycache__/encoder_pc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84848f5d541c7e5689de3e2bd05103b620c072d6 Binary files /dev/null and b/PartField/partfield/model/PVCNN/__pycache__/encoder_pc.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/__pycache__/pc_encoder.cpython-310.pyc b/PartField/partfield/model/PVCNN/__pycache__/pc_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61351109ddfb30ee43de8435b6efb3d28709278c Binary files /dev/null and b/PartField/partfield/model/PVCNN/__pycache__/pc_encoder.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/__pycache__/unet_3daware.cpython-310.pyc b/PartField/partfield/model/PVCNN/__pycache__/unet_3daware.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f77fadc20b13eb440d48700dbd4b2fa6592ba830 Binary files /dev/null and b/PartField/partfield/model/PVCNN/__pycache__/unet_3daware.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/conv_pointnet.py b/PartField/partfield/model/PVCNN/conv_pointnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8c5c806f1e725ed9a75aeb752f3e2ae4a5c606a1 --- /dev/null +++ b/PartField/partfield/model/PVCNN/conv_pointnet.py @@ -0,0 +1,251 @@ +""" +Taken from gensdf +https://github.com/princeton-computational-imaging/gensdf +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +# from dnnlib.util import printarr +try: + from torch_scatter import scatter_mean, scatter_max +except: + pass +# from .unet import UNet +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Resnet Blocks +class ResnetBlockFC(nn.Module): + ''' Fully connected ResNet Block class. + Args: + size_in (int): input dimension + size_out (int): output dimension + size_h (int): hidden dimension + ''' + + def __init__(self, size_in, size_out=None, size_h=None): + super().__init__() + # Attributes + if size_out is None: + size_out = size_in + + if size_h is None: + size_h = min(size_in, size_out) + + self.size_in = size_in + self.size_h = size_h + self.size_out = size_out + # Submodules + self.fc_0 = nn.Linear(size_in, size_h) + self.fc_1 = nn.Linear(size_h, size_out) + self.actvn = nn.ReLU() + + if size_in == size_out: + self.shortcut = None + else: + self.shortcut = nn.Linear(size_in, size_out, bias=False) + # Initialization + nn.init.zeros_(self.fc_1.weight) + + def forward(self, x): + net = self.fc_0(self.actvn(x)) + dx = self.fc_1(self.actvn(net)) + + if self.shortcut is not None: + x_s = self.shortcut(x) + else: + x_s = x + + return x_s + dx + + +class ConvPointnet(nn.Module): + ''' PointNet-based encoder network with ResNet blocks for each point. + Number of input points are fixed. + + Args: + c_dim (int): dimension of latent code c + dim (int): input points dimension + hidden_dim (int): hidden dimension of the network + scatter_type (str): feature aggregation when doing local pooling + unet (bool): weather to use U-Net + unet_kwargs (str): U-Net parameters + plane_resolution (int): defined resolution for plane feature + plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + n_blocks (int): number of blocks ResNetBlockFC layers + ''' + + def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', + # unet=False, unet_kwargs=None, + plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5): + super().__init__() + self.c_dim = c_dim + + self.fc_pos = nn.Linear(dim, 2*hidden_dim) + self.blocks = nn.ModuleList([ + ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks) + ]) + self.fc_c = nn.Linear(hidden_dim, c_dim) + + self.actvn = nn.ReLU() + self.hidden_dim = hidden_dim + + # if unet: + # self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs) + # else: + # self.unet = None + + self.reso_plane = plane_resolution + self.plane_type = plane_type + self.padding = padding + + if scatter_type == 'max': + self.scatter = scatter_max + elif scatter_type == 'mean': + self.scatter = scatter_mean + + + # takes in "p": point cloud and "query": sdf_xyz + # sample plane features for unlabeled_query as well + def forward(self, p):#, query2): + batch_size, T, D = p.size() + + # acquire the index for each point + coord = {} + index = {} + if 'xz' in self.plane_type: + coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding) + index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane) + if 'xy' in self.plane_type: + coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding) + index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane) + if 'yz' in self.plane_type: + coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding) + index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane) + + + net = self.fc_pos(p) + + net = self.blocks[0](net) + for block in self.blocks[1:]: + pooled = self.pool_local(coord, index, net) + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + + fea = {} + plane_feat_sum = 0 + #second_sum = 0 + if 'xz' in self.plane_type: + fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64) + # plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz') + #second_sum += self.sample_plane_feature(query2, fea['xz'], 'xz') + if 'xy' in self.plane_type: + fea['xy'] = self.generate_plane_features(p, c, plane='xy') + # plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy') + #second_sum += self.sample_plane_feature(query2, fea['xy'], 'xy') + if 'yz' in self.plane_type: + fea['yz'] = self.generate_plane_features(p, c, plane='yz') + # plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz') + #second_sum += self.sample_plane_feature(query2, fea['yz'], 'yz') + return fea + + # return plane_feat_sum.transpose(2,1)#, second_sum.transpose(2,1) + + + def normalize_coordinate(self, p, padding=0.1, plane='xz'): + ''' Normalize coordinate to [0, 1] for unit cube experiments + + Args: + p (tensor): point + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + plane (str): plane feature type, ['xz', 'xy', 'yz'] + ''' + if plane == 'xz': + xy = p[:, :, [0, 2]] + elif plane =='xy': + xy = p[:, :, [0, 1]] + else: + xy = p[:, :, [1, 2]] + + xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5) + xy_new = xy_new + 0.5 # range (0, 1) + + # f there are outliers out of the range + if xy_new.max() >= 1: + xy_new[xy_new >= 1] = 1 - 10e-6 + if xy_new.min() < 0: + xy_new[xy_new < 0] = 0.0 + return xy_new + + + def coordinate2index(self, x, reso): + ''' Normalize coordinate to [0, 1] for unit cube experiments. + Corresponds to our 3D model + + Args: + x (tensor): coordinate + reso (int): defined resolution + coord_type (str): coordinate type + ''' + x = (x * reso).long() + index = x[:, :, 0] + reso * x[:, :, 1] + index = index[:, None, :] + return index + + + # xy is the normalized coordinates of the point cloud of each plane + # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input + def pool_local(self, xy, index, c): + bs, fea_dim = c.size(0), c.size(2) + keys = xy.keys() + + c_out = 0 + for key in keys: + # scatter plane features from points + fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2) + if self.scatter == scatter_max: + fea = fea[0] + # gather feature back to points + fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) + c_out += fea + return c_out.permute(0, 2, 1) + + + def generate_plane_features(self, p, c, plane='xz'): + # acquire indices of features in plane + xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1) + index = self.coordinate2index(xy, self.reso_plane) + + # scatter plane features from points + fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2) + c = c.permute(0, 2, 1) # B x 512 x T + fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 + fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso) + + # printarr(fea_plane, c, p, xy, index) + # import pdb; pdb.set_trace() + + # process the plane features with UNet + # if self.unet is not None: + # fea_plane = self.unet(fea_plane) + + return fea_plane + + + # sample_plane_feature function copied from /src/conv_onet/models/decoder.py + # uses values from plane_feature and pixel locations from vgrid to interpolate feature + def sample_plane_feature(self, query, plane_feature, plane): + xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding) + xy = xy[:, :, None].float() + vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1) + sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1) + return sampled_feat + + + \ No newline at end of file diff --git a/PartField/partfield/model/PVCNN/dnnlib_util.py b/PartField/partfield/model/PVCNN/dnnlib_util.py new file mode 100644 index 0000000000000000000000000000000000000000..9514fe685275a66fc83bf78fb0cf3c94952678dd --- /dev/null +++ b/PartField/partfield/model/PVCNN/dnnlib_util.py @@ -0,0 +1,1074 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +"""Miscellaneous utility classes and functions.""" +from collections import namedtuple +import time +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import json +import os +import shutil +import sys +import types +import io +import pickle +import re +# import requests +import html +import hashlib +import glob +import tempfile +import urllib +import urllib.request +import uuid +import boto3 +import threading +from contextlib import ContextDecorator +from contextlib import contextmanager, nullcontext + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union +import importlib +from loguru import logger +# import wandb +import torch +import psutil +import subprocess + +import random +import string +import pdb + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: Union[str, bytes]) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if isinstance(text, bytes): + text = text.decode() + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + self.file = None + + +# Cache directories +# ------------------------------------------------------------------------------------------ + +_dnnlib_cache_dir = None + + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return os.path.join(_dnnlib_cache_dir, *paths) + if 'DNNLIB_CACHE_DIR' in os.environ: + return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) + + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def format_time_brief(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) + else: + return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? + for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """Finds the python class with the given name and constructs it with the given arguments.""" + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + module = obj.__module__ + if module == '__main__': + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. + Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. + Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith('file://'): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. + if cache_dir is None: + cache_dir = make_cache_dir_path('downloads') + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) + +# ------------------------------------------------------------------------------------------ +# util function modified from https://github.com/nv-tlabs/LION/blob/0467d2199076e95a7e88bafd99dcd7d48a04b4a7/utils/model_helper.py +def import_class(model_str): + from torch_utils.dist_utils import is_rank0 + if is_rank0(): + logger.info('import: {}', model_str) + p, m = model_str.rsplit('.', 1) + mod = importlib.import_module(p) + Model = getattr(mod, m) + return Model + +class ScopedTorchProfiler(ContextDecorator): + """ + Marks ranges for both nvtx profiling (with nsys) and torch autograd profiler + """ + __global_counts = {} + enabled=False + + def __init__(self, unique_name: str): + """ + Names must be unique! + """ + ScopedTorchProfiler.__global_counts[unique_name] = 0 + self._name = unique_name + self._autograd_scope = torch.profiler.record_function(unique_name) + + def __enter__(self): + if ScopedTorchProfiler.enabled: + torch.cuda.nvtx.range_push(self._name) + self._autograd_scope.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + self._autograd_scope.__exit__(exc_type, exc_value, traceback) + if ScopedTorchProfiler.enabled: + torch.cuda.nvtx.range_pop() + +class TimingsMonitor(): + CUDATimer = namedtuple('CUDATimer', ['start', 'end']) + def __init__(self, device, enabled=True, timing_names:List[str]=[], cuda_timing_names:List[str]=[]): + """ + Usage: + tmonitor = TimingsMonitor(device) + for i in range(n_iter): + # Record arbitrary scopes + with tmonitor.timing_scope('regular_scope_name'): + ... + with tmonitor.cuda_timing_scope('nested_scope_name'): + ... + with tmonitor.cuda_timing_scope('cuda_scope_name'): + ... + tmonitor.record_timing('duration_name', end_time - start_time) + + # Gather timings + tmonitor.record_all_cuda_timings() + tmonitor.update_all_averages() + averages = tmonitor.get_average_timings() + all_timings = tmonitor.get_timings() + + Two types of timers, standard report timing and cuda timings. + Cuda timing supports scoped context manager cuda_event_scope. + Args: + device: device to time on (needed for cuda timers) + # enabled: HACK to only report timings from rank 0, set enabled=(global_rank==0) + timing_names: timings to report optional (will auto add new names) + cuda_timing_names: cuda periods to time optional (will auto add new names) + """ + self.enabled=enabled + self.device = device + + # Normal timing + # self.all_timings_dict = {k:None for k in timing_names + cuda_timing_names} + self.all_timings_dict = {} + self.avg_meter_dict = {} + + # Cuda event timers to measure time spent on pushing data to gpu and on training step + self.cuda_event_timers = {} + + for k in timing_names: + self.add_new_timing(k) + + for k in cuda_timing_names: + self.add_new_cuda_timing(k) + + # Running averages + # self.avg_meter_dict = {k:AverageMeter() for k in self.all_timings_dict} + + def add_new_timing(self, name): + self.avg_meter_dict[name] = AverageMeter() + self.all_timings_dict[name] = None + + def add_new_cuda_timing(self, name): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + self.cuda_event_timers[name] = self.CUDATimer(start=start_event, end=end_event) + self.add_new_timing(name) + + def clear_timings(self): + self.all_timings_dict = {k:None for k in self.all_timings_dict} + + def get_timings(self): + return self.all_timings_dict + + def get_average_timings(self): + return {k:v.avg for k,v in self.avg_meter_dict.items()} + + def update_all_averages(self): + """ + Once per iter, when timings have been finished recording, one should + call update_average_iter to keep running average of timings. + """ + for k,v in self.all_timings_dict.items(): + if v is None: + print("none_timing", k) + continue + self.avg_meter_dict[k].update(v) + + def record_timing(self, name, value): + if name not in self.all_timings_dict: self.add_new_timing(name) + # assert name in self.all_timings_dict + self.all_timings_dict[name] = value + + def _record_cuda_event_start(self, name): + if name in self.cuda_event_timers: + self.cuda_event_timers[name].start.record( + torch.cuda.current_stream(self.device)) + + def _record_cuda_event_end(self, name): + if name in self.cuda_event_timers: + self.cuda_event_timers[name].end.record( + torch.cuda.current_stream(self.device)) + + @contextmanager + def cuda_timing_scope(self, name, profile=True): + if name not in self.all_timings_dict: self.add_new_cuda_timing(name) + with ScopedTorchProfiler(name) if profile else nullcontext(): + self._record_cuda_event_start(name) + try: + yield + finally: + self._record_cuda_event_end(name) + + @contextmanager + def timing_scope(self, name, profile=True): + if name not in self.all_timings_dict: self.add_new_timing(name) + with ScopedTorchProfiler(name) if profile else nullcontext(): + start_time = time.time() + try: + yield + finally: + self.record_timing(name, time.time()-start_time) + + def record_all_cuda_timings(self): + """ After all the cuda events call this to synchronize and record down the cuda timings. """ + for k, events in self.cuda_event_timers.items(): + with torch.no_grad(): + events.end.synchronize() + # Convert to seconds + time_elapsed = events.start.elapsed_time(events.end)/1000. + self.all_timings_dict[k] = time_elapsed + +def init_s3(config_file): + config = json.load(open(config_file, 'r')) + s3_client = boto3.client("s3", **config) + return s3_client + +def download_from_s3(file_path, target_path, cfg): + tic = time.time() + s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init + bucket_name = file_path.split('/')[2] + file_key = file_path.split(bucket_name+'/')[-1] + print(bucket_name, file_key) + s3_client.download_file(bucket_name, file_key, target_path) + logger.info(f'finish download from ! s3://{bucket_name}/{file_key} to {target_path} %.1f sec'%( + time.time() - tic)) + +def upload_to_s3(buffer, bucket_name, key, config_dict): + logger.info(f'start upload_to_s3! bucket_name={bucket_name}, key={key}') + tic = time.time() + s3 = boto3.client('s3', **config_dict) + s3.put_object(Bucket=bucket_name, Key=key, Body=buffer.getvalue()) + logger.info(f'finish upload_to_s3! s3://{bucket_name}/{key} %.1f sec'%(time.time() - tic)) + +def write_ckpt_to_s3(cfg, all_model_dict, ckpt_name): + buffer = io.BytesIO() + tic = time.time() + torch.save(all_model_dict, buffer) # take ~0.25 sec + # logger.info('write ckpt to buffer: %.2f sec'%(time.time() - tic)) + group, name = cfg.outdir.rstrip("/").split("/")[-2:] + key = f"checkpoints/{group}/{name}/ckpt/{ckpt_name}" + bucket_name = cfg.checkpoint.write_s3_bucket + + s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init + + config_dict = json.load(open(cfg.checkpoint.write_s3_config, 'r')) + upload_thread = threading.Thread(target=upload_to_s3, args=(buffer, bucket_name, key, config_dict)) + upload_thread.start() + path = f"s3://{bucket_name}/{key}" + return path + +def upload_file_to_s3(cfg, file_path, key_name=None): + # file_path is the local file path, can be a yaml file + # this function is used to upload the ckecpoint only + tic = time.time() + group, name = cfg.outdir.rstrip("/").split("/")[-2:] + if key_name is None: + key = os.path.basename(file_path) + key = f"checkpoints/{group}/{name}/{key}" + bucket_name = cfg.checkpoint.write_s3_bucket + s3_client = init_s3(cfg.checkpoint.write_s3_config) + # Upload the file + with open(file_path, 'rb') as f: + s3_client.upload_fileobj(f, bucket_name, key) + full_s3_path = f"s3://{bucket_name}/{key}" + logger.info(f'upload_to_s3: {file_path} {full_s3_path} | use time: {time.time()-tic}') + + return full_s3_path + + +def load_from_s3(file_path, cfg, load_fn): + """ + ckpt_path example: + s3://xzeng/checkpoints/2023_0413/vae_kl_5e-1/ckpt/snapshot_epo000163_iter164000.pt + """ + s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init + bucket_name = file_path.split("s3://")[-1].split('/')[0] + key = file_path.split(f'{bucket_name}/')[-1] + # logger.info(f"-> try to load s3://{bucket_name}/{key} ") + tic = time.time() + for attemp in range(10): + try: + # Download the state dict from S3 into memory (as a binary stream) + with io.BytesIO() as buffer: + s3_client.download_fileobj(bucket_name, key, buffer) + buffer.seek(0) + + # Load the state dict into a PyTorch model + # out = torch.load(buffer, map_location=torch.device("cpu")) + out = load_fn(buffer) + break + except: + logger.info(f"fail to load s3://{bucket_name}/{key} attemp: {attemp}") + from torch_utils.dist_utils import is_rank0 + if is_rank0(): + logger.info(f'loaded {file_path} | use time: {time.time()-tic:.1f} sec') + return out + +def load_torch_dict_from_s3(ckpt_path, cfg): + """ + ckpt_path example: + s3://xzeng/checkpoints/2023_0413/vae_kl_5e-1/ckpt/snapshot_epo000163_iter164000.pt + """ + s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init + bucket_name = ckpt_path.split("s3://")[-1].split('/')[0] + key = ckpt_path.split(f'{bucket_name}/')[-1] + for attemp in range(10): + try: + # Download the state dict from S3 into memory (as a binary stream) + with io.BytesIO() as buffer: + s3_client.download_fileobj(bucket_name, key, buffer) + buffer.seek(0) + + # Load the state dict into a PyTorch model + out = torch.load(buffer, map_location=torch.device("cpu")) + break + except: + logger.info(f"fail to load s3://{bucket_name}/{key} attemp: {attemp}") + return out + +def count_parameters_in_M(model): + return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6 + +def printarr(*arrs, float_width=6, **kwargs): + """ + Print a pretty table giving name, shape, dtype, type, and content information for input tensors or scalars. + + Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a variable number of arguments. + + Inputs can be: + - Numpy tensor arrays + - Pytorch tensor arrays + - Jax tensor arrays + - Python ints / floats + - None + + It may also work with other array-like types, but they have not been tested. + + Use the `float_width` option specify the precision to which floating point types are printed. + + Author: Nicholas Sharp (nmwsharp.com) + Canonical source: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233 + License: This snippet may be used under an MIT license, and it is also released into the public domain. + Please retain this docstring as a reference. + """ + + frame = inspect.currentframe().f_back + default_name = "[temporary]" + + ## helpers to gather data about each array + def name_from_outer_scope(a): + if a is None: + return '[None]' + name = default_name + for k, v in frame.f_locals.items(): + if v is a: + name = k + break + return name + + def type_strip(type_str): + return type_str.lstrip('').replace('torch.', '').strip("'") + + def dtype_str(a): + if a is None: + return 'None' + if isinstance(a, int): + return 'int' + if isinstance(a, float): + return 'float' + if isinstance(a, list) and len(a)>0: + return type_strip(str(type(a[0]))) + if hasattr(a, 'dtype'): + return type_strip(str(a.dtype)) + else: + return '' + def shape_str(a): + if a is None: + return 'N/A' + if isinstance(a, int): + return 'scalar' + if isinstance(a, float): + return 'scalar' + if isinstance(a, list): + return f"[{shape_str(a[0]) if len(a)>0 else '?'}]*{len(a)}" + if hasattr(a, 'shape'): + return str(tuple(a.shape)) + else: + return '' + def type_str(a): + return type_strip(str(type(a))) # TODO this is is weird... what's the better way? + def device_str(a): + if hasattr(a, 'device'): + device_str = str(a.device) + if len(device_str) < 10: + # heuristic: jax returns some goofy long string we don't want, ignore it + return device_str + return "" + def format_float(x): + return f"{x:{float_width}g}" + def minmaxmean_str(a): + if a is None: + return ('N/A', 'N/A', 'N/A', 'N/A') + if isinstance(a, int) or isinstance(a, float): + return (format_float(a),)*4 + + # compute min/max/mean. if anything goes wrong, just print 'N/A' + min_str = "N/A" + try: min_str = format_float(a.min()) + except: pass + max_str = "N/A" + try: max_str = format_float(a.max()) + except: pass + mean_str = "N/A" + try: mean_str = format_float(a.mean()) + except: pass + try: median_str = format_float(a.median()) + except: + try: median_str = format_float(np.median(np.array(a))) + except: median_str = 'N/A' + return (min_str, max_str, mean_str, median_str) + + def get_prop_dict(a,k=None): + minmaxmean = minmaxmean_str(a) + props = { + 'name' : name_from_outer_scope(a) if k is None else k, + # 'type' : str(type(a)).replace('torch.',''), + 'dtype' : dtype_str(a), + 'shape' : shape_str(a), + 'type' : type_str(a), + 'device' : device_str(a), + 'min' : minmaxmean[0], + 'max' : minmaxmean[1], + 'mean' : minmaxmean[2], + 'median': minmaxmean[3] + } + return props + + try: + + props = ['name', 'type', 'dtype', 'shape', 'device', 'min', 'max', 'mean', 'median'] + + # precompute all of the properties for each input + str_props = [] + for a in arrs: + str_props.append(get_prop_dict(a)) + for k,a in kwargs.items(): + str_props.append(get_prop_dict(a, k=k)) + + # for each property, compute its length + maxlen = {} + for p in props: maxlen[p] = 0 + for sp in str_props: + for p in props: + maxlen[p] = max(maxlen[p], len(sp[p])) + + # if any property got all empty strings, don't bother printing it, remove if from the list + props = [p for p in props if maxlen[p] > 0] + + # print a header + header_str = "" + for p in props: + prefix = "" if p == 'name' else " | " + fmt_key = ">" if p == 'name' else "<" + header_str += f"{prefix}{p:{fmt_key}{maxlen[p]}}" + print(header_str) + print("-"*len(header_str)) + + # now print the acual arrays + for strp in str_props: + for p in props: + prefix = "" if p == 'name' else " | " + fmt_key = ">" if p == 'name' else "<" + print(f"{prefix}{strp[p]:{fmt_key}{maxlen[p]}}", end='') + print("") + + finally: + del frame + +def debug_print_all_tensor_sizes(min_tot_size = 0): + import gc + print("---------------------------------------"*3) + for obj in gc.get_objects(): + try: + if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): + if np.prod(obj.size())>=min_tot_size: + print(type(obj), obj.size()) + except: + pass +def print_cpu_usage(): + + # Get current CPU usage as a percentage + cpu_usage = psutil.cpu_percent() + + # Get current memory usage + memory_usage = psutil.virtual_memory().used + + # Convert memory usage to a human-readable format + memory_usage_str = psutil._common.bytes2human(memory_usage) + + # Print CPU and memory usage + msg = f"Current CPU usage: {cpu_usage}% | " + msg += f"Current memory usage: {memory_usage_str}" + return msg + +def calmsize(num_bytes): + if math.isnan(num_bytes): + return '' + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(num_bytes) < 1024.0: + return "{:.1f}{}B".format(num_bytes, unit) + num_bytes /= 1024.0 + return "{:.1f}{}B".format(num_bytes, 'Y') + +def readable_size(num_bytes: int) -> str: + return calmsize(num_bytes) ## '' if math.isnan(num_bytes) else '{:.1f}'.format(calmsize(num_bytes)) + +def get_gpu_memory(): + """ + Get the current GPU memory usage for each device as a dictionary + """ + output = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.used", "--format=csv"]) + output = output.decode("utf-8") + gpu_memory_values = output.split("\n")[1:-1] + gpu_memory_values = [int(x.strip().split()[0]) for x in gpu_memory_values] + gpu_memory = dict(zip(range(len(gpu_memory_values)), gpu_memory_values)) + return gpu_memory + +def get_gpu_util(): + """ + Get the current GPU memory usage for each device as a dictionary + """ + output = subprocess.check_output(["nvidia-smi", "--query-gpu=utilization.gpu", "--format=csv"]) + output = output.decode("utf-8") + gpu_memory_values = output.split("\n")[1:-1] + gpu_memory_values = [int(x.strip().split()[0]) for x in gpu_memory_values] + gpu_util = dict(zip(range(len(gpu_memory_values)), gpu_memory_values)) + return gpu_util + + +def print_gpu_usage(): + useage = get_gpu_memory() + msg = f" | GPU usage: " + for k, v in useage.items(): + msg += f"{k}: {v} MB " + # utilization = get_gpu_util() + # msg + ' | util ' + # for k, v in utilization.items(): + # msg += f"{k}: {v} % " + return msg + +class AverageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.sum = 0 + self.cnt = 0 + + def update(self, val, n=1): + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt + + +def generate_random_string(length): + # This script will generate a string of 10 random ASCII letters (both lowercase and uppercase). + # You can adjust the length parameter to fit your needs. + letters = string.ascii_letters + return ''.join(random.choice(letters) for _ in range(length)) + + +class ForkedPdb(pdb.Pdb): + """ + PDB Subclass for debugging multi-processed code + Suggested in: https://stackoverflow.com/questions/4716533/how-to-attach-debugger-to-a-python-subproccess + """ + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + +def check_exist_in_s3(file_path, s3_config): + s3 = init_s3(s3_config) + bucket_name, object_name = s3path_to_bucket_key(file_path) + + try: + s3.head_object(Bucket=bucket_name, Key=object_name) + return 1 + except: + logger.info(f'file not found: s3://{bucket_name}/{object_name}') + return 0 + +def s3path_to_bucket_key(file_path): + bucket_name = file_path.split('/')[2] + object_name = file_path.split(bucket_name + '/')[-1] + return bucket_name, object_name + +def copy_file_to_s3(cfg, file_path_local, file_path_s3): + # work similar as upload_file_to_s3, but not trying to parse the file path + # file_path_s3: s3://{bucket}/{key} + bucket_name, key = s3path_to_bucket_key(file_path_s3) + tic = time.time() + s3_client = init_s3(cfg.checkpoint.write_s3_config) + + # Upload the file + with open(file_path_local, 'rb') as f: + s3_client.upload_fileobj(f, bucket_name, key) + full_s3_path = f"s3://{bucket_name}/{key}" + logger.info(f'copy file: {file_path_local} {full_s3_path} | use time: {time.time()-tic}') + return full_s3_path \ No newline at end of file diff --git a/PartField/partfield/model/PVCNN/encoder_pc.py b/PartField/partfield/model/PVCNN/encoder_pc.py new file mode 100644 index 0000000000000000000000000000000000000000..25a384ec0b7b8ce19d0336f1edd2463b8b3acd4e --- /dev/null +++ b/PartField/partfield/model/PVCNN/encoder_pc.py @@ -0,0 +1,243 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +from ast import Dict +import math + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torch_scatter import scatter_mean #, scatter_max + +from .unet_3daware import setup_unet #UNetTriplane3dAware +from .conv_pointnet import ConvPointnet + +from .pc_encoder import PVCNNEncoder #PointNet + +import einops + +from .dnnlib_util import ScopedTorchProfiler, printarr + +def generate_plane_features(p, c, resolution, plane='xz'): + """ + Args: + p: (B,3,n_p) + c: (B,C,n_p) + """ + padding = 0. + c_dim = c.size(1) + # acquire indices of features in plane + xy = normalize_coordinate(p.clone(), plane=plane, padding=padding) # normalize to the range of (0, 1) + index = coordinate2index(xy, resolution) + + # scatter plane features from points + fea_plane = c.new_zeros(p.size(0), c_dim, resolution**2) + fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 + fea_plane = fea_plane.reshape(p.size(0), c_dim, resolution, resolution) # sparce matrix (B x 512 x reso x reso) + return fea_plane + +def normalize_coordinate(p, padding=0.1, plane='xz'): + ''' Normalize coordinate to [0, 1] for unit cube experiments + + Args: + p (tensor): point + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + plane (str): plane feature type, ['xz', 'xy', 'yz'] + ''' + if plane == 'xz': + xy = p[:, :, [0, 2]] + elif plane =='xy': + xy = p[:, :, [0, 1]] + else: + xy = p[:, :, [1, 2]] + + xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5) + xy_new = xy_new + 0.5 # range (0, 1) + + # if there are outliers out of the range + if xy_new.max() >= 1: + xy_new[xy_new >= 1] = 1 - 10e-6 + if xy_new.min() < 0: + xy_new[xy_new < 0] = 0.0 + return xy_new + + +def coordinate2index(x, resolution): + ''' Normalize coordinate to [0, 1] for unit cube experiments. + Corresponds to our 3D model + + Args: + x (tensor): coordinate + reso (int): defined resolution + coord_type (str): coordinate type + ''' + x = (x * resolution).long() + index = x[:, :, 0] + resolution * x[:, :, 1] + index = index[:, None, :] + return index + +def softclip(x, min, max, hardness=5): + # Soft clipping for the logsigma + x = min + F.softplus(hardness*(x - min))/hardness + x = max - F.softplus(-hardness*(x - max))/hardness + return x + + +def sample_triplane_feat(feature_triplane, normalized_pos): + ''' + normalized_pos [-1, 1] + ''' + tri_plane = torch.unbind(feature_triplane, dim=1) + + x_feat = F.grid_sample( + tri_plane[0], + torch.cat( + [normalized_pos[:, :, 0:1], normalized_pos[:, :, 1:2]], + dim=-1).unsqueeze(dim=1), padding_mode='border', + align_corners=True) + y_feat = F.grid_sample( + tri_plane[1], + torch.cat( + [normalized_pos[:, :, 1:2], normalized_pos[:, :, 2:3]], + dim=-1).unsqueeze(dim=1), padding_mode='border', + align_corners=True) + + z_feat = F.grid_sample( + tri_plane[2], + torch.cat( + [normalized_pos[:, :, 0:1], normalized_pos[:, :, 2:3]], + dim=-1).unsqueeze(dim=1), padding_mode='border', + align_corners=True) + final_feat = (x_feat + y_feat + z_feat) + final_feat = final_feat.squeeze(dim=2).permute(0, 2, 1) # 32dimension + return final_feat + + +# @persistence.persistent_class +class TriPlanePC2Encoder(torch.nn.Module): + # Encoder that encode point cloud to triplane feature vector similar to ConvOccNet + def __init__( + self, + cfg, + device='cuda', + shape_min=-1.0, + shape_length=2.0, + use_2d_feat=False, + # point_encoder='pvcnn', + # use_point_scatter=False + ): + """ + Outputs latent triplane from PC input + Configs: + max_logsigma: (float) Soft clip upper range for logsigm + min_logsigma: (float) + point_encoder_type: (str) one of ['pvcnn', 'pointnet'] + pvcnn_flatten_voxels: (bool) for pvcnn whether to reduce voxel + features (instead of scattering point features) + unet_cfg: (dict) + z_triplane_channels: (int) output latent triplane + z_triplane_resolution: (int) + Args: + + """ + # assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0 + super().__init__() + self.device = device + + self.cfg = cfg + + self.shape_min = shape_min + self.shape_length = shape_length + + self.z_triplane_resolution = cfg.z_triplane_resolution + z_triplane_channels = cfg.z_triplane_channels + + point_encoder_out_dim = z_triplane_channels #* 2 + + in_channels = 6 + # self.resample_filter=[1, 3, 3, 1] + if cfg.point_encoder_type == 'pvcnn': + self.pc_encoder = PVCNNEncoder(point_encoder_out_dim, + device=self.device, in_channels=in_channels, use_2d_feat=use_2d_feat) # Encode it to a volume vector. + elif cfg.point_encoder_type == 'pointnet': + # TODO the pointnet was buggy, investigate + self.pc_encoder = ConvPointnet(c_dim=point_encoder_out_dim, + dim=in_channels, hidden_dim=32, + plane_resolution=self.z_triplane_resolution, + padding=0) + else: + raise NotImplementedError(f"Point encoder {cfg.point_encoder_type} not implemented") + + if cfg.unet_cfg.enabled: + self.unet_encoder = setup_unet( + output_channels=point_encoder_out_dim, + input_channels=point_encoder_out_dim, + unet_cfg=cfg.unet_cfg) + else: + self.unet_encoder = None + + # @ScopedTorchProfiler('encode') + def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None, pc2pc_idx=None) -> Dict: + # output = AttrDict() + point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length # [0, 1] + point_cloud_xyz = point_cloud_xyz - 0.5 # [-0.5, 0.5] + point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1) + + if self.cfg.point_encoder_type == 'pvcnn': + if mv_feat is not None: + pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx) + else: + pc_feat, points_feat = self.pc_encoder(point_cloud) # 3D feature volume: BxDx32x32x32 + if self.cfg.use_point_scatter: + # Scattering from PVCNN point features + points_feat_ = points_feat[0] + # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64) + pc_feat_1 = generate_plane_features(point_cloud_xyz, points_feat_, + resolution=self.z_triplane_resolution, plane='xy') + pc_feat_2 = generate_plane_features(point_cloud_xyz, points_feat_, + resolution=self.z_triplane_resolution, plane='yz') + pc_feat_3 = generate_plane_features(point_cloud_xyz, points_feat_, + resolution=self.z_triplane_resolution, plane='xz') + pc_feat = pc_feat[0] + + else: + pc_feat = pc_feat[0] + sf = self.z_triplane_resolution//32 # 32 is PVCNN's voxel dim + + pc_feat_1 = torch.mean(pc_feat, dim=-1) #xy_plane, normalize in z plane + pc_feat_2 = torch.mean(pc_feat, dim=-3) #yz_plane, normalize in x plane + pc_feat_3 = torch.mean(pc_feat, dim=-2) #xz_plane, normalize in y plane + + # nearest upsample + pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm ) (w wm)', hm = sf, wm = sf) + pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm = sf, wm = sf) + pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm = sf, wm = sf) + elif self.cfg.point_encoder_type == 'pointnet': + assert self.cfg.use_point_scatter + # Run ConvPointnet + pc_feat = self.pc_encoder(point_cloud) + pc_feat_1 = pc_feat['xy'] # + pc_feat_2 = pc_feat['yz'] + pc_feat_3 = pc_feat['xz'] + else: + raise NotImplementedError() + + if self.unet_encoder is not None: + # TODO eval adding a skip connection + # Unet expects B, 3, C, H, W + pc_feat_tri_plane_stack_pre = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1) + # dpc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre) + # pc_feat_tri_plane_stack = pc_feat_tri_plane_stack_pre + dpc_feat_tri_plane_stack + pc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre) + pc_feat_1, pc_feat_2, pc_feat_3 = torch.unbind(pc_feat_tri_plane_stack, dim=1) + + return torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1) + + def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None): + return self.encode(point_cloud_xyz, point_cloud_feature=point_cloud_feature, mv_feat=mv_feat, pc2pc_idx=pc2pc_idx) \ No newline at end of file diff --git a/PartField/partfield/model/PVCNN/pc_encoder.py b/PartField/partfield/model/PVCNN/pc_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..adeaeba96147a24515eb289b09da77bc4716869c --- /dev/null +++ b/PartField/partfield/model/PVCNN/pc_encoder.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import functools + +from .pv_module import SharedMLP, PVConv + +def create_pointnet_components( + blocks, in_channels, with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=False, device='cuda'): + r, vr = width_multiplier, voxel_resolution_multiplier + layers, concat_channels = [], 0 + for out_channels, num_blocks, voxel_resolution in blocks: + out_channels = int(r * out_channels) + if voxel_resolution is None: + block = functools.partial(SharedMLP, device=device) + else: + block = functools.partial( + PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), + with_se=with_se, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn, device=device) + for _ in range(num_blocks): + layers.append(block(in_channels, out_channels)) + in_channels = out_channels + concat_channels += out_channels + return layers, in_channels, concat_channels + +class PCMerger(nn.Module): +# merge surface sampled PC and rendering backprojected PC (w/ 2D features): + def __init__(self, in_channels=204, device="cuda"): + super(PCMerger, self).__init__() + self.mlp_normal = SharedMLP(3, [128, 128], device=device) + self.mlp_rgb = SharedMLP(3, [128, 128], device=device) + self.mlp_sam = SharedMLP(204 - 6, [128, 128], device=device) + + def forward(self, feat, mv_feat, pc2pc_idx): + mv_feat_normal = self.mlp_normal(mv_feat[:, :3, :]) + mv_feat_rgb = self.mlp_rgb(mv_feat[:, 3:6, :]) + mv_feat_sam = self.mlp_sam(mv_feat[:, 6:, :]) + + mv_feat_normal = mv_feat_normal.permute(0, 2, 1) + mv_feat_rgb = mv_feat_rgb.permute(0, 2, 1) + mv_feat_sam = mv_feat_sam.permute(0, 2, 1) + feat = feat.permute(0, 2, 1) + + for i in range(mv_feat.shape[0]): + mask = (pc2pc_idx[i] != -1).reshape(-1) + idx = pc2pc_idx[i][mask].reshape(-1) + feat[i][mask] += mv_feat_normal[i][idx] + mv_feat_rgb[i][idx] + mv_feat_sam[i][idx] + + return feat.permute(0, 2, 1) + + +class PVCNNEncoder(nn.Module): + def __init__(self, pvcnn_feat_dim, device='cuda', in_channels=3, use_2d_feat=False): + super(PVCNNEncoder, self).__init__() + self.device = device + self.blocks = ((pvcnn_feat_dim, 1, 32), (128, 2, 16), (256, 1, 8)) + self.use_2d_feat=use_2d_feat + if in_channels == 6: + self.append_channel = 2 + elif in_channels == 3: + self.append_channel = 1 + else: + raise NotImplementedError + layers, channels_point, concat_channels_point = create_pointnet_components( + blocks=self.blocks, in_channels=in_channels + self.append_channel, with_se=False, normalize=False, + width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=True, + device=device + ) + self.encoder = nn.ModuleList(layers)#.to(self.device) + if self.use_2d_feat: + self.merger = PCMerger() + + + + def forward(self, input_pc, mv_feat=None, pc2pc_idx=None): + features = input_pc.permute(0, 2, 1) * 2 # make point cloud [-1, 1] + coords = features[:, :3, :] + out_features_list = [] + voxel_feature_list = [] + zero_padding = torch.zeros(features.shape[0], self.append_channel, features.shape[-1], device=features.device, dtype=torch.float) + features = torch.cat([features, zero_padding], dim=1)################## + + for i in range(len(self.encoder)): + features, _, voxel_feature = self.encoder[i]((features, coords)) + if i == 0 and mv_feat is not None: + features = self.merger(features, mv_feat.permute(0, 2, 1), pc2pc_idx) + out_features_list.append(features) + voxel_feature_list.append(voxel_feature) + return voxel_feature_list, out_features_list \ No newline at end of file diff --git a/PartField/partfield/model/PVCNN/pv_module/__init__.py b/PartField/partfield/model/PVCNN/pv_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7fd32e598f709503f4e35171e09fbbedec05f9c3 --- /dev/null +++ b/PartField/partfield/model/PVCNN/pv_module/__init__.py @@ -0,0 +1,2 @@ +from .pvconv import PVConv +from .shared_mlp import SharedMLP \ No newline at end of file diff --git a/PartField/partfield/model/PVCNN/pv_module/__pycache__/__init__.cpython-310.pyc b/PartField/partfield/model/PVCNN/pv_module/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..563237c0c86173bffe4e30c55fb1f00c1f04911b Binary files /dev/null and b/PartField/partfield/model/PVCNN/pv_module/__pycache__/__init__.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/pv_module/__pycache__/pvconv.cpython-310.pyc b/PartField/partfield/model/PVCNN/pv_module/__pycache__/pvconv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c510195626351629b6b77aa1d32ab2c3d609cd56 Binary files /dev/null and b/PartField/partfield/model/PVCNN/pv_module/__pycache__/pvconv.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/pv_module/__pycache__/shared_mlp.cpython-310.pyc b/PartField/partfield/model/PVCNN/pv_module/__pycache__/shared_mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a048f980040d1dfbce6eebef4b1dacf5279a0758 Binary files /dev/null and b/PartField/partfield/model/PVCNN/pv_module/__pycache__/shared_mlp.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/pv_module/__pycache__/voxelization.cpython-310.pyc b/PartField/partfield/model/PVCNN/pv_module/__pycache__/voxelization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bd23955af7f91c921f07c1d0a3e2618897f3e1e Binary files /dev/null and b/PartField/partfield/model/PVCNN/pv_module/__pycache__/voxelization.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/pv_module/ball_query.py b/PartField/partfield/model/PVCNN/pv_module/ball_query.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2a8203baa3fc1c94959f1ba852839365bf38ce --- /dev/null +++ b/PartField/partfield/model/PVCNN/pv_module/ball_query.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn + +from . import functional as F + +__all__ = ['BallQuery'] + + +class BallQuery(nn.Module): + def __init__(self, radius, num_neighbors, include_coordinates=True): + super().__init__() + self.radius = radius + self.num_neighbors = num_neighbors + self.include_coordinates = include_coordinates + + def forward(self, points_coords, centers_coords, points_features=None): + points_coords = points_coords.contiguous() + centers_coords = centers_coords.contiguous() + neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors) + neighbor_coordinates = F.grouping(points_coords, neighbor_indices) + neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1) + + if points_features is None: + assert self.include_coordinates, 'No Features For Grouping' + neighbor_features = neighbor_coordinates + else: + neighbor_features = F.grouping(points_features, neighbor_indices) + if self.include_coordinates: + neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1) + return neighbor_features + + def extra_repr(self): + return 'radius={}, num_neighbors={}{}'.format( + self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '') diff --git a/PartField/partfield/model/PVCNN/pv_module/frustum.py b/PartField/partfield/model/PVCNN/pv_module/frustum.py new file mode 100644 index 0000000000000000000000000000000000000000..fb302963a6472f949f4ab69ed42575d79b68b4ea --- /dev/null +++ b/PartField/partfield/model/PVCNN/pv_module/frustum.py @@ -0,0 +1,141 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . import functional as PF + +__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d'] + + +class FrustumPointNetLoss(nn.Module): + def __init__( + self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0, + corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0): + super().__init__() + self.box_loss_weight = box_loss_weight + self.corners_loss_weight = corners_loss_weight + self.heading_residual_loss_weight = heading_residual_loss_weight + self.size_residual_loss_weight = size_residual_loss_weight + + self.num_heading_angle_bins = num_heading_angle_bins + self.num_size_templates = num_size_templates + self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3)) + self.register_buffer( + 'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins) + ) + + def forward(self, inputs, targets): + mask_logits = inputs['mask_logits'] # (B, 2, N) + center_reg = inputs['center_reg'] # (B, 3) + center = inputs['center'] # (B, 3) + heading_scores = inputs['heading_scores'] # (B, NH) + heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH) + heading_residuals = inputs['heading_residuals'] # (B, NH) + size_scores = inputs['size_scores'] # (B, NS) + size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3) + size_residuals = inputs['size_residuals'] # (B, NS, 3) + + mask_logits_target = targets['mask_logits'] # (B, N) + center_target = targets['center'] # (B, 3) + heading_bin_id_target = targets['heading_bin_id'] # (B, ) + heading_residual_target = targets['heading_residual'] # (B, ) + size_template_id_target = targets['size_template_id'] # (B, ) + size_residual_target = targets['size_residual'] # (B, 3) + + batch_size = center.size(0) + batch_id = torch.arange(batch_size, device=center.device) + + # Basic Classification and Regression losses + mask_loss = F.cross_entropy(mask_logits, mask_logits_target) + heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target) + size_loss = F.cross_entropy(size_scores, size_template_id_target) + center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0) + center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0) + + # Refinement losses for size/heading + heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target] # (B, ) + heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins) + heading_residual_normalized_loss = PF.huber_loss( + heading_residuals_normalized - heading_residual_normalized_target, delta=1.0 + ) + size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target] # (B, 3) + size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target] + size_residual_normalized_loss = PF.huber_loss( + torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0 + ) + + # Bounding box losses + heading = (heading_residuals[batch_id, heading_bin_id_target] + + self.heading_angle_bin_centers[heading_bin_id_target]) # (B, ) + # Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets) + size = (size_residuals[batch_id, size_template_id_target] + + self.size_templates[size_template_id_target]) # (B, 3) + corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8) + heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, ) + size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3) + corners_target, corners_target_flip = get_box_corners_3d( + centers=center_target, headings=heading_target, + sizes=size_target, with_flip=True) # (B, 3, 8) + corners_loss = PF.huber_loss( + torch.min( + torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1) + ), delta=1.0) + # Summing up + loss = mask_loss + self.box_loss_weight * ( + center_loss + center_reg_loss + heading_loss + size_loss + + self.heading_residual_loss_weight * heading_residual_normalized_loss + + self.size_residual_loss_weight * size_residual_normalized_loss + + self.corners_loss_weight * corners_loss + ) + + return loss + + +def get_box_corners_3d(centers, headings, sizes, with_flip=False): + """ + :param centers: coords of box centers, FloatTensor[N, 3] + :param headings: heading angles, FloatTensor[N, ] + :param sizes: box sizes, FloatTensor[N, 3] + :param with_flip: bool, whether to return flipped box (headings + np.pi) + :return: + coords of box corners, FloatTensor[N, 3, 8] + NOTE: corner points are in counter clockwise order, e.g., + 2--1 + 3--0 5 + 7--4 + """ + l = sizes[:, 0] # (N,) + w = sizes[:, 1] # (N,) + h = sizes[:, 2] # (N,) + x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1) # (N, 8) + y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1) # (N, 8) + z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1) # (N, 8) + + c = torch.cos(headings) # (N,) + s = torch.sin(headings) # (N,) + o = torch.ones_like(headings) # (N,) + z = torch.zeros_like(headings) # (N,) + + centers = centers.unsqueeze(-1) # (B, 3, 1) + corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8) + R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # roty matrix: (N, 3, 3) + if with_flip: + R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3) + return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers + else: + return torch.matmul(R, corners) + centers + + # centers = centers.unsqueeze(1) # (B, 1, 3) + # corners = torch.stack([x_corners, y_corners, z_corners], dim=-1) # (N, 8, 3) + # RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3) + # if with_flip: + # RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3) # (N, 3, 3) + # return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers # (N, 8, 3) + # else: + # return torch.matmul(corners, RT) + centers # (N, 8, 3) + + # corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8) + # R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3) + # corners = torch.matmul(R, corners) + centers.unsqueeze(2) # (N, 3, 8) + # corners = corners.transpose(1, 2) # (N, 8, 3) diff --git a/PartField/partfield/model/PVCNN/pv_module/functional/__init__.py b/PartField/partfield/model/PVCNN/pv_module/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..993d1d12511dce369d781b60be75e79a71762e47 --- /dev/null +++ b/PartField/partfield/model/PVCNN/pv_module/functional/__init__.py @@ -0,0 +1 @@ +from .devoxelization import trilinear_devoxelize diff --git a/PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/__init__.cpython-310.pyc b/PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07df15a976cf620a0cd1f5887c4e596b09498dd4 Binary files /dev/null and b/PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/__init__.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/devoxelization.cpython-310.pyc b/PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/devoxelization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a785cde458fca2e04cb85d27106e3d4254febe1 Binary files /dev/null and b/PartField/partfield/model/PVCNN/pv_module/functional/__pycache__/devoxelization.cpython-310.pyc differ diff --git a/PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py b/PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py new file mode 100644 index 0000000000000000000000000000000000000000..c60dab12d804ec3b41b53f7da3eecb20917077fc --- /dev/null +++ b/PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py @@ -0,0 +1,12 @@ +from torch.autograd import Function +import torch +import torch.nn.functional as F + +__all__ = ['trilinear_devoxelize'] + +def trilinear_devoxelize(c, coords, r, training=None): + coords = (coords * 2 + 1.0) / r - 1.0 + coords = coords.permute(0, 2, 1).reshape(c.shape[0], 1, 1, -1, 3) + f = F.grid_sample(input=c, grid=coords, padding_mode='border', align_corners=False) + f = f.squeeze(dim=2).squeeze(dim=2) + return f diff --git a/PartField/partfield/model/PVCNN/pv_module/loss.py b/PartField/partfield/model/PVCNN/pv_module/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a35cdd8a0fe83c8ca6b1d7040b66d142e76471df --- /dev/null +++ b/PartField/partfield/model/PVCNN/pv_module/loss.py @@ -0,0 +1,10 @@ +import torch.nn as nn + +from . import functional as F + +__all__ = ['KLLoss'] + + +class KLLoss(nn.Module): + def forward(self, x, y): + return F.kl_loss(x, y) diff --git a/PartField/partfield/model/PVCNN/pv_module/pointnet.py b/PartField/partfield/model/PVCNN/pv_module/pointnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e58e01cc2f84925d4817a8a04aefdbc3d36d484e --- /dev/null +++ b/PartField/partfield/model/PVCNN/pv_module/pointnet.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn + +from . import functional as F +from .ball_query import BallQuery +from .shared_mlp import SharedMLP + +__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule'] + + +class PointNetAModule(nn.Module): + def __init__(self, in_channels, out_channels, include_coordinates=True): + super().__init__() + if not isinstance(out_channels, (list, tuple)): + out_channels = [[out_channels]] + elif not isinstance(out_channels[0], (list, tuple)): + out_channels = [out_channels] + + mlps = [] + total_out_channels = 0 + for _out_channels in out_channels: + mlps.append( + SharedMLP( + in_channels=in_channels + (3 if include_coordinates else 0), + out_channels=_out_channels, dim=1) + ) + total_out_channels += _out_channels[-1] + + self.include_coordinates = include_coordinates + self.out_channels = total_out_channels + self.mlps = nn.ModuleList(mlps) + + def forward(self, inputs): + features, coords = inputs + if self.include_coordinates: + features = torch.cat([features, coords], dim=1) + coords = torch.zeros((coords.size(0), 3, 1), device=coords.device) + if len(self.mlps) > 1: + features_list = [] + for mlp in self.mlps: + features_list.append(mlp(features).max(dim=-1, keepdim=True).values) + return torch.cat(features_list, dim=1), coords + else: + return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords + + def extra_repr(self): + return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}' + + +class PointNetSAModule(nn.Module): + def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True): + super().__init__() + if not isinstance(radius, (list, tuple)): + radius = [radius] + if not isinstance(num_neighbors, (list, tuple)): + num_neighbors = [num_neighbors] * len(radius) + assert len(radius) == len(num_neighbors) + if not isinstance(out_channels, (list, tuple)): + out_channels = [[out_channels]] * len(radius) + elif not isinstance(out_channels[0], (list, tuple)): + out_channels = [out_channels] * len(radius) + assert len(radius) == len(out_channels) + + groupers, mlps = [], [] + total_out_channels = 0 + for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors): + groupers.append( + BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates) + ) + mlps.append( + SharedMLP( + in_channels=in_channels + (3 if include_coordinates else 0), + out_channels=_out_channels, dim=2) + ) + total_out_channels += _out_channels[-1] + + self.num_centers = num_centers + self.out_channels = total_out_channels + self.groupers = nn.ModuleList(groupers) + self.mlps = nn.ModuleList(mlps) + + def forward(self, inputs): + features, coords = inputs + centers_coords = F.furthest_point_sample(coords, self.num_centers) + features_list = [] + for grouper, mlp in zip(self.groupers, self.mlps): + features_list.append(mlp(grouper(coords, centers_coords, features)).max(dim=-1).values) + if len(features_list) > 1: + return torch.cat(features_list, dim=1), centers_coords + else: + return features_list[0], centers_coords + + def extra_repr(self): + return f'num_centers={self.num_centers}, out_channels={self.out_channels}' + + +class PointNetFPModule(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1) + + def forward(self, inputs): + if len(inputs) == 3: + points_coords, centers_coords, centers_features = inputs + points_features = None + else: + points_coords, centers_coords, centers_features, points_features = inputs + interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features) + if points_features is not None: + interpolated_features = torch.cat( + [interpolated_features, points_features], dim=1 + ) + return self.mlp(interpolated_features), points_coords diff --git a/PartField/partfield/model/PVCNN/pv_module/pvconv.py b/PartField/partfield/model/PVCNN/pv_module/pvconv.py new file mode 100644 index 0000000000000000000000000000000000000000..a64705da194cf2d32ff641025fad7b92d71dc67b --- /dev/null +++ b/PartField/partfield/model/PVCNN/pv_module/pvconv.py @@ -0,0 +1,38 @@ +import torch.nn as nn + +from . import functional as F +from .voxelization import Voxelization +from .shared_mlp import SharedMLP +import torch + +__all__ = ['PVConv'] + + +class PVConv(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, resolution, with_se=False, normalize=True, eps=0, scale_pvcnn=False, + device='cuda'): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.resolution = resolution + self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn) + voxel_layers = [ + nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, device=device), + nn.InstanceNorm3d(out_channels, eps=1e-4, device=device), + nn.LeakyReLU(0.1, True), + nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, device=device), + nn.InstanceNorm3d(out_channels, eps=1e-4, device=device), + nn.LeakyReLU(0.1, True), + ] + self.voxel_layers = nn.Sequential(*voxel_layers) + self.point_features = SharedMLP(in_channels, out_channels, device=device) + + def forward(self, inputs): + features, coords = inputs + voxel_features, voxel_coords = self.voxelization(features, coords) + voxel_features = self.voxel_layers(voxel_features) + devoxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training) + fused_features = devoxel_features + self.point_features(features) + return fused_features, coords, voxel_features diff --git a/PartField/partfield/model/PVCNN/pv_module/shared_mlp.py b/PartField/partfield/model/PVCNN/pv_module/shared_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d4ff864c05b894194ef11ac4b629ec72c4952b --- /dev/null +++ b/PartField/partfield/model/PVCNN/pv_module/shared_mlp.py @@ -0,0 +1,35 @@ +import torch.nn as nn + +__all__ = ['SharedMLP'] + + +class SharedMLP(nn.Module): + def __init__(self, in_channels, out_channels, dim=1, device='cuda'): + super().__init__() + # print('==> SharedMLP device: ', device) + if dim == 1: + conv = nn.Conv1d + bn = nn.InstanceNorm1d + elif dim == 2: + conv = nn.Conv2d + bn = nn.InstanceNorm1d + else: + raise ValueError + if not isinstance(out_channels, (list, tuple)): + out_channels = [out_channels] + layers = [] + for oc in out_channels: + layers.extend( + [ + conv(in_channels, oc, 1, device=device), + bn(oc, device=device), + nn.ReLU(True), + ]) + in_channels = oc + self.layers = nn.Sequential(*layers) + + def forward(self, inputs): + if isinstance(inputs, (list, tuple)): + return (self.layers(inputs[0]), *inputs[1:]) + else: + return self.layers(inputs) diff --git a/PartField/partfield/model/PVCNN/pv_module/voxelization.py b/PartField/partfield/model/PVCNN/pv_module/voxelization.py new file mode 100644 index 0000000000000000000000000000000000000000..15535abc63f7adebcd12bc436e84710c7c9862d2 --- /dev/null +++ b/PartField/partfield/model/PVCNN/pv_module/voxelization.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn + +from . import functional as F + +__all__ = ['Voxelization'] + + +def my_voxelization(features, coords, resolution): + b, c, _ = features.shape + result = torch.zeros(b, c + 1, resolution * resolution * resolution, device=features.device, dtype=torch.float) + r = resolution + r2 = resolution * resolution + indices = coords[:, 0] * r2 + coords[:, 1] * r + coords[:, 2] + indices = indices.unsqueeze(dim=1).expand(-1, result.shape[1], -1) + features = torch.cat([features, torch.ones(features.shape[0], 1, features.shape[2], device=features.device, dtype=features.dtype)], dim=1) + out_feature = result.scatter_(index=indices.long(), src=features, dim=2, reduce='add') + cnt = out_feature[:, -1:, :] + zero_mask = (cnt == 0).float() + cnt = cnt * (1 - zero_mask) + zero_mask * 1e-5 + vox_feature = out_feature[:, :-1, :] / cnt + return vox_feature.view(b, c, resolution, resolution, resolution) + +class Voxelization(nn.Module): + def __init__(self, resolution, normalize=True, eps=0, scale_pvcnn=False): + super().__init__() + self.r = int(resolution) + self.normalize = normalize + self.eps = eps + self.scale_pvcnn = scale_pvcnn + assert not normalize + + def forward(self, features, coords): + with torch.no_grad(): + coords = coords.detach() + + if self.normalize: + norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5 + else: + if self.scale_pvcnn: + norm_coords = (coords + 1) / 2.0 # [0, 1] + else: + norm_coords = (norm_coords + 1) / 2.0 + norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1) + vox_coords = torch.round(norm_coords) + new_vox_feat = my_voxelization(features, vox_coords, self.r) + return new_vox_feat, norm_coords + + def extra_repr(self): + return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '') diff --git a/PartField/partfield/model/PVCNN/unet_3daware.py b/PartField/partfield/model/PVCNN/unet_3daware.py new file mode 100644 index 0000000000000000000000000000000000000000..b0084f0c1d6989ae4ad103f364401c2f2bd5e361 --- /dev/null +++ b/PartField/partfield/model/PVCNN/unet_3daware.py @@ -0,0 +1,427 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init + +import einops + +def conv3x3(in_channels, out_channels, stride=1, + padding=1, bias=True, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=bias, + groups=groups) + +def upconv2x2(in_channels, out_channels, mode='transpose'): + if mode == 'transpose': + return nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=2, + stride=2) + else: + # out_channels is always going to be the same + # as in_channels + return nn.Sequential( + nn.Upsample(mode='bilinear', scale_factor=2), + conv1x1(in_channels, out_channels)) + +def conv1x1(in_channels, out_channels, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + groups=groups, + stride=1) + +class ConvTriplane3dAware(nn.Module): + """ 3D aware triplane conv (as described in RODIN) """ + def __init__(self, internal_conv_f, in_channels, out_channels, order='xz'): + """ + Args: + internal_conv_f: function that should return a 2D convolution Module + given in and out channels + order: if triplane input is in 'xz' order + """ + super(ConvTriplane3dAware, self).__init__() + # Need 3 seperate convolutions + self.in_channels = in_channels + self.out_channels = out_channels + assert order in ['xz', 'zx'] + self.order = order + # Going to stack from other planes + self.plane_convs = nn.ModuleList([ + internal_conv_f(3*self.in_channels, self.out_channels) for _ in range(3)]) + + def forward(self, triplanes_list): + """ + Args: + triplanes_list: [(B,Ci,H,W)]*3 in xy,yz,(zx or xz) depending on order + Returns: + out_triplanes_list: [(B,Co,H,W)]*3 in xy,yz,(zx or xz) depending on order + """ + inps = list(triplanes_list) + xp = 1 #(yz) + yp = 2 #(zx) + zp = 0 #(xy) + + if self.order == 'xz': + # get into zx order + inps[yp] = einops.rearrange(inps[yp], 'b c x z -> b c z x') + + + oplanes = [None]*3 + # order shouldn't matter + for iplane in [zp, xp, yp]: + # i_plane -> (j,k) + + # need to average out i and convert to (j,k) + # j_plane -> (k,i) + # k_plane -> (i,j) + jplane = (iplane+1)%3 + kplane = (iplane+2)%3 + + ifeat = inps[iplane] + # need to average out nonshared dim + # Average pool across + + # j_plane -> (k,i) -> (k,1) -> (1,k) -> (j,k) + # b c k i -> b c k 1 + jpool = torch.mean(inps[jplane], dim=3 ,keepdim=True) + jpool = einops.rearrange(jpool, 'b c k 1 -> b c 1 k') + jpool = einops.repeat(jpool, 'b c 1 k -> b c j k', j=ifeat.size(2)) + + # k_plane -> (i,j) -> (1,j) -> (j,1) -> (j,k) + # b c i j -> b c 1 j + kpool = torch.mean(inps[kplane], dim=2 ,keepdim=True) + kpool = einops.rearrange(kpool, 'b c 1 j -> b c j 1') + kpool = einops.repeat(kpool, 'b c j 1 -> b c j k', k=ifeat.size(3)) + + # b c h w + # jpool = jpool.expand_as(ifeat) + # kpool = kpool.expand_as(ifeat) + + # concat and conv on feature dim + catfeat = torch.cat([ifeat, jpool, kpool], dim=1) + oplane = self.plane_convs[iplane](catfeat) + oplanes[iplane] = oplane + + if self.order == 'xz': + # get back into xz order + oplanes[yp] = einops.rearrange(oplanes[yp], 'b c z x -> b c x z') + + return oplanes + +def roll_triplanes(triplanes_list): + # B, C, tri, h, w + tristack = torch.stack((triplanes_list),dim=2) + return einops.rearrange(tristack, 'b c tri h w -> b c (tri h) w', tri=3) + +def unroll_triplanes(rolled_triplane): + # B, C, tri*h, w + tristack = einops.rearrange(rolled_triplane, 'b c (tri h) w -> b c tri h w', tri=3) + return torch.unbind(tristack, dim=2) + +def conv1x1triplane3daware(in_channels, out_channels, order='xz', **kwargs): + return ConvTriplane3dAware(lambda inp, out: conv1x1(inp,out,**kwargs), + in_channels, out_channels,order=order) + +def Normalize(in_channels, num_groups=32): + num_groups = min(in_channels, num_groups) # avoid error if in_channels < 32 + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + +def nonlinearity(x): + # return F.relu(x) + # Swish + return x*torch.sigmoid(x) + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class ResnetBlock3dAware(nn.Module): + def __init__(self, in_channels, out_channels=None): + #, conv_shortcut=False): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + # self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = conv3x3(self.in_channels, self.out_channels) + + self.norm_mid = Normalize(out_channels) + self.conv_3daware = conv1x1triplane3daware(self.out_channels, self.out_channels) + + self.norm2 = Normalize(out_channels) + self.conv2 = conv3x3(self.out_channels, self.out_channels) + + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + # 3x3 plane comm + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + # 1x1 3d aware, crossplane comm + h = self.norm_mid(h) + h = nonlinearity(h) + h = unroll_triplanes(h) + h = self.conv_3daware(h) + h = roll_triplanes(h) + + # 3x3 plane comm + h = self.norm2(h) + h = nonlinearity(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x+h + +class DownConv3dAware(nn.Module): + """ + A helper Module that performs 2 convolutions and 1 MaxPool. + A ReLU activation follows each convolution. + """ + def __init__(self, in_channels, out_channels, downsample=True, with_conv=False): + super(DownConv3dAware, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + self.block = ResnetBlock3dAware(in_channels=in_channels, + out_channels=out_channels) + + self.do_downsample = downsample + self.downsample = Downsample(out_channels, with_conv=with_conv) + + def forward(self, x): + """ + rolled input, rolled output + Args: + x: rolled (b c (tri*h) w) + """ + x = self.block(x) + before_pool = x + # if self.pooling: + # x = self.pool(x) + if self.do_downsample: + # unroll and cat channel-wise (to prevent pooling across triplane boundaries) + x = einops.rearrange(x, 'b c (tri h) w -> b (c tri) h w', tri=3) + x = self.downsample(x) + # undo + x = einops.rearrange(x, 'b (c tri) h w -> b c (tri h) w', tri=3) + return x, before_pool + +class UpConv3dAware(nn.Module): + """ + A helper Module that performs 2 convolutions and 1 UpConvolution. + A ReLU activation follows each convolution. + """ + def __init__(self, in_channels, out_channels, + merge_mode='concat', with_conv=False): #up_mode='transpose', ): + super(UpConv3dAware, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.merge_mode = merge_mode + + self.upsample = Upsample(in_channels, with_conv) + + if self.merge_mode == 'concat': + self.norm1 = Normalize(in_channels+out_channels) + self.block = ResnetBlock3dAware(in_channels=in_channels+out_channels, + out_channels=out_channels) + else: + self.norm1 = Normalize(in_channels) + self.block = ResnetBlock3dAware(in_channels=in_channels, + out_channels=out_channels) + + + def forward(self, from_down, from_up): + """ Forward pass + rolled inputs, rolled output + rolled (b c (tri*h) w) + Arguments: + from_down: tensor from the encoder pathway + from_up: upconv'd tensor from the decoder pathway + """ + # from_up = self.upconv(from_up) + from_up = self.upsample(from_up) + if self.merge_mode == 'concat': + x = torch.cat((from_up, from_down), 1) + else: + x = from_up + from_down + + x = self.norm1(x) + x = self.block(x) + return x + +class UNetTriplane3dAware(nn.Module): + def __init__(self, out_channels, in_channels=3, depth=5, + start_filts=64,# up_mode='transpose', + use_initial_conv=False, + merge_mode='concat', **kwargs): + """ + Arguments: + in_channels: int, number of channels in the input tensor. + Default is 3 for RGB images. + depth: int, number of MaxPools in the U-Net. + start_filts: int, number of convolutional filters for the + first conv. + """ + super(UNetTriplane3dAware, self).__init__() + + + self.out_channels = out_channels + self.in_channels = in_channels + self.start_filts = start_filts + self.depth = depth + + self.use_initial_conv = use_initial_conv + if use_initial_conv: + self.conv_initial = conv1x1(self.in_channels, self.start_filts) + + self.down_convs = [] + self.up_convs = [] + + # create the encoder pathway and add to a list + for i in range(depth): + if i == 0: + ins = self.start_filts if use_initial_conv else self.in_channels + else: + ins = outs + outs = self.start_filts*(2**i) + downsamp_it = True if i < depth-1 else False + + down_conv = DownConv3dAware(ins, outs, downsample = downsamp_it) + self.down_convs.append(down_conv) + + for i in range(depth-1): + ins = outs + outs = ins // 2 + up_conv = UpConv3dAware(ins, outs, + merge_mode=merge_mode) + self.up_convs.append(up_conv) + + # add the list of modules to current module + self.down_convs = nn.ModuleList(self.down_convs) + self.up_convs = nn.ModuleList(self.up_convs) + + self.norm_out = Normalize(outs) + self.conv_final = conv1x1(outs, self.out_channels) + + self.reset_params() + + @staticmethod + def weight_init(m): + if isinstance(m, nn.Conv2d): + # init.xavier_normal_(m.weight, gain=0.1) + init.xavier_normal_(m.weight) + init.constant_(m.bias, 0) + + + def reset_params(self): + for i, m in enumerate(self.modules()): + self.weight_init(m) + + + def forward(self, x): + """ + Args: + x: Stacked triplane expected to be in (B,3,C,H,W) + """ + # Roll + x = einops.rearrange(x, 'b tri c h w -> b c (tri h) w', tri=3) + + if self.use_initial_conv: + x = self.conv_initial(x) + + encoder_outs = [] + # encoder pathway, save outputs for merging + for i, module in enumerate(self.down_convs): + x, before_pool = module(x) + encoder_outs.append(before_pool) + + # Spend a block in the middle + # x = self.block_mid(x) + + for i, module in enumerate(self.up_convs): + before_pool = encoder_outs[-(i+2)] + x = module(before_pool, x) + + x = self.norm_out(x) + + # No softmax is used. This means you need to use + # nn.CrossEntropyLoss is your training script, + # as this module includes a softmax already. + x = self.conv_final(nonlinearity(x)) + + # Unroll + x = einops.rearrange(x, 'b c (tri h) w -> b tri c h w', tri=3) + return x + + +def setup_unet(output_channels, input_channels, unet_cfg): + if unet_cfg['use_3d_aware']: + assert(unet_cfg['rolled']) + unet = UNetTriplane3dAware( + out_channels=output_channels, + in_channels=input_channels, + depth=unet_cfg['depth'], + use_initial_conv=unet_cfg['use_initial_conv'], + start_filts=unet_cfg['start_hidden_channels'],) + else: + raise NotImplementedError + return unet + diff --git a/PartField/partfield/model/UNet/__pycache__/buildingblocks.cpython-310.pyc b/PartField/partfield/model/UNet/__pycache__/buildingblocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ce9edda65d7543e79b1812ed14875c6d94d2f6f Binary files /dev/null and b/PartField/partfield/model/UNet/__pycache__/buildingblocks.cpython-310.pyc differ diff --git a/PartField/partfield/model/UNet/__pycache__/model.cpython-310.pyc b/PartField/partfield/model/UNet/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1eac5be40e8fd470e505a9c6033ad9801f769af4 Binary files /dev/null and b/PartField/partfield/model/UNet/__pycache__/model.cpython-310.pyc differ diff --git a/PartField/partfield/model/UNet/buildingblocks.py b/PartField/partfield/model/UNet/buildingblocks.py new file mode 100644 index 0000000000000000000000000000000000000000..e97f501d1813b03555dbec5658d024e06d761443 --- /dev/null +++ b/PartField/partfield/model/UNet/buildingblocks.py @@ -0,0 +1,546 @@ +#https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/buildingblocks.py +# MIT License + +# Copyright (c) 2018 Adrian Wolny + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from functools import partial + +import torch +from torch import nn as nn +from torch.nn import functional as F + +# from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D + + +def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding, + dropout_prob, is3d): + """ + Create a list of modules with together constitute a single conv layer with non-linearity + and optional batchnorm/groupnorm. + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + kernel_size(int or tuple): size of the convolving kernel + order (string): order of things, e.g. + 'cr' -> conv + ReLU + 'gcr' -> groupnorm + conv + ReLU + 'cl' -> conv + LeakyReLU + 'ce' -> conv + ELU + 'bcr' -> batchnorm + conv + ReLU + 'cbrd' -> conv + batchnorm + ReLU + dropout + 'cbrD' -> conv + batchnorm + ReLU + dropout2d + num_groups (int): number of groups for the GroupNorm + padding (int or tuple): add zero-padding added to all three sides of the input + dropout_prob (float): dropout probability + is3d (bool): is3d (bool): if True use Conv3d, otherwise use Conv2d + Return: + list of tuple (name, module) + """ + assert 'c' in order, "Conv layer MUST be present" + assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer' + + modules = [] + for i, char in enumerate(order): + if char == 'r': + modules.append(('ReLU', nn.ReLU(inplace=True))) + elif char == 'l': + modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True))) + elif char == 'e': + modules.append(('ELU', nn.ELU(inplace=True))) + elif char == 'c': + # add learnable bias only in the absence of batchnorm/groupnorm + bias = not ('g' in order or 'b' in order) + if is3d: + conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) + else: + conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) + + modules.append(('conv', conv)) + elif char == 'g': + is_before_conv = i < order.index('c') + if is_before_conv: + num_channels = in_channels + else: + num_channels = out_channels + + # use only one group if the given number of groups is greater than the number of channels + if num_channels < num_groups: + num_groups = 1 + + assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}' + modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))) + elif char == 'b': + is_before_conv = i < order.index('c') + if is3d: + bn = nn.BatchNorm3d + else: + bn = nn.BatchNorm2d + + if is_before_conv: + modules.append(('batchnorm', bn(in_channels))) + else: + modules.append(('batchnorm', bn(out_channels))) + elif char == 'd': + modules.append(('dropout', nn.Dropout(p=dropout_prob))) + elif char == 'D': + modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob))) + else: + raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']") + + return modules + + +class SingleConv(nn.Sequential): + """ + Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order + of operations can be specified via the `order` parameter + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + kernel_size (int or tuple): size of the convolving kernel + order (string): determines the order of layers, e.g. + 'cr' -> conv + ReLU + 'crg' -> conv + ReLU + groupnorm + 'cl' -> conv + LeakyReLU + 'ce' -> conv + ELU + num_groups (int): number of groups for the GroupNorm + padding (int or tuple): add zero-padding + dropout_prob (float): dropout probability, default 0.1 + is3d (bool): if True use Conv3d, otherwise use Conv2d + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8, + padding=1, dropout_prob=0.1, is3d=True): + super(SingleConv, self).__init__() + + for name, module in create_conv(in_channels, out_channels, kernel_size, order, + num_groups, padding, dropout_prob, is3d): + self.add_module(name, module) + + +class DoubleConv(nn.Sequential): + """ + A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d). + We use (Conv3d+ReLU+GroupNorm3d) by default. + This can be changed however by providing the 'order' argument, e.g. in order + to change to Conv3d+BatchNorm3d+ELU use order='cbe'. + Use padded convolutions to make sure that the output (H_out, W_out) is the same + as (H_in, W_in), so that you don't have to crop in the decoder path. + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + encoder (bool): if True we're in the encoder path, otherwise we're in the decoder + kernel_size (int or tuple): size of the convolving kernel + order (string): determines the order of layers, e.g. + 'cr' -> conv + ReLU + 'crg' -> conv + ReLU + groupnorm + 'cl' -> conv + LeakyReLU + 'ce' -> conv + ELU + num_groups (int): number of groups for the GroupNorm + padding (int or tuple): add zero-padding added to all three sides of the input + upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2 + dropout_prob (float or tuple): dropout probability for each convolution, default 0.1 + is3d (bool): if True use Conv3d instead of Conv2d layers + """ + + def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr', + num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True): + super(DoubleConv, self).__init__() + if encoder: + # we're in the encoder path + conv1_in_channels = in_channels + if upscale == 1: + conv1_out_channels = out_channels + else: + conv1_out_channels = out_channels // 2 + if conv1_out_channels < in_channels: + conv1_out_channels = in_channels + conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels + else: + # we're in the decoder path, decrease the number of channels in the 1st convolution + conv1_in_channels, conv1_out_channels = in_channels, out_channels + conv2_in_channels, conv2_out_channels = out_channels, out_channels + + # check if dropout_prob is a tuple and if so + # split it for different dropout probabilities for each convolution. + if isinstance(dropout_prob, list) or isinstance(dropout_prob, tuple): + dropout_prob1 = dropout_prob[0] + dropout_prob2 = dropout_prob[1] + else: + dropout_prob1 = dropout_prob2 = dropout_prob + + # conv1 + self.add_module('SingleConv1', + SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups, + padding=padding, dropout_prob=dropout_prob1, is3d=is3d)) + # conv2 + self.add_module('SingleConv2', + SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups, + padding=padding, dropout_prob=dropout_prob2, is3d=is3d)) + + +class ResNetBlock(nn.Module): + """ + Residual block that can be used instead of standard DoubleConv in the Encoder module. + Motivated by: https://arxiv.org/pdf/1706.00120.pdf + + Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs): + super(ResNetBlock, self).__init__() + + if in_channels != out_channels: + # conv1x1 for increasing the number of channels + if is3d: + self.conv1 = nn.Conv3d(in_channels, out_channels, 1) + else: + self.conv1 = nn.Conv2d(in_channels, out_channels, 1) + else: + self.conv1 = nn.Identity() + + self.conv2 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups, + is3d=is3d) + # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual + n_order = order + for c in 'rel': + n_order = n_order.replace(c, '') + self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order, + num_groups=num_groups, is3d=is3d) + + # create non-linearity separately + if 'l' in order: + self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True) + elif 'e' in order: + self.non_linearity = nn.ELU(inplace=True) + else: + self.non_linearity = nn.ReLU(inplace=True) + + def forward(self, x): + # apply first convolution to bring the number of channels to out_channels + residual = self.conv1(x) + + out = self.conv2(x) + out = self.conv3(out) + + out += residual + out = self.non_linearity(out) + + return out + +class Encoder(nn.Module): + """ + A single module from the encoder path consisting of the optional max + pooling layer (one may specify the MaxPool kernel_size to be different + from the standard (2,2,2), e.g. if the volumetric data is anisotropic + (make sure to use complementary scale_factor in the decoder path) followed by + a basic module (DoubleConv or ResNetBlock). + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + conv_kernel_size (int or tuple): size of the convolving kernel + apply_pooling (bool): if True use MaxPool3d before DoubleConv + pool_kernel_size (int or tuple): the size of the window + pool_type (str): pooling layer: 'max' or 'avg' + basic_module(nn.Module): either ResNetBlock or DoubleConv + conv_layer_order (string): determines the order of layers + in `DoubleConv` module. See `DoubleConv` for more info. + num_groups (int): number of groups for the GroupNorm + padding (int or tuple): add zero-padding added to all three sides of the input + upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2 + dropout_prob (float or tuple): dropout probability, default 0.1 + is3d (bool): use 3d or 2d convolutions/pooling operation + """ + + def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True, + pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr', + num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True): + super(Encoder, self).__init__() + assert pool_type in ['max', 'avg'] + if apply_pooling: + if pool_type == 'max': + if is3d: + self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size) + else: + self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size) + else: + if is3d: + self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size) + else: + self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size) + else: + self.pooling = None + + self.basic_module = basic_module(in_channels, out_channels, + encoder=True, + kernel_size=conv_kernel_size, + order=conv_layer_order, + num_groups=num_groups, + padding=padding, + upscale=upscale, + dropout_prob=dropout_prob, + is3d=is3d) + + def forward(self, x): + if self.pooling is not None: + x = self.pooling(x) + x = self.basic_module(x) + return x + + +class Decoder(nn.Module): + """ + A single module for decoder path consisting of the upsampling layer + (either learned ConvTranspose3d or nearest neighbor interpolation) + followed by a basic module (DoubleConv or ResNetBlock). + + Args: + in_channels (int): number of input channels + out_channels (int): number of output channels + conv_kernel_size (int or tuple): size of the convolving kernel + scale_factor (int or tuple): used as the multiplier for the image H/W/D in + case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation + from the corresponding encoder + basic_module(nn.Module): either ResNetBlock or DoubleConv + conv_layer_order (string): determines the order of layers + in `DoubleConv` module. See `DoubleConv` for more info. + num_groups (int): number of groups for the GroupNorm + padding (int or tuple): add zero-padding added to all three sides of the input + upsample (str): algorithm used for upsampling: + InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area' + TransposeConvUpsampling: 'deconv' + No upsampling: None + Default: 'default' (chooses automatically) + dropout_prob (float or tuple): dropout probability, default 0.1 + """ + + def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv, + conv_layer_order='gcr', num_groups=8, padding=1, upsample='default', + dropout_prob=0.1, is3d=True): + super(Decoder, self).__init__() + + # perform concat joining per default + concat = True + + # don't adapt channels after join operation + adapt_channels = False + + if upsample is not None and upsample != 'none': + if upsample == 'default': + if basic_module == DoubleConv: + upsample = 'nearest' # use nearest neighbor interpolation for upsampling + concat = True # use concat joining + adapt_channels = False # don't adapt channels + elif basic_module == ResNetBlock: #or basic_module == ResNetBlockSE: + upsample = 'deconv' # use deconvolution upsampling + concat = False # use summation joining + adapt_channels = True # adapt channels after joining + + # perform deconvolution upsampling if mode is deconv + if upsample == 'deconv': + self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels, + kernel_size=conv_kernel_size, scale_factor=scale_factor, + is3d=is3d) + else: + self.upsampling = InterpolateUpsampling(mode=upsample) + else: + # no upsampling + self.upsampling = NoUpsampling() + # concat joining + self.joining = partial(self._joining, concat=True) + + # perform joining operation + self.joining = partial(self._joining, concat=concat) + + # adapt the number of in_channels for the ResNetBlock + if adapt_channels is True: + in_channels = out_channels + + self.basic_module = basic_module(in_channels, out_channels, + encoder=False, + kernel_size=conv_kernel_size, + order=conv_layer_order, + num_groups=num_groups, + padding=padding, + dropout_prob=dropout_prob, + is3d=is3d) + + def forward(self, encoder_features, x): + x = self.upsampling(encoder_features=encoder_features, x=x) + x = self.joining(encoder_features, x) + x = self.basic_module(x) + return x + + @staticmethod + def _joining(encoder_features, x, concat): + if concat: + return torch.cat((encoder_features, x), dim=1) + else: + return encoder_features + x + + +def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding, + conv_upscale, dropout_prob, + layer_order, num_groups, pool_kernel_size, is3d): + # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)` + encoders = [] + for i, out_feature_num in enumerate(f_maps): + if i == 0: + # apply conv_coord only in the first encoder if any + encoder = Encoder(in_channels, out_feature_num, + apply_pooling=False, # skip pooling in the firs encoder + basic_module=basic_module, + conv_layer_order=layer_order, + conv_kernel_size=conv_kernel_size, + num_groups=num_groups, + padding=conv_padding, + upscale=conv_upscale, + dropout_prob=dropout_prob, + is3d=is3d) + else: + encoder = Encoder(f_maps[i - 1], out_feature_num, + basic_module=basic_module, + conv_layer_order=layer_order, + conv_kernel_size=conv_kernel_size, + num_groups=num_groups, + pool_kernel_size=pool_kernel_size, + padding=conv_padding, + upscale=conv_upscale, + dropout_prob=dropout_prob, + is3d=is3d) + + encoders.append(encoder) + + return nn.ModuleList(encoders) + + +def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order, + num_groups, upsample, dropout_prob, is3d): + # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1` + decoders = [] + reversed_f_maps = list(reversed(f_maps[1:])) + for i in range(len(reversed_f_maps) - 1): + if basic_module == DoubleConv and upsample != 'deconv': + in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1] + else: + in_feature_num = reversed_f_maps[i] + + out_feature_num = reversed_f_maps[i + 1] + + decoder = Decoder(in_feature_num, out_feature_num, + basic_module=basic_module, + conv_layer_order=layer_order, + conv_kernel_size=conv_kernel_size, + num_groups=num_groups, + padding=conv_padding, + upsample=upsample, + dropout_prob=dropout_prob, + is3d=is3d) + decoders.append(decoder) + return nn.ModuleList(decoders) + + +class AbstractUpsampling(nn.Module): + """ + Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either + interpolation or learned transposed convolution. + """ + + def __init__(self, upsample): + super(AbstractUpsampling, self).__init__() + self.upsample = upsample + + def forward(self, encoder_features, x): + # get the spatial dimensions of the output given the encoder_features + output_size = encoder_features.size()[2:] + # upsample the input and return + return self.upsample(x, output_size) + + +class InterpolateUpsampling(AbstractUpsampling): + """ + Args: + mode (str): algorithm used for upsampling: + 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest' + used only if transposed_conv is False + """ + + def __init__(self, mode='nearest'): + upsample = partial(self._interpolate, mode=mode) + super().__init__(upsample) + + @staticmethod + def _interpolate(x, size, mode): + return F.interpolate(x, size=size, mode=mode) + + +class TransposeConvUpsampling(AbstractUpsampling): + """ + Args: + in_channels (int): number of input channels for transposed conv + used only if transposed_conv is True + out_channels (int): number of output channels for transpose conv + used only if transposed_conv is True + kernel_size (int or tuple): size of the convolving kernel + used only if transposed_conv is True + scale_factor (int or tuple): stride of the convolution + used only if transposed_conv is True + is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d + """ + + class Upsample(nn.Module): + """ + Workaround the 'ValueError: requested an output size...' in the `_output_padding` method in + transposed convolution. It performs transposed conv followed by the interpolation to the correct size if necessary. + """ + + def __init__(self, conv_transposed, is3d): + super().__init__() + self.conv_transposed = conv_transposed + self.is3d = is3d + + def forward(self, x, size): + x = self.conv_transposed(x) + return F.interpolate(x, size=size) + + def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True): + # make sure that the output size reverses the MaxPool3d from the corresponding encoder + if is3d is True: + conv_transposed = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, + stride=scale_factor, padding=1, bias=False) + else: + conv_transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, + stride=scale_factor, padding=1, bias=False) + upsample = self.Upsample(conv_transposed, is3d) + super().__init__(upsample) + + +class NoUpsampling(AbstractUpsampling): + def __init__(self): + super().__init__(self._no_upsampling) + + @staticmethod + def _no_upsampling(x, size): + return x \ No newline at end of file diff --git a/PartField/partfield/model/UNet/model.py b/PartField/partfield/model/UNet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..db20b2f5de3d37a52f7465450f915e003ef412d6 --- /dev/null +++ b/PartField/partfield/model/UNet/model.py @@ -0,0 +1,170 @@ +# https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/buildingblocks.py +# MIT License + +# Copyright (c) 2018 Adrian Wolny + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch.nn as nn + +from partfield.model.UNet.buildingblocks import DoubleConv, ResNetBlock, \ + create_decoders, create_encoders + +def number_of_features_per_level(init_channel_number, num_levels): + return [init_channel_number * 2 ** k for k in range(num_levels)] + +class AbstractUNet(nn.Module): + """ + Base class for standard and residual UNet. + + Args: + in_channels (int): number of input channels + out_channels (int): number of output segmentation masks; + Note that the of out_channels might correspond to either + different semantic classes or to different binary segmentation mask. + It's up to the user of the class to interpret the out_channels and + use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) + or BCEWithLogitsLoss (two-class) respectively) + f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number + of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 + final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution, + otherwise apply nn.Softmax. In effect only if `self.training == False`, i.e. during validation/testing + basic_module: basic model for the encoder/decoder (DoubleConv, ResNetBlock, ....) + layer_order (string): determines the order of layers in `SingleConv` module. + E.g. 'crg' stands for GroupNorm3d+Conv3d+ReLU. See `SingleConv` for more info + num_groups (int): number of groups for the GroupNorm + num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int) + default: 4 + is_segmentation (bool): if True and the model is in eval mode, Sigmoid/Softmax normalization is applied + after the final convolution; if False (regression problem) the normalization layer is skipped + conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module + pool_kernel_size (int or tuple): the size of the window + conv_padding (int or tuple): add zero-padding added to all three sides of the input + conv_upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2 + upsample (str): algorithm used for decoder upsampling: + InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area' + TransposeConvUpsampling: 'deconv' + No upsampling: None + Default: 'default' (chooses automatically) + dropout_prob (float or tuple): dropout probability, default: 0.1 + is3d (bool): if True the model is 3D, otherwise 2D, default: True + """ + + def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr', + num_groups=8, num_levels=4, is_segmentation=False, conv_kernel_size=3, pool_kernel_size=2, + conv_padding=1, conv_upscale=2, upsample='default', dropout_prob=0.1, is3d=True, encoder_only=False): + super(AbstractUNet, self).__init__() + + if isinstance(f_maps, int): + f_maps = number_of_features_per_level(f_maps, num_levels=num_levels) + + assert isinstance(f_maps, list) or isinstance(f_maps, tuple) + assert len(f_maps) > 1, "Required at least 2 levels in the U-Net" + if 'g' in layer_order: + assert num_groups is not None, "num_groups must be specified if GroupNorm is used" + + # create encoder path + self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, + conv_padding, conv_upscale, dropout_prob, + layer_order, num_groups, pool_kernel_size, is3d) + + self.encoder_only = encoder_only + + if encoder_only == False: + # create decoder path + self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, + layer_order, num_groups, upsample, dropout_prob, + is3d) + + # in the last layer a 1×1 convolution reduces the number of output channels to the number of labels + if is3d: + self.final_conv = nn.Conv3d(f_maps[1], out_channels, 1) + else: + self.final_conv = nn.Conv2d(f_maps[1], out_channels, 1) + + if is_segmentation: + # semantic segmentation problem + if final_sigmoid: + self.final_activation = nn.Sigmoid() + else: + self.final_activation = nn.Softmax(dim=1) + else: + # regression problem + self.final_activation = None + + def forward(self, x, return_bottleneck_feat=False): + # encoder part + encoders_features = [] + for encoder in self.encoders: + x = encoder(x) + # reverse the encoder outputs to be aligned with the decoder + encoders_features.insert(0, x) + + # remove the last encoder's output from the list + # !!remember: it's the 1st in the list + bottleneck_feat = encoders_features[0] + if self.encoder_only: + return bottleneck_feat + else: + encoders_features = encoders_features[1:] + + # decoder part + for decoder, encoder_features in zip(self.decoders, encoders_features): + # pass the output from the corresponding encoder and the output + # of the previous decoder + x = decoder(encoder_features, x) + + x = self.final_conv(x) + # During training the network outputs logits + if self.final_activation is not None: + x = self.final_activation(x) + + if return_bottleneck_feat: + return x, bottleneck_feat + else: + return x + +class ResidualUNet3D(AbstractUNet): + """ + Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf. + Uses ResNetBlock as a basic building block, summation joining instead + of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts). + Since the model effectively becomes a residual net, in theory it allows for deeper UNet. + """ + + def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=(8, 16, 64, 256, 1024), layer_order='gcr', + num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1, + conv_upscale=2, upsample='default', dropout_prob=0.1, encoder_only=False, **kwargs): + super(ResidualUNet3D, self).__init__(in_channels=in_channels, + out_channels=out_channels, + final_sigmoid=final_sigmoid, + basic_module=ResNetBlock, + f_maps=f_maps, + layer_order=layer_order, + num_groups=num_groups, + num_levels=num_levels, + is_segmentation=is_segmentation, + conv_padding=conv_padding, + conv_upscale=conv_upscale, + upsample=upsample, + dropout_prob=dropout_prob, + encoder_only=encoder_only, + is3d=True) + + diff --git a/PartField/partfield/model/__pycache__/model_utils.cpython-310.pyc b/PartField/partfield/model/__pycache__/model_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db284a3e3591ed4cadc4d7100c99f17c501cdedd Binary files /dev/null and b/PartField/partfield/model/__pycache__/model_utils.cpython-310.pyc differ diff --git a/PartField/partfield/model/__pycache__/triplane.cpython-310.pyc b/PartField/partfield/model/__pycache__/triplane.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f4c2fa510e7e8e2338570e8cbb53b59f6bb955a Binary files /dev/null and b/PartField/partfield/model/__pycache__/triplane.cpython-310.pyc differ diff --git a/PartField/partfield/model/model_utils.py b/PartField/partfield/model/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c1cc16bd96d21b69f21998e577f7f5f970d25cf3 --- /dev/null +++ b/PartField/partfield/model/model_utils.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn + +class VanillaMLP(nn.Module): + def __init__(self, input_dim, output_dim, out_activation, n_hidden_layers=4, n_neurons=64, activation="ReLU"): + super().__init__() + self.n_neurons = n_neurons + self.n_hidden_layers = n_hidden_layers + self.activation = activation + self.out_activation = out_activation + layers = [ + self.make_linear(input_dim, self.n_neurons, is_first=True, is_last=False), + self.make_activation(), + ] + for i in range(self.n_hidden_layers - 1): + layers += [ + self.make_linear( + self.n_neurons, self.n_neurons, is_first=False, is_last=False + ), + self.make_activation(), + ] + layers += [ + self.make_linear(self.n_neurons, output_dim, is_first=False, is_last=True) + ] + if self.out_activation == "sigmoid": + layers += [nn.Sigmoid()] + elif self.out_activation == "tanh": + layers += [nn.Tanh()] + elif self.out_activation == "hardtanh": + layers += [nn.Hardtanh()] + elif self.out_activation == "GELU": + layers += [nn.GELU()] + elif self.out_activation == "RELU": + layers += [nn.ReLU()] + else: + raise NotImplementedError + self.layers = nn.Sequential(*layers) + + def forward(self, x, split_size=100000): + with torch.cuda.amp.autocast(enabled=False): + out = self.layers(x) + return out + + def make_linear(self, dim_in, dim_out, is_first, is_last): + layer = nn.Linear(dim_in, dim_out, bias=False) + return layer + + def make_activation(self): + if self.activation == "ReLU": + return nn.ReLU(inplace=True) + elif self.activation == "GELU": + return nn.GELU() + else: + raise NotImplementedError \ No newline at end of file diff --git a/PartField/partfield/model/triplane.py b/PartField/partfield/model/triplane.py new file mode 100644 index 0000000000000000000000000000000000000000..6274a8398d248d3ba4a5a4734c7b1bc90d596b10 --- /dev/null +++ b/PartField/partfield/model/triplane.py @@ -0,0 +1,331 @@ +#https://github.com/3DTopia/OpenLRM/blob/main/openlrm/models/modeling_lrm.py +# Copyright (c) 2023-2024, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from functools import partial + +def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. + + Takes plane axes of shape n_planes, 3, 3 + # Takes coordinates of shape N, M, 3 + # returns projections of shape N*n_planes, M, 2 + """ + N, M, C = coordinates.shape + n_planes, _, _ = planes.shape + coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) + inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) + projections = torch.bmm(coordinates, inv_planes) + return projections[..., :2] + +def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): + plane_axes = torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]], dtype=torch.float32).cuda() + + assert padding_mode == 'zeros' + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N*n_planes, C, H, W) + + projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) + output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + return output_features + +def get_grid_coord(grid_size = 256, align_corners=False): + if align_corners == False: + coords = torch.linspace(-1 + 1/(grid_size), 1 - 1/(grid_size), steps=grid_size) + else: + coords = torch.linspace(-1, 1, steps=grid_size) + i, j, k = torch.meshgrid(coords, coords, coords, indexing='ij') + coordinates = torch.stack((i, j, k), dim=-1).reshape(-1, 3) + return coordinates + +class BasicBlock(nn.Module): + """ + Transformer block that is in its simplest form. + Designed for PF-LRM architecture. + """ + # Block contains a self-attention layer and an MLP + def __init__(self, inner_dim: int, num_heads: int, eps: float, + attn_drop: float = 0., attn_bias: bool = False, + mlp_ratio: float = 4., mlp_drop: float = 0.): + super().__init__() + self.norm1 = nn.LayerNorm(inner_dim, eps=eps) + self.self_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm2 = nn.LayerNorm(inner_dim, eps=eps) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x): + # x: [N, L, D] + before_sa = self.norm1(x) + x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] + x = x + self.mlp(self.norm2(x)) + return x + +class ConditionBlock(nn.Module): + """ + Transformer block that takes in a cross-attention condition. + Designed for SparseLRM architecture. + """ + # Block contains a cross-attention layer, a self-attention layer, and an MLP + def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float, + attn_drop: float = 0., attn_bias: bool = False, + mlp_ratio: float = 4., mlp_drop: float = 0.): + super().__init__() + self.norm1 = nn.LayerNorm(inner_dim, eps=eps) + self.cross_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm2 = nn.LayerNorm(inner_dim, eps=eps) + self.self_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm3 = nn.LayerNorm(inner_dim, eps=eps) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x, cond): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] + x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0] + before_sa = self.norm2(x) + x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0] + x = x + self.mlp(self.norm3(x)) + return x + +class TransformerDecoder(nn.Module): + def __init__(self, block_type: str, + num_layers: int, num_heads: int, + inner_dim: int, cond_dim: int = None, + eps: float = 1e-6): + super().__init__() + self.block_type = block_type + self.layers = nn.ModuleList([ + self._block_fn(inner_dim, cond_dim)( + num_heads=num_heads, + eps=eps, + ) + for _ in range(num_layers) + ]) + self.norm = nn.LayerNorm(inner_dim, eps=eps) + + @property + def block_type(self): + return self._block_type + + @block_type.setter + def block_type(self, block_type): + assert block_type in ['cond', 'basic'], \ + f"Unsupported block type: {block_type}" + self._block_type = block_type + + def _block_fn(self, inner_dim, cond_dim): + assert inner_dim is not None, f"inner_dim must always be specified" + if self.block_type == 'basic': + return partial(BasicBlock, inner_dim=inner_dim) + elif self.block_type == 'cond': + assert cond_dim is not None, f"Condition dimension must be specified for ConditionBlock" + return partial(ConditionBlock, inner_dim=inner_dim, cond_dim=cond_dim) + else: + raise ValueError(f"Unsupported block type during runtime: {self.block_type}") + + + def forward_layer(self, layer: nn.Module, x: torch.Tensor, cond: torch.Tensor,): + if self.block_type == 'basic': + return layer(x) + elif self.block_type == 'cond': + return layer(x, cond) + else: + raise NotImplementedError + + def forward(self, x: torch.Tensor, cond: torch.Tensor = None): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] or None + for layer in self.layers: + x = self.forward_layer(layer, x, cond) + x = self.norm(x) + return x + +class Voxel2Triplane(nn.Module): + """ + Full model of the basic single-view large reconstruction model. + """ + def __init__(self, transformer_dim: int, transformer_layers: int, transformer_heads: int, + triplane_low_res: int, triplane_high_res: int, triplane_dim: int, voxel_feat_dim: int, normalize_vox_feat=False, voxel_dim=16): + super().__init__() + + # attributes + self.triplane_low_res = triplane_low_res + self.triplane_high_res = triplane_high_res + self.triplane_dim = triplane_dim + self.voxel_feat_dim = voxel_feat_dim + + # initialize pos_embed with 1/sqrt(dim) * N(0, 1) + self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, transformer_dim) * (1. / transformer_dim) ** 0.5) + self.transformer = TransformerDecoder( + block_type='cond', + num_layers=transformer_layers, num_heads=transformer_heads, + inner_dim=transformer_dim, cond_dim=voxel_feat_dim + ) + self.upsampler = nn.ConvTranspose2d(transformer_dim, triplane_dim, kernel_size=8, stride=8, padding=0) + + self.normalize_vox_feat = normalize_vox_feat + if normalize_vox_feat: + self.vox_norm = nn.LayerNorm(voxel_feat_dim, eps=1e-6) + self.vox_pos_embed = nn.Parameter(torch.randn(1, voxel_dim * voxel_dim * voxel_dim, voxel_feat_dim) * (1. / voxel_feat_dim) ** 0.5) + + def forward_transformer(self, voxel_feats): + N = voxel_feats.shape[0] + x = self.pos_embed.repeat(N, 1, 1) # [N, L, D] + if self.normalize_vox_feat: + vox_pos_embed = self.vox_pos_embed.repeat(N, 1, 1) # [N, L, D] + voxel_feats = self.vox_norm(voxel_feats + vox_pos_embed) + x = self.transformer( + x, + cond=voxel_feats + ) + return x + + def reshape_upsample(self, tokens): + N = tokens.shape[0] + H = W = self.triplane_low_res + x = tokens.view(N, 3, H, W, -1) + x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W] + x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W] + x = self.upsampler(x) # [3*N, D', H', W'] + x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] + x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] + x = x.contiguous() + return x + + def forward(self, voxel_feats): + N = voxel_feats.shape[0] + + # encode image + assert voxel_feats.shape[-1] == self.voxel_feat_dim, \ + f"Feature dimension mismatch: {voxel_feats.shape[-1]} vs {self.voxel_feat_dim}" + + # transformer generating planes + tokens = self.forward_transformer(voxel_feats) + planes = self.reshape_upsample(tokens) + assert planes.shape[0] == N, "Batch size mismatch for planes" + assert planes.shape[1] == 3, "Planes should have 3 channels" + + return planes + + +class TriplaneTransformer(nn.Module): + """ + Full model of the basic single-view large reconstruction model. + """ + def __init__(self, input_dim: int, transformer_dim: int, transformer_layers: int, transformer_heads: int, + triplane_low_res: int, triplane_high_res: int, triplane_dim: int): + super().__init__() + + # attributes + self.triplane_low_res = triplane_low_res + self.triplane_high_res = triplane_high_res + self.triplane_dim = triplane_dim + + # initialize pos_embed with 1/sqrt(dim) * N(0, 1) + self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, transformer_dim) * (1. / transformer_dim) ** 0.5) + self.transformer = TransformerDecoder( + block_type='basic', + num_layers=transformer_layers, num_heads=transformer_heads, + inner_dim=transformer_dim, + ) + + self.downsampler = nn.Sequential( + nn.Conv2d(input_dim, transformer_dim, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), # Reduces size from 128x128 to 64x64 + + nn.Conv2d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), # Reduces size from 64x64 to 32x32 + ) + + self.upsampler = nn.ConvTranspose2d(transformer_dim, triplane_dim, kernel_size=4, stride=4, padding=0) + + self.mlp = nn.Sequential( + nn.Linear(input_dim, triplane_dim), + nn.ReLU(), + nn.Linear(triplane_dim, triplane_dim) + ) + + def forward_transformer(self, triplanes): + N = triplanes.shape[0] + tokens = torch.einsum('nidhw->nihwd', triplanes).reshape(N, self.pos_embed.shape[1], -1) # [N, L, D] + x = self.pos_embed.repeat(N, 1, 1) + tokens # [N, L, D] + x = self.transformer(x) + return x + + def reshape_downsample(self, triplanes): + N = triplanes.shape[0] + H = W = self.triplane_high_res + x = triplanes.view(N, 3, -1, H, W) + x = torch.einsum('nidhw->indhw', x) # [3, N, D, H, W] + x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W] + x = self.downsampler(x) # [3*N, D', H', W'] + x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] + x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] + x = x.contiguous() + return x + + def reshape_upsample(self, tokens): + N = tokens.shape[0] + H = W = self.triplane_low_res + x = tokens.view(N, 3, H, W, -1) + x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W] + x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W] + x = self.upsampler(x) # [3*N, D', H', W'] + x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] + x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] + x = x.contiguous() + return x + + def forward(self, triplanes): + downsampled_triplanes = self.reshape_downsample(triplanes) + tokens = self.forward_transformer(downsampled_triplanes) + residual = self.reshape_upsample(tokens) + + triplanes = triplanes.permute(0, 1, 3, 4, 2).contiguous() + triplanes = self.mlp(triplanes) + triplanes = triplanes.permute(0, 1, 4, 2, 3).contiguous() + planes = triplanes + residual + return planes diff --git a/PartField/partfield/model_trainer_pvcnn_only_demo.py b/PartField/partfield/model_trainer_pvcnn_only_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..3f2f4423bb38424d8b05e850f48b429554da8dbd --- /dev/null +++ b/PartField/partfield/model_trainer_pvcnn_only_demo.py @@ -0,0 +1,286 @@ +import torch +import lightning.pytorch as pl +from .dataloader import Demo_Dataset, Demo_Remesh_Dataset, Correspondence_Demo_Dataset +from torch.utils.data import DataLoader +from partfield.model.UNet.model import ResidualUNet3D +from partfield.model.triplane import TriplaneTransformer, get_grid_coord #, sample_from_planes, Voxel2Triplane +from partfield.model.model_utils import VanillaMLP +import torch.nn.functional as F +import torch.nn as nn +import os +import trimesh +import skimage +import numpy as np +import h5py +import torch.distributed as dist +from partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat +import json +import gc +import time +from plyfile import PlyData, PlyElement + + +class Model(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + + self.save_hyperparameters() + self.cfg = cfg + self.automatic_optimization = False + self.triplane_resolution = cfg.triplane_resolution + self.triplane_channels_low = cfg.triplane_channels_low + self.triplane_transformer = TriplaneTransformer( + input_dim=cfg.triplane_channels_low * 2, + transformer_dim=1024, + transformer_layers=6, + transformer_heads=8, + triplane_low_res=32, + triplane_high_res=128, + triplane_dim=cfg.triplane_channels_high, + ) + self.sdf_decoder = VanillaMLP(input_dim=64, + output_dim=1, + out_activation="tanh", + n_neurons=64, #64 + n_hidden_layers=6) #6 + self.use_pvcnn = cfg.use_pvcnnonly + self.use_2d_feat = cfg.use_2d_feat + if self.use_pvcnn: + self.pvcnn = TriPlanePC2Encoder( + cfg.pvcnn, + device="cuda", + shape_min=-1, + shape_length=2, + use_2d_feat=self.use_2d_feat) #.cuda() + self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True)) + self.grid_coord = get_grid_coord(256) + self.mse_loss = torch.nn.MSELoss() + self.l1_loss = torch.nn.L1Loss(reduction='none') + + if cfg.regress_2d_feat: + self.feat_decoder = VanillaMLP(input_dim=64, + output_dim=192, + out_activation="GELU", + n_neurons=64, #64 + n_hidden_layers=6) #6 + + def predict_dataloader(self): + if self.cfg.remesh_demo: + dataset = Demo_Remesh_Dataset(self.cfg) + elif self.cfg.correspondence_demo: + dataset = Correspondence_Demo_Dataset(self.cfg) + else: + dataset = Demo_Dataset(self.cfg) + + dataloader = DataLoader(dataset, + num_workers=self.cfg.dataset.val_num_workers, + batch_size=self.cfg.dataset.val_batch_size, + shuffle=False, + pin_memory=True, + drop_last=False) + + return dataloader + + + @torch.no_grad() + def predict_step(self, batch, batch_idx): + save_dir = f"exp_results/{self.cfg.result_name}" + os.makedirs(save_dir, exist_ok=True) + + uid = batch['uid'][0] + view_id = 0 + starttime = time.time() + + if uid == "car" or uid == "complex_car": + # if uid == "complex_car": + print("Skipping this for now.") + print(uid) + return + + ### Skip if model already processed + if os.path.exists(f'{save_dir}/part_feat_{uid}_{view_id}.npy') or os.path.exists(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy'): + print("Already processed "+uid) + return + + N = batch['pc'].shape[0] + assert N == 1 + + if self.use_2d_feat: + print("ERROR. Dataloader not implemented with input 2d feat.") + exit() + else: + pc_feat = self.pvcnn(batch['pc'], batch['pc']) + + planes = pc_feat + planes = self.triplane_transformer(planes) + sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2) + + if self.cfg.is_pc: + tensor_vertices = batch['pc'].reshape(1, -1, 3).cuda().to(torch.float16) + point_feat = sample_triplane_feat(part_planes, tensor_vertices) # N, M, C + point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448) + + np.save(f'{save_dir}/part_feat_{uid}_{view_id}.npy', point_feat) + print(f"Exported part_feat_{uid}_{view_id}.npy") + + ########### + from sklearn.decomposition import PCA + data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True) + + pca = PCA(n_components=3) + + data_reduced = pca.fit_transform(data_scaled) + data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min()) + colors_255 = (data_reduced * 255).astype(np.uint8) + + points = batch['pc'].squeeze().detach().cpu().numpy() + + if colors_255 is None: + colors_255 = np.full_like(points, 255) # Default to white color (255,255,255) + else: + assert colors_255.shape == points.shape, "Colors must have the same shape as points" + + # Convert to structured array for PLY format + vertex_data = np.array( + [(*point, *color) for point, color in zip(points, colors_255)], + dtype=[("x", "f4"), ("y", "f4"), ("z", "f4"), ("red", "u1"), ("green", "u1"), ("blue", "u1")] + ) + + # Create PLY element + el = PlyElement.describe(vertex_data, "vertex") + # Write to file + filename = f'{save_dir}/feat_pca_{uid}_{view_id}.ply' + PlyData([el], text=True).write(filename) + print(f"Saved PLY file: {filename}") + ############ + + else: + use_cuda_version = True + if use_cuda_version: + + def sample_points(vertices, faces, n_point_per_face): + # Generate random barycentric coordinates + # borrowed from Kaolin https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/mesh/trianglemesh.py#L43 + n_f = faces.shape[0] + u = torch.sqrt(torch.rand((n_f, n_point_per_face, 1), + device=vertices.device, + dtype=vertices.dtype)) + v = torch.rand((n_f, n_point_per_face, 1), + device=vertices.device, + dtype=vertices.dtype) + w0 = 1 - u + w1 = u * (1 - v) + w2 = u * v + + face_v_0 = torch.index_select(vertices, 0, faces[:, 0].reshape(-1)) + face_v_1 = torch.index_select(vertices, 0, faces[:, 1].reshape(-1)) + face_v_2 = torch.index_select(vertices, 0, faces[:, 2].reshape(-1)) + points = w0 * face_v_0.unsqueeze(dim=1) + w1 * face_v_1.unsqueeze(dim=1) + w2 * face_v_2.unsqueeze(dim=1) + return points + + def sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face): + n_sample_each = self.cfg.n_sample_each # we iterate over this to avoid OOM + n_v = tensor_vertices.shape[1] + n_sample = n_v // n_sample_each + 1 + all_sample = [] + for i_sample in range(n_sample): + sampled_feature = sample_triplane_feat(part_planes, tensor_vertices[:, i_sample * n_sample_each: i_sample * n_sample_each + n_sample_each,]) + assert sampled_feature.shape[1] % n_point_per_face == 0 + sampled_feature = sampled_feature.reshape(1, -1, n_point_per_face, sampled_feature.shape[-1]) + sampled_feature = torch.mean(sampled_feature, axis=-2) + all_sample.append(sampled_feature) + return torch.cat(all_sample, dim=1) + + if self.cfg.vertex_feature: + tensor_vertices = batch['vertices'][0].reshape(1, -1, 3).to(torch.float32) + point_feat = sample_and_mean_memory_save_version(part_planes, tensor_vertices, 1) + else: + n_point_per_face = self.cfg.n_point_per_face + tensor_vertices = sample_points(batch['vertices'][0], batch['faces'][0], n_point_per_face) + tensor_vertices = tensor_vertices.reshape(1, -1, 3).to(torch.float32) + point_feat = sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face) # N, M, C + + #### Take mean feature in the triangle + print("Time elapsed for feature prediction: " + str(time.time() - starttime)) + point_feat = point_feat.reshape(-1, 448).cpu().numpy() + np.save(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy', point_feat) + print(f"Exported part_feat_{uid}_{view_id}.npy") + + ########### + from sklearn.decomposition import PCA + + combined_feat = np.load("/scratch/shared/beegfs/ruining/projects/PartField/exp_results/partfield_features/correspondence/combined.npy") + combined_feat_scaled = combined_feat / np.linalg.norm(combined_feat, axis=-1, keepdims=True) + data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True) + + pca = PCA(n_components=3).fit(combined_feat_scaled) + + data_reduced = pca.transform(data_scaled) + data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min()) + colors_255 = (data_reduced * 255).astype(np.uint8) + V = batch['vertices'][0].cpu().numpy() + F = batch['faces'][0].cpu().numpy() + if self.cfg.vertex_feature: + colored_mesh = trimesh.Trimesh(vertices=V, faces=F, vertex_colors=colors_255, process=False) + else: + colored_mesh = trimesh.Trimesh(vertices=V, faces=F, face_colors=colors_255, process=False) + colored_mesh.export(f'{save_dir}/feat_pca_{uid}_{view_id}.ply') + ############ + torch.cuda.empty_cache() + + else: + ### Mesh input (obj file) + V = batch['vertices'][0].cpu().numpy() + F = batch['faces'][0].cpu().numpy() + + ##### Loop through faces ##### + num_samples_per_face = self.cfg.n_point_per_face + + all_point_feats = [] + for face in F: + # Get the vertices of the current face + v0, v1, v2 = V[face] + + # Generate random barycentric coordinates + u = np.random.rand(num_samples_per_face, 1) + v = np.random.rand(num_samples_per_face, 1) + is_prob = (u+v) >1 + u[is_prob] = 1 - u[is_prob] + v[is_prob] = 1 - v[is_prob] + w = 1 - u - v + + # Calculate points in Cartesian coordinates + points = u * v0 + v * v1 + w * v2 + + tensor_vertices = torch.from_numpy(points.copy()).reshape(1, -1, 3).cuda().to(torch.float32) + point_feat = sample_triplane_feat(part_planes, tensor_vertices) # N, M, C + + #### Take mean feature in the triangle + point_feat = torch.mean(point_feat, axis=1).cpu().detach().numpy() + all_point_feats.append(point_feat) + ############################## + + all_point_feats = np.array(all_point_feats).reshape(-1, 448) + + point_feat = all_point_feats + + np.save(f'{save_dir}/part_feat_{uid}_{view_id}.npy', point_feat) + print(f"Exported part_feat_{uid}_{view_id}.npy") + + ########### + from sklearn.decomposition import PCA + data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True) + + pca = PCA(n_components=3) + + data_reduced = pca.fit_transform(data_scaled) + data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min()) + colors_255 = (data_reduced * 255).astype(np.uint8) + + colored_mesh = trimesh.Trimesh(vertices=V, faces=F, face_colors=colors_255, process=False) + colored_mesh.export(f'{save_dir}/feat_pca_{uid}_{view_id}.ply') + ############ + + print("Time elapsed: " + str(time.time()-starttime)) + + return \ No newline at end of file diff --git a/PartField/partfield/utils.py b/PartField/partfield/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc176434f91bd753f2daee6cc7d3c4fd7ce163b --- /dev/null +++ b/PartField/partfield/utils.py @@ -0,0 +1,5 @@ +import trimesh + +def load_mesh_util(input_fname): + mesh = trimesh.load(input_fname, force='mesh', process=False) + return mesh \ No newline at end of file diff --git a/PartField/partfield_inference.py b/PartField/partfield_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..abc41c772d34df1c0427178cd3d76fc547b6ab48 --- /dev/null +++ b/PartField/partfield_inference.py @@ -0,0 +1,61 @@ +from partfield.config import default_argument_parser, setup +from lightning.pytorch import seed_everything, Trainer +from lightning.pytorch.strategies import DDPStrategy +from lightning.pytorch.callbacks import ModelCheckpoint +import lightning +import torch +import glob +import os, sys +import numpy as np +import random + +def predict(cfg): + seed_everything(cfg.seed) + + torch.manual_seed(0) + random.seed(0) + np.random.seed(0) + + checkpoint_callbacks = [ModelCheckpoint( + monitor="train/current_epoch", + dirpath=cfg.output_dir, + filename="{epoch:02d}", + save_top_k=100, + save_last=True, + every_n_epochs=cfg.save_every_epoch, + mode="max", + verbose=True + )] + + trainer = Trainer(devices=-1, + accelerator="gpu", + precision="16-mixed", + strategy=DDPStrategy(find_unused_parameters=True), + max_epochs=cfg.training_epochs, + log_every_n_steps=1, + limit_train_batches=3500, + limit_val_batches=None, + callbacks=checkpoint_callbacks + ) + + from partfield.model_trainer_pvcnn_only_demo import Model + model = Model(cfg) + + if cfg.remesh_demo: + cfg.n_point_per_face = 10 + + trainer.predict(model, ckpt_path=cfg.continue_ckpt) + +def main(): + parser = default_argument_parser() + + npz_file = "/scratch/shared/beegfs/ruining/data/articulate-3d/points-all-dinov3/7265-combination_000-pos_000.npz" + datum = np.load(npz_file) + pc = datum['points'] + + args = parser.parse_args() + cfg = setup(args, freeze=False) + predict(cfg) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/PartField/partfield_inference_pc.py b/PartField/partfield_inference_pc.py new file mode 100644 index 0000000000000000000000000000000000000000..34f95406003e04f4c56522169ce7497aab9804f7 --- /dev/null +++ b/PartField/partfield_inference_pc.py @@ -0,0 +1,201 @@ +from partfield.config import default_argument_parser, setup +from lightning.pytorch import seed_everything, Trainer +from lightning.pytorch.strategies import DDPStrategy +from lightning.pytorch.callbacks import ModelCheckpoint +import torch +import glob +import os +import numpy as np +import random +import zipfile + +from partfield.model.PVCNN.encoder_pc import sample_triplane_feat + +def predict(cfg): + seed_everything(cfg.seed) + + torch.manual_seed(0) + random.seed(0) + np.random.seed(0) + + checkpoint_callbacks = [ModelCheckpoint( + monitor="train/current_epoch", + dirpath=cfg.output_dir, + filename="{epoch:02d}", + save_top_k=100, + save_last=True, + every_n_epochs=cfg.save_every_epoch, + mode="max", + verbose=True + )] + + trainer = Trainer(devices=-1, + accelerator="gpu", + precision="16-mixed", + strategy=DDPStrategy(find_unused_parameters=True), + max_epochs=cfg.training_epochs, + log_every_n_steps=1, + limit_train_batches=3500, + limit_val_batches=None, + callbacks=checkpoint_callbacks + ) + + from partfield.model_trainer_pvcnn_only_demo import Model + model = Model(cfg) + + if cfg.remesh_demo: + cfg.n_point_per_face = 10 + + trainer.predict(model, ckpt_path=cfg.continue_ckpt) + +def main(): + from tqdm import tqdm + + parser = default_argument_parser() + parser.add_argument('--num_jobs', type=int, default=1, help='Total number of parallel jobs') + parser.add_argument('--job_id', type=int, default=0, help='Current job ID (0 to num_jobs-1)') + args = parser.parse_args() + cfg = setup(args, freeze=False) + cfg.is_pc = True + + # Validate job arguments + if args.job_id >= args.num_jobs: + raise ValueError(f"job_id ({args.job_id}) must be less than num_jobs ({args.num_jobs})") + if args.job_id < 0: + raise ValueError(f"job_id ({args.job_id}) must be >= 0") + + from partfield.model_trainer_pvcnn_only_demo import Model + model = Model.load_from_checkpoint(cfg.continue_ckpt, cfg=cfg) + model.eval() + model.to('cuda') + + encode_pc_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-uniform-100k-singlestate-pts" + decode_pc_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-sharp50pct-40k-singlestate-pts" + dest_feat_root = "/scratch/shared/beegfs/ruining/data/articulate-3d/Lightwheel/all-sharp50pct-40k-singlestate-feats" + + # Create destination directory + os.makedirs(dest_feat_root, exist_ok=True) + + encode_files = sorted(glob.glob(os.path.join(encode_pc_root, "*.npy"))) + decode_files = sorted(glob.glob(os.path.join(decode_pc_root, "*.npy"))) + + # Filter files for this job + job_files = [pair for i, pair in enumerate(zip(encode_files, decode_files)) if i % args.num_jobs == args.job_id] + + print(f"Job {args.job_id}/{args.num_jobs}: Processing {len(job_files)}/{len(encode_files)} files") + + num_bad_zip, num_failed_others = 0, 0 + for encode_file, decode_file in tqdm(job_files, desc=f"Job {args.job_id}"): + try: + # Get UID from decode file (the one we're extracting features for) + uid = os.path.basename(decode_file).split('.')[0] + assert uid == os.path.basename(encode_file).split('.')[0] + + dest_feat_file = os.path.join(dest_feat_root, f"{uid}.npy") + if os.path.exists(dest_feat_file): + continue + + # Load both encode and decode point clouds + encode_pc = np.load(encode_file) + decode_pc = np.load(decode_file) + + # Validate input data + if np.isnan(encode_pc).any() or np.isnan(decode_pc).any(): + print(f"Skipping {uid}: NaN values in point cloud") + num_failed_others += 1 + continue + if np.isinf(encode_pc).any() or np.isinf(decode_pc).any(): + print(f"Skipping {uid}: Inf values in point cloud") + num_failed_others += 1 + continue + + # Compute bounding box from ALL points (encode + decode) for consistent normalization + all_points = np.vstack([encode_pc, decode_pc]) + bbmin = all_points.min(0) + bbmax = all_points.max(0) + + # Check for degenerate bounding box + bbox_size = (bbmax - bbmin).max() + if bbox_size < 1e-6: + print(f"Skipping {uid}: Degenerate bounding box (size={bbox_size})") + num_failed_others += 1 + continue + + center = (bbmin + bbmax) * 0.5 + scale = 2.0 * 0.9 / bbox_size + + # Apply same normalization to both point clouds + encode_pc_normalized = (encode_pc - center) * scale + decode_pc_normalized = (decode_pc - center) * scale + + # Validate normalized coordinates + if np.isnan(encode_pc_normalized).any() or np.isnan(decode_pc_normalized).any(): + print(f"Skipping {uid}: NaN in normalized coordinates") + num_failed_others += 1 + continue + if np.isinf(encode_pc_normalized).any() or np.isinf(decode_pc_normalized).any(): + print(f"Skipping {uid}: Inf in normalized coordinates") + num_failed_others += 1 + continue + + # Check if normalized coordinates are within reasonable range (should be ~[-1, 1]) + encode_max = np.abs(encode_pc_normalized).max() + decode_max = np.abs(decode_pc_normalized).max() + if encode_max > 10 or decode_max > 10: + print(f"Skipping {uid}: Normalized coordinates out of range (encode_max={encode_max:.2f}, decode_max={decode_max:.2f})") + num_failed_others += 1 + continue + + # Use encode_pc to generate triplane + batch_encode_pc = torch.from_numpy(encode_pc_normalized).unsqueeze(0).float().to('cuda') + + with torch.no_grad(): + try: + # Generate triplane from encode_pc + pc_feat = model.pvcnn(batch_encode_pc, batch_encode_pc) + planes = model.triplane_transformer(pc_feat) + sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2) + + # Sample features at decode_pc points + tensor_vertices = torch.from_numpy(decode_pc_normalized).reshape(1, -1, 3).to(torch.float32).cuda() + + # Validate tensor before sampling + if torch.isnan(tensor_vertices).any() or torch.isinf(tensor_vertices).any(): + print(f"Skipping {uid}: Invalid tensor_vertices after conversion to torch") + num_failed_others += 1 + continue + + point_feat = sample_triplane_feat(part_planes, tensor_vertices) + point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448) + + # Save point features + np.save(dest_feat_file, point_feat.astype(np.float16)) + + except RuntimeError as e: + if "CUDA" in str(e) or "index" in str(e).lower(): + print(f"Skipping {uid}: CUDA error - {str(e)[:100]}") + print(f" encode shape: {encode_pc.shape}, decode shape: {decode_pc.shape}") + print(f" bbox_size: {bbox_size:.6f}, scale: {scale:.6f}") + print(f" normalized range: [{encode_pc_normalized.min():.3f}, {encode_pc_normalized.max():.3f}]") + num_failed_others += 1 + # Clear CUDA cache to recover from error + torch.cuda.empty_cache() + continue + else: + raise + + except zipfile.BadZipFile: + num_bad_zip += 1 + continue + + except Exception: + num_failed_others += 1 + continue + + print(f"Job {args.job_id} - Number of bad zip files: {num_bad_zip}") + print(f"Job {args.job_id} - Number of failed others: {num_failed_others}") + print(f"Job {args.job_id} completed successfully!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/PartField/run_part_clustering.py b/PartField/run_part_clustering.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e98a8f1c0c227b0b3f820ef4c98ebe9867d514 --- /dev/null +++ b/PartField/run_part_clustering.py @@ -0,0 +1,806 @@ +from sklearn.cluster import AgglomerativeClustering, KMeans +import numpy as np +import trimesh +import matplotlib.pyplot as plt +import numpy as np +import os +import argparse +import time + +import json +from os.path import join +from typing import List + +from collections import defaultdict +from scipy.sparse import coo_matrix, csr_matrix +from scipy.sparse.csgraph import connected_components +from sklearn.neighbors import NearestNeighbors +import networkx as nx + +from plyfile import PlyData +import open3d as o3d +from partfield.utils import * + +#### Export to file ##### +def export_colored_mesh_ply(V, F, FL, filename='segmented_mesh.ply'): + """ + Export a mesh with per-face segmentation labels into a colored PLY file. + + Parameters: + - V (np.ndarray): Vertices array of shape (N, 3) + - F (np.ndarray): Faces array of shape (M, 3) + - FL (np.ndarray): Face labels of shape (M,) + - filename (str): Output filename + """ + assert V.shape[1] == 3 + assert F.shape[1] == 3 + assert F.shape[0] == FL.shape[0] + + # Generate distinct colors for each unique label + unique_labels = np.unique(FL) + colormap = plt.cm.get_cmap("tab20", len(unique_labels)) + label_to_color = { + label: (np.array(colormap(i)[:3]) * 255).astype(np.uint8) + for i, label in enumerate(unique_labels) + } + + mesh = trimesh.Trimesh(vertices=V, faces=F) + FL = np.squeeze(FL) + for i, face in enumerate(F): + label = FL[i] + color = label_to_color[label] + color_with_alpha = np.append(color, 255) # Add alpha value + mesh.visual.face_colors[i] = color_with_alpha + + mesh.export(filename) + print(f"Exported mesh to {filename}") + +def export_pointcloud_with_labels_to_ply(V, VL, filename='colored_pointcloud.ply'): + """ + Export a labeled point cloud to a PLY file with vertex colors. + + Parameters: + - V: (N, 3) numpy array of XYZ coordinates + - VL: (N,) numpy array of integer labels + - filename: Output PLY file name + """ + assert V.shape[0] == VL.shape[0], "Number of vertices and labels must match" + + # Generate unique colors for each label + unique_labels = np.unique(VL) + colormap = plt.cm.get_cmap("tab20", len(unique_labels)) + label_to_color = { + label: colormap(i)[:3] for i, label in enumerate(unique_labels) + } + + VL = np.squeeze(VL) + # Map labels to RGB colors + colors = np.array([label_to_color[label] for label in VL]) + + # Open3D requires colors in float [0, 1] + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(V) + pcd.colors = o3d.utility.Vector3dVector(colors) + + # Save to .ply + o3d.io.write_point_cloud(filename, pcd) + print(f"Point cloud saved to {filename}") +######################### + +######################### +def construct_face_adjacency_matrix_ccmst(face_list, vertices, k=10, with_knn=True): + """ + Given a list of faces (each face is a 3-tuple of vertex indices), + construct a face-based adjacency matrix of shape (num_faces, num_faces). + + Two faces are adjacent if they share an edge (the "mesh adjacency"). + If multiple connected components remain, we: + 1) Compute the centroid of each connected component as the mean of all face centroids. + 2) Use a KNN graph (k=10) based on centroid distances on each connected component. + 3) Compute MST of that KNN graph. + 4) Add MST edges that connect different components as "dummy" edges + in the face adjacency matrix, ensuring one connected component. The selected face for + each connected component is the face closest to the component centroid. + + Parameters + ---------- + face_list : list of tuples + List of faces, each face is a tuple (v0, v1, v2) of vertex indices. + vertices : np.ndarray of shape (num_vertices, 3) + Array of vertex coordinates. + k : int, optional + Number of neighbors to use in centroid KNN. Default is 10. + + Returns + ------- + face_adjacency : scipy.sparse.csr_matrix + A CSR sparse matrix of shape (num_faces, num_faces), + containing 1s for adjacent faces (shared-edge adjacency) + plus dummy edges ensuring a single connected component. + """ + num_faces = len(face_list) + if num_faces == 0: + # Return an empty matrix if no faces + return csr_matrix((0, 0)) + + #-------------------------------------------------------------------------- + # 1) Build adjacency based on shared edges. + # (Same logic as the original code, plus import statements.) + #-------------------------------------------------------------------------- + edge_to_faces = defaultdict(list) + uf = UnionFind(num_faces) + for f_idx, (v0, v1, v2) in enumerate(face_list): + # Sort each edge’s endpoints so (i, j) == (j, i) + edges = [ + tuple(sorted((v0, v1))), + tuple(sorted((v1, v2))), + tuple(sorted((v2, v0))) + ] + for e in edges: + edge_to_faces[e].append(f_idx) + + row = [] + col = [] + for edge, face_indices in edge_to_faces.items(): + unique_faces = list(set(face_indices)) + if len(unique_faces) > 1: + # For every pair of distinct faces that share this edge, + # mark them as mutually adjacent + for i in range(len(unique_faces)): + for j in range(i + 1, len(unique_faces)): + fi = unique_faces[i] + fj = unique_faces[j] + row.append(fi) + col.append(fj) + row.append(fj) + col.append(fi) + uf.union(fi, fj) + + data = np.ones(len(row), dtype=np.int8) + face_adjacency = coo_matrix( + (data, (row, col)), shape=(num_faces, num_faces) + ).tocsr() + + #-------------------------------------------------------------------------- + # 2) Check if the graph from shared edges is already connected. + #-------------------------------------------------------------------------- + n_components = 0 + for i in range(num_faces): + if uf.find(i) == i: + n_components += 1 + print("n_components", n_components) + + if n_components == 1: + # Already a single connected component, no need for dummy edges + return face_adjacency + + #-------------------------------------------------------------------------- + # 3) Compute centroids of each face for building a KNN graph. + #-------------------------------------------------------------------------- + face_centroids = [] + for (v0, v1, v2) in face_list: + centroid = (vertices[v0] + vertices[v1] + vertices[v2]) / 3.0 + face_centroids.append(centroid) + face_centroids = np.array(face_centroids) + + # #-------------------------------------------------------------------------- + # # 4a) Build a KNN graph (k=10) over face centroids using scikit‐learn + # #-------------------------------------------------------------------------- + # knn = NearestNeighbors(n_neighbors=k, algorithm='auto') + # knn.fit(face_centroids) + # distances, indices = knn.kneighbors(face_centroids) + # # 'distances[i]' are the distances from face i to each of its 'k' neighbors + # # 'indices[i]' are the face indices of those neighbors + + #-------------------------------------------------------------------------- + # 4b) Build a KNN graph on connected components + #-------------------------------------------------------------------------- + # Group faces by their root representative in the Union-Find structure + component_dict = {} + for face_idx in range(num_faces): + root = uf.find(face_idx) + if root not in component_dict: + component_dict[root] = set() + component_dict[root].add(face_idx) + + connected_components = list(component_dict.values()) + + print("Using connected component MST.") + component_centroid_face_idx = [] + connected_component_centroids = [] + knn = NearestNeighbors(n_neighbors=1, algorithm='auto') + for component in connected_components: + curr_component_faces = list(component) + curr_component_face_centroids = face_centroids[curr_component_faces] + component_centroid = np.mean(curr_component_face_centroids, axis=0) + + ### Assign a face closest to the centroid + face_idx = curr_component_faces[np.argmin(np.linalg.norm(curr_component_face_centroids-component_centroid, axis=-1))] + + connected_component_centroids.append(component_centroid) + component_centroid_face_idx.append(face_idx) + + component_centroid_face_idx = np.array(component_centroid_face_idx) + connected_component_centroids = np.array(connected_component_centroids) + + if n_components < k: + knn = NearestNeighbors(n_neighbors=n_components, algorithm='auto') + else: + knn = NearestNeighbors(n_neighbors=k, algorithm='auto') + knn.fit(connected_component_centroids) + distances, indices = knn.kneighbors(connected_component_centroids) + + #-------------------------------------------------------------------------- + # 5) Build a weighted graph in NetworkX using centroid-distances as edges + #-------------------------------------------------------------------------- + G = nx.Graph() + # Add each face as a node in the graph + G.add_nodes_from(range(num_faces)) + + # For each face i, add edges (i -> j) for each neighbor j in the KNN + for idx1 in range(n_components): + i = component_centroid_face_idx[idx1] + for idx2, dist in zip(indices[idx1], distances[idx1]): + j = component_centroid_face_idx[idx2] + if i == j: + continue # skip self-loop + # Add an undirected edge with 'weight' = distance + # NetworkX handles parallel edges gracefully via last add_edge, + # but it typically overwrites the weight if (i, j) already exists. + G.add_edge(i, j, weight=dist) + + #-------------------------------------------------------------------------- + # 6) Compute MST on that KNN graph + #-------------------------------------------------------------------------- + mst = nx.minimum_spanning_tree(G, weight='weight') + # Sort MST edges by ascending weight, so we add the shortest edges first + mst_edges_sorted = sorted( + mst.edges(data=True), key=lambda e: e[2]['weight'] + ) + print("mst edges sorted", len(mst_edges_sorted)) + #-------------------------------------------------------------------------- + # 7) Use a union-find structure to add MST edges only if they + # connect two currently disconnected components of the adjacency matrix + #-------------------------------------------------------------------------- + + # Convert face_adjacency to LIL format for efficient edge addition + adjacency_lil = face_adjacency.tolil() + + # Now, step through MST edges in ascending order + for (u, v, attr) in mst_edges_sorted: + if uf.find(u) != uf.find(v): + # These belong to different components, so unify them + uf.union(u, v) + # And add a "dummy" edge to our adjacency matrix + adjacency_lil[u, v] = 1 + adjacency_lil[v, u] = 1 + + # Convert back to CSR format and return + face_adjacency = adjacency_lil.tocsr() + + if with_knn: + print("Adding KNN edges.") + ### Add KNN edges graph too + dummy_row = [] + dummy_col = [] + for idx1 in range(n_components): + i = component_centroid_face_idx[idx1] + for idx2 in indices[idx1]: + j = component_centroid_face_idx[idx2] + dummy_row.extend([i, j]) + dummy_col.extend([j, i]) ### duplicates are handled by coo + + dummy_data = np.ones(len(dummy_row), dtype=np.int16) + dummy_mat = coo_matrix( + (dummy_data, (dummy_row, dummy_col)), + shape=(num_faces, num_faces) + ).tocsr() + face_adjacency = face_adjacency + dummy_mat + ########################### + + return face_adjacency +######################### + +def construct_face_adjacency_matrix_facemst(face_list, vertices, k=10, with_knn=True): + """ + Given a list of faces (each face is a 3-tuple of vertex indices), + construct a face-based adjacency matrix of shape (num_faces, num_faces). + + Two faces are adjacent if they share an edge (the "mesh adjacency"). + If multiple connected components remain, we: + 1) Compute the centroid of each face. + 2) Use a KNN graph (k=10) based on centroid distances. + 3) Compute MST of that KNN graph. + 4) Add MST edges that connect different components as "dummy" edges + in the face adjacency matrix, ensuring one connected component. + + Parameters + ---------- + face_list : list of tuples + List of faces, each face is a tuple (v0, v1, v2) of vertex indices. + vertices : np.ndarray of shape (num_vertices, 3) + Array of vertex coordinates. + k : int, optional + Number of neighbors to use in centroid KNN. Default is 10. + + Returns + ------- + face_adjacency : scipy.sparse.csr_matrix + A CSR sparse matrix of shape (num_faces, num_faces), + containing 1s for adjacent faces (shared-edge adjacency) + plus dummy edges ensuring a single connected component. + """ + num_faces = len(face_list) + if num_faces == 0: + # Return an empty matrix if no faces + return csr_matrix((0, 0)) + + #-------------------------------------------------------------------------- + # 1) Build adjacency based on shared edges. + # (Same logic as the original code, plus import statements.) + #-------------------------------------------------------------------------- + edge_to_faces = defaultdict(list) + uf = UnionFind(num_faces) + for f_idx, (v0, v1, v2) in enumerate(face_list): + # Sort each edge’s endpoints so (i, j) == (j, i) + edges = [ + tuple(sorted((v0, v1))), + tuple(sorted((v1, v2))), + tuple(sorted((v2, v0))) + ] + for e in edges: + edge_to_faces[e].append(f_idx) + + row = [] + col = [] + for edge, face_indices in edge_to_faces.items(): + unique_faces = list(set(face_indices)) + if len(unique_faces) > 1: + # For every pair of distinct faces that share this edge, + # mark them as mutually adjacent + for i in range(len(unique_faces)): + for j in range(i + 1, len(unique_faces)): + fi = unique_faces[i] + fj = unique_faces[j] + row.append(fi) + col.append(fj) + row.append(fj) + col.append(fi) + uf.union(fi, fj) + + data = np.ones(len(row), dtype=np.int8) + face_adjacency = coo_matrix( + (data, (row, col)), shape=(num_faces, num_faces) + ).tocsr() + + #-------------------------------------------------------------------------- + # 2) Check if the graph from shared edges is already connected. + #-------------------------------------------------------------------------- + n_components = 0 + for i in range(num_faces): + if uf.find(i) == i: + n_components += 1 + print("n_components", n_components) + + if n_components == 1: + # Already a single connected component, no need for dummy edges + return face_adjacency + #-------------------------------------------------------------------------- + # 3) Compute centroids of each face for building a KNN graph. + #-------------------------------------------------------------------------- + face_centroids = [] + for (v0, v1, v2) in face_list: + centroid = (vertices[v0] + vertices[v1] + vertices[v2]) / 3.0 + face_centroids.append(centroid) + face_centroids = np.array(face_centroids) + + #-------------------------------------------------------------------------- + # 4) Build a KNN graph (k=10) over face centroids using scikit‐learn + #-------------------------------------------------------------------------- + knn = NearestNeighbors(n_neighbors=k, algorithm='auto') + knn.fit(face_centroids) + distances, indices = knn.kneighbors(face_centroids) + # 'distances[i]' are the distances from face i to each of its 'k' neighbors + # 'indices[i]' are the face indices of those neighbors + + #-------------------------------------------------------------------------- + # 5) Build a weighted graph in NetworkX using centroid-distances as edges + #-------------------------------------------------------------------------- + G = nx.Graph() + # Add each face as a node in the graph + G.add_nodes_from(range(num_faces)) + + # For each face i, add edges (i -> j) for each neighbor j in the KNN + for i in range(num_faces): + for j, dist in zip(indices[i], distances[i]): + if i == j: + continue # skip self-loop + # Add an undirected edge with 'weight' = distance + # NetworkX handles parallel edges gracefully via last add_edge, + # but it typically overwrites the weight if (i, j) already exists. + G.add_edge(i, j, weight=dist) + + #-------------------------------------------------------------------------- + # 6) Compute MST on that KNN graph + #-------------------------------------------------------------------------- + mst = nx.minimum_spanning_tree(G, weight='weight') + # Sort MST edges by ascending weight, so we add the shortest edges first + mst_edges_sorted = sorted( + mst.edges(data=True), key=lambda e: e[2]['weight'] + ) + print("mst edges sorted", len(mst_edges_sorted)) + #-------------------------------------------------------------------------- + # 7) Use a union-find structure to add MST edges only if they + # connect two currently disconnected components of the adjacency matrix + #-------------------------------------------------------------------------- + + # Convert face_adjacency to LIL format for efficient edge addition + adjacency_lil = face_adjacency.tolil() + + # Now, step through MST edges in ascending order + for (u, v, attr) in mst_edges_sorted: + if uf.find(u) != uf.find(v): + # These belong to different components, so unify them + uf.union(u, v) + # And add a "dummy" edge to our adjacency matrix + adjacency_lil[u, v] = 1 + adjacency_lil[v, u] = 1 + + # Convert back to CSR format and return + face_adjacency = adjacency_lil.tocsr() + + if with_knn: + print("Adding KNN edges.") + ### Add KNN edges graph too + dummy_row = [] + dummy_col = [] + for i in range(num_faces): + for j in indices[i]: + dummy_row.extend([i, j]) + dummy_col.extend([j, i]) ### duplicates are handled by coo + + dummy_data = np.ones(len(dummy_row), dtype=np.int16) + dummy_mat = coo_matrix( + (dummy_data, (dummy_row, dummy_col)), + shape=(num_faces, num_faces) + ).tocsr() + face_adjacency = face_adjacency + dummy_mat + ########################### + + return face_adjacency + +def construct_face_adjacency_matrix_naive(face_list): + """ + Given a list of faces (each face is a 3-tuple of vertex indices), + construct a face-based adjacency matrix of shape (num_faces, num_faces). + Two faces are adjacent if they share an edge. + + If multiple connected components exist, dummy edges are added to + turn them into a single connected component. Edges are added naively by + randomly selecting a face and connecting consecutive components -- (comp_i, comp_i+1) ... + + Parameters + ---------- + face_list : list of tuples + List of faces, each face is a tuple (v0, v1, v2) of vertex indices. + + Returns + ------- + face_adjacency : scipy.sparse.csr_matrix + A CSR sparse matrix of shape (num_faces, num_faces), + containing 1s for adjacent faces and 0s otherwise. + Additional edges are added if the faces are in multiple components. + """ + + num_faces = len(face_list) + if num_faces == 0: + # Return an empty matrix if no faces + return csr_matrix((0, 0)) + + # Step 1: Map each undirected edge -> list of face indices that contain that edge + edge_to_faces = defaultdict(list) + + # Populate the edge_to_faces dictionary + for f_idx, (v0, v1, v2) in enumerate(face_list): + # For an edge, we always store its endpoints in sorted order + # to avoid duplication (e.g. edge (2,5) is the same as (5,2)). + edges = [ + tuple(sorted((v0, v1))), + tuple(sorted((v1, v2))), + tuple(sorted((v2, v0))) + ] + for e in edges: + edge_to_faces[e].append(f_idx) + + # Step 2: Build the adjacency (row, col) lists among faces + row = [] + col = [] + for e, faces_sharing_e in edge_to_faces.items(): + # If an edge is shared by multiple faces, make each pair of those faces adjacent + f_indices = list(set(faces_sharing_e)) # unique face indices for this edge + if len(f_indices) > 1: + # For each pair of faces, mark them as adjacent + for i in range(len(f_indices)): + for j in range(i + 1, len(f_indices)): + f_i = f_indices[i] + f_j = f_indices[j] + row.append(f_i) + col.append(f_j) + row.append(f_j) + col.append(f_i) + + # Create a COO matrix, then convert it to CSR + data = np.ones(len(row), dtype=np.int8) + face_adjacency = coo_matrix( + (data, (row, col)), + shape=(num_faces, num_faces) + ).tocsr() + + # Step 3: Ensure single connected component + # Use connected_components to see how many components exist + n_components, labels = connected_components(face_adjacency, directed=False) + + if n_components > 1: + # We have multiple components; let's "connect" them via dummy edges + # The simplest approach is to pick one face from each component + # and connect them sequentially to enforce a single component. + component_representatives = [] + + for comp_id in range(n_components): + # indices of faces in this component + faces_in_comp = np.where(labels == comp_id)[0] + if len(faces_in_comp) > 0: + # take the first face in this component as a representative + component_representatives.append(faces_in_comp[0]) + + # Now, add edges between consecutive representatives + dummy_row = [] + dummy_col = [] + for i in range(len(component_representatives) - 1): + f_i = component_representatives[i] + f_j = component_representatives[i + 1] + dummy_row.extend([f_i, f_j]) + dummy_col.extend([f_j, f_i]) + + if dummy_row: + dummy_data = np.ones(len(dummy_row), dtype=np.int8) + dummy_mat = coo_matrix( + (dummy_data, (dummy_row, dummy_col)), + shape=(num_faces, num_faces) + ).tocsr() + face_adjacency = face_adjacency + dummy_mat + + return face_adjacency + +class UnionFind: + def __init__(self, n): + self.parent = list(range(n)) + self.rank = [1] * n + + def find(self, x): + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] + + def union(self, x, y): + rootX = self.find(x) + rootY = self.find(y) + + if rootX != rootY: + if self.rank[rootX] > self.rank[rootY]: + self.parent[rootY] = rootX + elif self.rank[rootX] < self.rank[rootY]: + self.parent[rootX] = rootY + else: + self.parent[rootY] = rootX + self.rank[rootX] += 1 + +def hierarchical_clustering_labels(children, n_samples, max_cluster=20): + # Union-Find structure to maintain cluster merges + uf = UnionFind(2 * n_samples - 1) # We may need to store up to 2*n_samples - 1 clusters + + current_cluster_count = n_samples + + # Process merges from the children array + hierarchical_labels = [] + for i, (child1, child2) in enumerate(children): + uf.union(child1, i + n_samples) + uf.union(child2, i + n_samples) + #uf.union(child1, child2) + current_cluster_count -= 1 # After each merge, we reduce the cluster count + + if current_cluster_count <= max_cluster: + labels = [uf.find(i) for i in range(n_samples)] + hierarchical_labels.append(labels) + + return hierarchical_labels + +def load_ply_to_numpy(filename): + """ + Load a PLY file and extract the point cloud as a (N, 3) NumPy array. + + Parameters: + filename (str): Path to the PLY file. + + Returns: + numpy.ndarray: Point cloud array of shape (N, 3). + """ + # Read PLY file + ply_data = PlyData.read(filename) + + # Extract vertex data + vertex_data = ply_data["vertex"] + + # Convert to NumPy array (x, y, z) + points = np.vstack([vertex_data["x"], vertex_data["y"], vertex_data["z"]]).T + + return points + +def solve_clustering(input_fname, uid, view_id, save_dir="test_results1", out_render_fol= "test_render_clustering", use_agglo=False, max_num_clusters=18, is_pc=False, option=1, with_knn=True, export_mesh=True): + print(uid, view_id) + + if not is_pc: + input_fname = f'{save_dir}/input_{uid}_{view_id}.ply' + mesh = load_mesh_util(input_fname) + + else: + pc = load_ply_to_numpy(input_fname) + + ### Load inferred PartField features + try: + point_feat = np.load(f'{save_dir}/part_feat_{uid}_{view_id}.npy') + except: + try: + point_feat = np.load(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy') + + except: + print() + print("pointfeat loading error. skipping...") + print(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy') + return + + point_feat = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True) + + if not use_agglo: + for num_cluster in range(2, max_num_clusters): + clustering = KMeans(n_clusters=num_cluster, random_state=0).fit(point_feat) + labels = clustering.labels_ + + + pred_labels = np.zeros((len(labels), 1)) + for i, label in enumerate(np.unique(labels)): + # print(i, label) + pred_labels[labels == label] = i # Assign RGB values to each label + + fname_clustering = os.path.join(out_render_fol, "cluster_out", str(uid) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2)) + np.save(fname_clustering, pred_labels) + + + if not is_pc: + V = mesh.vertices + F = mesh.faces + + if export_mesh : + fname_mesh = os.path.join(out_render_fol, "ply", str(uid) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2) + ".ply") + export_colored_mesh_ply(V, F, pred_labels, filename=fname_mesh) + + + else: + if export_mesh: + fname_pc = os.path.join(out_render_fol, "ply", str(uid) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2) + ".ply") + export_pointcloud_with_labels_to_ply(pc, pred_labels, filename=fname_pc) + + else: + if is_pc: + print("Not implemented error. Agglomerative clustering only for mesh inputs.") + exit() + + if option == 0: + adj_matrix = construct_face_adjacency_matrix_naive(mesh.faces) + elif option == 1: + adj_matrix = construct_face_adjacency_matrix_facemst(mesh.faces, mesh.vertices, with_knn=with_knn) + else: + adj_matrix = construct_face_adjacency_matrix_ccmst(mesh.faces, mesh.vertices, with_knn=with_knn) + + clustering = AgglomerativeClustering(connectivity=adj_matrix, + n_clusters=1, + ).fit(point_feat) + hierarchical_labels = hierarchical_clustering_labels(clustering.children_, point_feat.shape[0], max_cluster=max_num_clusters) + + all_FL = [] + for n_cluster in range(max_num_clusters): + print("Processing cluster: "+str(n_cluster)) + labels = hierarchical_labels[n_cluster] + all_FL.append(labels) + + + all_FL = np.array(all_FL) + unique_labels = np.unique(all_FL) + + for n_cluster in range(max_num_clusters): + FL = all_FL[n_cluster] + relabel = np.zeros((len(FL), 1)) + for i, label in enumerate(unique_labels): + relabel[FL == label] = i # Assign RGB values to each label + + V = mesh.vertices + F = mesh.faces + + if export_mesh : + fname_mesh = os.path.join(out_render_fol, "ply", str(uid) + "_" + str(view_id) + "_" + str(max_num_clusters - n_cluster).zfill(2) + ".ply") + export_colored_mesh_ply(V, F, FL, filename=fname_mesh) + + fname_clustering = os.path.join(out_render_fol, "cluster_out", str(uid) + "_" + str(view_id) + "_" + str(max_num_clusters - n_cluster).zfill(2)) + np.save(fname_clustering, FL) + + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--source_dir', default= "", type=str) + parser.add_argument('--root', default= "", type=str) + parser.add_argument('--dump_dir', default= "", type=str) + + parser.add_argument('--max_num_clusters', default= 20, type=int) + parser.add_argument('--use_agglo', default= False, type=bool) + parser.add_argument('--is_pc', default= False, type=bool) + parser.add_argument('--option', default= 1, type=int) + parser.add_argument('--with_knn', default= False, type=bool) + + parser.add_argument('--export_mesh', default= True, type=bool) + + FLAGS = parser.parse_args() + root = FLAGS.root + OUTPUT_FOL = FLAGS.dump_dir + SOURCE_DIR = FLAGS.source_dir + + MAX_NUM_CLUSTERS = FLAGS.max_num_clusters + USE_AGGLO = FLAGS.use_agglo + IS_PC = FLAGS.is_pc + + OPTION = FLAGS.option + WITH_KNN = FLAGS.with_knn + + EXPORT_MESH = FLAGS.export_mesh + + models = os.listdir(root) + os.makedirs(OUTPUT_FOL, exist_ok=True) + + cluster_fol = os.path.join(OUTPUT_FOL, "cluster_out") + os.makedirs(cluster_fol, exist_ok=True) + + if EXPORT_MESH: + ply_fol = os.path.join(OUTPUT_FOL, "ply") + os.makedirs(ply_fol, exist_ok=True) + + #### Get existing model_ids ### + all_files = os.listdir(os.path.join(OUTPUT_FOL, "ply")) + + existing_model_ids = [] + for sample in all_files: + uid = sample.split("_")[0] + view_id = sample.split("_")[1] + # sample_name = str(uid) + "_" + str(view_id) + sample_name = str(uid) + + if sample_name not in existing_model_ids: + existing_model_ids.append(sample_name) + ############################## + + all_files = os.listdir(SOURCE_DIR) + selected = [] + for f in all_files: + if ".ply" in f and IS_PC and f.split(".")[0] not in existing_model_ids: + selected.append(f) + elif (".obj" in f or ".glb" in f) and not IS_PC and f.split(".")[0] not in existing_model_ids: + selected.append(f) + + print("Number of models to process: " + str(len(selected))) + + for model in selected: + fname = os.path.join(SOURCE_DIR, model) + uid = model.split(".")[-2] + view_id = 0 + + solve_clustering(fname, uid, view_id, save_dir=root, out_render_fol= OUTPUT_FOL, use_agglo=USE_AGGLO, max_num_clusters=MAX_NUM_CLUSTERS, is_pc=IS_PC, option=OPTION, with_knn=WITH_KNN, export_mesh=EXPORT_MESH) diff --git a/PartField/run_part_clustering_remesh.py b/PartField/run_part_clustering_remesh.py new file mode 100644 index 0000000000000000000000000000000000000000..098df0dcb0b4fe5accff91e5e6e7ffb926bda2ce --- /dev/null +++ b/PartField/run_part_clustering_remesh.py @@ -0,0 +1,439 @@ +from sklearn.cluster import AgglomerativeClustering, KMeans +import numpy as np +import trimesh +import matplotlib.pyplot as plt +import numpy as np +import os +import argparse +import time + +import json +from os.path import join +from typing import List + +from collections import defaultdict +from scipy.sparse import coo_matrix, csr_matrix +from scipy.sparse.csgraph import connected_components + +from plyfile import PlyData +import open3d as o3d + +from scipy.spatial import cKDTree +from collections import Counter +from partfield.utils import * + +#### Export to file ##### +def export_colored_mesh_ply(V, F, FL, filename='segmented_mesh.ply'): + """ + Export a mesh with per-face segmentation labels into a colored PLY file. + + Parameters: + - V (np.ndarray): Vertices array of shape (N, 3) + - F (np.ndarray): Faces array of shape (M, 3) + - FL (np.ndarray): Face labels of shape (M,) + - filename (str): Output filename + """ + assert V.shape[1] == 3 + assert F.shape[1] == 3 + assert F.shape[0] == FL.shape[0] + + # Generate distinct colors for each unique label + unique_labels = np.unique(FL) + colormap = plt.cm.get_cmap("tab20", len(unique_labels)) + label_to_color = { + label: (np.array(colormap(i)[:3]) * 255).astype(np.uint8) + for i, label in enumerate(unique_labels) + } + + mesh = trimesh.Trimesh(vertices=V, faces=F) + FL = np.squeeze(FL) + for i, face in enumerate(F): + label = FL[i] + color = label_to_color[label] + color_with_alpha = np.append(color, 255) # Add alpha value + mesh.visual.face_colors[i] = color_with_alpha + + mesh.export(filename) + print(f"Exported mesh to {filename}") + +def export_pointcloud_with_labels_to_ply(V, VL, filename='colored_pointcloud.ply'): + """ + Export a labeled point cloud to a PLY file with vertex colors. + + Parameters: + - V: (N, 3) numpy array of XYZ coordinates + - VL: (N,) numpy array of integer labels + - filename: Output PLY file name + """ + assert V.shape[0] == VL.shape[0], "Number of vertices and labels must match" + + # Generate unique colors for each label + unique_labels = np.unique(VL) + colormap = plt.cm.get_cmap("tab20", len(unique_labels)) + label_to_color = { + label: colormap(i)[:3] for i, label in enumerate(unique_labels) + } + + VL = np.squeeze(VL) + # Map labels to RGB colors + colors = np.array([label_to_color[label] for label in VL]) + + # Open3D requires colors in float [0, 1] + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(V) + pcd.colors = o3d.utility.Vector3dVector(colors) + + # Save to .ply + o3d.io.write_point_cloud(filename, pcd) + print(f"Point cloud saved to {filename}") +######################### + +def construct_face_adjacency_matrix(face_list): + """ + Given a list of faces (each face is a 3-tuple of vertex indices), + construct a face-based adjacency matrix of shape (num_faces, num_faces). + Two faces are adjacent if they share an edge. + + If multiple connected components exist, dummy edges are added to + turn them into a single connected component. + + Parameters + ---------- + face_list : list of tuples + List of faces, each face is a tuple (v0, v1, v2) of vertex indices. + + Returns + ------- + face_adjacency : scipy.sparse.csr_matrix + A CSR sparse matrix of shape (num_faces, num_faces), + containing 1s for adjacent faces and 0s otherwise. + Additional edges are added if the faces are in multiple components. + """ + + num_faces = len(face_list) + if num_faces == 0: + # Return an empty matrix if no faces + return csr_matrix((0, 0)) + + # Step 1: Map each undirected edge -> list of face indices that contain that edge + edge_to_faces = defaultdict(list) + + # Populate the edge_to_faces dictionary + for f_idx, (v0, v1, v2) in enumerate(face_list): + # For an edge, we always store its endpoints in sorted order + # to avoid duplication (e.g. edge (2,5) is the same as (5,2)). + edges = [ + tuple(sorted((v0, v1))), + tuple(sorted((v1, v2))), + tuple(sorted((v2, v0))) + ] + for e in edges: + edge_to_faces[e].append(f_idx) + + # Step 2: Build the adjacency (row, col) lists among faces + row = [] + col = [] + for e, faces_sharing_e in edge_to_faces.items(): + # If an edge is shared by multiple faces, make each pair of those faces adjacent + f_indices = list(set(faces_sharing_e)) # unique face indices for this edge + if len(f_indices) > 1: + # For each pair of faces, mark them as adjacent + for i in range(len(f_indices)): + for j in range(i + 1, len(f_indices)): + f_i = f_indices[i] + f_j = f_indices[j] + row.append(f_i) + col.append(f_j) + row.append(f_j) + col.append(f_i) + + # Create a COO matrix, then convert it to CSR + data = np.ones(len(row), dtype=np.int8) + face_adjacency = coo_matrix( + (data, (row, col)), + shape=(num_faces, num_faces) + ).tocsr() + + return face_adjacency + + +def relabel_coarse_mesh(dense_mesh, dense_labels, coarse_mesh): + """ + Relabels a coarse mesh using voting from a dense mesh, where every dense face gets to vote. + + Parameters: + dense_mesh (trimesh.Trimesh): High-resolution input mesh. + dense_labels (numpy.ndarray): Per-face labels for the dense mesh (shape: (N_dense_faces,)). + coarse_mesh (trimesh.Trimesh): Coarser mesh to be relabeled. + + Returns: + numpy.ndarray: New labels for the coarse mesh (shape: (N_coarse_faces,)). + """ + # Compute centroids for both dense and coarse mesh faces + dense_centroids = dense_mesh.vertices[dense_mesh.faces].mean(axis=1) # (N_dense_faces, 3) + coarse_centroids = coarse_mesh.vertices[coarse_mesh.faces].mean(axis=1) # (N_coarse_faces, 3) + + # Use KDTree to efficiently find nearest coarse face for each dense face + tree = cKDTree(coarse_centroids) + _, nearest_coarse_faces = tree.query(dense_centroids) # (N_dense_faces,) + + # Initialize label votes per coarse face + face_label_votes = {i: [] for i in range(len(coarse_mesh.faces))} + + # Every dense face votes for its nearest coarse face + dense_labels += 1 + for dense_face_idx, coarse_face_idx in enumerate(nearest_coarse_faces): + face_label_votes[coarse_face_idx].append(dense_labels[dense_face_idx]) + + # Assign new labels based on majority voting + coarse_labels = np.zeros(len(coarse_mesh.faces), dtype=np.int32) + + for i, votes in face_label_votes.items(): + if votes: # If this coarse face received votes + most_common_label = Counter(votes).most_common(1)[0][0] + coarse_labels[i] = most_common_label + else: + coarse_labels[i] = 0 # Mark as unassigned (optional) + + return coarse_labels + +class UnionFind: + def __init__(self, n): + self.parent = list(range(n)) + self.rank = [1] * n + + def find(self, x): + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] + + def union(self, x, y): + rootX = self.find(x) + rootY = self.find(y) + + if rootX != rootY: + if self.rank[rootX] > self.rank[rootY]: + self.parent[rootY] = rootX + elif self.rank[rootX] < self.rank[rootY]: + self.parent[rootX] = rootY + else: + self.parent[rootY] = rootX + self.rank[rootX] += 1 + +def hierarchical_clustering_labels(children, n_samples, max_cluster=20): + # Union-Find structure to maintain cluster merges + uf = UnionFind(2 * n_samples - 1) # We may need to store up to 2*n_samples - 1 clusters + + current_cluster_count = n_samples + + # Process merges from the children array + hierarchical_labels = [] + for i, (child1, child2) in enumerate(children): + uf.union(child1, i + n_samples) + uf.union(child2, i + n_samples) + #uf.union(child1, child2) + current_cluster_count -= 1 # After each merge, we reduce the cluster count + + if current_cluster_count <= max_cluster: + labels = [uf.find(i) for i in range(n_samples)] + hierarchical_labels.append(labels) + + return hierarchical_labels + +def load_ply_to_numpy(filename): + """ + Load a PLY file and extract the point cloud as a (N, 3) NumPy array. + + Parameters: + filename (str): Path to the PLY file. + + Returns: + numpy.ndarray: Point cloud array of shape (N, 3). + """ + # Read PLY file + ply_data = PlyData.read(filename) + + # Extract vertex data + vertex_data = ply_data["vertex"] + + # Convert to NumPy array (x, y, z) + points = np.vstack([vertex_data["x"], vertex_data["y"], vertex_data["z"]]).T + + return points + +def solve_clustering(input_fname, uid, view_id, save_dir="test_results1", max_cluster=20, out_render_fol= "test_render_clustering", filehandle=None, use_agglo=False, max_num_clusters=18, viz_dense=False, export_mesh=True): + print(uid, view_id) + + try: + mesh_fname = f'{save_dir}/feat_pca_{uid}_{view_id}.ply' + dense_mesh = load_mesh_util(mesh_fname) + except: + mesh_fname = f'{save_dir}/feat_pca_{uid}_{view_id}_batch.ply' + dense_mesh = load_mesh_util(mesh_fname) + + vertices = dense_mesh.vertices + bbmin = vertices.min(0) + bbmax = vertices.max(0) + center = (bbmin + bbmax) * 0.5 + scale = 2.0 * 0.9 / (bbmax - bbmin).max() + vertices = (vertices - center) * scale + dense_mesh.vertices = vertices + + ### Load coarse mesh + input_fname = f'{save_dir}/input_{uid}_{view_id}.ply' + coarse_mesh = trimesh.load(input_fname, force='mesh') + vertices = coarse_mesh.vertices + + bbmin = vertices.min(0) + bbmax = vertices.max(0) + center = (bbmin + bbmax) * 0.5 + scale = 2.0 * 0.9 / (bbmax - bbmin).max() + vertices = (vertices - center) * scale + coarse_mesh.vertices = vertices + ##################### + + try: + point_feat = np.load(f'{save_dir}/part_feat_{uid}_{view_id}.npy') + except: + try: + point_feat = np.load(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy') + + except: + print() + print("pointfeat loading error. skipping...") + print(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy') + return + + point_feat = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True) + + if not use_agglo: + for num_cluster in range(2, max_num_clusters): + clustering = KMeans(n_clusters=num_cluster, random_state=0).fit(point_feat) + labels = clustering.labels_ + + if not viz_dense: + #### Relabel coarse from dense #### + labels = relabel_coarse_mesh(dense_mesh, labels, coarse_mesh) + V = coarse_mesh.vertices + F = coarse_mesh.faces + ################################### + else: + V = dense_mesh.vertices + F = dense_mesh.faces + + pred_labels = np.zeros((len(labels), 1)) + for i, label in enumerate(np.unique(labels)): + # print(i, label) + pred_labels[labels == label] = i # Assign RGB values to each label + + + fname = str(uid) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2) + fname_clustering = os.path.join(out_render_fol, "cluster_out", str(uid) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2)) + np.save(fname_clustering, pred_labels) + + if export_mesh : + fname_mesh = os.path.join(out_render_fol, "ply", str(uid) + "_" + str(view_id) + "_" + str(num_cluster).zfill(2) + ".ply") + export_colored_mesh_ply(V, F, pred_labels, filename=fname_mesh) + + else: + + adj_matrix = construct_face_adjacency_matrix(dense_mesh.faces) + clustering = AgglomerativeClustering(connectivity=adj_matrix, + n_clusters=1, + ).fit(point_feat) + hierarchical_labels = hierarchical_clustering_labels(clustering.children_, point_feat.shape[0], max_cluster=max_num_clusters) + + all_FL = [] + for n_cluster in range(max_num_clusters): + print("Processing cluster: "+str(n_cluster)) + labels = hierarchical_labels[n_cluster] + all_FL.append(labels) + + + all_FL = np.array(all_FL) + unique_labels = np.unique(all_FL) + + for n_cluster in range(max_num_clusters): + FL = all_FL[n_cluster] + + if not viz_dense: + #### Relabel coarse from dense #### + FL = relabel_coarse_mesh(dense_mesh, FL, coarse_mesh) + V = coarse_mesh.vertices + F = coarse_mesh.faces + ################################### + else: + V = dense_mesh.vertices + F = dense_mesh.faces + + unique_labels = np.unique(FL) + relabel = np.zeros((len(FL), 1)) + for i, label in enumerate(unique_labels): + relabel[FL == label] = i # Assign RGB values to each label + + if export_mesh : + fname_mesh = os.path.join(out_render_fol, "ply", str(uid) + "_" + str(view_id) + "_" + str(max_cluster - n_cluster).zfill(2) + ".ply") + export_colored_mesh_ply(V, F, FL, filename=fname_mesh) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--source_dir', default= "", type=str) + parser.add_argument('--root', default= "", type=str) + parser.add_argument('--dump_dir', default= "", type=str) + + parser.add_argument('--max_num_clusters', default= 18, type=int) + parser.add_argument('--use_agglo', default= True, type=bool) + parser.add_argument('--viz_dense', default= False, type=bool) + parser.add_argument('--export_mesh', default= True, type=bool) + + + FLAGS = parser.parse_args() + root = FLAGS.root + OUTPUT_FOL = FLAGS.dump_dir + SOURCE_DIR = FLAGS.source_dir + + MAX_NUM_CLUSTERS = FLAGS.max_num_clusters + USE_AGGLO = FLAGS.use_agglo + EXPORT_MESH = FLAGS.export_mesh + + models = os.listdir(root) + os.makedirs(OUTPUT_FOL, exist_ok=True) + + if EXPORT_MESH: + ply_fol = os.path.join(OUTPUT_FOL, "ply") + os.makedirs(ply_fol, exist_ok=True) + + cluster_fol = os.path.join(OUTPUT_FOL, "cluster_out") + os.makedirs(cluster_fol, exist_ok=True) + + #### Get existing model_ids ### + all_files = os.listdir(os.path.join(OUTPUT_FOL, "ply")) + + existing_model_ids = [] + for sample in all_files: + uid = sample.split("_")[0] + view_id = sample.split("_")[1] + # sample_name = str(uid) + "_" + str(view_id) + sample_name = str(uid) + + if sample_name not in existing_model_ids: + existing_model_ids.append(sample_name) + ############################## + + all_files = os.listdir(SOURCE_DIR) + selected = [] + for f in all_files: + if (".obj" in f or ".glb" in f) and f.split(".")[0] not in existing_model_ids: + selected.append(f) + + print("Number of models to process: " + str(len(selected))) + + + for model in selected: + fname = os.path.join(SOURCE_DIR, model) + uid = model.split(".")[-2] + view_id = 0 + + solve_clustering(fname, uid, view_id, save_dir=root, out_render_fol= OUTPUT_FOL, use_agglo=USE_AGGLO, max_num_clusters=MAX_NUM_CLUSTERS, viz_dense=FLAGS.viz_dense, export_mesh=EXPORT_MESH) \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..bd78ab6e2a48872b99c0c4995dc5eacbddb7d798 --- /dev/null +++ b/app.py @@ -0,0 +1,404 @@ +import os +import gradio as gr +import tempfile +import shutil +import torch +from omegaconf import OmegaConf +import trimesh +from pathlib import Path +from huggingface_hub import hf_hub_download +import zipfile +from datetime import datetime + +from infer_asset import infer_single_asset +from particulate.models import Articulate3D_B +from particulate.data_utils import load_obj_raw_preserve +from particulate.export_utils import export_urdf, export_mjcf +from particulate.visualization_utils import plot_mesh +from yacs.config import CfgNode +torch.serialization.add_safe_globals([CfgNode]) + + +class ParticulateApp: + """ + Main application class for Particulate with Gradio interface. + """ + def __init__(self, model_config_path: str, output_dir: str): + self.model_config = OmegaConf.load(model_config_path) + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + self.model = Articulate3D_B(**self.model_config) + self.model.eval() + + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cpu": + print("WARNING: CUDA is not available. This application requires CUDA for full functionality as infer_asset.py assumes CUDA.") + # We attempt to use CUDA anyway because infer_asset.py hardcodes it in prepare_inputs + + print("Downloading/Loading model from Hugging Face...") + self.model_checkpoint = hf_hub_download(repo_id="rayli/Particulate", filename="model.pt") + self.model.load_state_dict(torch.load(self.model_checkpoint, map_location="cuda" if torch.cuda.is_available() else "cpu")) + self.model.to('cuda') + + model_dir = os.path.join("PartField", "model") + os.makedirs(model_dir, exist_ok=True) + hf_hub_download(repo_id="mikaelaangel/partfield-ckpt", filename="model_objaverse.ckpt", local_dir=model_dir) + print("Model loaded successfully.") + + def visualize_mesh(self, input_mesh_path): + if input_mesh_path is None: + return None, None + + # Handle Gradio file object (dict) or file path (string) + if isinstance(input_mesh_path, dict): + file_path = input_mesh_path.get("path") or input_mesh_path.get("name") + else: + file_path = input_mesh_path + + print(f"Visualizing mesh from: {file_path}") + if file_path.endswith(".obj"): + verts, faces = load_obj_raw_preserve(Path(file_path)) + mesh = trimesh.Trimesh(vertices=verts, faces=faces) + else: + mesh = trimesh.load(file_path, process=False) + if isinstance(mesh, trimesh.Scene): + mesh = trimesh.util.concatenate(mesh.geometry.values()) + + return plot_mesh(mesh), mesh + + def predict( + self, + mesh, + min_part_confidence, + num_points, + up_dir, + animation_frames, + strict, + ): + if mesh is None: + return None, "Please upload a 3D model." + + with tempfile.TemporaryDirectory() as temp_dir: + try: + ( + mesh_parts_original, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range + ) = infer_single_asset( + mesh=mesh, + up_dir=up_dir, + model=self.model, + num_points=int(num_points), + strict=strict, + output_path=temp_dir, + animation_frames=int(animation_frames), + min_part_confidence=min_part_confidence, + ) + + animated_glb_file = os.path.join(temp_dir, "animated_textured.glb") + prediction_file = os.path.join(temp_dir, "mesh_parts_with_axes.glb") + + if os.path.exists(animated_glb_file) and os.path.exists(prediction_file): + # Copy to a persistent location in the output directory + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + dest_animated_glb_file = os.path.join(self.output_dir, f"animated_textured_{timestamp}.glb") + dest_prediction_file = os.path.join(self.output_dir, f"mesh_parts_with_axes_{timestamp}.glb") + shutil.copy(animated_glb_file, dest_animated_glb_file) + shutil.copy(prediction_file, dest_prediction_file) + return ( + dest_animated_glb_file, dest_prediction_file, f"Success!", + mesh_parts_original, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range + ) + else: + return ( + None, None, f"No output file generated.", + *[None] * 9 + ) + + except Exception as e: + import traceback + traceback.print_exc() + return ( + None, None, f"Error: {str(e)}", + *[None] * 9 + ) + + def export_urdf( + self, + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range + ): + if mesh_parts is None: + return None, "Please run inference first." + try: + with tempfile.TemporaryDirectory() as temp_dir: + export_urdf( + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range, + output_path=os.path.join(temp_dir, "urdf", "model.urdf"), + name="model" + ) + + # Zip the output directory + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + with zipfile.ZipFile(os.path.join(self.output_dir, f"urdf_{timestamp}.zip"), "w") as zipf: + for root, dirs, files in os.walk(os.path.join(temp_dir, "urdf")): + for file in files: + zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.join(temp_dir, "urdf"))) + return os.path.join(self.output_dir, f"urdf_{timestamp}.zip"), "Success!" + + except Exception as e: + print(f"Error exporting URDF: {e}") + import traceback + traceback.print_exc() + return None, f"Error exporting URDF: {str(e)}" + + def export_mjcf( + self, + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range + ): + if mesh_parts is None: + return None, "Please run inference first." + try: + with tempfile.TemporaryDirectory() as temp_dir: + export_mjcf( + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range, + output_path=os.path.join(temp_dir, "mjcf", "model.xml"), + name="model" + ) + + # Zip the output directory + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + with zipfile.ZipFile(os.path.join(self.output_dir, f"mjcf_{timestamp}.zip"), "w") as zipf: + for root, dirs, files in os.walk(os.path.join(temp_dir, "mjcf")): + for file in files: + zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.join(temp_dir, "mjcf"))) + return os.path.join(self.output_dir, f"mjcf_{timestamp}.zip"), "Success!" + + except Exception as e: + print(f"Error exporting MJCF: {e}") + import traceback + traceback.print_exc() + return None, f"Error exporting MJCF: {str(e)}" + +def create_gradio_app(particulate_app): + # Get example files from examples folder + examples_dir = "examples" + example_files = [] + if os.path.exists(examples_dir): + for file in os.listdir(examples_dir): + if file.lower().endswith(('.glb', '.obj')): + example_files.append(os.path.join(examples_dir, file)) + example_files.sort() # Sort for consistent ordering + + with gr.Blocks(title="Particulate Demo") as demo: + gr.HTML( + """ +

Particulate: Feed-Forward 3D Object Articulation

+

+ 🌟 GitHub Repository | + 🚀 Project Page +

+
+

Upload a 3D model (.obj or .glb format supported) to articulate it. Particulate takes this model and predicts the underlying articulated structure, which can be directly exported to URDF or MJCF format.

+

Getting Started:

+
    +
  1. Upload a 3D model.
  2. +
  3. Preview: Your uploaded 3D model will be visualized below.
  4. +
  5. Confirm Orientation: Select the direction (one of X, -X, Y, -Y, Z, -Z) that corresponds to the up direction of the object in the preview (for all example assets, the up direction is -Z).
  6. +
  7. Run Inference: Click the "Run Inference" button to start the inference process.
  8. +
  9. Visualization: The articulated 3D model with animation and model prediction (3D part segmentation, motion types and axes) will appear on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file.
  10. +
  11. + Adjust Inference Parameters (Optional): + You can potentially obtain better results by adjusting the following parameters: +
      +
    • Min Part Confidence: Increasing this value will merge parts that have low confidence scores to other parts. Consider increasing this value if the prediction is over segmented.
    • +
    • Refine with Connected Components: If toggled on, the prediction will be post-processed to ensure that each articulated part is a union of different connected components in the original mesh (i.e., no connected components are split across parts). Toggle this on (default) if the input mesh has clean connected components.
    • +
    • Normally, you should not need to change the other parameters.
    • +
    +
  12. +
+
+ """ + ) + + loaded_mesh = gr.State(None) + mesh_parts = gr.State(None) + unique_part_ids = gr.State(None) + motion_hierarchy = gr.State(None) + is_part_revolute = gr.State(None) + is_part_prismatic = gr.State(None) + revolute_plucker = gr.State(None) + revolute_range = gr.State(None) + prismatic_axis = gr.State(None) + prismatic_range = gr.State(None) + + with gr.Row(): + with gr.Column(scale=1): + input_mesh = gr.Model3D( + label="Upload 3D Model", + interactive=True + ) + if example_files: + gr.Examples( + examples=[[file] for file in example_files], + inputs=[input_mesh], + label="Example Models" + ) + mesh_plot = gr.Plot(label="Mesh Preview") + + with gr.Accordion("Inference Parameters", open=True): + with gr.Row(): + up_dir = gr.Radio(choices=["X", "Y", "Z", "-X", "-Y", "-Z"], value="-Z", label="Up Direction (Select after viewing plot)") + animation_frames = gr.Number(value=50, label="Animation Frames", precision=0) + + with gr.Row(): + num_points = gr.Number(value=102400, label="Number of Points", precision=0, minimum=2048, maximum=102400) + + min_part_confidence = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, label="Min Part Confidence") + + with gr.Row(): + strict = gr.Checkbox(label="Refine with Connected Components", value=True) + + run_btn = gr.Button("Run Inference", variant="primary") + + with gr.Column(scale=2): + animated_model = gr.Model3D(label="Animated 3D Model") + prediction_model = gr.Model3D(label="Visualization of Model Prediction") + status_text = gr.Textbox(label="Status") + + with gr.Row(): + urdf_btn = gr.Button("Export URDF") + mjcf_btn = gr.Button("Export MJCF") + + with gr.Row(): + urdf_status = gr.Textbox(label="URDF Status") + mjcf_status = gr.Textbox(label="MJCF Status") + with gr.Row(): + urdf_file = gr.File(label="URDF Zip File") + mjcf_file = gr.File(label="MJCF Zip File") + + # Event triggers + input_mesh.change( + fn=particulate_app.visualize_mesh, + inputs=[input_mesh], + outputs=[mesh_plot, loaded_mesh] + ) + + run_btn.click( + fn=particulate_app.predict, + inputs=[ + loaded_mesh, + min_part_confidence, + num_points, + up_dir, + animation_frames, + strict + ], + outputs=[ + animated_model, prediction_model, status_text, + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range + ] + ) + + urdf_btn.click( + fn=particulate_app.export_urdf, + inputs=[ + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range + ], + outputs=[urdf_file, urdf_status] + ) + + mjcf_btn.click( + fn=particulate_app.export_mjcf, + inputs=[ + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range + ], + outputs=[mjcf_file, mjcf_status] + ) + + return demo + +if __name__ == "__main__": + output_dir = "gradio_outputs" + + # Load model configuration + model_config_path = "configs/particulate-B.yaml" + + # Initialize app + print("Initializing Particulate App...") + app = ParticulateApp(model_config_path, output_dir) + + # Create and launch Gradio demo + demo = create_gradio_app(app) + print("Launching Gradio server...") + demo.launch(server_name="0.0.0.0", server_port=7860, share=True) diff --git a/configs/particulate-B.yaml b/configs/particulate-B.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5bea6c04b233f7c6bb69b46dba2b44f2bdb77e4e --- /dev/null +++ b/configs/particulate-B.yaml @@ -0,0 +1,10 @@ +input_dim: 448 +dropout: 0.1 +use_normals: true +use_text_prompts: false +max_parts: 16 +use_part_id_embedding: true +use_raw_coords: true +use_point_features_for_motion_decoding: false +num_mask_hypotheses: 1 +motion_representation: per_point_closest \ No newline at end of file diff --git a/examples/cabinet.glb b/examples/cabinet.glb new file mode 100644 index 0000000000000000000000000000000000000000..89d7094e9dbceba40939c672a1c7b1b2c39b9a81 --- /dev/null +++ b/examples/cabinet.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:296481ebdc263da8cf779602cfcb5ec89519b2d1dbea4d6afd0093ca7ed4ccb1 +size 27011396 diff --git a/examples/eyeglasses.glb b/examples/eyeglasses.glb new file mode 100644 index 0000000000000000000000000000000000000000..9e542a7f9805f2ea84cf05abf0f3e91b9a6a2503 --- /dev/null +++ b/examples/eyeglasses.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3cc580444996fa7ecb61d560480678862683e9ded586bb3fe626479f1bcf6a1d +size 16083668 diff --git a/examples/foldingchair.glb b/examples/foldingchair.glb new file mode 100644 index 0000000000000000000000000000000000000000..bf9971dcb4b68e9f87701386fa96452c03ba3725 --- /dev/null +++ b/examples/foldingchair.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ff7f2c815b1a78c7afef326c5550b0e32d44ae8359d63397cfc1b304eb50d08 +size 27675292 diff --git a/examples/laptop.glb b/examples/laptop.glb new file mode 100644 index 0000000000000000000000000000000000000000..eff0bcfc0911157ffcc2af14568c9d32658243d4 --- /dev/null +++ b/examples/laptop.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de967f39f39a32faeeb3f7abfe3a62c0e7595f5a14244d73f5775230be582563 +size 27460156 diff --git a/examples/scissors.glb b/examples/scissors.glb new file mode 100644 index 0000000000000000000000000000000000000000..451fe0e88d9464f17db98c4bd7d192ae6513e080 --- /dev/null +++ b/examples/scissors.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d679d51b5c48de754ea6ef31a4aee9a2ca78a7618604ebf4f23b02acee13d69d +size 20856212 diff --git a/examples/toilet.glb b/examples/toilet.glb new file mode 100644 index 0000000000000000000000000000000000000000..5940a79c241f8cc0585a3d4cce3db6a2522b465b --- /dev/null +++ b/examples/toilet.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5cf9124bf1a9c5eac38c66aff8ddf7bbe9bfb9d65153a64c09541ba4e51ca047 +size 11346860 diff --git a/examples/trashcan.glb b/examples/trashcan.glb new file mode 100644 index 0000000000000000000000000000000000000000..e2b60a5c8b78dcbd7f04d2fc8f288895ff635bb4 --- /dev/null +++ b/examples/trashcan.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ccfb8024b37faff562ffbc12a0a5e18aef90f63859e0d90aafed6baf7bdb140c +size 18751768 diff --git a/examples/washingmachine.glb b/examples/washingmachine.glb new file mode 100644 index 0000000000000000000000000000000000000000..4e75a5fa8e3ca5d8f16a39f548422a947d345f4e --- /dev/null +++ b/examples/washingmachine.glb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:489b6ae5728de26689c52b7a9715509d5a93a906ac7b27e6f8d44c07dbde5c76 +size 26462956 diff --git a/infer_asset.py b/infer_asset.py new file mode 100644 index 0000000000000000000000000000000000000000..9910bf3833a7cfb445cdfe875c0306a35226e720 --- /dev/null +++ b/infer_asset.py @@ -0,0 +1,575 @@ +import numpy as np +import os +from pathlib import Path +import torch +import trimesh + +from particulate.visualization_utils import ( + get_3D_arrow_on_points, + create_arrow, + create_ring, + create_textured_mesh_parts, + ARROW_COLOR_REVOLUTE, + ARROW_COLOR_PRISMATIC +) +from particulate.articulation_utils import plucker_to_axis_point +from particulate.export_utils import ( + export_animated_glb_file, + export_urdf, + export_mjcf +) +from partfield_utils import obtain_partfield_feats, get_partfield_model + + +DATA_CONFIG = { + 'sharp_point_ratio': 0.5, + 'normalize_points': True +} + + +def sharp_sample_pointcloud(mesh, num_points: int = 8192): + V = mesh.vertices + N = mesh.face_normals + F = mesh.faces + + # Build edge-to-faces mapping + # Each edge is represented as (min_vertex_id, max_vertex_id) to ensure consistent ordering + edge_to_faces = {} + + for face_idx in range(len(F)): + face = F[face_idx] + # Get the three edges of this face + edges = [ + (face[0], face[1]), + (face[1], face[2]), + (face[2], face[0]) + ] + + for edge in edges: + # Normalize edge ordering (min vertex first) + edge_key = tuple(sorted(edge)) + if edge_key not in edge_to_faces: + edge_to_faces[edge_key] = [] + edge_to_faces[edge_key].append(face_idx) + + # Identify sharp edges based on face normal angles and store their averaged normals + sharp_edges = [] + sharp_edge_normals = [] + sharp_edge_faces = [] # Store adjacent faces for each sharp edge + cos_30 = np.cos(np.radians(30)) # ≈ 0.866 + cos_150 = np.cos(np.radians(150)) # ≈ -0.866 + + for edge_key, face_indices in edge_to_faces.items(): + # Check if edge has >= 2 faces + if len(face_indices) < 2: + continue + + # Check all pairs of face normals + is_sharp = False + for i in range(len(face_indices)): + for j in range(i + 1, len(face_indices)): + n1 = N[face_indices[i]] + n2 = N[face_indices[j]] + dot_product = np.dot(n1, n2) + + # Check if angle is between 30 and 150 degrees + if cos_150 < dot_product < cos_30 and np.linalg.norm(n1) > 1e-8 and np.linalg.norm(n2) > 1e-8: + is_sharp = True + sharp_edges.append(edge_key) + averaged_normal = (n1 + n2) / 2 + sharp_edge_normals.append(averaged_normal) + sharp_edge_faces.append(face_indices) # Store all adjacent faces + break + if is_sharp: + break + + # Convert sharp edges to vertex arrays + edge_a = np.array([edge[0] for edge in sharp_edges], dtype=np.int32) + edge_b = np.array([edge[1] for edge in sharp_edges], dtype=np.int32) + sharp_edge_normals = np.array(sharp_edge_normals, dtype=np.float64) + + # Handle the case where there are no sharp edges + if len(sharp_edges) == 0: + # Return empty arrays with appropriate shape + samples = np.zeros((0, 3), dtype=np.float64) + normals = np.zeros((0, 3), dtype=np.float64) + edge_indices = np.zeros((0,), dtype=np.int32) + return samples, normals, edge_indices, sharp_edge_faces + + sharp_verts_a = V[edge_a] + sharp_verts_b = V[edge_b] + + weights = np.linalg.norm(sharp_verts_b - sharp_verts_a, axis=-1) + weights /= np.sum(weights) + + random_number = np.random.rand(num_points) + w = np.random.rand(num_points, 1) + index = np.searchsorted(weights.cumsum(), random_number) + samples = w * sharp_verts_a[index] + (1 - w) * sharp_verts_b[index] + normals = sharp_edge_normals[index] # Use the averaged face normal for each edge + return samples, normals, index, sharp_edge_faces + + +def sample_points(mesh, num_points, sharp_point_ratio, at_least_one_point_per_face=False): + """Sample points from mesh using sharp edge and uniform sampling.""" + num_points_sharp_edges = int(num_points * sharp_point_ratio) + num_points_uniform = num_points - num_points_sharp_edges + points_sharp, normals_sharp, edge_indices, sharp_edge_faces = sharp_sample_pointcloud(mesh, num_points_sharp_edges) + + # If no sharp edges were found, sample all points uniformly + if len(points_sharp) == 0 and sharp_point_ratio > 0: + print(f"Warning: No sharp edges found, sampling all points uniformly") + num_points_uniform = num_points + + if at_least_one_point_per_face: + num_faces = len(mesh.faces) + assert num_points_uniform >= num_faces, f"num_points_uniform ({num_points_uniform}) < num_faces ({num_faces})" + + # Get a random permutation of face indices + face_perm = np.random.permutation(num_faces) + + # Sample one point from each face + points_per_face = [] + for face_idx in face_perm: + # Sample one random point on this face using barycentric coordinates + r1, r2 = np.random.random(), np.random.random() + sqrt_r1 = np.sqrt(r1) + # Barycentric coordinates + u = 1 - sqrt_r1 + v = sqrt_r1 * (1 - r2) + w = sqrt_r1 * r2 + + # Get vertices of the face + face = mesh.faces[face_idx] + vertices = mesh.vertices[face] + + # Compute point using barycentric coordinates + point = u * vertices[0] + v * vertices[1] + w * vertices[2] + points_per_face.append(point) + + points_per_face = np.array(points_per_face) + normals_per_face = mesh.face_normals[face_perm] + + # Sample remaining points uniformly + num_remaining_points = num_points_uniform - num_faces + if num_remaining_points > 0: + points_remaining, face_indices_remaining = mesh.sample(num_remaining_points, return_index=True) + normals_remaining = mesh.face_normals[face_indices_remaining] + + points_uniform = np.concatenate([points_per_face, points_remaining], axis=0) + normals_uniform = np.concatenate([normals_per_face, normals_remaining], axis=0) + face_indices = np.concatenate([face_perm, face_indices_remaining], axis=0) + else: + points_uniform = points_per_face + normals_uniform = normals_per_face + face_indices = face_perm + else: + points_uniform, face_indices = mesh.sample(num_points_uniform, return_index=True) + normals_uniform = mesh.face_normals[face_indices] + + points = np.concatenate([points_sharp, points_uniform], axis=0) + normals = np.concatenate([normals_sharp, normals_uniform], axis=0) + sharp_flag = np.concatenate([ + np.ones(len(points_sharp), dtype=np.bool_), + np.zeros(len(points_uniform), dtype=np.bool_) + ], axis=0) + + # For each sharp point, randomly select one of the adjacent faces from the edge + sharp_face_indices = np.zeros(len(points_sharp), dtype=np.int32) + for i, edge_idx in enumerate(edge_indices): + adjacent_faces = sharp_edge_faces[edge_idx] + # Randomly select one of the adjacent faces + sharp_face_indices[i] = np.random.choice(adjacent_faces) + + face_indices = np.concatenate([ + sharp_face_indices, + face_indices + ], axis=0) + + return points, normals, sharp_flag, face_indices + + +def prepare_inputs(mesh, num_points_global: int = 40000, num_points_decode: int = 2048, device: str = "cuda"): + """Prepare inputs from a mesh file for model inference.""" + sharp_point_ratio = DATA_CONFIG['sharp_point_ratio'] + all_points, _, _, _ = sample_points(mesh, num_points_global, sharp_point_ratio) + points, normals, sharp_flag, face_indices = sample_points(mesh, num_points_decode, sharp_point_ratio, at_least_one_point_per_face=True) + + if DATA_CONFIG['normalize_points']: + bbmin = np.concatenate([all_points, points], axis=0).min(0) + bbmax = np.concatenate([all_points, points], axis=0).max(0) + center = (bbmin + bbmax) * 0.5 + scale = 1.0 / (bbmax - bbmin).max() + all_points = (all_points - center) * scale + points = (points - center) * scale + + all_points = torch.from_numpy(all_points).to(device).float().unsqueeze(0) + points = torch.from_numpy(points).to(device).float().unsqueeze(0) + normals = torch.from_numpy(normals).to(device).float().unsqueeze(0) + + partfield_model = get_partfield_model(device=device) + feats = obtain_partfield_feats(partfield_model, all_points, points) + + return dict(xyz=points, normals=normals, feats=feats), sharp_flag, face_indices + + +def refine_part_ids_strict(mesh, face_part_ids): + """ + Refine face part IDs by treating each connected component (CC) in the mesh independently. + For each CC, all faces are labeled with the part ID that has the largest surface area in that CC. + + Args: + mesh: trimesh object + face_part_ids: part ID for each face [num_faces] + + Returns: + refined_face_part_ids: refined part ID for each face [num_faces] + """ + face_part_ids = face_part_ids.copy() # Don't modify the input + + # Use trimesh's built-in connected components functionality + # mesh.face_adjacency gives pairs of face indices that share an edge + mesh_components = trimesh.graph.connected_components( + edges=mesh.face_adjacency, + nodes=np.arange(len(mesh.faces)), + min_len=1 + ) + + # For each connected component, find the part ID with the largest surface area + for component in mesh_components: + if len(component) == 0: + continue + + # Collect part IDs in this component and their surface areas + part_id_areas = {} + for face_idx in component: + part_id = face_part_ids[face_idx] + if part_id == -1: + continue # Skip unassigned faces + + face_area = mesh.area_faces[face_idx] + if part_id not in part_id_areas: + part_id_areas[part_id] = 0.0 + part_id_areas[part_id] += face_area + + # Find the part ID with the largest area + if len(part_id_areas) == 0: + # No valid part IDs in this component, skip + continue + + dominant_part_id = max(part_id_areas.keys(), key=lambda pid: part_id_areas[pid]) + + # Assign all faces in this component to the dominant part ID + for face_idx in component: + face_part_ids[face_idx] = dominant_part_id + + return face_part_ids + + +def compute_part_components_for_mesh_cc(mesh, mesh_cc_faces, current_face_part_ids, face_adjacency_dict): + """ + Compute part-specific connected components for faces in this mesh CC. + Returns a list of dicts with 'faces', 'part_id', and 'area'. + + Two faces are in the same component if: + - They have the same part ID + - They are connected through faces of the same part ID + """ + components = [] + + # Get unique part IDs in this mesh CC + unique_part_ids = np.unique(current_face_part_ids[mesh_cc_faces]) + + for part_id in unique_part_ids: + if part_id == -1: + continue + + # Get faces in this mesh CC with this part ID + mask = current_face_part_ids[mesh_cc_faces] == part_id + faces_with_part = mesh_cc_faces[mask] + + if len(faces_with_part) == 0: + continue + + # Convert to set for faster lookup + faces_with_part_set = set(faces_with_part) + + # Build edges between these faces (both must have same part ID and be adjacent) + edges_for_part = [] + for face_i in faces_with_part: + for face_j in face_adjacency_dict[face_i]: + if face_j in faces_with_part_set: + edges_for_part.append([face_i, face_j]) + + if len(edges_for_part) == 0: + # Each face is its own component + for face_i in faces_with_part: + components.append({ + 'faces': np.array([face_i]), + 'part_id': part_id, + 'area': mesh.area_faces[face_i] + }) + else: + # Find connected components + edges_for_part = np.array(edges_for_part) + comps = trimesh.graph.connected_components( + edges=edges_for_part, + nodes=faces_with_part, + min_len=1 + ) + + for comp in comps: + comp_faces = np.array(list(comp)) + components.append({ + 'faces': comp_faces, + 'part_id': part_id, + 'area': np.sum(mesh.area_faces[comp_faces]) + }) + + return components + + +def refine_part_ids_for_faces(mesh, face_part_ids): + """ + Refine face part IDs to ensure each part ID forms a single connected component. + + For each part ID, if there are multiple disconnected components, the smaller + components (by surface area) are reassigned based on adjacent faces' part IDs. + This is done iteratively until convergence or max iterations. + + Args: + mesh: trimesh object + xyz: sampled points on the mesh [num_points, 3] + part_ids: part IDs for each sampled point [num_points] + face_indices: which face each point lies on (-1 means on edge) [num_points] + face_part_ids: initial part ID for each face [num_faces] + + Returns: + refined_face_part_ids: refined part ID for each face [num_faces] + """ + face_part_ids_final_strict = refine_part_ids_strict(mesh, face_part_ids) + + face_part_ids_final = face_part_ids.copy() # Don't modify the input + + # Step 1: Find connected components of the original mesh (immutable structure) + mesh_components = trimesh.graph.connected_components( + edges=mesh.face_adjacency, + nodes=np.arange(len(mesh.faces)), + min_len=1 + ) + mesh_components = [np.array(list(comp)) for comp in mesh_components] + + # Step 2: Build face adjacency dict (immutable structure) + face_adjacency_dict = {i: set() for i in range(len(mesh.faces))} + for face_i, face_j in mesh.face_adjacency: + face_adjacency_dict[face_i].add(face_j) + face_adjacency_dict[face_j].add(face_i) + + # Step 3: Process each mesh CC independently + for mesh_cc_faces in mesh_components: + done = False + while not done: + comps = compute_part_components_for_mesh_cc(mesh, mesh_cc_faces, face_part_ids_final, face_adjacency_dict) + comps.sort(key=lambda c: c['area']) + + part_id_areas = {} + for comp in comps: + pid = comp['part_id'] + if pid not in part_id_areas: + part_id_areas[pid] = 0.0 + part_id_areas[pid] += comp['area'] + + done = True + for comp_idx in range(len(comps)): + current_part_id = comps[comp_idx]['part_id'] + if len([c for c in comps if c['part_id'] == current_part_id]) > 1: + done = False + # Find adjacent components + adjacent_part_ids = set() + current_faces_set = set(comps[comp_idx]['faces']) + + for face_i in current_faces_set: + for face_j in face_adjacency_dict[face_i]: + if face_j in current_faces_set: + continue + adjacent_part_ids.add(face_part_ids_final[face_j]) + + chosen_part_id = max(adjacent_part_ids, key=lambda x: part_id_areas[x]) + comps[comp_idx]['part_id'] = chosen_part_id + face_part_ids_final[comps[comp_idx]['faces']] = chosen_part_id + break + + return face_part_ids_final_strict, face_part_ids_final + + +def find_part_ids_for_faces(mesh, part_ids, face_indices): + """ + Assign part IDs to each face in the mesh. + + Args: + mesh: trimesh object + xyz: sampled points on the mesh [num_points, 3] + part_ids: part IDs for each sampled point [num_points] + face_indices: which face each point lies on (-1 means on edge) [num_points] + + Returns: + face_part_ids: part ID for each face [num_faces] + """ + num_faces = len(mesh.faces) + face_part_ids = np.full(num_faces, -1, dtype=np.int32) + + # Step 1: Assign part IDs to faces that have points on them + # For each face, collect all points that lie on it and use majority vote + face_to_points = {} + for point_idx, face_idx in enumerate(face_indices): + if face_idx == -1: # Point is on an edge, ignore it + continue + if face_idx not in face_to_points: + face_to_points[face_idx] = [] + face_to_points[face_idx].append(part_ids[point_idx]) + + # Assign part IDs based on majority vote from points + for face_idx, point_part_ids in face_to_points.items(): + # Use bincount to find the majority part ID + counts = np.bincount(point_part_ids) + majority_part_id = np.argmax(counts) + face_part_ids[face_idx] = majority_part_id + + face_part_ids_refined_strict, face_part_ids_refined = refine_part_ids_for_faces(mesh, face_part_ids) + return face_part_ids, face_part_ids_refined_strict, face_part_ids_refined + + +@torch.no_grad() +def infer_single_asset( + mesh, + up_dir, + model, + num_points, + strict, + output_path, + animation_frames, + min_part_confidence=0.0 +): + if up_dir is ["x", "X"]: + rotation_matrix = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]], dtype=np.float32) + mesh.vertices = mesh.vertices @ rotation_matrix.T + elif up_dir is ["-x", "-X"]: + rotation_matrix = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]], dtype=np.float32) + mesh.vertices = mesh.vertices @ rotation_matrix.T + if up_dir in ["y", "Y"]: + rotation_matrix = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32) + mesh.vertices = mesh.vertices @ rotation_matrix.T + elif up_dir in ["-y", "-Y"]: + rotation_matrix = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=np.float32) + mesh.vertices = mesh.vertices @ rotation_matrix.T + elif up_dir in ["z", "Z"]: + pass + elif up_dir in ["-z", "-Z"]: + rotation_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]], dtype=np.float32) + mesh.vertices = mesh.vertices @ rotation_matrix.T + else: + raise ValueError(f"Invalid up direction: {up_dir}") + + # Normalize mesh to [-0.5, 0.5]^3 bounding box + bbox_min = mesh.vertices.min(axis=0) + bbox_max = mesh.vertices.max(axis=0) + center = (bbox_min + bbox_max) / 2 + mesh.vertices -= center # Center the mesh + + # Scale to fit in [-0.5, 0.5]^3 + scale = (bbox_max - bbox_min).max() # Use the largest dimension + mesh.vertices /= scale + + inputs, sharp_flag, face_indices = prepare_inputs(mesh, num_points_global=40000, num_points_decode=num_points) + + with torch.no_grad(): + outputs = model.infer( + xyz=inputs['xyz'], + feats=inputs['feats'], + normals=inputs['normals'], + output_all_hyps=True, + min_part_confidence=min_part_confidence + ) + + part_ids = outputs[0]['part_ids'] + motion_hierarchy = outputs[0]['motion_hierarchy'] + is_part_revolute = outputs[0]['is_part_revolute'] + is_part_prismatic = outputs[0]['is_part_prismatic'] + revolute_plucker = outputs[0]['revolute_plucker'] + revolute_range = outputs[0]['revolute_range'] + prismatic_axis = outputs[0]['prismatic_axis'] + prismatic_range = outputs[0]['prismatic_range'] + + _, face_part_ids_refined_strict, face_part_ids_refined = find_part_ids_for_faces( + mesh, + part_ids, + face_indices + ) + face_part_ids = face_part_ids_refined_strict if strict else face_part_ids_refined + unique_part_ids = np.unique(face_part_ids) + num_parts = len(unique_part_ids) + print(f"Found {num_parts} unique parts") + + # Check if original mesh has texture/UV coordinates + has_original_texture = ( + hasattr(mesh.visual, 'uv') and + mesh.visual.uv is not None and + len(mesh.visual.uv) > 0 + ) + + mesh_parts_original = [mesh.submesh([face_part_ids == part_id], append=True) for part_id in unique_part_ids] + mesh_parts_segmented = create_textured_mesh_parts([mp.copy() for mp in mesh_parts_original]) + + # Create axes + axes = [] + for i, mesh_part in enumerate(mesh_parts_segmented): + part_id = unique_part_ids[i] + if is_part_revolute[part_id]: + axis, point = plucker_to_axis_point(revolute_plucker[part_id]) + arrow_start, arrow_end = get_3D_arrow_on_points(axis, mesh_part.vertices, fixed_point=point, extension=0.2) + axes.append(create_arrow(arrow_start, arrow_end, color=ARROW_COLOR_REVOLUTE, radius=0.01, radius_tip=0.018)) + # Add rings at arrow_start and arrow_end + arrow_dir = arrow_end - arrow_start + axes.append(create_ring(arrow_start, arrow_dir, major_radius=0.03, minor_radius=0.006, color=ARROW_COLOR_REVOLUTE)) + axes.append(create_ring(arrow_end, arrow_dir, major_radius=0.03, minor_radius=0.006, color=ARROW_COLOR_REVOLUTE)) + elif is_part_prismatic[part_id]: + axis = prismatic_axis[part_id] + arrow_start, arrow_end = get_3D_arrow_on_points(axis, mesh_part.vertices, extension=0.2) + axes.append(create_arrow(arrow_start, arrow_end, color=ARROW_COLOR_PRISMATIC, radius=0.01, radius_tip=0.018)) + + trimesh.Scene(mesh_parts_segmented + axes).export(Path(output_path) / "mesh_parts_with_axes.glb") + + print("Exporting animated GLB files...") + + try: + export_animated_glb_file( + mesh_parts_original, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range, + animation_frames, + str(Path(output_path) / "animated_textured.glb"), + include_axes=False, + axes_meshes=None + ) + except Exception as e: + print(f"Error exporting animated.glb: {e}") + import traceback + traceback.print_exc() + + return ( + mesh_parts_original, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range + ) \ No newline at end of file diff --git a/partfield_utils.py b/partfield_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5b3d5a00818e82248beff96192ff8d72803af9e3 --- /dev/null +++ b/partfield_utils.py @@ -0,0 +1,40 @@ +import argparse +import os +import sys + +import torch + +sys.path.append(os.path.join(os.path.dirname(__file__), 'PartField')) +from partfield.model.PVCNN.encoder_pc import sample_triplane_feat +from partfield.model_trainer_pvcnn_only_demo import Model +from partfield.config import setup + +@torch.no_grad() +@torch.autocast(device_type='cuda', dtype=torch.bfloat16) +def obtain_partfield_feats( + partfield_model, + points_enc, + points_dec, +): + bbmin = points_enc.min(dim=-2, keepdim=True)[0] + bbmax = points_enc.max(dim=-2, keepdim=True)[0] + center = (bbmin + bbmax) * 0.5 + scale = 2.0 * 0.9 / (bbmax - bbmin).max() + points_enc = (points_enc - center) * scale + points_dec = (points_dec - center) * scale + + pc_feat = partfield_model.pvcnn(points_enc, points_enc) + planes = partfield_model.triplane_transformer(pc_feat) + sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2) + point_feat = sample_triplane_feat(part_planes, points_dec) + return point_feat + + +def get_partfield_model(device='cuda'): + partfield_model = Model.load_from_checkpoint( + os.path.join(os.path.dirname(__file__), 'PartField', 'model', 'model_objaverse.ckpt'), + cfg=setup(argparse.Namespace(config_file=os.path.join(os.path.dirname(__file__), 'PartField', 'configs', 'final', 'demo.yaml'), opts=[]), freeze=False) + ) + partfield_model.eval() + partfield_model.to(device=device) + return partfield_model diff --git a/particulate/__pycache__/articulation_utils.cpython-310.pyc b/particulate/__pycache__/articulation_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbb186b498909669e78e3c2dd9f749ea7605babb Binary files /dev/null and b/particulate/__pycache__/articulation_utils.cpython-310.pyc differ diff --git a/particulate/__pycache__/data_utils.cpython-310.pyc b/particulate/__pycache__/data_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a0330d3de948cfe812ae70f73f3c855ca2fd930 Binary files /dev/null and b/particulate/__pycache__/data_utils.cpython-310.pyc differ diff --git a/particulate/__pycache__/export_utils.cpython-310.pyc b/particulate/__pycache__/export_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d1221f0cf670d7801b503c7a0900c857ed7c24e Binary files /dev/null and b/particulate/__pycache__/export_utils.cpython-310.pyc differ diff --git a/particulate/__pycache__/inference_utils.cpython-310.pyc b/particulate/__pycache__/inference_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7782fe2079742a6cb83cd555d857cf877a02aaf5 Binary files /dev/null and b/particulate/__pycache__/inference_utils.cpython-310.pyc differ diff --git a/particulate/__pycache__/matcher.cpython-310.pyc b/particulate/__pycache__/matcher.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4155c7dd7b140bc7ca6e3345dd9b58c4e6bed559 Binary files /dev/null and b/particulate/__pycache__/matcher.cpython-310.pyc differ diff --git a/particulate/__pycache__/models.cpython-310.pyc b/particulate/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb4a20764d45c82de91ccb6a0546120c2604949f Binary files /dev/null and b/particulate/__pycache__/models.cpython-310.pyc differ diff --git a/particulate/__pycache__/visualization_utils.cpython-310.pyc b/particulate/__pycache__/visualization_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9795bb47dd8a5b8bc90da2664b19e558a1b4ba13 Binary files /dev/null and b/particulate/__pycache__/visualization_utils.cpython-310.pyc differ diff --git a/particulate/articulation_utils.py b/particulate/articulation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f9255abb15daf2b3cb1e4a40dddb7093a43065df --- /dev/null +++ b/particulate/articulation_utils.py @@ -0,0 +1,329 @@ +import numpy as np +from typing import List, Tuple, Optional, Union + + +def axis_point_to_plucker(axis: np.ndarray, point: np.ndarray) -> np.ndarray: + """ + Convert axis-point coordinates to plucker coordinates. + """ + assert axis.shape[-1] == 3 + assert point.shape[-1] == 3 + l = axis / (np.linalg.norm(axis, axis=-1, keepdims=True) + 1e-8) + m = np.cross(l, point, axis=-1) + return np.concatenate([l, m], axis=-1) + + +def plucker_to_axis_point(plucker: np.ndarray) -> np.ndarray: + """ + Convert plucker coordinates to axis-point coordinates. + """ + assert plucker.shape[-1] == 6 + l, m = plucker[..., :3], plucker[..., 3:] + axis = l / (np.linalg.norm(l, axis=-1, keepdims=True) + 1e-8) + point = np.cross(m, axis, axis=-1) + return axis, point + + +def plucker_to_4x4_transform_matrix(plucker: np.ndarray, angle: float) -> np.ndarray: + """ + Convert plucker coordinates to a 4x4 transformation matrix. + """ + assert plucker.shape == (6,) + axis, point = plucker_to_axis_point(plucker) + + # Create rotation matrix using Rodrigues' rotation formula + # R = I + sin(θ) * K + (1 - cos(θ)) * K² + # where K is the skew-symmetric matrix of the axis + + # Skew-symmetric matrix of the axis + K = np.array([ + [0, -axis[2], axis[1]], + [axis[2], 0, -axis[0]], + [-axis[1], axis[0], 0] + ]) + + # Identity matrix + I = np.eye(3) + + # Rotation matrix using Rodrigues' formula + R = I + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K) + + # Create 4x4 transformation matrix + # We need to translate to origin, rotate, then translate back + T = np.eye(4) + T[:3, :3] = R + T[:3, 3] = point - R @ point + + return T + + +def transform_points(points: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray: + """ + Transform points by a 4x4 transformation matrix. + + points: (..., 3) + transform_matrix: (4, 4) + """ + return points @ transform_matrix[:3, :3].T + transform_matrix[:3, 3] + + +def transform_direction(direction: np.ndarray, transform_matrix: np.ndarray) -> np.ndarray: + """ + Transform a direction vector by a 4x4 transformation matrix. + + direction: (..., 3) + transform_matrix: (4, 4) + """ + return direction @ transform_matrix[:3, :3].T + + +def get_subtree_part_ids(motion_hierarchy: List[Tuple[int, int]], part_id: int) -> List[int]: + """ + Get the subtree part ids for a given part id. + """ + subtree_part_ids = [part_id] + for parent_id, child_id in motion_hierarchy: + if parent_id == part_id: + subtree_part_ids.extend(get_subtree_part_ids(motion_hierarchy, child_id)) + return subtree_part_ids + + +def get_part_order_from_root(motion_hierarchy: List[Tuple[int, int]]) -> List[int]: + """ + Depth-first search to get the part order from the root. + """ + part_order = [] + visited = set() + def dfs(part_id): + if part_id in visited: + return + part_order.append(part_id) + visited.add(part_id) + for parent_id, child_id in motion_hierarchy: + if parent_id == part_id: + dfs(child_id) + + # Find the base/root part id + all_part_ids = set([parent_id for parent_id, _ in motion_hierarchy]) + all_part_ids.update([child_id for _, child_id in motion_hierarchy]) + + # Find the root part id + for _, child_id in motion_hierarchy: + all_part_ids.remove(child_id) + + # assert len(all_part_ids) == 1 + root_part_id = all_part_ids.pop() + + dfs(root_part_id) # Populate part_order + return part_order + + +def articulate_points( + xyz: np.ndarray, + part_ids: np.ndarray, + motion_hierarchy: List[Tuple[int, int]], + is_part_revolute: np.ndarray, + is_part_prismatic: np.ndarray, + revolute_plucker: np.ndarray, + revolute_range: np.ndarray, + prismatic_axis: np.ndarray, + prismatic_range: np.ndarray, + articulation_state: Union[float, np.ndarray], # Value between 0 and 1 + normals: Optional[np.ndarray] = None, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Articulate points based on given articulation state. + + Args: + articulation_state: Value between 0 (low limit) and 1 (high limit) + + Returns: + Tuple of (articulated_xyz, articulated_revolute_plucker, articulated_prismatic_axis) + """ + articulated_xyz = xyz.copy() + articulated_revolute_plucker = revolute_plucker.copy() + articulated_prismatic_axis = prismatic_axis.copy() + articulated_revolute_range = revolute_range.copy() + articulated_prismatic_range = prismatic_range.copy() + if normals is not None: + articulated_normals = normals.copy() + + if len(motion_hierarchy) == 0: + return articulated_xyz, articulated_revolute_plucker, articulated_prismatic_axis + + part_order = get_part_order_from_root(motion_hierarchy) + + for pid in part_order: + affected_part_ids = get_subtree_part_ids(motion_hierarchy, pid) + + if is_part_revolute[pid]: + low_limit, high_limit = revolute_range[pid] + part_articulation_state = articulation_state if isinstance(articulation_state, float) else articulation_state[pid] + # Interpolate between low and high limits + angle = low_limit + part_articulation_state * (high_limit - low_limit) + articulated_revolute_range[pid] = np.array([low_limit - angle, high_limit - angle]) + transform_matrix = plucker_to_4x4_transform_matrix(articulated_revolute_plucker[pid], angle) + + for affected_pid in affected_part_ids: + # Transform points + articulated_xyz[part_ids == affected_pid] = transform_points( + articulated_xyz[part_ids == affected_pid], transform_matrix + ) + # Transform normals + if normals is not None: + articulated_normals[part_ids == affected_pid] = transform_direction( + articulated_normals[part_ids == affected_pid], transform_matrix + ) + + # Transform revolute axes for affected parts + if is_part_revolute[affected_pid]: + current_axis, current_point = plucker_to_axis_point(articulated_revolute_plucker[affected_pid]) + new_axis = transform_direction(current_axis, transform_matrix) + new_point = transform_points(current_point, transform_matrix) + articulated_revolute_plucker[affected_pid] = axis_point_to_plucker(new_axis, new_point) + + # Transform prismatic axes for affected parts + if is_part_prismatic[affected_pid]: + articulated_prismatic_axis[affected_pid] = transform_direction( + articulated_prismatic_axis[affected_pid], transform_matrix + ) + + if is_part_prismatic[pid]: + low_limit, high_limit = prismatic_range[pid] + part_articulation_state = articulation_state if isinstance(articulation_state, float) else articulation_state[pid] + # Interpolate between low and high limits + displacement = low_limit + part_articulation_state * (high_limit - low_limit) + articulated_prismatic_range[pid] = np.array([low_limit - displacement, high_limit - displacement]) + paxis = articulated_prismatic_axis[pid] + + for affected_pid in affected_part_ids: + # Translate points + articulated_xyz[part_ids == affected_pid] = ( + articulated_xyz[part_ids == affected_pid] + displacement * paxis + ) + # Translation does not affect normals + + # Translate revolute axes for affected parts (only the point, not the direction) + if is_part_revolute[affected_pid]: + current_axis, current_point = plucker_to_axis_point(articulated_revolute_plucker[affected_pid]) + new_point = current_point + displacement * paxis + articulated_revolute_plucker[affected_pid] = axis_point_to_plucker(current_axis, new_point) + + output = (articulated_xyz, articulated_revolute_plucker, articulated_prismatic_axis, articulated_revolute_range, articulated_prismatic_range) + if normals is not None: + output = output + (articulated_normals,) + return output + + +def compute_part_transforms( + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range, + articulation_state +): + """ + Compute the 4x4 transformation matrix for each part at a given articulation state. + Returns a dictionary mapping part_id to its cumulative transformation matrix. + + The transformation represents how to transform each part from its rest pose to the articulated pose. + """ + if len(motion_hierarchy) == 0: + return {pid: np.eye(4) for pid in unique_part_ids} + + # Collect all relevant part IDs from motion hierarchy and unique_part_ids + all_part_ids = set(unique_part_ids) + for parent, child in motion_hierarchy: + all_part_ids.add(parent) + all_part_ids.add(child) + + transforms = {pid: np.eye(4) for pid in all_part_ids} + + # Process parts in hierarchical order (BFS/DFS from root) + part_order = get_part_order_from_root(motion_hierarchy) + + for pid in part_order: + affected_part_ids = get_subtree_part_ids(motion_hierarchy, pid) + + # Compute transformation for this part's joint + joint_transform = np.eye(4) + + if is_part_revolute[pid]: + low_limit, high_limit = revolute_range[pid] + angle = low_limit + articulation_state * (high_limit - low_limit) + joint_transform = plucker_to_4x4_transform_matrix(revolute_plucker[pid], angle) + + elif is_part_prismatic[pid]: + low_limit, high_limit = prismatic_range[pid] + displacement = low_limit + articulation_state * (high_limit - low_limit) + paxis = prismatic_axis[pid] + joint_transform[:3, 3] = displacement * paxis + + # Apply joint transformation to all affected (descendant) parts + for affected_pid in affected_part_ids: + if affected_pid in transforms: + transforms[affected_pid] = joint_transform @ transforms[affected_pid] + + return transforms + + +def closest_point_on_axis_to_revolute_plucker( + closest_point_on_axis: np.ndarray, + part_ids: np.ndarray, + is_part_revolute: np.ndarray, + is_part_prismatic: np.ndarray, + revolute_axis: np.ndarray, +) -> np.ndarray: + """ + Convert closest point on axis to motion parameters. + """ + num_parts = revolute_axis.shape[0] + revolute_plucker = np.zeros((num_parts, 6)) + for pid in np.unique(part_ids): + if is_part_revolute[pid]: + closest_points_on_axis = closest_point_on_axis[part_ids == pid] + current_revolute_axis = revolute_axis[pid] + selected_point = np.median(closest_points_on_axis, axis=0) + plucker = axis_point_to_plucker(current_revolute_axis, selected_point) + revolute_plucker[pid, :] = plucker + return revolute_plucker + + +def articulate_mesh_parts( + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, is_part_prismatic, + revolute_plucker, revolute_range, + prismatic_axis, prismatic_range, + articulation_state +): + """ + Articulate mesh parts based on given articulation state. + """ + all_verts = [mesh_parts[i].vertices for i in range(len(mesh_parts))] + all_part_ids = [np.full(len(mesh_parts[i].vertices), unique_part_ids[i], dtype=np.int32) for i in range(len(mesh_parts))] + + verts_transformed = articulate_points( + xyz=np.concatenate(all_verts, axis=0), + part_ids=np.concatenate(all_part_ids, axis=0), + motion_hierarchy=motion_hierarchy, + is_part_revolute=is_part_revolute, + is_part_prismatic=is_part_prismatic, + revolute_plucker=revolute_plucker, + revolute_range=revolute_range, + prismatic_axis=prismatic_axis, + prismatic_range=prismatic_range, + articulation_state=articulation_state + )[0] + + mesh_parts_articulated = [mesh_parts[i].copy() for i in range(len(mesh_parts))] + vert_offset = 0 + for i in range(len(mesh_parts)): + mesh_parts_articulated[i].vertices = verts_transformed[vert_offset:vert_offset + len(mesh_parts[i].vertices)] + vert_offset += len(mesh_parts[i].vertices) + return mesh_parts_articulated diff --git a/particulate/data_utils.py b/particulate/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f580e90ef470403337efb8cfc08cedad66afca2b --- /dev/null +++ b/particulate/data_utils.py @@ -0,0 +1,85 @@ +from pathlib import Path +from typing import List, Tuple + +import numpy as np + + +def load_obj_raw_preserve(path: Path) -> Tuple[np.ndarray, np.ndarray]: + """Load vertices and faces from an OBJ file while preserving vertex order. + + Args: + path (Path): Path to the OBJ file + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple containing: + - vertices: Nx3 array of vertex positions + - faces: Mx3 array of face indices (0-based) + """ + verts, faces = [], [] + with path.open() as fh: + for ln in fh: + if ln.startswith('v '): # keep order *exactly* as file + _, x, y, z = ln.split()[:4] + verts.append([float(x), float(y), float(z)]) + elif ln.startswith('f '): + toks = ln[2:].strip().split() + if len(toks) == 3: + faces.append([int(t.split('/')[0]) - 1 for t in toks]) + else: + faces.append([int(t.split('/')[0]) - 1 for t in toks[:3]]) + for i in range(2, len(toks) - 1): + faces.append([int(toks[0].split('/')[0]) - 1, + int(toks[i].split('/')[0]) - 1, + int(toks[i + 1].split('/')[0]) - 1]) + return np.asarray(verts, float), np.asarray(faces, int) + + +def get_plucker_coordinates(axis: np.ndarray, point: np.ndarray) -> np.ndarray: + """Converts a joint axis to Plücker coordinates.""" + l = axis / np.linalg.norm(axis, axis=-1, keepdims=True) + m = np.cross(l, point, axis=-1) + return np.concatenate([l, m], axis=-1) + + +def get_axis_point_from_plucker(plucker: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Converts Plücker coordinates to axis and point.""" + l = plucker[..., :3] + m = plucker[..., 3:] + axis = l / np.linalg.norm(l, axis=-1, keepdims=True) + point = np.cross(m, axis, axis=-1) + return axis, point + + +def normalize_meshes(all_verts: List[np.ndarray]) -> Tuple[List[np.ndarray], float, float, float, float]: + x_min, y_min, z_min = ( + min(verts[:, 0].min() for verts in all_verts), + min(verts[:, 1].min() for verts in all_verts), + min(verts[:, 2].min() for verts in all_verts) + ) + x_max, y_max, z_max = ( + max(verts[:, 0].max() for verts in all_verts), + max(verts[:, 1].max() for verts in all_verts), + max(verts[:, 2].max() for verts in all_verts) + ) + x_center, y_center, z_center = (x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2 + scale = 1.0 / max(x_max - x_min, y_max - y_min, z_max - z_min) + all_new_verts = [] + for verts in all_verts: + new_verts = verts.copy() + new_verts[:, 0] -= x_center + new_verts[:, 1] -= y_center + new_verts[:, 2] -= z_center + new_verts *= scale + all_new_verts.append(new_verts) + return all_new_verts, x_center, y_center, z_center, scale + + +def shift_axes_plucker( + axes_plucker: np.ndarray, + x_center: float, y_center: float, z_center: float, scale: float +) -> np.ndarray: + """Shift the axes plucker coordinates.""" + axis, point = get_axis_point_from_plucker(axes_plucker) + point_new = point - np.array([x_center, y_center, z_center]) + point_new *= scale + return get_plucker_coordinates(axis, point_new) diff --git a/particulate/export_utils.py b/particulate/export_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..429708874207e7860dccdc2fda00c8776afac3ad --- /dev/null +++ b/particulate/export_utils.py @@ -0,0 +1,850 @@ +import numpy as np +import os +import trimesh +from particulate.articulation_utils import ( + compute_part_transforms, + plucker_to_axis_point +) + + +def export_animated_glb_file( + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range, + animation_frames, + output_path, + include_axes=False, + axes_meshes=None +): + """ + Export an animated GLB file with proper node transformations. + + This function creates a GLB file with baked animations where each mesh part is a separate node + with transformation animations (translation, rotation, scale) that represent the articulation + motion over time. + + Args: + mesh_parts: List of trimesh objects, one per part + unique_part_ids: Array of unique part IDs + motion_hierarchy: List of (parent_id, child_id) tuples defining the kinematic tree + is_part_revolute: Boolean array indicating if each part has a revolute joint + is_part_prismatic: Boolean array indicating if each part has a prismatic joint + revolute_plucker: Plucker coordinates for revolute joint axes + revolute_range: [low, high] angle limits for revolute joints + prismatic_axis: Direction vectors for prismatic joints + prismatic_range: [low, high] displacement limits for prismatic joints + animation_frames: Number of keyframes in the animation + output_path: Path where the GLB file will be saved + include_axes: Whether to include axis visualization meshes + axes_meshes: List of trimesh objects representing axis visualizations (arrows/rings) + + The animation interpolates linearly from the low limit (state=0) to high limit (state=1) + over the specified number of frames at 30 FPS. + """ + import tempfile + from pygltflib import GLTF2, Animation, AnimationChannel, AnimationSampler, Accessor, BufferView + import os + + # Step 1: Export base mesh using trimesh (which handles textures/UVs correctly) + # Create a Scene with all parts and axes + scene = trimesh.Scene() + + # Keep track of part names to find their node indices later + part_node_names = [] + + for i, mesh_part in enumerate(mesh_parts): + # Assign a unique name for this part + # We use a specific prefix to identify it later + node_name = f"part_node_{i}" + part_node_names.append(node_name) + scene.add_geometry(mesh_part, node_name=node_name) + + if include_axes and axes_meshes: + for i, axis_mesh in enumerate(axes_meshes): + scene.add_geometry(axis_mesh, node_name=f"axis_node_{i}") + + # Export to a temporary file using trimesh + with tempfile.NamedTemporaryFile(suffix='.glb', delete=False) as tmp: + tmp_path = tmp.name + + try: + scene.export(tmp_path) + + # Step 2: Load the GLB using pygltflib + gltf = GLTF2().load(tmp_path) + + # Map node names to node indices + node_name_to_idx = {} + if gltf.nodes: + for i, node in enumerate(gltf.nodes): + if node.name: + node_name_to_idx[node.name] = i + + # Step 3: Add animation data + if not gltf.animations: + gltf.animations = [] + gltf.animations.append(Animation(channels=[], samplers=[])) + + animation_idx = len(gltf.animations) - 1 + + # Get the current binary buffer + # Read it from the file directly to ensure we have the correct data + with open(tmp_path, 'rb') as f: + # GLB format: 12-byte header, then chunks + header = f.read(12) + # Read JSON chunk + json_chunk_length = int.from_bytes(f.read(4), byteorder='little') + json_chunk_type = f.read(4) + json_data = f.read(json_chunk_length) + # Read binary chunk + bin_chunk_length = int.from_bytes(f.read(4), byteorder='little') + bin_chunk_type = f.read(4) + binary_data = bytearray(f.read(bin_chunk_length)) + + # Helper function to add binary data to the GLB buffer + def add_to_binary(data_bytes): + """Add data to binary blob and return BufferView info.""" + nonlocal binary_data + + # Align to 4 bytes + while len(binary_data) % 4 != 0: + binary_data.append(0) + + start = len(binary_data) + binary_data.extend(data_bytes) + + # Update buffer length in gltf structure + gltf.buffers[0].byteLength = len(binary_data) + + return start, len(data_bytes) + + # Step 4: Create animation data + states = np.linspace(0, 1, animation_frames) + times = np.linspace(0, animation_frames / 30.0, animation_frames).astype(np.float32) # 30 FPS + + # Add time accessor + time_bytes = times.tobytes() + time_start, time_length = add_to_binary(time_bytes) + time_bv_idx = len(gltf.bufferViews) + gltf.bufferViews.append(BufferView( + buffer=0, + byteOffset=time_start, + byteLength=time_length + )) + + time_acc_idx = len(gltf.accessors) + gltf.accessors.append(Accessor( + bufferView=time_bv_idx, + componentType=5126, # FLOAT + count=len(times), + type='SCALAR', + max=[float(times.max())], + min=[float(times.min())] + )) + + # For each part, create TRS animation samplers + for part_idx, part_id in enumerate(unique_part_ids): + # Find the correct node index for this part + part_node_name = part_node_names[part_idx] + target_node_idx = node_name_to_idx.get(part_node_name) + + if target_node_idx is None: + print(f"Warning: Could not find node index for part {part_idx} (name: {part_node_name})") + continue + + # Compute transforms for all frames + transforms_over_time = [] + for state in states: + transforms = compute_part_transforms( + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range, + state + ) + transforms_over_time.append(transforms[part_id]) + + # Decompose transforms into TRS + translations = [] + rotations = [] + scales = [] + + for T in transforms_over_time: + # Extract translation + translation = T[:3, 3] + translations.append(translation) + + # Extract rotation (convert to quaternion) + R = T[:3, :3] + # Compute scale + scale = np.array([ + np.linalg.norm(R[:, 0]), + np.linalg.norm(R[:, 1]), + np.linalg.norm(R[:, 2]) + ]) + scales.append(scale) + + # Remove scale from rotation matrix + R_normalized = R / scale + + # Convert rotation matrix to quaternion + trace = np.trace(R_normalized) + if trace > 0: + s = 0.5 / np.sqrt(trace + 1.0) + w = 0.25 / s + x = (R_normalized[2, 1] - R_normalized[1, 2]) * s + y = (R_normalized[0, 2] - R_normalized[2, 0]) * s + z = (R_normalized[1, 0] - R_normalized[0, 1]) * s + else: + if R_normalized[0, 0] > R_normalized[1, 1] and R_normalized[0, 0] > R_normalized[2, 2]: + s = 2.0 * np.sqrt(1.0 + R_normalized[0, 0] - R_normalized[1, 1] - R_normalized[2, 2]) + w = (R_normalized[2, 1] - R_normalized[1, 2]) / s + x = 0.25 * s + y = (R_normalized[0, 1] + R_normalized[1, 0]) / s + z = (R_normalized[0, 2] + R_normalized[2, 0]) / s + elif R_normalized[1, 1] > R_normalized[2, 2]: + s = 2.0 * np.sqrt(1.0 + R_normalized[1, 1] - R_normalized[0, 0] - R_normalized[2, 2]) + w = (R_normalized[0, 2] - R_normalized[2, 0]) / s + x = (R_normalized[0, 1] + R_normalized[1, 0]) / s + y = 0.25 * s + z = (R_normalized[1, 2] + R_normalized[2, 1]) / s + else: + s = 2.0 * np.sqrt(1.0 + R_normalized[2, 2] - R_normalized[0, 0] - R_normalized[1, 1]) + w = (R_normalized[1, 0] - R_normalized[0, 1]) / s + x = (R_normalized[0, 2] + R_normalized[2, 0]) / s + y = (R_normalized[1, 2] + R_normalized[2, 1]) / s + z = 0.25 * s + + rotations.append([x, y, z, w]) + + translations = np.array(translations, dtype=np.float32) + rotations = np.array(rotations, dtype=np.float32) + scales = np.array(scales, dtype=np.float32) + + # Add translation accessor + trans_bytes = translations.tobytes() + trans_start, trans_length = add_to_binary(trans_bytes) + trans_bv_idx = len(gltf.bufferViews) + gltf.bufferViews.append(BufferView( + buffer=0, + byteOffset=trans_start, + byteLength=trans_length + )) + + trans_acc_idx = len(gltf.accessors) + gltf.accessors.append(Accessor( + bufferView=trans_bv_idx, + componentType=5126, + count=len(translations), + type='VEC3', + max=translations.max(axis=0).tolist(), + min=translations.min(axis=0).tolist() + )) + + # Add rotation accessor + rot_bytes = rotations.tobytes() + rot_start, rot_length = add_to_binary(rot_bytes) + rot_bv_idx = len(gltf.bufferViews) + gltf.bufferViews.append(BufferView( + buffer=0, + byteOffset=rot_start, + byteLength=rot_length + )) + + rot_acc_idx = len(gltf.accessors) + gltf.accessors.append(Accessor( + bufferView=rot_bv_idx, + componentType=5126, + count=len(rotations), + type='VEC4', + max=rotations.max(axis=0).tolist(), + min=rotations.min(axis=0).tolist() + )) + + # Add scale accessor + scale_bytes = scales.tobytes() + scale_start, scale_length = add_to_binary(scale_bytes) + scale_bv_idx = len(gltf.bufferViews) + gltf.bufferViews.append(BufferView( + buffer=0, + byteOffset=scale_start, + byteLength=scale_length + )) + + scale_acc_idx = len(gltf.accessors) + gltf.accessors.append(Accessor( + bufferView=scale_bv_idx, + componentType=5126, + count=len(scales), + type='VEC3', + max=scales.max(axis=0).tolist(), + min=scales.min(axis=0).tolist() + )) + + # Create animation samplers and channels + # Translation sampler + trans_sampler_idx = len(gltf.animations[animation_idx].samplers) + gltf.animations[animation_idx].samplers.append(AnimationSampler( + input=time_acc_idx, + output=trans_acc_idx, + interpolation='LINEAR' + )) + gltf.animations[animation_idx].channels.append(AnimationChannel( + sampler=trans_sampler_idx, + target={'node': target_node_idx, 'path': 'translation'} + )) + + # Rotation sampler + rot_sampler_idx = len(gltf.animations[animation_idx].samplers) + gltf.animations[animation_idx].samplers.append(AnimationSampler( + input=time_acc_idx, + output=rot_acc_idx, + interpolation='LINEAR' + )) + gltf.animations[animation_idx].channels.append(AnimationChannel( + sampler=rot_sampler_idx, + target={'node': target_node_idx, 'path': 'rotation'} + )) + + # Scale sampler + scale_sampler_idx = len(gltf.animations[animation_idx].samplers) + gltf.animations[animation_idx].samplers.append(AnimationSampler( + input=time_acc_idx, + output=scale_acc_idx, + interpolation='LINEAR' + )) + gltf.animations[animation_idx].channels.append(AnimationChannel( + sampler=scale_sampler_idx, + target={'node': target_node_idx, 'path': 'scale'} + )) + + # Step 5: Save the animated GLB with updated binary data + # We need to manually write the GLB file to ensure our binary_data is used + import json + + # Helper function to recursively convert non-serializable objects to dicts + def make_json_serializable(obj): + """Recursively convert objects to JSON-serializable format.""" + # Handle numpy arrays and scalars + if isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, (np.integer, np.floating)): + return obj.item() + elif isinstance(obj, np.bool_): + return bool(obj) + # Handle objects with __dict__ (like Attributes) + elif hasattr(obj, '__dict__') and not isinstance(obj, (str, bytes, type)): + result = {} + for key, value in obj.__dict__.items(): + if not key.startswith('_'): # Skip private attributes + result[key] = make_json_serializable(value) + return result + elif isinstance(obj, dict): + return {k: make_json_serializable(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return [make_json_serializable(item) for item in obj] + elif hasattr(obj, 'to_dict') and callable(getattr(obj, 'to_dict')): + # Handle objects with to_dict method + return make_json_serializable(obj.to_dict()) + else: + # Return primitive types as-is (str, int, float, bool, None) + return obj + + # Helper function to clean GLTF dict by removing null values and empty arrays + def clean_gltf_dict(obj): + """Remove null values and empty arrays to comply with GLTF spec.""" + if isinstance(obj, dict): + result = {} + for key, value in obj.items(): + cleaned_value = clean_gltf_dict(value) + # Skip null values (GLTF spec: optional fields should be omitted, not null) + if cleaned_value is None: + continue + # Skip empty arrays (GLTF spec: empty arrays should be omitted) + if isinstance(cleaned_value, list) and len(cleaned_value) == 0: + continue + result[key] = cleaned_value + return result + elif isinstance(obj, list): + cleaned_list = [clean_gltf_dict(item) for item in obj] + # Filter out None values from lists + return [item for item in cleaned_list if item is not None] + else: + return obj + + # Helper function to validate and fix mesh primitives + def validate_mesh_primitives(gltf_dict): + """Remove invalid accessor indices from mesh primitives.""" + if 'meshes' not in gltf_dict: + return gltf_dict + + num_accessors = len(gltf_dict.get('accessors', [])) + + for mesh in gltf_dict['meshes']: + if 'primitives' not in mesh: + continue + for primitive in mesh['primitives']: + if 'attributes' not in primitive: + continue + # Remove invalid attribute references + valid_attributes = {} + for attr_name, accessor_idx in primitive['attributes'].items(): + # Only keep attributes with valid accessor indices + if (isinstance(accessor_idx, int) and + accessor_idx >= 0 and + accessor_idx < num_accessors): + valid_attributes[attr_name] = accessor_idx + primitive['attributes'] = valid_attributes + + # Validate indices accessor if present + if 'indices' in primitive: + indices_idx = primitive['indices'] + if not (isinstance(indices_idx, int) and + indices_idx >= 0 and + indices_idx < num_accessors): + del primitive['indices'] + + # Validate material index if present + if 'material' in primitive: + material_idx = primitive['material'] + num_materials = len(gltf_dict.get('materials', [])) + if not (isinstance(material_idx, int) and + material_idx >= 0 and + material_idx < num_materials): + del primitive['material'] + + return gltf_dict + + # Helper function to validate node references + def validate_node_references(gltf_dict): + """Validate and fix node references to other objects.""" + if 'nodes' not in gltf_dict: + return gltf_dict + + num_meshes = len(gltf_dict.get('meshes', [])) + num_cameras = len(gltf_dict.get('cameras', [])) + num_skins = len(gltf_dict.get('skins', [])) + num_nodes = len(gltf_dict['nodes']) + + for node in gltf_dict['nodes']: + # Validate mesh reference + if 'mesh' in node: + mesh_idx = node['mesh'] + if not (isinstance(mesh_idx, int) and + mesh_idx >= 0 and + mesh_idx < num_meshes): + del node['mesh'] + + # Validate camera reference + if 'camera' in node: + camera_idx = node['camera'] + if not (isinstance(camera_idx, int) and + camera_idx >= 0 and + camera_idx < num_cameras): + del node['camera'] + + # Validate skin reference + if 'skin' in node: + skin_idx = node['skin'] + if not (isinstance(skin_idx, int) and + skin_idx >= 0 and + skin_idx < num_skins): + del node['skin'] + + # Validate children references + if 'children' in node: + valid_children = [] + for child_idx in node['children']: + if (isinstance(child_idx, int) and + child_idx >= 0 and + child_idx < num_nodes): + valid_children.append(child_idx) + if len(valid_children) > 0: + node['children'] = valid_children + else: + del node['children'] + + return gltf_dict + + # Helper function to validate texture and image references + def validate_texture_references(gltf_dict): + """Validate and fix texture and image references.""" + num_images = len(gltf_dict.get('images', [])) + num_samplers = len(gltf_dict.get('samplers', [])) + + # Validate textures + if 'textures' in gltf_dict: + for texture in gltf_dict['textures']: + # Validate sampler reference + if 'sampler' in texture: + sampler_idx = texture['sampler'] + if not (isinstance(sampler_idx, int) and + sampler_idx >= 0 and + sampler_idx < num_samplers): + del texture['sampler'] + + # Validate source (image) reference + if 'source' in texture: + source_idx = texture['source'] + if not (isinstance(source_idx, int) and + source_idx >= 0 and + source_idx < num_images): + del texture['source'] + + return gltf_dict + + # Update JSON to reflect new buffer size + gltf_dict = gltf.to_dict() + # Recursively convert all nested objects to be JSON serializable + gltf_dict = make_json_serializable(gltf_dict) + # Validate and fix references + gltf_dict = validate_mesh_primitives(gltf_dict) + gltf_dict = validate_node_references(gltf_dict) + gltf_dict = validate_texture_references(gltf_dict) + # Clean up null values and empty arrays (must be last to remove invalid fields) + gltf_dict = clean_gltf_dict(gltf_dict) + + # Write GLB file manually + with open(output_path, 'wb') as f: + # Write GLB header + # Magic: "glTF" + f.write(b'glTF') + # Version: 2 + f.write((2).to_bytes(4, byteorder='little')) + # Total length (will update later) + total_length_pos = f.tell() + f.write((0).to_bytes(4, byteorder='little')) + + # Write JSON chunk + json_str = json.dumps(gltf_dict, separators=(',', ':')) + json_bytes = json_str.encode('utf-8') + json_chunk_length = len(json_bytes) + # Align JSON to 4 bytes + while json_chunk_length % 4 != 0: + json_bytes += b' ' + json_chunk_length += 1 + + f.write(json_chunk_length.to_bytes(4, byteorder='little')) + f.write(b'JSON') + f.write(json_bytes) + + # Write binary chunk + # Align binary to 4 bytes + while len(binary_data) % 4 != 0: + binary_data.append(0) + + bin_chunk_length = len(binary_data) + f.write(bin_chunk_length.to_bytes(4, byteorder='little')) + f.write(b'BIN\x00') + f.write(binary_data) + + # Update total length + total_length = f.tell() + f.seek(total_length_pos) + f.write(total_length.to_bytes(4, byteorder='little')) + + print(f"Saved animated GLB to {output_path}") + + finally: + # Clean up temporary file + if os.path.exists(tmp_path): + os.unlink(tmp_path) + + +def export_urdf( + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range, + output_path, + name="robot" +): + urdf_dir = os.path.dirname(output_path) + os.makedirs(urdf_dir, exist_ok=True) + mesh_dir = os.path.abspath(os.path.join(urdf_dir, "meshes")) + os.makedirs(mesh_dir, exist_ok=True) + + # Identify parents and children + unique_part_ids_set = set(unique_part_ids) + parent_map = {} + children_map = {pid: [] for pid in unique_part_ids} + for p, c in motion_hierarchy: + # Filter out hierarchy edges where parts don't exist in the mesh + if p not in unique_part_ids_set or c not in unique_part_ids_set: + continue + + parent_map[c] = p + if p in children_map: + children_map[p].append(c) + else: + children_map[p] = [c] + + # Find roots + roots = [] + for pid in unique_part_ids: + if pid not in parent_map: + roots.append(pid) + + # Determine local frame origins for each link (in World Coordinates) + link_origins_world = {} + + for i, pid in enumerate(unique_part_ids): + if pid in roots: + link_origins_world[pid] = np.zeros(3) + elif is_part_revolute[pid]: + # Revolute: Origin at point on axis + axis, point = plucker_to_axis_point(revolute_plucker[pid]) + link_origins_world[pid] = point + elif is_part_prismatic[pid]: + # Prismatic: Origin at Centroid of mesh + link_origins_world[pid] = mesh_parts[i].vertices.mean(axis=0) + else: + # Fixed/Other + link_origins_world[pid] = mesh_parts[i].vertices.mean(axis=0) + + # Prepare URDF string + urdf_lines = [] + urdf_lines.append(f'') + urdf_lines.append(f'') + + # Process each part + for i, pid in enumerate(unique_part_ids): + mesh = mesh_parts[i] + origin = link_origins_world[pid] + + # Save mesh (centered at local origin) + mesh_local = mesh.copy() + mesh_local.vertices -= origin + + mesh_filename = f"part_{pid}.obj" + mesh_path = os.path.join(mesh_dir, mesh_filename) + mesh_local.export(mesh_path) + + link_name = f"link_{pid}" + + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + + # Joints + for pid in unique_part_ids: + if pid in parent_map: + parent_pid = parent_map[pid] + child_pid = pid + + joint_name = f"joint_{parent_pid}_{child_pid}" + parent_link = f"link_{parent_pid}" + child_link = f"link_{child_pid}" + + p_origin = link_origins_world[parent_pid] + c_origin = link_origins_world[child_pid] + offset = c_origin - p_origin + + if is_part_revolute[pid]: + j_type = "revolute" + axis, _ = plucker_to_axis_point(revolute_plucker[pid]) + axis = axis / (np.linalg.norm(axis) + 1e-6) + lower, upper = revolute_range[pid] + elif is_part_prismatic[pid]: + j_type = "prismatic" + axis = prismatic_axis[pid] + axis = axis / (np.linalg.norm(axis) + 1e-6) + lower, upper = prismatic_range[pid] + else: + j_type = "fixed" + axis = [0, 0, 1] + lower, upper = 0, 0 + + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + if j_type != "fixed": + urdf_lines.append(f' ') + urdf_lines.append(f' ') + urdf_lines.append(f' ') + + urdf_lines.append(f'') + + with open(output_path, 'w') as f: + f.write('\n'.join(urdf_lines)) + + print(f"Exported URDF to {output_path}") + + +def export_mjcf( + mesh_parts, + unique_part_ids, + motion_hierarchy, + is_part_revolute, + is_part_prismatic, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range, + output_path, + name="robot" +): + import os + + mjcf_dir = os.path.dirname(output_path) + os.makedirs(mjcf_dir, exist_ok=True) + mesh_dir = os.path.join(mjcf_dir, "meshes") + os.makedirs(mesh_dir, exist_ok=True) + + # Identify parents and children + unique_part_ids_set = set(unique_part_ids) + parent_map = {} + children_map = {pid: [] for pid in unique_part_ids} + for p, c in motion_hierarchy: + # Filter out hierarchy edges where parts don't exist in the mesh + if p not in unique_part_ids_set or c not in unique_part_ids_set: + continue + + parent_map[c] = p + if p in children_map: + children_map[p].append(c) + else: + children_map[p] = [c] + + # Find roots + roots = [] + for pid in unique_part_ids: + if pid not in parent_map: + roots.append(pid) + + # Determine local frame origins for each link (in World Coordinates) + link_origins_world = {} + + for i, pid in enumerate(unique_part_ids): + if pid in roots: + link_origins_world[pid] = np.zeros(3) + elif is_part_revolute[pid]: + # Revolute: Origin at point on axis + axis, point = plucker_to_axis_point(revolute_plucker[pid]) + link_origins_world[pid] = point + elif is_part_prismatic[pid]: + # Prismatic: Origin at Centroid of mesh + link_origins_world[pid] = mesh_parts[i].vertices.mean(axis=0) + else: + # Fixed/Other + link_origins_world[pid] = mesh_parts[i].vertices.mean(axis=0) + + # Save meshes and prepare assets + asset_lines = [] + asset_lines.append(f' ') + + for i, pid in enumerate(unique_part_ids): + mesh = mesh_parts[i] + origin = link_origins_world[pid] + + # Save mesh (centered at local origin) + mesh_local = mesh.copy() + mesh_local.vertices -= origin + + mesh_filename = f"part_{pid}.obj" + mesh_path = os.path.join(mesh_dir, mesh_filename) + mesh_local.export(mesh_path) + + asset_lines.append(f' ') + + asset_lines.append(f' ') + + # Recursive function to build body hierarchy + def build_body_xml(pid, parent_pid=None, indent=" "): + lines = [] + + # Calculate position relative to parent + origin = link_origins_world[pid] + if parent_pid is not None: + parent_origin = link_origins_world[parent_pid] + rel_pos = origin - parent_origin + else: + rel_pos = origin # Relative to world (0,0,0) + + lines.append(f'{indent}') + + # Add geom + lines.append(f'{indent} ') + # Optional: Add collision geom (using same mesh for now) + # lines.append(f'{indent} ') + + # Add joint if not root + if parent_pid is not None: + if is_part_revolute[pid]: + axis, _ = plucker_to_axis_point(revolute_plucker[pid]) + axis = axis / (np.linalg.norm(axis) + 1e-6) + lower, upper = revolute_range[pid] + # Convert radians to degrees for MJCF default + lower_deg = np.degrees(lower) + upper_deg = np.degrees(upper) + lines.append(f'{indent} ') + elif is_part_prismatic[pid]: + axis = prismatic_axis[pid] + axis = axis / (np.linalg.norm(axis) + 1e-6) + lower, upper = prismatic_range[pid] + lines.append(f'{indent} ') + else: + # Fixed joint (no joint element needed in MJCF, bodies are fused) + pass + + # Process children + for child_pid in children_map[pid]: + lines.extend(build_body_xml(child_pid, pid, indent + " ")) + + lines.append(f'{indent}') + return lines + + # Build full MJCF + mjcf_lines = [] + mjcf_lines.append(f'') + mjcf_lines.append(f' ') # Explicitly set angle unit + mjcf_lines.extend(asset_lines) + mjcf_lines.append(f' ') + + # Add floor (optional but good for visualization) + # mjcf_lines.append(f' ') + + for root_pid in roots: + mjcf_lines.extend(build_body_xml(root_pid, indent=" ")) + + mjcf_lines.append(f' ') + mjcf_lines.append(f'') + + with open(output_path, 'w') as f: + f.write("\n".join(mjcf_lines)) + + print(f"Exported MJCF to {output_path}") diff --git a/particulate/inference_utils.py b/particulate/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..933fadc5cce6dc49637f848e5cdc6147a140cb45 --- /dev/null +++ b/particulate/inference_utils.py @@ -0,0 +1,40 @@ +from typing import List, Tuple + +import torch +import torch.nn.functional as F +import networkx as nx + + +def extract_motion_hierarchy(motion_structure_logits: torch.Tensor) -> List[Tuple[int, int]]: + """ + Extract the motion hierarchy from the motion structure logits using NetworkX's + maximum_spanning_arborescence (which implements Edmonds' algorithm). + + Args: + motion_structure_logits: (N, N) tensor where motion_structure_logits[i,j] + represents the logit for directed edge from i to j + + Returns: + List of (parent, child) tuples representing the directed spanning tree + """ + weights = F.logsigmoid(motion_structure_logits).detach().cpu().numpy() + + n = weights.shape[0] + + if n <= 1: + return [] + + G = nx.DiGraph() + + for i in range(n): + for j in range(n): + if i != j: + G.add_edge(i, j, weight=weights[i, j]) + + arborescence = nx.maximum_spanning_arborescence(G, attr='weight') + + result = [] + for p, c in arborescence.edges(): + result.append((p, c)) + + return result diff --git a/particulate/matcher.py b/particulate/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..72d0c3a3465884baba6e94e11a11d66c745a64aa --- /dev/null +++ b/particulate/matcher.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn +from scipy.optimize import linear_sum_assignment + + +class HungarianMatcher(nn.Module): + def __init__(self): + super().__init__() + + @torch.no_grad() + def forward(self, point_mask, part_ids, num_valid_parts): + """ + Perform Hungarian matching between predicted columns and ground truth parts. + + Args: + point_mask: (B, N, M) - predicted logits before softmax + part_ids: (B, N) - ground truth part-id for each point + num_valid_parts: (B,) - number of available parts in each batch + + Returns: + matched_part_ids: (B, M) - mapped part-id for each column (-1 if unmapped) + """ + batch_size, num_points, num_columns = point_mask.shape + device = point_mask.device + + # Convert logits to probabilities and log probabilities + probs = torch.softmax(point_mask, dim=-1) # (B, N, M) + log_probs = torch.log_softmax(point_mask, dim=-1) # (B, N, M) + + result = [] + + for i in range(batch_size): + n_valid = num_valid_parts[i].item() + + if n_valid == 0: + # No valid parts, all columns get -1 + result.append(torch.full((num_columns,), -1, dtype=torch.long, device=device)) + continue + + # Create part masks: (n_valid, N) - True if point belongs to part + part_masks = (part_ids[i].unsqueeze(0) == torch.arange(n_valid, device=device).unsqueeze(1)) # (n_valid, N) + + # Create prediction masks: (N, num_columns) - True if predicted to belong to column + pred_assignments = torch.argmax(probs[i], dim=-1) # (N,) + pred_masks = torch.zeros_like(probs[i], dtype=torch.bool) # (N, num_columns) + pred_masks[torch.arange(pred_masks.size(0)), pred_assignments] = True + + # Matrix multiplication to compute sum of log probs for each part-column pair + log_prob_sums = part_masks.float() @ log_probs[i] # (n_valid, num_columns) + cost_matrix = -log_prob_sums # (n_valid, num_columns) + + # Apply Hungarian algorithm to find optimal assignment + # Convert to float32 to avoid BFloat16 compatibility issues with scipy + row_indices, col_indices = linear_sum_assignment(cost_matrix.cpu().float().numpy()) + + # Create result for this batch + batch_result = torch.full((num_columns,), -1, dtype=torch.long, device=device) + batch_result[col_indices] = torch.tensor(row_indices, dtype=torch.long, device=device) + result.append(batch_result) + + return torch.stack(result) diff --git a/particulate/models.py b/particulate/models.py new file mode 100644 index 0000000000000000000000000000000000000000..9246f7347b64453846cd33277b78b6c140145122 --- /dev/null +++ b/particulate/models.py @@ -0,0 +1,1046 @@ +from typing import List, Optional, Tuple +import math + +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.normalization import FP32LayerNorm +import torch +import torch.nn as nn + +from particulate.inference_utils import extract_motion_hierarchy +from particulate.matcher import HungarianMatcher +from particulate.articulation_utils import closest_point_on_axis_to_revolute_plucker + + +class PositionalEmbedder(nn.Module): + def __init__( + self, + frequency_embedding_size: int, + hidden_size: int, + input_dim: int, + raw: bool = False, + ): + super(PositionalEmbedder, self).__init__() + self.frequency_embedding_size = frequency_embedding_size + self.raw = raw + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size * input_dim + (input_dim if raw else 0), hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + ) + self.input_dim = input_dim + + @staticmethod + def pos_embedding(x, dim, max_period=10000, mult_factor: float = 1000.0): + x = mult_factor * x + half = dim // 2 + # freqs = torch.exp(torch.arange(half, dtype=torch.float32) / half).to(x.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + x.device + ) + args = x[..., None].float() * freqs[None, :] + embeddings = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) + if dim % 2: + embeddings = torch.cat([embeddings, torch.zeros_like(embeddings[..., :1])], dim=-1) + return embeddings + + def forward( + self, + x: torch.FloatTensor + ): + assert x.shape[-1] == self.input_dim + x_embed = self.pos_embedding(x, self.frequency_embedding_size) + x_embed = x_embed.flatten(start_dim=-2) # Flatten first: (..., input_dim * frequency_embedding_size) + if self.raw: + x_embed = torch.cat([x_embed, x], dim=-1) # Now concatenate: (..., input_dim * frequency_embedding_size + input_dim) + x_embed = self.mlp(x_embed) + return x_embed + + +class Block(nn.Module): + def __init__( + self, + hidden_size: int, + n_heads: int, + dropout: float, + ): + super(Block, self).__init__() + + # Query self-attention + self.norm1 = FP32LayerNorm(hidden_size, eps=1e-5, elementwise_affine=True) + self.attn1 = Attention( + query_dim=hidden_size, + dim_head=hidden_size // n_heads, + heads=n_heads, + dropout=dropout, + qk_norm="rms_norm", + eps=1e-6, + bias=False + ) + + # Point cloud to query cross-attention + self.norm2 = FP32LayerNorm(hidden_size, eps=1e-5, elementwise_affine=True) + # self.norm2_k = nn.RMSNorm(hidden_size, eps=1e-6, elementwise_affine=True) + self.attn2 = Attention( + query_dim=hidden_size, + dim_head=hidden_size // n_heads, + heads=n_heads, + dropout=dropout, + qk_norm="rms_norm", + eps=1e-6, + bias=False + ) + + # Query to point cloud cross-attention + self.norm3 = FP32LayerNorm(hidden_size, eps=1e-5, elementwise_affine=True) + # self.norm3_k = nn.RMSNorm(hidden_size, eps=1e-6, elementwise_affine=True) + self.attn3 = Attention( + query_dim=hidden_size, + dim_head=hidden_size // n_heads, + heads=n_heads, + dropout=dropout, + qk_norm="rms_norm", + eps=1e-6, + bias=False + ) + + self.final_norm_x = FP32LayerNorm(hidden_size, eps=1e-5, elementwise_affine=True) + self.final_norm_q = FP32LayerNorm(hidden_size, eps=1e-5, elementwise_affine=True) + self.mlp_x = FeedForward(dim=hidden_size, dropout=dropout) + self.mlp_q = FeedForward(dim=hidden_size, dropout=dropout) + + def forward( + self, + x: torch.FloatTensor, + q: torch.FloatTensor + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + # 1. query self-attention + # q_norm = self.norm1(q) + # q = self.attn1(q_norm) + q + q = self.norm1(q) + q = self.attn1(q) + q + + # 2. point cloud to query cross-attention + q = q + self.attn2(self.norm2(q), x) + + # 3. query to point cloud cross-attention + x = x + self.attn3(self.norm3(x), q) + + # 4. final MLP + x = x + self.mlp_x(self.final_norm_x(x)) + q = q + self.mlp_q(self.final_norm_q(q)) + + return x, q + + +class Articulate3D(nn.Module): + def __init__( + self, + input_dim: int, + hidden_size: int, + num_layers: int, + n_heads: int, + dropout: float, + use_normals: bool = False, + use_text_prompts: bool = False, + max_parts: int = 128, + use_part_id_embedding: bool = True, + use_raw_coords: bool = False, + use_point_features_for_motion_decoding: bool = False, + point_feature_random_ratio: float = 0.0, + num_mask_hypotheses: int = 1, + motion_representation: str = 'per_part_plucker', # one of ["per_part_plucker", "per_point_closest"] + ): + super(Articulate3D, self).__init__() + + self.feat_proj = nn.Linear(input_dim, hidden_size) + self.pos_embed = PositionalEmbedder( + frequency_embedding_size=64, + hidden_size=hidden_size, + input_dim=3, + raw=use_raw_coords + ) + + self.use_normals = use_normals + if use_normals: + self.normal_embed = PositionalEmbedder( + frequency_embedding_size=64, + hidden_size=hidden_size, + input_dim=3, + raw=use_raw_coords + ) + + self.blocks = nn.ModuleList([ + Block( + hidden_size=hidden_size, + n_heads=n_heads, + dropout=dropout, + ) + for _ in range(num_layers) + ]) + + # Decoders + self.num_mask_hypotheses = num_mask_hypotheses + if self.num_mask_hypotheses == 1: + self.point_mask_decoder = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, 1) + ) + self.point_mask_decoding_func = lambda p, q: [( + self.point_mask_decoder(torch.cat([ + p.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, M, D) + q.unsqueeze(1).expand(-1, p.size(1), -1, -1) # (B, N, M, D) + ], dim=-1)).squeeze(-1) + )] + else: + self.point_mask_decoder = nn.ModuleList([ + nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, 1) + ) + for _ in range(self.num_mask_hypotheses) + ]) + self.point_mask_decoding_func = lambda p, q: [( + self.point_mask_decoder[i](torch.cat([ + p.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, M, D) + q.unsqueeze(1).expand(-1, p.size(1), -1, -1) # (B, N, M, D) + ], dim=-1)).squeeze(-1) + ) for i in range(self.num_mask_hypotheses)] + + self.use_point_features_for_motion_decoding = use_point_features_for_motion_decoding + self.point_feature_random_ratio = point_feature_random_ratio + + part_hierarchy_input_dim = hidden_size * 4 if self.use_point_features_for_motion_decoding else hidden_size * 2 + self.part_hierarchy_decoder = nn.Sequential( + nn.Linear(part_hierarchy_input_dim, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, 1) + ) + + motion_input_dim = hidden_size * 2 if self.use_point_features_for_motion_decoding else hidden_size + self.part_motion_classifier = nn.Sequential( + nn.Linear(motion_input_dim, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, 4) # 4 classes: 0 --> "no motion", 1 --> "revolute", 2 --> "prismatic", 3 --> "both" + ) + + self.motion_representation = motion_representation + motion_input_dim = hidden_size * 2 if self.use_point_features_for_motion_decoding else hidden_size + self.revolute_motion_decoder = nn.Sequential( + nn.Linear(motion_input_dim, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, (6 if self.motion_representation == 'per_part_plucker' else 3) + 2) # Plucker coordinate in R^6 and low & high limits. + ) + self.prismatic_motion_decoder = nn.Sequential( + nn.Linear(motion_input_dim, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, 3 + 2) # Axis only in R^3 and low & high limits. + ) + if self.motion_representation == 'per_point_closest': + motion_input_dim = hidden_size * 3 if self.use_point_features_for_motion_decoding else hidden_size * 2 + self.point_motion_decoder = nn.Sequential( + nn.Linear(motion_input_dim, hidden_size * 4), + nn.SiLU(), + nn.Linear(hidden_size * 4, 3) # Closest point on the axis to the point + ) + + if use_text_prompts: + raise NotImplementedError("Text prompts are not implemented yet") + + self.max_parts = max_parts + self.use_part_id_embedding = use_part_id_embedding + if use_part_id_embedding: + self.part_id_embed = PositionalEmbedder( + frequency_embedding_size=64, + hidden_size=hidden_size, + input_dim=1, + raw=False + ) + + self.matcher = HungarianMatcher() + + def forward_attn( + self, + xyz: torch.FloatTensor, + feats: torch.FloatTensor, + query_xyz: Optional[torch.FloatTensor] = None, + query_feats: Optional[torch.FloatTensor] = None, + normals: Optional[torch.FloatTensor] = None, + text_prompts: Optional[List[str]] = None, + ): + batch_size = xyz.shape[0] + x = self.feat_proj(feats) + self.pos_embed(xyz) + if self.use_normals: + assert normals is not None + x = x + self.normal_embed(normals) + + if text_prompts is not None: + raise NotImplementedError("Text prompts are not implemented yet") + + assert query_xyz is not None or query_feats is not None or self.use_part_id_embedding + q = 0 + if self.use_part_id_embedding: + num_parts = self.max_parts if query_xyz is None else query_xyz.shape[1] + part_indices = torch.arange(num_parts, device=x.device, dtype=torch.float32) / num_parts + q = self.part_id_embed(part_indices.unsqueeze(0).unsqueeze(-1).expand(batch_size, -1, -1)) + if query_xyz is not None: + q = q + self.pos_embed(query_xyz) + if query_feats is not None: + q = q + self.feat_proj(query_feats) + + for block in self.blocks: + x, q = block(x, q) + + return x, q + + def forward_results( + self, + xyz: torch.FloatTensor, + feats: torch.FloatTensor, + query_xyz: Optional[torch.FloatTensor] = None, + query_feats: Optional[torch.FloatTensor] = None, + normals: Optional[torch.FloatTensor] = None, + text_prompts: Optional[List[str]] = None, + forward_motion_class: bool = True, + forward_motion_params: bool = True, + gt_part_ids: Optional[torch.LongTensor] = None, + overwrite_part_ids: Optional[torch.LongTensor] = None, + num_valid_parts: Optional[torch.LongTensor] = None, + run_matching: bool = False, + force_hyp_idx: int = -1, + min_part_confidence: float = 0.0 + ): + batch_size, num_points = xyz.shape[:2] + x, q = self.forward_attn(xyz, feats, query_xyz, query_feats, normals, text_prompts) + + point_masks = self.point_mask_decoding_func(x, q) # (B, M, N) + + best_point_mask_id = None + if self.num_mask_hypotheses == 1: + point_mask = point_masks[0] + best_point_mask_id = torch.zeros(batch_size, device=x.device, dtype=torch.long) + elif force_hyp_idx >= 0: + point_mask = point_masks[force_hyp_idx] + else: + all_point_masks = torch.cat(point_masks, dim=0) + all_gt_part_ids = torch.cat([gt_part_ids] * self.num_mask_hypotheses, dim=0) + all_num_valid_parts = torch.cat([num_valid_parts] * self.num_mask_hypotheses, dim=0) + + if run_matching: + matching = self.matcher(all_point_masks, all_gt_part_ids, all_num_valid_parts) + sorted_indices = self.get_sorted_indices(matching) + all_point_masks = self.sort_result(all_point_masks.permute(0, 2, 1), sorted_indices).permute(0, 2, 1) + + all_point_mask_losses = self.compute_point_mask_loss( + all_point_masks.detach(), all_gt_part_ids, all_num_valid_parts, + reduction='none' + ).reshape(batch_size * self.num_mask_hypotheses, num_points).mean(-1).split(batch_size) + best_point_mask_id = torch.argmin(torch.stack(all_point_mask_losses, dim=0), dim=0) + + stacked_point_masks = torch.stack(point_masks, dim=0) # (num_hypotheses, batch_size, M, N) + batch_indices = torch.arange(batch_size, device=best_point_mask_id.device) + point_mask = stacked_point_masks[best_point_mask_id, batch_indices] + point_mask = point_mask + stacked_point_masks.sum(dim=0) * 0.0 # ensure all hypotheses receive gradients + + sorted_indices = None + if run_matching: + assert num_valid_parts is not None and gt_part_ids is not None + matching = self.matcher(point_mask, gt_part_ids, num_valid_parts) + sorted_indices = self.get_sorted_indices(matching) + q = self.sort_result(q, sorted_indices) + point_mask = self.sort_result(point_mask.permute(0, 2, 1), sorted_indices).permute(0, 2, 1) + + if self.training: + assert gt_part_ids is not None + part_ids = gt_part_ids + elif overwrite_part_ids is not None: + part_ids = overwrite_part_ids + else: + part_ids = point_mask.argmax(dim=-1) + # Mask all columns where no point is affiliated to -torch.inf + num_parts = point_mask.shape[-1] + for part_id in range(num_parts): + if not (part_ids == part_id).any(): + point_mask[..., part_id] = -torch.inf + + done = False + while not done: + done = True + probs = point_mask.softmax(dim=-1) + idx = part_ids.long().unsqueeze(-1) + part_probs = probs.gather(dim=-1, index=idx).squeeze(-1) + for part_id in part_ids.unique(): + # Compute part confidence as the average probability + part_confidence = (part_probs[part_ids == part_id].log().sum() / (part_ids == part_id).sum()).exp() + if part_confidence < min_part_confidence: + done = False + point_mask[..., part_id] = -torch.inf + part_ids = torch.argmax(point_mask, dim=-1) + + if self.use_point_features_for_motion_decoding: + max_parts = point_mask.shape[-1] + part_id_mask = part_ids.unsqueeze(-1) == torch.arange(max_parts, device=x.device).unsqueeze(0).unsqueeze(0) # (B, N, M) + sample_probs = torch.where( + part_id_mask, + 1.0 - (self.point_feature_random_ratio if self.training else 0), + self.point_feature_random_ratio if self.training else 0 + ) # (B, N, M) + point_to_part_mask = torch.bernoulli(sample_probs) # (B, N, M) + + part_features = torch.einsum('bnd,bnm->bmd', x, point_to_part_mask) # (B, M, D) + counts = point_to_part_mask.sum(dim=1) # (B, M) + nonzero_mask = counts > 0 # (B, M) + part_features = torch.where( + nonzero_mask.unsqueeze(-1), + part_features / counts.unsqueeze(-1), + part_features + ) + + q = torch.cat([q, part_features], dim=-1) + + # Prepare input for part hierarchy decoder + part_adjacency_matrix = self.part_hierarchy_decoder( + torch.cat([ + q.unsqueeze(2).expand(-1, -1, q.size(1), -1), # (B, N, N, D) + q.unsqueeze(1).expand(-1, q.size(1), -1, -1) # (B, N, N, D) + ], dim=-1) # ( 1, 5, 5) + ).squeeze(-1) + + ( + part_motion_logits, + revolute_plucker, + revolute_range, + prismatic_axis, + prismatic_range + ) = (None, None, None, None, None) + + if forward_motion_class: + part_motion_logits = self.part_motion_classifier(q) + + closest_point_on_axis = None + if self.motion_representation == 'per_part_plucker': + if forward_motion_params and forward_motion_class: + revolute_motion_params = self.revolute_motion_decoder(q) + prismatic_motion_params = self.prismatic_motion_decoder(q) + revolute_plucker, revolute_range = ( + revolute_motion_params[..., :6], + revolute_motion_params[..., 6:] + ) + prismatic_axis, prismatic_range = ( + prismatic_motion_params[..., :3], + prismatic_motion_params[..., 3:] + ) + elif self.motion_representation == 'per_point_closest': + if forward_motion_params and forward_motion_class: + revolute_motion_params = self.revolute_motion_decoder(q) + prismatic_motion_params = self.prismatic_motion_decoder(q) + revolute_plucker, revolute_range = ( + revolute_motion_params[..., :3], + revolute_motion_params[..., 3:] + ) + prismatic_axis, prismatic_range = ( + prismatic_motion_params[..., :3], + prismatic_motion_params[..., 3:] + ) + + per_point_q = torch.gather(q, dim=1, index=part_ids.unsqueeze(-1).expand(-1, -1, q.size(-1))) + motion_decoder_input = torch.cat([x, per_point_q], dim=-1) + closest_point_on_axis = self.point_motion_decoder(motion_decoder_input) + return ( + point_mask.contiguous(), + part_adjacency_matrix, + part_motion_logits, + revolute_plucker, revolute_range, + prismatic_axis, prismatic_range, + closest_point_on_axis, + part_ids, + best_point_mask_id + ) + + def get_sorted_indices( + self, + matching: torch.LongTensor + ): + """ + Get sorted indices of matching values. + Columns with valid matches (matching != -1) are placed first, sorted by matching values. + Columns with invalid matches are placed last. + + Args: + matching: LongTensor of shape (batch_size, num_columns) with part assignments (-1 for invalid) + + Returns: + sorted_indices: LongTensor of shape (batch_size, num_columns) with permutation indices + """ + batch_size, num_columns = matching.shape + + # Create valid mask: (batch_size, num_columns) + valid_mask = matching > -1 + + # For sorting, replace -1 with large values so they sort last + matching_for_sort = matching.clone() + matching_for_sort[~valid_mask] = num_columns + + # Sort columns by matching values within each batch + sorted_indices = torch.argsort(matching_for_sort, dim=-1) # (batch_size, num_columns) + + return sorted_indices + + def sort_result( + self, + result: torch.Tensor, + sorted_indices: torch.LongTensor, + ): + """ + Reorder columns of a tensor using pre-computed sorted indices. + + This method applies a column permutation to the result tensor based on sorted indices + that were previously computed (e.g., from get_sorted_indices method). The permutation + reorders columns so that those corresponding to valid part matches appear first. + + Args: + result: Tensor of shape (batch_size, num_columns, ...) to be permuted along the column dimension + sorted_indices: LongTensor of shape (batch_size, num_columns) containing permutation indices + + Returns: + permuted_result: Tensor with columns reordered according to sorted_indices + """ + new_dims = [None] * (len(result.shape) - 2) + expanded_indices = sorted_indices[(..., *new_dims)] + expanded_indices = expanded_indices.expand(-1, -1, *result.shape[2:]) + permuted_result = torch.gather(result, dim=1, index=expanded_indices) + return permuted_result + + def compute_point_mask_loss( + self, + point_mask_logits: torch.FloatTensor, + part_ids: torch.LongTensor, + num_valid_parts: torch.LongTensor, + reduction: str = 'mean' + ) -> Tuple[dict, Optional[float]]: + device = point_mask_logits.device + num_parts = point_mask_logits.shape[-1] + + # invalid_parts_mask = ( + # torch.arange(num_parts).unsqueeze(0).unsqueeze(0).to(device) >= \ + # num_valid_parts.unsqueeze(-1).unsqueeze(-1) # (B, 1, M) + # ) + # point_mask_logits = point_mask_logits.masked_fill(invalid_parts_mask, float('-inf')) + + return torch.nn.functional.cross_entropy( + point_mask_logits.reshape(-1, num_parts), + part_ids.reshape(-1), + reduction=reduction + ) + + def compute_dice_loss( + self, + point_mask_logits: torch.FloatTensor, + part_ids: torch.LongTensor, + num_valid_parts: torch.LongTensor, + ) -> Tuple[dict, Optional[float]]: + """ + Compute soft dice loss for multi-class point segmentation. + + Args: + point_mask_logits: (B, N, M) - predicted point mask logits + part_ids: (B, N) - ground truth part assignments + num_valid_parts: (B,) - number of valid parts for each batch + sorted_indices: (B, M) - sorted indices of part assignments + smooth: smoothing constant to avoid division by zero + Returns: + dice_loss: scalar loss value + """ + device = point_mask_logits.device + num_parts = point_mask_logits.shape[-1] + + # Create mask for invalid parts + invalid_parts_mask = ( + torch.arange(num_parts).unsqueeze(0).unsqueeze(0).to(device) >= \ + num_valid_parts.unsqueeze(-1).unsqueeze(-1) # (B, 1, M) + ) + + # Apply softmax to get probabilities, masking invalid parts + # point_mask_logits_masked = point_mask_logits.masked_fill(invalid_parts_mask, float('-inf')) + # point_mask_probs = torch.softmax(point_mask_logits_masked, dim=-1) # (B, N, M) + point_mask_probs = torch.softmax(point_mask_logits, dim=-1) + + # Convert part_ids to one-hot encoding + part_ids_onehot = torch.zeros_like(point_mask_probs) # (B, N, M) + part_ids_onehot.scatter_(-1, part_ids.unsqueeze(-1), 1.0) # (B, N, M) + + # Create mask for valid parts only + valid_parts_mask = ~invalid_parts_mask # (B, 1, M) + + # Compute dice coefficient for each part separately + # For each part m: intersection = sum over points where both pred and target are high for part m + intersection = (point_mask_probs * part_ids_onehot).sum(dim=1) # (B, M) + + # Union = sum of predictions + sum of targets for each part + pred_sum = point_mask_probs.sum(dim=1) # (B, M) + target_sum = part_ids_onehot.sum(dim=1) # (B, M) + + # Dice coefficient: 2 * intersection / (pred_sum + target_sum) + dice_scores = (2.0 * intersection + 1e-6) / (pred_sum + target_sum + 1e-6) # (B, M) + + # Only compute loss for valid parts + valid_parts_per_batch = valid_parts_mask.squeeze(1) # (B, M) + dice_loss = 1.0 - dice_scores + + # Mask out invalid parts + dice_loss = dice_loss * valid_parts_per_batch + + # Average over valid parts and batches + total_valid_parts = valid_parts_per_batch.sum() + if total_valid_parts > 0: + return dice_loss.sum() / total_valid_parts + else: + return self.parameters().__next__().sum() * 0 + + def compute_motion_hierarchy_loss( + self, + logits_motion_structure: torch.FloatTensor, + gt_motion_structure: torch.BoolTensor, + num_valid_parts: torch.LongTensor + ) -> Tuple[dict, Optional[float]]: + """ + Compute binary cross-entropy loss for part hierarchy. + + Args: + logits_motion_structure: (B, M, M) - predicted adjacency matrix logits + gt_motion_structure: (B, M, M) - ground truth adjacency matrix + num_valid_parts: (B,) - number of valid parts for each batch + sorted_indices: (B, M) - sorted indices of part assignments + Returns: + motion_hierarchy_loss: scalar loss value + """ + + device = logits_motion_structure.device + max_parts = logits_motion_structure.shape[-1] + + # Create 2D mask for valid parts (top-left num_valid_parts × num_valid_parts submatrix) + row_indices = torch.arange(max_parts, device=device).unsqueeze(0).unsqueeze(-1) # (1, M, 1) + col_indices = torch.arange(max_parts, device=device).unsqueeze(0).unsqueeze(0) # (1, 1, M) + + valid_parts_mask = ( + (row_indices < num_valid_parts.unsqueeze(-1).unsqueeze(-1)) & # (B, M, 1) + (col_indices < num_valid_parts.unsqueeze(-1).unsqueeze(-1)) # (B, 1, M) + ) # (B, M, M) + + logits_valid = logits_motion_structure[valid_parts_mask] + gt_valid = gt_motion_structure[valid_parts_mask].float() + # Compute binary cross-entropy loss only for valid parts + return torch.nn.functional.binary_cross_entropy_with_logits( + logits_valid, # (N_valid,) + gt_valid, # (N_valid,) + pos_weight=(gt_valid < 0.5).sum() / (gt_valid > 0.5).sum() + ) + + def compute_part_motion_classification_loss( + self, + part_motion_logits: torch.FloatTensor, + gt_part_motion_class: torch.LongTensor, + num_valid_parts: torch.LongTensor + ) -> Tuple[dict, Optional[float]]: + """ + Compute part motion classification loss. + + Args: + part_motion_logits: (B, M, 4) - predicted motion class logits + gt_part_motion_class: (B, M) - ground truth motion classes + num_valid_parts: (B,) - number of valid parts for each batch + sorted_indices: (B, M) - sorted indices of part assignments + Returns: + part_motion_loss: scalar loss value + """ + + # Only compute loss for valid parts + valid_parts_mask = torch.arange(part_motion_logits.shape[1], device=part_motion_logits.device).unsqueeze(0) < num_valid_parts.unsqueeze(1) # (B, M) + + logits_valid = part_motion_logits[valid_parts_mask] + gt_valid = gt_part_motion_class[valid_parts_mask] + + # Compute cross-entropy loss only for valid parts + return torch.nn.functional.cross_entropy( + logits_valid.reshape(-1, 4), # (B*M, 4) + gt_valid.reshape(-1), # (B*M,) + reduction='mean' + ) + + def compute_motion_axis_losses( + self, + revolute_plucker: torch.FloatTensor, + prismatic_axis: torch.FloatTensor, + gt_revolute_plucker: torch.FloatTensor, + gt_prismatic_axis: torch.FloatTensor, + num_valid_parts: torch.LongTensor + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """ + Compute motion parameter losses for revolute and prismatic motion. + + Args: + revolute_plucker: (B, M, 6) - predicted revolute motion parameters + prismatic_axis: (B, M, 3) - predicted prismatic motion parameters + gt_revolute_plucker: (B, M, 6) - ground truth revolute motion parameters + gt_prismatic_axis: (B, M, 3) - ground truth prismatic motion parameters + num_valid_parts: (B,) - number of valid parts for each batch + sorted_indices: (B, M) - sorted indices of part assignments + Returns: + revolute_loss: scalar loss value + prismatic_loss: scalar loss value + """ + + valid_parts_mask = ( + torch.arange(revolute_plucker.shape[1], device=revolute_plucker.device).unsqueeze(0) + < num_valid_parts.unsqueeze(1) + ) + + revolute_loss = self.parameters().__next__().sum() * 0 + # 1. Revolute loss + valid_revolute_mask = valid_parts_mask & torch.any(gt_revolute_plucker[..., :3] != 0, dim=-1) + if valid_revolute_mask.any(): + revolute_plucker_valid = revolute_plucker[valid_revolute_mask] + gt_revolute_plucker_valid = gt_revolute_plucker[valid_revolute_mask] + revolute_loss = torch.nn.functional.l1_loss( + revolute_plucker_valid, + gt_revolute_plucker_valid[..., :revolute_plucker_valid.shape[-1]] + ) + + prismatic_loss = self.parameters().__next__().sum() * 0 + # 2. Prismatic loss + valid_prismatic_mask = valid_parts_mask & torch.any(gt_prismatic_axis[..., :3] != 0, dim=-1) + if valid_prismatic_mask.any(): + prismatic_axis_valid = prismatic_axis[valid_prismatic_mask] + gt_prismatic_axis_valid = gt_prismatic_axis[valid_prismatic_mask] + prismatic_loss = torch.nn.functional.l1_loss( + prismatic_axis_valid, + gt_prismatic_axis_valid + ) + + return revolute_loss, prismatic_loss + + def compute_motion_range_losses( + self, + revolute_range: torch.FloatTensor, + prismatic_range: torch.FloatTensor, + gt_revolute_range: torch.FloatTensor, + gt_prismatic_range: torch.FloatTensor, + num_valid_parts: torch.LongTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """ + Compute motion range losses for revolute and prismatic joints. + """ + valid_parts_mask = ( + torch.arange(revolute_range.shape[1], device=revolute_range.device).unsqueeze(0) + < num_valid_parts.unsqueeze(1) + ) + + valid_revolute_mask = valid_parts_mask & torch.any(gt_revolute_range != 0, dim=-1) + valid_prismatic_mask = valid_parts_mask & torch.any(gt_prismatic_range != 0, dim=-1) + + revolute_range_loss = self.parameters().__next__().sum() * 0 + if valid_revolute_mask.any(): + revolute_range_valid = revolute_range[valid_revolute_mask] + gt_revolute_range_valid = gt_revolute_range[valid_revolute_mask] + revolute_range_loss = torch.nn.functional.l1_loss( + revolute_range_valid, gt_revolute_range_valid + ) + + prismatic_range_loss = self.parameters().__next__().sum() * 0 + if valid_prismatic_mask.any(): + prismatic_range_valid = prismatic_range[valid_prismatic_mask] + gt_prismatic_range_valid = gt_prismatic_range[valid_prismatic_mask] + prismatic_range_loss = torch.nn.functional.l1_loss( + prismatic_range_valid, gt_prismatic_range_valid + ) + + return revolute_range_loss, prismatic_range_loss + + def forward( + self, + xyz: torch.FloatTensor, + feats: torch.FloatTensor, + part_ids: torch.LongTensor, + num_valid_parts: torch.LongTensor, + part_structure_matrix: torch.BoolTensor, + query_xyz: Optional[torch.FloatTensor] = None, + query_feats: Optional[torch.FloatTensor] = None, + gt_part_motion_class: Optional[torch.LongTensor] = None, + gt_revolute_plucker: Optional[torch.FloatTensor] = None, + gt_revolute_range: Optional[torch.FloatTensor] = None, + gt_prismatic_axis: Optional[torch.FloatTensor] = None, + gt_prismatic_range: Optional[torch.FloatTensor] = None, + gt_closest_point_on_axis: Optional[torch.FloatTensor] = None, + normals: Optional[torch.FloatTensor] = None, + text_prompts: Optional[List[str]] = None, + run_matching: bool = False, + ) -> Tuple[dict, torch.LongTensor]: + """ + Forward pass during training. + """ + forward_motion_class = gt_part_motion_class is not None + forward_motion_params = ( + ( + gt_revolute_plucker is not None or \ + gt_revolute_range is not None or \ + gt_prismatic_axis is not None or \ + gt_prismatic_range is not None + ) and self.motion_representation == 'per_part_plucker' + ) or ( + ( + gt_revolute_plucker is not None or \ + gt_revolute_range is not None or \ + gt_prismatic_axis is not None or \ + gt_prismatic_range is not None + ) and gt_closest_point_on_axis is not None and self.motion_representation == 'per_point_closest' + ) + + ( + point_mask, + part_adjacency_matrix, + part_motion_logits, + revolute_plucker, revolute_range, + prismatic_axis, prismatic_range, + closest_point_on_axis, + _, best_point_mask_id + ) = self.forward_results( + xyz, feats, query_xyz, query_feats, + normals, text_prompts, + forward_motion_class, forward_motion_params, + gt_part_ids=part_ids, + num_valid_parts=num_valid_parts, + run_matching=run_matching + ) + + # Compute losses + # 1. Point mask loss + point_mask_loss = self.compute_point_mask_loss( + point_mask, + part_ids, + num_valid_parts + ) + + # 2. Dice loss + dice_loss = self.compute_dice_loss( + point_mask, + part_ids, + num_valid_parts + ) + + # 2. Part hierarchy loss + motion_hierarchy_loss = self.compute_motion_hierarchy_loss( + part_adjacency_matrix, + part_structure_matrix, + num_valid_parts + ) + + # 3. Part motion classification loss + part_motion_classification_loss = self.parameters().__next__().sum() * 0 + if gt_part_motion_class is not None and part_motion_logits is not None and forward_motion_class: + part_motion_classification_loss = self.compute_part_motion_classification_loss( + part_motion_logits, + gt_part_motion_class, + num_valid_parts + ) + + # 4. Motion axis losses (if needed) + part_motion_axis_loss_revolute = self.parameters().__next__().sum() * 0 + part_motion_axis_loss_prismatic = self.parameters().__next__().sum() * 0 + if forward_motion_params and ( + revolute_plucker is not None and + prismatic_axis is not None and + gt_revolute_plucker is not None and + gt_prismatic_axis is not None + ): + part_motion_axis_loss_revolute, part_motion_axis_loss_prismatic = \ + self.compute_motion_axis_losses( + revolute_plucker, prismatic_axis, + gt_revolute_plucker, gt_prismatic_axis, num_valid_parts + ) + + # 5. Motion range losses (if needed) + part_motion_range_loss_revolute = self.parameters().__next__().sum() * 0 + part_motion_range_loss_prismatic = self.parameters().__next__().sum() * 0 + if forward_motion_params and ( + revolute_range is not None and + prismatic_range is not None and + gt_revolute_range is not None and + gt_prismatic_range is not None + ): + part_motion_range_loss_revolute, part_motion_range_loss_prismatic = \ + self.compute_motion_range_losses( + revolute_range, prismatic_range, + gt_revolute_range, gt_prismatic_range, num_valid_parts + ) + + # 6. Per point closest point on axis loss (if needed) + point_closest_point_on_axis_loss = self.parameters().__next__().sum() * 0 + if forward_motion_params and ( + closest_point_on_axis is not None and + gt_closest_point_on_axis is not None + ): + per_point_motion_type = torch.gather(gt_part_motion_class, dim=1, index=part_ids) + revolute_points_flag = per_point_motion_type.eq(1) | per_point_motion_type.eq(3) + point_closest_point_on_axis_loss = torch.nn.functional.l1_loss( + closest_point_on_axis[revolute_points_flag], + gt_closest_point_on_axis[revolute_points_flag] + ) + + # Combine all losses + losses = dict( + point_mask_loss=point_mask_loss, + dice_loss=dice_loss, + motion_hierarchy_loss=motion_hierarchy_loss, + part_motion_classification_loss=part_motion_classification_loss, + part_motion_axis_loss_revolute=part_motion_axis_loss_revolute, + part_motion_axis_loss_prismatic=part_motion_axis_loss_prismatic, + part_motion_range_loss_revolute=part_motion_range_loss_revolute, + part_motion_range_loss_prismatic=part_motion_range_loss_prismatic, + point_closest_point_on_axis_loss=point_closest_point_on_axis_loss + ) + return losses, best_point_mask_id + + def _postprocess_results( + self, + point_mask, + part_adjacency_matrix, + part_motion_logits, + revolute_plucker, revolute_range, + prismatic_axis, prismatic_range, + closest_point_on_axis, + part_ids, + ): + motion_hierarchy_batch = [] + for batch_idx, single_part_adjacency_matrix in enumerate(part_adjacency_matrix): # (B, N, N) + unique_part_ids = torch.unique(part_ids[batch_idx]) + + # Extract submatrix for only the unique part IDs that exist + submatrix = single_part_adjacency_matrix[unique_part_ids][:, unique_part_ids] + + # Extract motion hierarchy from the submatrix + hierarchy_compressed = extract_motion_hierarchy(submatrix) + + # Map back to original indices + hierarchy_original = [] + for parent_idx, child_idx in hierarchy_compressed: + original_parent = unique_part_ids[parent_idx].item() + original_child = unique_part_ids[child_idx].item() + hierarchy_original.append((original_parent, original_child)) + + motion_hierarchy_batch.append(hierarchy_original) + + part_motion_class = torch.argmax(part_motion_logits, dim=-1) + is_part_revolute = part_motion_class.eq(1) | part_motion_class.eq(3) + is_part_prismatic = part_motion_class.eq(2) | part_motion_class.eq(3) + + # Make sure the plucker and axis parameters are valid + if revolute_plucker is not None: + revolute_plucker[..., :3] = revolute_plucker[..., :3] / torch.norm(revolute_plucker[..., :3], dim=-1, keepdim=True).clamp_min(1e-8) + if prismatic_axis is not None: + prismatic_axis[..., :3] = prismatic_axis[..., :3] / torch.norm(prismatic_axis[..., :3], dim=-1, keepdim=True).clamp_min(1e-8) + + # Assert that all part IDs in motion hierarchy have at least one associated point + for batch_idx, hierarchy in enumerate(motion_hierarchy_batch): + if len(hierarchy) > 0: + # Get all part IDs mentioned in the hierarchy + hierarchy_part_ids = set() + for parent_id, child_id in hierarchy: + hierarchy_part_ids.add(parent_id) + hierarchy_part_ids.add(child_id) + + # Get unique part IDs that actually exist in the point cloud + existing_part_ids = set(part_ids[batch_idx].cpu().numpy().tolist()) + + # Assert that all hierarchy part IDs exist in the point cloud + missing_part_ids = hierarchy_part_ids - existing_part_ids + assert len(missing_part_ids) == 0, f"Batch {batch_idx}: Part IDs {missing_part_ids} in motion hierarchy have no associated points" + + return dict( + part_ids=part_ids.cpu().numpy().squeeze(0), + motion_hierarchy=motion_hierarchy_batch[0], + is_part_revolute=is_part_revolute.cpu().numpy().squeeze(0), + is_part_prismatic=is_part_prismatic.cpu().numpy().squeeze(0), + revolute_plucker=revolute_plucker.cpu().numpy().squeeze(0) if revolute_plucker is not None else None, + revolute_range=revolute_range.cpu().numpy().squeeze(0) if revolute_range is not None else None, + prismatic_axis=prismatic_axis.cpu().numpy().squeeze(0) if prismatic_axis is not None else None, + prismatic_range=prismatic_range.cpu().numpy().squeeze(0) if prismatic_range is not None else None, + closest_point_on_axis=closest_point_on_axis.cpu().numpy().squeeze(0) if closest_point_on_axis is not None else None, + ) + + @torch.no_grad() + def infer( + self, + xyz: torch.FloatTensor, + feats: torch.FloatTensor, + query_xyz: Optional[torch.FloatTensor] = None, + query_feats: Optional[torch.FloatTensor] = None, + normals: Optional[torch.FloatTensor] = None, + text_prompts: Optional[List[str]] = None, + forward_motion_class: bool = True, + forward_motion_params: bool = True, + run_matching: bool = False, + gt_part_ids: Optional[torch.LongTensor] = None, + overwrite_part_ids: Optional[torch.LongTensor] = None, + output_all_hyps: bool = False, + min_part_confidence: float = 0.0 + ): + assert xyz.shape[0] == 1, "Only batch size 1 is supported" + + num_valid_parts = None + if gt_part_ids is not None: + num_valid_parts = gt_part_ids.max(dim=-1).values + 1 + + results = [] + if output_all_hyps: + for hyp_idx in range(self.num_mask_hypotheses): + results.append(self.forward_results( + xyz, feats, query_xyz, query_feats, + normals, text_prompts, + forward_motion_class, forward_motion_params, + gt_part_ids=gt_part_ids, + overwrite_part_ids=overwrite_part_ids, + num_valid_parts=num_valid_parts, + run_matching=run_matching, + force_hyp_idx=hyp_idx, + min_part_confidence=min_part_confidence + )) + else: + results.append(self.forward_results( + xyz, feats, query_xyz, query_feats, + normals, text_prompts, + forward_motion_class, forward_motion_params, + gt_part_ids=gt_part_ids, + overwrite_part_ids=overwrite_part_ids, + num_valid_parts=num_valid_parts, + run_matching=run_matching, + force_hyp_idx=-1, + min_part_confidence=min_part_confidence + )) + + postprocessed_results = [self._postprocess_results(*result[:-1]) for result in results] # Ignore the last element (best_point_mask_id) + for postprocessed_result in postprocessed_results: + if self.motion_representation == 'per_point_closest': + postprocessed_result['revolute_plucker'] = closest_point_on_axis_to_revolute_plucker( + postprocessed_result['closest_point_on_axis'], + postprocessed_result['part_ids'], + postprocessed_result['is_part_revolute'], + postprocessed_result['is_part_prismatic'], + postprocessed_result['revolute_plucker'] + ) + + return postprocessed_results + + +def Articulate3D_S(**kwargs): + return Articulate3D(num_layers=6, hidden_size=384, n_heads=6, **kwargs) + +def Articulate3D_B(**kwargs): + return Articulate3D(num_layers=6, hidden_size=768, n_heads=12, **kwargs) + +def Articulate3D_L(**kwargs): + return Articulate3D(num_layers=12, hidden_size=1024, n_heads=16, **kwargs) + +def Articulate3D_XL(**kwargs): + return Articulate3D(num_layers=14, hidden_size=1152, n_heads=16, **kwargs) diff --git a/particulate/visualization_utils.py b/particulate/visualization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f60cf3d5c210ce2ca5e5ea7b9fab9fffff565b1e --- /dev/null +++ b/particulate/visualization_utils.py @@ -0,0 +1,300 @@ +from typing import Optional, Tuple + +import numpy as np +import trimesh +from PIL import Image +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d.art3d import Poly3DCollection + +COLORS = [ + (72, 36, 117), + (33, 145, 140), + (189, 223, 38), + (153, 80, 8), + (12, 12, 242), + (242, 12, 150), + (12, 242, 150), + (12, 150, 242) +] +ARROW_COLOR_REVOLUTE = (255, 0, 0) +ARROW_COLOR_PRISMATIC = (255, 255, 0) + + +def plot_mesh(mesh): + verts = mesh.vertices + faces = getattr(mesh, 'faces', []) + + fig = plt.figure(figsize=(8, 8)) + ax = fig.add_subplot(111, projection='3d') + triangles = verts[faces] + poly = Poly3DCollection(triangles, facecolors='lightblue', edgecolor='none', alpha=0.5, zorder=1) + ax.add_collection3d(poly) + + min_vals = verts.min(axis=0) + max_vals = verts.max(axis=0) + center = (min_vals + max_vals) / 2 + max_range = (max_vals - min_vals).max() / 2.0 + + # Set axis limits to ensure equal aspect ratio + ax.set_xlim(center[0] - max_range, center[0] + max_range) + ax.set_ylim(center[1] - max_range, center[1] + max_range) + ax.set_zlim(center[2] - max_range, center[2] + max_range) + + # Draw axes (with zorder to ensure they're visible above the mesh) + length = max_range * 1.2 + # X axis (Red) + ax.quiver(center[0], center[1], center[2], length, 0, 0, color='r', label='X', + linewidth=2, arrow_length_ratio=0.15, zorder=10) + # Y axis (Green) + ax.quiver(center[0], center[1], center[2], 0, length, 0, color='g', label='Y', + linewidth=2, arrow_length_ratio=0.15, zorder=10) + # Z axis (Blue) + ax.quiver(center[0], center[1], center[2], 0, 0, length, color='b', label='Z', + linewidth=2, arrow_length_ratio=0.15, zorder=10) + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.legend() + ax.set_title("Mesh with Axes (Select Up Direction)") + + ax.set_box_aspect([1,1,1]) + return fig + + +def create_textured_mesh_parts(mesh_parts, colors=COLORS, tex_res=256): + # Create a texture map with evenly distributed color blocks + # Use a horizontal strip layout: texture height = tex_res, width = num_parts * tex_res + texture_height = block_width = tex_res + texture_width = len(mesh_parts) * block_width + texture_array = np.zeros((texture_height, texture_width, 3), dtype=np.uint8) + + for i in range(len(mesh_parts)): + color_rgb = colors[i % len(colors)][:3] + x_start = i * block_width + x_end = (i + 1) * block_width + texture_array[:, x_start:x_end] = color_rgb + texture = Image.fromarray(texture_array) + + mesh_parts_colored = [] + for i, mesh_part in enumerate(mesh_parts): + # Create UV coordinates specifically for this part + # All faces in this part should point to the same color block + u_center = (i + 0.5) * block_width / texture_width + v_center = 0.5 + + # Create UV coordinates for all vertices in this submesh + num_part_vertices = len(mesh_part.vertices) + part_uv_coords = np.full((num_part_vertices, 2), [u_center, v_center], dtype=np.float32) + mesh_part.visual = trimesh.visual.TextureVisuals(uv=part_uv_coords, image=texture) + mesh_parts_colored.append(mesh_part) + + return mesh_parts_colored + + +def apply_color_with_texture(mesh: trimesh.Trimesh, color: Tuple, tex_res: int = 16) -> trimesh.Trimesh: + """ + Apply a solid color to a mesh using UV texture coordinates instead of face colors. + This ensures compatibility with Blender and other tools that don't support face colors. + + Args: + mesh: The mesh to apply color to + color: Color as tuple (R, G, B) with values 0-1 or (R, G, B, A) with values 0-255 + tex_res: Resolution of the texture (default: 16x16) + + Returns: + mesh: The mesh with texture applied + """ + # Normalize color to 0-255 range + if len(color) >= 3: + if all(c <= 1.0 for c in color[:3]): + # Color is in 0-1 range, convert to 0-255 + color_rgb = tuple(int(c * 255) for c in color[:3]) + else: + # Color is already in 0-255 range + color_rgb = tuple(int(c) for c in color[:3]) + else: + raise ValueError("Color must have at least 3 components (R, G, B)") + + # Create a solid color texture + texture_array = np.full((tex_res, tex_res, 3), color_rgb, dtype=np.uint8) + texture = Image.fromarray(texture_array) + + # Create UV coordinates (all pointing to center of texture) + num_vertices = len(mesh.vertices) + uv_coords = np.full((num_vertices, 2), 0.5, dtype=np.float32) + + # Apply texture to mesh + mesh.visual = trimesh.visual.TextureVisuals(uv=uv_coords, image=texture) + + return mesh + + +def create_sphere(center: np.ndarray, radius: float, color: Tuple[float, float, float]) -> trimesh.Trimesh: + """ + Create a sphere mesh. + """ + sphere = trimesh.creation.icosphere(radius=radius, subdivisions=0) + sphere.vertices += center + sphere = apply_color_with_texture(sphere, color) + return sphere + + +def create_ring(center, normal, major_radius=0.04, minor_radius=0.006, color=(255, 0, 0), segments=32, tube_segments=16): + """ + Create a 3D ring (torus) perpendicular to a given direction. + + Args: + center: The center position of the ring (3D point) + normal: The normal direction of the ring plane (will be normalized) + major_radius: The radius of the ring from center to tube center + minor_radius: The radius of the tube itself (ring width) + color: RGB color tuple (can be 0-1 or 0-255 range) + segments: Number of segments around the ring + tube_segments: Number of segments around the tube cross-section + + Returns: + trimesh.Trimesh: The ring mesh + """ + center = np.array(center) + normal = np.array(normal) + normal = normal / np.linalg.norm(normal) + + # Find two perpendicular vectors to the normal + if abs(normal[2]) < 0.9: + v1 = np.cross(normal, np.array([0, 0, 1])) + else: + v1 = np.cross(normal, np.array([1, 0, 0])) + v1 = v1 / np.linalg.norm(v1) + v2 = np.cross(normal, v1) + v2 = v2 / np.linalg.norm(v2) + + # Generate torus vertices + vertices = [] + for i in range(segments): + theta = 2 * np.pi * i / segments + # Point on the major circle + circle_point = center + major_radius * (np.cos(theta) * v1 + np.sin(theta) * v2) + # Direction from center to this point on the major circle + radial_dir = np.cos(theta) * v1 + np.sin(theta) * v2 + + for j in range(tube_segments): + phi = 2 * np.pi * j / tube_segments + # Point on the tube cross-section + tube_offset = minor_radius * (np.cos(phi) * radial_dir + np.sin(phi) * normal) + vertex = circle_point + tube_offset + vertices.append(vertex) + + vertices = np.array(vertices) + + # Generate faces + faces = [] + for i in range(segments): + for j in range(tube_segments): + # Current vertex indices + v0 = i * tube_segments + j + v1 = i * tube_segments + (j + 1) % tube_segments + v2 = ((i + 1) % segments) * tube_segments + (j + 1) % tube_segments + v3 = ((i + 1) % segments) * tube_segments + j + + # Create two triangles for this quad + faces.append([v0, v1, v2]) + faces.append([v0, v2, v3]) + + faces = np.array(faces) + + # Create mesh with color using UV texture (compatible with Blender) + ring_mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) + ring_mesh = apply_color_with_texture(ring_mesh, color) + + return ring_mesh + + +def create_arrow( + start_point: np.ndarray, + end_point: np.ndarray, + color=(1, 0, 0, 1), + radius: float = 0.03, + radius_tip: float = 0.05 +) -> trimesh.Trimesh: + """ + Build a 3-D arrow (cylinder + cone) going from `start_point` to `end_point`. + """ + direction = end_point - start_point + length = np.linalg.norm(direction) + if length == 0: + raise ValueError("start_point and end_point must be different.") + + # Unit vector in arrow direction + v_dir = direction / length + + # Heuristic: tip is 10 % of length but never longer than 0.07 m + tip_h = min(0.1 * length, 0.04) + body_h = length - tip_h + if body_h <= 0: # extremely short arrow fallback + tip_h = 0.5 * length + body_h = length - tip_h + + # Cylinder (body) -- origin on z, height along +z + cyl = trimesh.creation.cylinder(radius=radius, height=body_h, sections=32) + cyl.apply_translation([0, 0, body_h / 2]) # base sits at z = 0 + + # Cone (tip) -- base at z = 0, apex at z = +tip_h + cone = trimesh.creation.cone(radius=radius_tip, height=tip_h, sections=32) + cone.apply_translation([0, 0, body_h]) # base starts where cylinder ends + + # Rotate both meshes from +Z to desired direction + R = trimesh.geometry.align_vectors([0, 0, 1], v_dir) + cyl.apply_transform(R) + cone.apply_transform(R) + + # Translate so tail is at start_point + cyl.apply_translation(start_point) + cone.apply_translation(start_point) + + cyl = apply_color_with_texture(cyl, color) + cone = apply_color_with_texture(cone, color) + + return trimesh.util.concatenate([cyl, cone]) + + +def get_3D_arrow_on_points( + direction: np.ndarray, + points: np.ndarray, + fixed_point: Optional[np.ndarray] = None, + extension: float = 0.05, +) -> Tuple[float, float]: + """ + Build a 3-D arrow (cylinder + cone) that encloses `points` along `direction`. + """ + # ── normalise direction ──────────────────────────────────────────────── + direction = np.asarray(direction, dtype=float) + if np.linalg.norm(direction) == 0: + raise ValueError("`direction` must be a non-zero vector.") + d_hat = direction / np.linalg.norm(direction) + + # ── validate points ─────────────────────────────────────────────────── + points = np.asarray(points, dtype=float) + if points.ndim != 2 or points.shape[1] != 3: + raise ValueError("`points` must be of shape (N, 3).") + + # ── choose reference point on axis ──────────────────────────────────── + P0 = ( + np.asarray(fixed_point, dtype=float) + if fixed_point is not None + else points.mean(axis=0) + ) + + # ── project points onto axis to find extents ────────────────────────── + scalars = np.dot(points - P0, d_hat) + if scalars.shape[0] > 0: + s_min = scalars.min() - max(extension * (scalars.max() - scalars.min()), 0.1) + s_max = scalars.max() + max(extension * (scalars.max() - scalars.min()), 0.1) + else: + s_min = -0.1 + s_max = 0.1 + + start_pt = P0 + s_min * d_hat + end_pt = P0 + s_max * d_hat + + return start_pt, end_pt