Add distillation script for faster inference #54

Open · wants to merge 41 commits into base: main
Changes from 1 commit

Commits (41)
f58f5a8  First version of small model (Flova, Jan 28, 2025)
aeb6cd5  Add patching (Flova, Jan 28, 2025)
e8db54f  Add basic destillation training script (Flova, Jan 28, 2025)
cb936c2  Fix warning (Flova, Jan 28, 2025)
a116338  Fix destillation (Flova, Jan 28, 2025)
2226934  Add destilled model to plot (Flova, Jan 30, 2025)
b6955f5  Merge branch 'main' into feature/destillation (Flova, Jan 30, 2025)
d0b7b63  Rename file (Flova, Jan 30, 2025)
9fc7f78  Fix joint names (Flova, Jan 30, 2025)
191cf78  use destilled hyperparam (Flova, Jan 30, 2025)
34bcf83  Migrate distillation to new hyperparameter standard (Flova, Jan 30, 2025)
f4e74b0  Current WIP (Flova, Jan 30, 2025)
8dd6023  Make it possible to only load some of the features from the dataset (Flova, Feb 4, 2025)
5b0b7f4  Add decoder only training (Flova, Feb 4, 2025)
577a105  Format (Flova, Feb 4, 2025)
56b3b58  Shorter pretraining (Flova, Feb 5, 2025)
b6d626d  Merge branch 'main' into feature/small_model (Flova, Feb 5, 2025)
19dd73b  Reverse to transformer decoder (Flova, Feb 5, 2025)
8f53490  Merge branch 'feature/small_model' into feature/destillation (Flova, Feb 5, 2025)
f5b51b5  Fix other scripts (Flova, Feb 5, 2025)
420158a  Remove profiling (Flova, Feb 5, 2025)
12f95c0  Be able to load pretrained decoders (Flova, Feb 5, 2025)
d29db0e  Apply formatting (Flova, Feb 5, 2025)
213d285  Optmize data transfer (Flova, Feb 5, 2025)
c0e7755  Remove pinned memory (Flova, Feb 5, 2025)
73c9b6c  Add wandb (Flova, Feb 6, 2025)
92f0693  Fix wand and add larger model (Flova, Feb 9, 2025)
ca4312b  Add wandb logs to gitignore (Flova, Feb 9, 2025)
048cdca  Fix model loading in distillation (Flova, Feb 9, 2025)
af70a87  Avoid printing out all the params (Flova, Feb 9, 2025)
e2b2299  Add wandb to distill (Flova, Feb 9, 2025)
8c7965d  Sort keys (Flova, Feb 11, 2025)
9fb40c6  Change ros runtime (Flova, Feb 11, 2025)
ada1b33  Fix five dim input (Flova, Feb 20, 2025)
3454ee1  Add different resolutions, fix five dim imu, add current training con… (Flova, Feb 20, 2025)
9b0f14e  Fix image padding (Flova, Feb 20, 2025)
5d27aa9  Add support for other imu input (Flova, Feb 27, 2025)
69f21fa  Cleanup (Flova, Feb 27, 2025)
da0c1ac  Normalize image during data loading (Flova, Feb 27, 2025)
114f1eb  Sample image at correct rate (Flova, Feb 27, 2025)
f114ea7  Fix preprocessing pipeline (Flova, Feb 27, 2025)
Add destilled model to plot
Flova committed Jan 30, 2025
commit 22269341e597a235931f5bdebe700fae9a76879a
ddlitlab2024/ml/inference/plot.py · 24 changes: 15 additions & 9 deletions
@@ -36,7 +36,8 @@
 joint_state_context_length = 100
 num_normalization_samples = 50
 num_joints = 20
-checkpoint = "/homes/17vahl/ddlitlab2024/ddlitlab2024/ml/training/trajectory_transformer_model.pth"
+checkpoint = "/homes/17vahl/ddlitlab2024/ddlitlab2024/ml/training/destilled_trajectory_transformer_model.pth"
+distilled = True
 
 logger.info("Load model")
 model = End2EndDiffusionTransformer(
@@ -110,15 +111,20 @@
 noisy_trajectory = torch.randn_like(joint_targets).to(device)
 trajectory = noisy_trajectory
 
-# Perform the denoising process
-scheduler.set_timesteps(inference_denosing_timesteps)
-for t in scheduler.timesteps:
-    with torch.no_grad():
-        # Predict the noise residual
-        noise_pred = model(batch, trajectory, torch.tensor([t], device=device))
+if distilled:
+    # Directly predict the trajectory based on the noise
+    with torch.no_grad():
+        trajectory = model(batch, noisy_trajectory, torch.tensor([0], device=device))
+else:
+    # Perform the denoising process
+    scheduler.set_timesteps(inference_denosing_timesteps)
+    for t in scheduler.timesteps:
+        with torch.no_grad():
+            # Predict the noise residual
+            noise_pred = model(batch, trajectory, torch.tensor([t], device=device))
 
-    # Update the trajectory based on the predicted noise and the current step of the denoising process
-    trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample
+        # Update the trajectory based on the predicted noise and the current step of the denoising process
+        trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample
 
 # Undo the normalization
 print(normalizer.mean)
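
For readers wondering what the distilled checkpoint was trained to do: the distillation script itself is not shown in this commit, so below is only a minimal sketch of one common approach. Everything in it is an assumption rather than this PR's actual code: the function names (teacher_sample, distillation_step), the MSE objective, and num_teacher_steps are hypothetical; only the model(batch, trajectory, t) call signature and the scheduler API are taken from the diff above. The idea is that the student learns to map pure noise straight to the teacher's fully denoised trajectory in a single pass at t = 0, which is why the distilled branch above can skip the scheduler loop entirely.

import torch
import torch.nn.functional as F

@torch.no_grad()
def teacher_sample(teacher, scheduler, batch, noise, num_teacher_steps):
    # Run the teacher's full iterative denoising loop (the else-branch above)
    # to produce a target trajectory for the student.
    scheduler.set_timesteps(num_teacher_steps)
    trajectory = noise
    for t in scheduler.timesteps:
        noise_pred = teacher(batch, trajectory, torch.tensor([t], device=noise.device))
        trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample
    return trajectory

def distillation_step(student, teacher, scheduler, batch, noise, optimizer, num_teacher_steps=50):
    # One hypothetical training step: regress the student's single-shot
    # prediction at t=0 onto the teacher's multi-step result.
    target = teacher_sample(teacher, scheduler, batch, noise, num_teacher_steps)
    pred = student(batch, noise, torch.tensor([0], device=noise.device))
    loss = F.mse_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

Trained this way, one student forward pass replaces inference_denosing_timesteps scheduler iterations at inference time, which is where the speed-up promised in the PR title comes from.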