Add distillation script for faster inference #54

Open · wants to merge 10 commits into main
3 changes: 2 additions & 1 deletion ddlitlab2024/dataset/models.py
@@ -214,7 +214,8 @@ class JointStates(Base):
         Index(None, "recording_id", asc("stamp")),
     )

-    def get_ordered_joint_names(self) -> list[str]:
+    @staticmethod
+    def get_ordered_joint_names() -> list[str]:
         return [
             JointStates.head_pan.name,
             JointStates.head_tilt.name,
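Review note: the method never touched `self`, so making it a `@staticmethod` lets callers fetch the canonical joint ordering without constructing an ORM row. A minimal sketch of the pattern, with a hypothetical, abbreviated joint list standing in for the real SQLAlchemy model:

```python
class JointStates:
    """Stand-in for the SQLAlchemy model above; only the ordering helper is shown."""

    @staticmethod
    def get_ordered_joint_names() -> list[str]:
        # Fixed ordering shared by the dataset loader and the inference nodes.
        return ["head_pan", "head_tilt"]  # ... remaining joints elided

# Call sites need no instance:
assert JointStates.get_ordered_joint_names()[0] == "head_pan"
```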
3 changes: 2 additions & 1 deletion ddlitlab2024/dataset/pytorch.py
@@ -78,6 +78,7 @@ def __init__(
         self.num_frames_video = num_frames_video
         self.trajectory_stride = trajectory_stride
         self.num_joints = num_joints
+        self.joint_names = JointStates.get_ordered_joint_names()

         # Print out metadata
         cursor = self.db_connection.cursor()
@@ -119,7 +120,7 @@ def query_joint_data(
         )

         # Convert to numpy array, keep only the joint angle columns in alphabetical order
-        raw_joint_data = raw_joint_data[JointStates.get_ordered_joint_names()].to_numpy(dtype=np.float32)
+        raw_joint_data = raw_joint_data[self.joint_names].to_numpy(dtype=np.float32)

         assert raw_joint_data.shape[1] == self.num_joints, "The number of joints is not correct"

23 changes: 14 additions & 9 deletions ddlitlab2024/ml/inference/plot.py
@@ -25,7 +25,7 @@
 # Parse the command line arguments
 parser = argparse.ArgumentParser(description="Inference Plot")
 parser.add_argument("checkpoint", type=str, help="Path to the checkpoint to load")
-parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps")
+parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps (not used for distilled)")
 parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate")
 args = parser.parse_args()

@@ -104,15 +104,20 @@
     noisy_trajectory = torch.randn_like(joint_targets).to(device)
     trajectory = noisy_trajectory

-    # Perform the denoising process
-    scheduler.set_timesteps(args.steps)
-    for t in scheduler.timesteps:
-        with torch.no_grad():
-            # Predict the noise residual
-            noise_pred = model(batch, trajectory, torch.tensor([t], device=device))
-
-            # Update the trajectory based on the predicted noise and the current step of the denoising process
-            trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample
+    if params.get("distilled_decoder", False):
+        # Directly predict the trajectory based on the noise
+        with torch.no_grad():
+            trajectory = model(batch, noisy_trajectory, torch.tensor([0], device=device))
+    else:
+        # Perform the denoising process
+        scheduler.set_timesteps(args.steps)
+        for t in scheduler.timesteps:
+            with torch.no_grad():
+                # Predict the noise residual
+                noise_pred = model(batch, trajectory, torch.tensor([t], device=device))
+
+                # Update the trajectory based on the predicted noise and the current step of the denoising process
+                trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample

     # Undo the normalization
     trajectory = normalizer.denormalize(trajectory)
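For context on the `distilled_decoder` branch above: a distilled model maps noise to a finished trajectory in a single forward pass at a fixed timestep, instead of looping the DDIM scheduler for `--steps` iterations. The distillation script itself is not part of this excerpt; the following is only a rough sketch of how a single-step student could be regressed onto a multi-step teacher (all names here are illustrative, not the repository's API):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student, teacher, scheduler, batch, joint_targets, steps=30, device="cpu"):
    """Hypothetical distillation objective: match the teacher's full
    multi-step DDIM sample with a single student forward pass."""
    noise = torch.randn_like(joint_targets).to(device)

    # Teacher: frozen, full iterative denoising (expensive, done once per sample).
    with torch.no_grad():
        target = noise
        scheduler.set_timesteps(steps)
        for t in scheduler.timesteps:
            noise_pred = teacher(batch, target, torch.tensor([t], device=device))
            target = scheduler.step(noise_pred, t, target).prev_sample

    # Student: one pass at timestep 0, mirroring the inference branch above.
    prediction = student(batch, noise, torch.tensor([0], device=device))
    return F.mse_loss(prediction, target)
```

With 30 teacher steps, this trades a fixed up-front distillation cost for roughly a 30x reduction in decoder calls at inference time.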
131 changes: 69 additions & 62 deletions ddlitlab2024/ml/inference/ros.py
@@ -11,7 +11,6 @@
 from bitbots_tf_buffer import Buffer
 from cv_bridge import CvBridge
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
-from ema_pytorch import EMA
 from game_controller_hl_interfaces.msg import GameState
 from profilehooks import profile
 from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
@@ -42,21 +41,18 @@ def __init__(self, node_name, context):
             [rclpy.parameter.Parameter("use_sim_time", rclpy.Parameter.Type.BOOL, True)],
         )

+        checkpoint_path = (
+            "../training/destilled_trajectory_transformer_model_first_train_20_epoch_hyp.pth"
+            #"../training/trajectory_transformer_model_500_epoch_xmas_hyp.pth"
+        )
+        self.inference_denosing_timesteps = 30
+
         # Params
         self.sample_rate = DEFAULT_RESAMPLE_RATE_HZ
-        hidden_dim = 256
-        self.action_context_length = 100
-        self.trajectory_prediction_length = 10
-        train_denoising_timesteps = 1000
-        self.inference_denosing_timesteps = 10
-        self.image_context_length = 10
-        self.imu_context_length = 100
-        self.joint_state_context_length = 100
-        self.num_joints = 20
-        checkpoint = (
-            "/home/florian/ddlitlab/ddlitlab_repo/ddlitlab2024/ml/training/"
-            "trajectory_transformer_model_500_epoch_xmas.pth"
-        )
+        # Load the hyperparameters from the checkpoint
+        self.get_logger().info(f"Loading checkpoint '{checkpoint_path}'")
+        checkpoint = torch.load(checkpoint_path, weights_only=True)
+        self.hyper_params = checkpoint["hyperparams"]

         # Subscribe to all the input topics
         self.joint_state_sub = self.create_subscription(JointState, "/joint_states", self.joint_state_callback, 10)
@@ -89,12 +85,14 @@ def __init__(self, node_name, context):
         self.latest_game_state = None

         # Add default values to the buffers
-        self.image_embeddings = [torch.randn(3, 480, 480)] * self.image_context_length
-        self.imu_data = [torch.randn(4)] * self.imu_context_length
-        self.joint_state_data = [
-            torch.randn(len(JointStates.get_ordered_joint_names()))
-        ] * self.joint_state_context_length
-        self.joint_command_data = [torch.randn(self.num_joints)] * self.action_context_length
+        self.image_embeddings = [torch.randn(3, 480, 480)] * self.hyper_params["image_context_length"]
+        self.imu_data = [torch.randn(4)] * self.hyper_params["imu_context_length"]
+        self.joint_state_data = [torch.randn(len(JointStates.get_ordered_joint_names()))] * self.hyper_params[
+            "joint_state_context_length"
+        ]
+        self.joint_command_data = [torch.randn(self.hyper_params["num_joints"])] * self.hyper_params[
+            "action_context_length"
+        ]

         self.data_lock = Lock()

@@ -106,43 +104,41 @@ def __init__(self, node_name, context):
         # Load model
         self.get_logger().info("Load model")
         self.model = End2EndDiffusionTransformer(
-            num_joints=self.num_joints,
-            hidden_dim=hidden_dim,
-            use_action_history=True,
-            num_action_history_encoder_layers=2,
-            max_action_context_length=self.action_context_length,
-            use_imu=True,
-            imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod.QUATERNION,
-            num_imu_encoder_layers=2,
-            max_imu_context_length=self.imu_context_length,
-            use_joint_states=True,
-            joint_state_encoder_layers=2,
-            max_joint_state_context_length=self.joint_state_context_length,
-            use_images=True,
-            image_sequence_encoder_type=SequenceEncoderType.TRANSFORMER,
-            image_encoder_type=ImageEncoderType.RESNET18,
-            num_image_sequence_encoder_layers=1,
-            max_image_context_length=self.image_context_length,
-            num_decoder_layers=4,
-            trajectory_prediction_length=self.trajectory_prediction_length,
+            num_joints=self.hyper_params["num_joints"],
+            hidden_dim=self.hyper_params["hidden_dim"],
+            use_action_history=self.hyper_params["use_action_history"],
+            num_action_history_encoder_layers=self.hyper_params["num_action_history_encoder_layers"],
+            max_action_context_length=self.hyper_params["action_context_length"],
+            use_imu=self.hyper_params["use_imu"],
+            imu_orientation_embedding_method=IMUEncoder.OrientationEmbeddingMethod(
+                self.hyper_params["imu_orientation_embedding_method"]
+            ),
+            num_imu_encoder_layers=self.hyper_params["num_imu_encoder_layers"],
+            imu_context_length=self.hyper_params["imu_context_length"],
+            use_joint_states=self.hyper_params["use_joint_states"],
+            joint_state_encoder_layers=self.hyper_params["joint_state_encoder_layers"],
+            joint_state_context_length=self.hyper_params["joint_state_context_length"],
+            use_images=self.hyper_params["use_images"],
+            image_sequence_encoder_type=SequenceEncoderType(self.hyper_params["image_sequence_encoder_type"]),
+            image_encoder_type=ImageEncoderType(self.hyper_params["image_encoder_type"]),
+            num_image_sequence_encoder_layers=self.hyper_params["num_image_sequence_encoder_layers"],
+            image_context_length=self.hyper_params["image_context_length"],
+            num_decoder_layers=self.hyper_params["num_decoder_layers"],
+            trajectory_prediction_length=self.hyper_params["trajectory_prediction_length"],
         ).to(device)

-        self.og_model = self.model
-
         self.normalizer = Normalizer(self.model.mean, self.model.std)
-        self.model = EMA(self.model)
-        self.model.load_state_dict(torch.load(checkpoint, weights_only=True))
+        self.model.load_state_dict(checkpoint["model_state_dict"])
         self.model.eval()
+        print(self.normalizer.mean)

         # Create diffusion noise scheduler
         self.get_logger().info("Create diffusion noise scheduler")
         self.scheduler = DDIMScheduler(beta_schedule="squaredcos_cap_v2", clip_sample=False)
-        self.scheduler.config["num_train_timesteps"] = train_denoising_timesteps
+        self.scheduler.config["num_train_timesteps"] = self.hyper_params["train_denoising_timesteps"]
         self.scheduler.set_timesteps(self.inference_denosing_timesteps)

         # Create control timer to run inference at a fixed rate
-        interval = 1 / self.sample_rate * self.trajectory_prediction_length
+        interval = 1 / self.sample_rate * (self.hyper_params["trajectory_prediction_length"])
         # We want to run the inference in a separate thread to not block the callbacks, but we also want to make sure
         # that the inference is not running multiple times in parallel
         self.create_timer(interval, self.step, callback_group=MutuallyExclusiveCallbackGroup())
@@ -217,10 +213,10 @@ def update_buffers(self):
        )

        # Remove the oldest data from the buffers
-        self.joint_state_data = self.joint_state_data[-self.joint_state_context_length :]
-        self.image_embeddings = self.image_embeddings[-self.image_context_length :]
-        self.imu_data = self.imu_data[-self.imu_context_length :]
-        self.joint_command_data = self.joint_command_data[-self.action_context_length :]
+        self.joint_state_data = self.joint_state_data[-self.hyper_params["joint_state_context_length"] :]
+        self.image_embeddings = self.image_embeddings[-self.hyper_params["image_context_length"] :]
+        self.imu_data = self.imu_data[-self.hyper_params["imu_context_length"] :]
+        self.joint_command_data = self.joint_command_data[-self.hyper_params["action_context_length"] :]

    @profile
    def step(self):
@@ -239,28 +235,39 @@ def step(self):
            % (2 * np.pi),  # torch.stack(list(self.joint_command_data), dim=0).unsqueeze(0).to(device),
        }

+        print("Batch: ", batch["image_data"].shape)
+
        # Perform the denoising process
-        trajectory = torch.randn(1, self.trajectory_prediction_length, self.num_joints).to(device)
+        trajectory = torch.randn(
+            1, self.hyper_params["trajectory_prediction_length"], self.hyper_params["num_joints"]
+        ).to(device)

        start_ros_time = self.get_clock().now()

        ## Perform the embedding of the conditioning
        start = time.time()
-        embedded_input = self.og_model.encode_input_data(batch)
+        embedded_input = self.model.encode_input_data(batch)
        print("Time for embedding: ", time.time() - start)

        # Denoise the trajectory
        start = time.time()
-        self.scheduler.set_timesteps(self.inference_denosing_timesteps)
-        for t in self.scheduler.timesteps:
-            with torch.no_grad():
-                # Predict the noise residual
-                noise_pred = self.og_model.forward_with_context(
-                    embedded_input, trajectory, torch.tensor([t], device=device)
-                )
-
-                # Update the trajectory based on the predicted noise and the current step of the denoising process
-                trajectory = self.scheduler.step(noise_pred, t, trajectory).prev_sample
+        if self.hyper_params.get("distilled_decoder", False):
+            # Directly predict the trajectory based on the noise
+            with torch.no_grad():
+                trajectory = self.model.forward_with_context(embedded_input, trajectory, torch.tensor([0], device=device))
+        else:
+            # Perform the denoising process
+            self.scheduler.set_timesteps(self.inference_denosing_timesteps)
+            for t in self.scheduler.timesteps:
+                with torch.no_grad():
+                    # Predict the noise residual
+                    noise_pred = self.model.forward_with_context(
+                        embedded_input, trajectory, torch.tensor([t], device=device)
+                    )
+
+                    # Update the trajectory based on the predicted noise and the current step of the denoising process
+                    trajectory = self.scheduler.step(noise_pred, t, trajectory).prev_sample

        print("Time for forward: ", time.time() - start)

@@ -272,7 +279,7 @@ def step(self):
        trajectory_msg.header.stamp = Time.to_msg(start_ros_time)
        trajectory_msg.joint_names = JointStates.get_ordered_joint_names()
        trajectory_msg.points = []
-        for i in range(self.trajectory_prediction_length):
+        for i in range(self.hyper_params["trajectory_prediction_length"]):
            point = JointTrajectoryPoint()
            point.positions = trajectory[0, i].cpu().numpy() - np.pi
            point.time_from_start = Duration(nanoseconds=int(1e9 / self.sample_rate * i)).to_msg()
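A general note on the checkpoint format the node now expects: because the file is read with `torch.load(checkpoint_path, weights_only=True)`, it may contain only tensors and plain Python containers, which is presumably why the enum-valued hyperparameters are stored as strings and reconstructed via `SequenceEncoderType(...)` and `ImageEncoderType(...)` above. A hypothetical save-side sketch matching the keys the node reads ("hyperparams", "model_state_dict"):

```python
import torch
from torch import nn

model = nn.Linear(4, 4)  # stand-in for the trained End2EndDiffusionTransformer

# Only primitives, containers, and tensors: safe for weights_only=True loading.
hyper_params = {
    "num_joints": 20,
    "hidden_dim": 256,
    "trajectory_prediction_length": 10,
    "train_denoising_timesteps": 1000,
    "image_sequence_encoder_type": "transformer",  # enum stored as its string value
    "distilled_decoder": True,  # flag that selects the single-step branch in step()
    # ... remaining architecture and context-length keys elided
}
torch.save(
    {"hyperparams": hyper_params, "model_state_dict": model.state_dict()},
    "checkpoint_with_hyperparams.pth",
)

# Round-trips without unpickling arbitrary objects:
checkpoint = torch.load("checkpoint_with_hyperparams.pth", weights_only=True)
assert checkpoint["hyperparams"]["distilled_decoder"] is True
```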
6 changes: 3 additions & 3 deletions ddlitlab2024/ml/model/encoder/image.py
@@ -3,6 +3,7 @@
 import torch
 from torch import nn
 from torchvision.models import resnet18, resnet50, swin_s, swin_t
+from torchvision.models.resnet import ResNet18_Weights, ResNet50_Weights

 from ddlitlab2024.ml.model.encoder.base import BaseEncoder

@@ -60,12 +61,11 @@ def __init__(self, resnet_type: ImageEncoderType, hidden_dim: int):
         super().__init__()
         match resnet_type:
             case ImageEncoderType.RESNET18:
-                self.encoder = resnet18(pretrained=True)
+                self.encoder = resnet18(weights=ResNet18_Weights.DEFAULT)
             case ImageEncoderType.RESNET50:
-                self.encoder = resnet50(pretrained=True)
+                self.encoder = resnet50(weights=ResNet50_Weights.DEFAULT)
             case _:
                 raise ValueError(f"Invalid ResNet type: {resnet_type}")
-        # TODO check for softmax layer etc.
         self.encoder.fc = nn.Linear(self.encoder.fc.in_features, hidden_dim)

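Background on the torchvision change: `pretrained=True` has been deprecated since torchvision 0.13 in favor of explicit weight enums. The two are not always byte-identical, though: for ResNet-50, `ResNet50_Weights.DEFAULT` currently resolves to the newer `IMAGENET1K_V2` weights, whereas `pretrained=True` corresponded to `IMAGENET1K_V1`, so encoder features may shift slightly. If exact parity with the old behavior matters, the V1 enum can be pinned:

```python
from torchvision.models import resnet50
from torchvision.models.resnet import ResNet50_Weights

# Deprecated since torchvision 0.13 (emits a warning), equivalent to V1 weights:
# encoder = resnet50(pretrained=True)

encoder_v1 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)  # exact old behavior
encoder_best = resnet50(weights=ResNet50_Weights.DEFAULT)      # best available (V2)
```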
3 changes: 2 additions & 1 deletion ddlitlab2024/ml/training/config/default.yaml
@@ -22,4 +22,5 @@ use_images: True
 image_sequence_encoder_type: "transformer"
 image_encoder_type: "resnet18"
 num_image_sequence_encoder_layers: 1
-num_decoder_layers: 4
\ No newline at end of file
+num_decoder_layers: 4
+distill_teacher_inference_steps: 30
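The repeated `num_decoder_layers: 4` line only reflects adding a trailing newline; the substantive change is the new `distill_teacher_inference_steps` key, presumably the number of DDIM steps the teacher runs when generating distillation targets (compare the sketch after the `plot.py` section). A minimal, hypothetical consumer of the key:

```python
import yaml  # assumes PyYAML, as is typical for these config files

# The distillation script itself is not part of this excerpt; this only shows
# how the new key would be read alongside the rest of the config.
with open("ddlitlab2024/ml/training/config/default.yaml") as f:
    config = yaml.safe_load(f)

teacher_steps = config.get("distill_teacher_inference_steps", 30)
print(f"Distillation teacher will run {teacher_steps} DDIM steps per sample")
```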