diff --git a/README.md b/README.md index d5fe68f..9b87a29 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ This repository provides the official implementation of our paper, "PhysORD: A N ## Dataset - This project utilizes the [TartanDrive](https://github.com/castacks/tartan_drive) dataset. Follow the instructions in its [repository](https://github.com/castacks/tartan_drive?tab=readme-ov-file#create-traintest-split) to create the `train`, `test-easy` and `test-hard` sets. - `test-easy` is used for validation during training, and `test-hard` for model evaluation. -- We also provide pre-processed data for quick reproduction. Download the [train_val_easy_507_step20.pt / ...step5.pt](https://drive.google.com/drive/folders/16PX9j6SUU8_LB0vq5wj31WWVUY49cR8l?usp=sharing) file. +- We also provide pre-processed data for quick reproduction. Download the [train_val_easy_507_step20.pt / ...step5.pt](https://drive.google.com/drive/folders/16PX9j6SUU8_LB0vq5wj31WWVUY49cR8l?usp=sharing) file into the [data](data) folder. ## Reproduce Guide To reproduce our result in the paper, you can follow the the steps below. diff --git a/eval.py b/eval.py index e625fe2..d9ac45a 100644 --- a/eval.py +++ b/eval.py @@ -5,7 +5,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--exp_name', default='physord_s5', type=str, help='experiment name') + parser.add_argument('--exp_name', default='physord', type=str, help='experiment name') parser.add_argument('--eval_data_fp', type=str, required=False, default='/data/data0/datasets/tartandrive/data/test-hard/', help='Path to test data') parser.add_argument('--timesteps', type=int, required=False, default=20, help='Number of timesteps to predict') parser.add_argument('--test_sample_interval', type=int, default=1, help='test data sample interval') @@ -16,8 +16,8 @@ # Load the model print("Loading the model ...") model = PhysORD(device=device, use_dVNet = True, time_step = 0.1).to(device) - model_dir = f'./result2/{args.exp_name}' - model_fp = f'{model_dir}/best/best-data507-timestep5.tar' + model_dir = f'./pretrained/{args.exp_name}' + model_fp = f'{model_dir}/best/best-data507-timestep20.tar' model.load_state_dict(torch.load(model_fp, map_location=device)) # Load the data diff --git a/physord/model.py b/physord/model.py index 26684d3..7a55428 100644 --- a/physord/model.py +++ b/physord/model.py @@ -132,66 +132,31 @@ def step_forward(self, x): vk_next = vk_next[:,:,0] return torch.cat((qk_next, vk_next, omegak_next, sk, rpmk, uk), dim=1) - - def predict_traininga(self, step_num, x, action): - # initial_x, initial_y = x[:, 0].clone(), x[:, 1].clone() - # x[:, 0] = x[:, 0] - initial_x - # x[:, 1] = x[:, 1] - initial_y - - x = torch.cat((x, action[0:1,:]), dim = 1) - # print("x shape: ", x.shape) - xseq = x[None,:,:] - curx = x - for i in range(step_num): - nextx = self.step_forward(curx) - # curx = nextx - curx = torch.cat((nextx[:, :26], action[i+1:i+2,:]), dim = 1) - # print("curx shape: ", curx.shape) - xseq = torch.cat((xseq, curx[None,:,:]), dim = 0) - # for i in range(step_num + 1): - # xseq[i, :, 0] = xseq[i, :, 0] + initial_x - # xseq[i, :, 1] = xseq[i, :, 1] + initial_y - - return xseq def efficient_evaluation(self, step_num, x, action): - # initial_x, initial_y = x[:, 0].clone(), x[:, 1].clone() - # x[:, 0] = x[:, 0] - initial_x - # x[:, 1] = x[:, 1] - initial_y - - # x = torch.cat((x, action[0:1,:]), dim = 1) - # print("x shape: ", x.shape) + initial_x, initial_y = x[:, 0].clone(), x[:, 1].clone() + x[:, 0] = x[:, 0] - initial_x + x[:, 1] = x[:, 1] - initial_y xseq = x[None,:,:] curx = x for i in range(step_num): nextx = self.step_forward(curx) - # curx = nextx curx = torch.cat((nextx[:, :26], action[i+1, :, :]), dim = 1) - # print("curx shape: ", curx.shape) xseq = torch.cat((xseq, curx[None,:,:]), dim = 0) - # for i in range(step_num + 1): - # xseq[i, :, 0] = xseq[i, :, 0] + initial_x - # xseq[i, :, 1] = xseq[i, :, 1] + initial_y + for i in range(step_num + 1): + xseq[i, :, 0] = xseq[i, :, 0] + initial_x + xseq[i, :, 1] = xseq[i, :, 1] + initial_y return xseq def evaluation(self, step_num, traj): - # print("x shape: ", x.shape) - # print("traj: ", traj[:, 10]) xseq = traj[0,:,:] xseq = xseq[None,:,:] curx = traj[0,:,:] - # print("curx shape: ", curx.shape) for i in range(step_num): - # print("curx", curx[10]) nextx = self.step_forward(curx) - # curx = nextx curx = torch.cat((nextx[:, :26], traj[i+1, :,26:29]), dim = 1) - # print("curx shape: ", curx.shape) xseq = torch.cat((xseq, curx[None,:,:]), dim = 0) - # for i in range(step_num + 1): - # xseq[i, :, 0] = xseq[i, :, 0] + initial_x - # xseq[i, :, 1] = xseq[i, :, 1] + initial_y return xseq diff --git a/train.py b/train.py index 2954e67..5210c15 100644 --- a/train.py +++ b/train.py @@ -137,11 +137,11 @@ def train(args, train_data, val_data): if __name__ == "__main__": parser = argparse.ArgumentParser(description=None) - parser.add_argument('--exp_name', default='physord_s5', type=str, help='experiment name') + parser.add_argument('--exp_name', default='physord', type=str, help='experiment name') parser.add_argument('--train_data_size', type=int, default=507, help='number of training data: 100% = 507, 80% = 406, 50% = 254, 10% = 51, 1%=5') - parser.add_argument('--timesteps', type=int, default=5, help='number of prediction steps') - parser.add_argument('--preprocessed_data_dir', default=None, type=str, help='directory of the preprocessed data.') - parser.add_argument('--save_dir', default="./result2/", type=str, help='where to save the trained model') + parser.add_argument('--timesteps', type=int, default=20, help='number of prediction steps') + parser.add_argument('--preprocessed_data_dir', default='./data/', type=str, help='directory of the preprocessed data.') + parser.add_argument('--save_dir', default="./result/", type=str, help='where to save the trained model') parser.add_argument('--val_sample_interval', type=int, default=1, help='validation_data') parser.add_argument('--early_stop', dest='early_stopping', action='store_true', help='early stopping?') parser.add_argument('--pretrained', default=None, type=str, help='Path to the pretrained model. If not provided, no pretrained model will be loaded.') diff --git a/util/data_process.py b/util/data_process.py index e055263..a193c4b 100644 --- a/util/data_process.py +++ b/util/data_process.py @@ -185,13 +185,6 @@ def get_test_data(eval_data_fp, norm_params, T, sample_intervals): state = get_state_seq_from_traj(traj, min_val_st, max_val_st, min_val_brake, max_val_brake) state = arrange_data_sample(state, T + 1, sample_intervals) state = torch.from_numpy(state) - - initial_x = state[0, :, 0].unsqueeze(0) - initial_y = state[0, :, 1].unsqueeze(0) - initial_x = initial_x.expand(state.size(0), -1) - initial_y = initial_y.expand(state.size(0), -1) - state[:, :, 0] = state[:, :, 0] - initial_x - state[:, :, 1] = state[:, :, 1] - initial_y states_test.append(state) test_x_cat = torch.cat(states_test, dim=1) return test_x_cat \ No newline at end of file