229 changes: 229 additions & 0 deletions colab_start.ipynb
@@ -0,0 +1,229 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SAM 3D Objects - Colab Setup\n",
"\n",
"This notebook sets up the environment and runs a demo inference on Google Colab (T4 GPU recommended).\n",
"\n",
"**Note:** If you encounter build errors with PyTorch3D or other libraries, make sure you are using a GPU runtime."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 1. Clone the repository\n",
"import os\n",
"if not os.path.exists(\"sam-3d-objects\"):\n",
" !git clone https://github.com/facebookresearch/sam-3d-objects.git\n",
" %cd sam-3d-objects\n",
"else:\n",
" %cd sam-3d-objects\n",
" !git pull\n",
"\n",
"# 2. Install Dependencies\n",
"print(\"Installing dependencies... This may take a few minutes.\")\n",
"\n",
"# Install general requirements first (excluding pytorch3d)\n",
"!pip install -r requirements_colab.txt\n",
"\n",
"# 3. Install PyTorch3D (Binary Wheel)\n",
"# Building PyTorch3D from source on Colab is slow and error-prone. We use pre-built wheels.\n",
"import torch\n",
"import sys\n",
"\n",
"try:\n",
" import pytorch3d\n",
" print(\"PyTorch3D is already installed.\")\n",
"except ImportError:\n",
" print(\"Installing PyTorch3D...\")\n",
" pyt_version_str = torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" cuda_version_str = torch.version.cuda.replace(\".\", \"\")\n",
" \n",
" if \"2.4\" in torch.__version__ or \"2.5\" in torch.__version__:\n",
" try:\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt240/download.html\n",
" except:\n",
" print(\"Binary wheel not found, falling back to source install (slow)...\")\n",
" !pip install \"git+https://github.com/facebookresearch/pytorch3d.git@stable\"\n",
" else:\n",
" version_str = \"\".join([\n",
" f\"py3{sys.version_info.minor}_cu\",\n",
" cuda_version_str,\n",
" f\"_pyt{pyt_version_str}\"\n",
" ])\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
"\n",
"# 4. Install the package in editable mode and patch Hydra\n",
"!pip install -e .\n",
"!python patching/hydra\n",
"\n",
"print(\"Installation complete. PLEASE RESTART THE RUNTIME (Runtime > Restart session) if you see import errors!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download Checkpoints\n",
"\n",
"Attempt to download weights. \n",
"Option 1: Hugging Face (Requires Token, Full Model)\n",
"Option 2: Direct Download (Public Link - User Requested, may be partial)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import shutil\n",
"from huggingface_hub import snapshot_download\n",
"\n",
"def download_weights_hf(token):\n",
" tag = \"hf\"\n",
" download_dir = f\"checkpoints/{tag}-download\"\n",
" target_dir = f\"checkpoints/{tag}\"\n",
" \n",
" if os.path.exists(target_dir) and os.path.exists(os.path.join(target_dir, \"pipeline.yaml\")):\n",
" print(f\"Checkpoints already exist at {target_dir}\")\n",
" return\n",
" \n",
" print(\"Downloading model weights from Hugging Face...\")\n",
" try:\n",
" snapshot_download(\n",
" repo_id=\"facebook/sam-3d-objects\",\n",
" repo_type=\"model\",\n",
" local_dir=download_dir,\n",
" max_workers=1,\n",
" token=token\n",
" )\n",
" \n",
" source = os.path.join(download_dir, \"checkpoints\")\n",
" if os.path.exists(source):\n",
" if os.path.exists(target_dir):\n",
" shutil.rmtree(target_dir)\n",
" shutil.move(source, target_dir)\n",
" shutil.rmtree(download_dir)\n",
" print(\"Download complete and files moved.\")\n",
" else:\n",
" if os.path.exists(target_dir):\n",
" shutil.rmtree(target_dir)\n",
" shutil.move(download_dir, target_dir)\n",
" print(\"Download complete (fallback structure).\")\n",
" \n",
" except Exception as e:\n",
" print(f\"Error downloading weights from HF: {e}\")\n",
"\n",
"def download_weights_public():\n",
" # User requested link for public SAM weights\n",
" url = \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth\"\n",
" target_dir = \"checkpoints/hf\"\n",
" target_file = os.path.join(target_dir, \"sam_vit_h_4b8939.pth\")\n",
" \n",
" if not os.path.exists(target_dir):\n",
" os.makedirs(target_dir)\n",
" \n",
" print(f\"Downloading public SAM weights from {url}...\")\n",
" !wget -O {target_file} {url}\n",
" print(\"Download complete.\")\n",
" print(\"WARNING: This is only the SAM checkpoint. The full SAM-3D-Objects pipeline likely requires 'pipeline.yaml' and other weights from the gated HF repo.\")\n",
"\n",
"# Try to get token\n",
"token = None\n",
"try:\n",
" from google.colab import userdata\n",
" token = userdata.get('HF_TOKEN')\n",
"except:\n",
" pass\n",
"\n",
"# If no token, ask or fallback\n",
"if not token:\n",
" print(\"No HF_TOKEN found in secrets.\")\n",
" choice = input(\"Enter '1' to provide HF Token (Recommended), '2' to use Public SAM Link (May be incomplete): \")\n",
" if choice == '1':\n",
" from getpass import getpass\n",
" token = getpass(\"Enter Hugging Face Token: \")\n",
" download_weights_hf(token)\n",
" else:\n",
" download_weights_public()\n",
"else:\n",
" print(\"Using HF_TOKEN from secrets.\")\n",
" download_weights_hf(token)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal Inference Code\n",
"import sys\n",
"import os\n",
"sys.path.append(\"notebook\")\n",
"from inference import Inference, load_image, load_single_mask\n",
"import torch\n",
"\n",
"if not torch.cuda.is_available():\n",
" print(\"Warning: CUDA is not available. Inference will be slow or fail.\")\n",
"\n",
"tag = \"hf\"\n",
"config_path = f\"checkpoints/{tag}/pipeline.yaml\"\n",
"\n",
"if not os.path.exists(config_path):\n",
" print(f\"Error: Config not found at {config_path}. Did you download the full model from Hugging Face?\")\n",
"else:\n",
" # Load model\n",
" print(\"Loading model...\")\n",
" inference = Inference(config_path, compile=False)\n",
"\n",
" # Load dummy image/mask (using one from repo)\n",
" image_path = \"notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png\"\n",
" mask_folder = \"notebook/images/shutterstock_stylish_kidsroom_1640806567\"\n",
" \n",
" print(f\"Processing image: {image_path}\")\n",
" if os.path.exists(image_path):\n",
" image = load_image(image_path)\n",
" mask = load_single_mask(mask_folder, index=14)\n",
"\n",
" # Run model\n",
" output = inference(image, mask, seed=42)\n",
"\n",
" # Save output\n",
" output[\"gs\"].save_ply(\"splat.ply\")\n",
" print(\"Success! Output saved to splat.ply\")\n",
" else:\n",
" print(\"Test image not found.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 1 addition & 1 deletion notebook/inference.py
@@ -2,7 +2,7 @@
import os

# not ideal to put that here
os.environ["CUDA_HOME"] = os.environ["CONDA_PREFIX"]
os.environ["CUDA_HOME"] = os.environ.get("CONDA_PREFIX", "/usr/local/cuda")
os.environ["LIDRA_SKIP_INIT"] = "true"

import sys
26 changes: 26 additions & 0 deletions requirements_colab.txt
@@ -0,0 +1,26 @@
numpy
Pillow
opencv-python
matplotlib
seaborn
gradio
omegaconf
hydra-core
timm
h5py
scikit-image
einops-exts
transformers
accelerate
bitsandbytes
gdown
ninja
kaolin==0.17.0
gsplat @ git+https://github.com/nerfstudio-project/gsplat.git@2323de5905d5e90e035f792fe65bad0fedd413e7
xformers
spconv-cu121
open3d
pandas
scipy
MoGe @ git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b
cuda-python