Skip to content

Commit

Permalink
some modification
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhaozhpe committed Apr 1, 2024
1 parent 06ea757 commit 1795b3f
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ This repository provides the official implementation of our paper, "PhysORD: A N
## Dataset
- This project utilizes the [TartanDrive](https://github.com/castacks/tartan_drive) dataset. Follow the instructions in its [repository](https://github.com/castacks/tartan_drive?tab=readme-ov-file#create-traintest-split) to create the `train`, `test-easy` and `test-hard` sets.
- `test-easy` is used for validation during training, and `test-hard` for model evaluation.
- We also provide pre-processed data for quick reproduction. Download the [train_val_easy_507_step20.pt / ...step5.pt](https://drive.google.com/drive/folders/16PX9j6SUU8_LB0vq5wj31WWVUY49cR8l?usp=sharing) file.
- We also provide pre-processed data for quick reproduction. Download the [train_val_easy_507_step20.pt / ...step5.pt](https://drive.google.com/drive/folders/16PX9j6SUU8_LB0vq5wj31WWVUY49cR8l?usp=sharing) file into the [data](data) folder.

## Reproduce Guide
To reproduce our result in the paper, you can follow the the steps below.
Expand Down
6 changes: 3 additions & 3 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--exp_name', default='physord_s5', type=str, help='experiment name')
parser.add_argument('--exp_name', default='physord', type=str, help='experiment name')
parser.add_argument('--eval_data_fp', type=str, required=False, default='/data/data0/datasets/tartandrive/data/test-hard/', help='Path to test data')
parser.add_argument('--timesteps', type=int, required=False, default=20, help='Number of timesteps to predict')
parser.add_argument('--test_sample_interval', type=int, default=1, help='test data sample interval')
Expand All @@ -16,8 +16,8 @@
# Load the model
print("Loading the model ...")
model = PhysORD(device=device, use_dVNet = True, time_step = 0.1).to(device)
model_dir = f'./result2/{args.exp_name}'
model_fp = f'{model_dir}/best/best-data507-timestep5.tar'
model_dir = f'./pretrained/{args.exp_name}'
model_fp = f'{model_dir}/best/best-data507-timestep20.tar'
model.load_state_dict(torch.load(model_fp, map_location=device))

# Load the data
Expand Down
47 changes: 6 additions & 41 deletions physord/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,66 +132,31 @@ def step_forward(self, x):
vk_next = vk_next[:,:,0]

return torch.cat((qk_next, vk_next, omegak_next, sk, rpmk, uk), dim=1)

def predict_traininga(self, step_num, x, action):
# initial_x, initial_y = x[:, 0].clone(), x[:, 1].clone()
# x[:, 0] = x[:, 0] - initial_x
# x[:, 1] = x[:, 1] - initial_y

x = torch.cat((x, action[0:1,:]), dim = 1)
# print("x shape: ", x.shape)
xseq = x[None,:,:]
curx = x
for i in range(step_num):
nextx = self.step_forward(curx)
# curx = nextx
curx = torch.cat((nextx[:, :26], action[i+1:i+2,:]), dim = 1)
# print("curx shape: ", curx.shape)
xseq = torch.cat((xseq, curx[None,:,:]), dim = 0)
# for i in range(step_num + 1):
# xseq[i, :, 0] = xseq[i, :, 0] + initial_x
# xseq[i, :, 1] = xseq[i, :, 1] + initial_y

return xseq

def efficient_evaluation(self, step_num, x, action):
# initial_x, initial_y = x[:, 0].clone(), x[:, 1].clone()
# x[:, 0] = x[:, 0] - initial_x
# x[:, 1] = x[:, 1] - initial_y

# x = torch.cat((x, action[0:1,:]), dim = 1)
# print("x shape: ", x.shape)
initial_x, initial_y = x[:, 0].clone(), x[:, 1].clone()
x[:, 0] = x[:, 0] - initial_x
x[:, 1] = x[:, 1] - initial_y
xseq = x[None,:,:]
curx = x
for i in range(step_num):
nextx = self.step_forward(curx)
# curx = nextx
curx = torch.cat((nextx[:, :26], action[i+1, :, :]), dim = 1)
# print("curx shape: ", curx.shape)
xseq = torch.cat((xseq, curx[None,:,:]), dim = 0)
# for i in range(step_num + 1):
# xseq[i, :, 0] = xseq[i, :, 0] + initial_x
# xseq[i, :, 1] = xseq[i, :, 1] + initial_y
for i in range(step_num + 1):
xseq[i, :, 0] = xseq[i, :, 0] + initial_x
xseq[i, :, 1] = xseq[i, :, 1] + initial_y

return xseq

def evaluation(self, step_num, traj):
# print("x shape: ", x.shape)
# print("traj: ", traj[:, 10])
xseq = traj[0,:,:]
xseq = xseq[None,:,:]
curx = traj[0,:,:]
# print("curx shape: ", curx.shape)
for i in range(step_num):
# print("curx", curx[10])
nextx = self.step_forward(curx)
# curx = nextx
curx = torch.cat((nextx[:, :26], traj[i+1, :,26:29]), dim = 1)
# print("curx shape: ", curx.shape)
xseq = torch.cat((xseq, curx[None,:,:]), dim = 0)
# for i in range(step_num + 1):
# xseq[i, :, 0] = xseq[i, :, 0] + initial_x
# xseq[i, :, 1] = xseq[i, :, 1] + initial_y

return xseq

Expand Down
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ def train(args, train_data, val_data):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description=None)
parser.add_argument('--exp_name', default='physord_s5', type=str, help='experiment name')
parser.add_argument('--exp_name', default='physord', type=str, help='experiment name')
parser.add_argument('--train_data_size', type=int, default=507, help='number of training data: 100% = 507, 80% = 406, 50% = 254, 10% = 51, 1%=5')
parser.add_argument('--timesteps', type=int, default=5, help='number of prediction steps')
parser.add_argument('--preprocessed_data_dir', default=None, type=str, help='directory of the preprocessed data.')
parser.add_argument('--save_dir', default="./result2/", type=str, help='where to save the trained model')
parser.add_argument('--timesteps', type=int, default=20, help='number of prediction steps')
parser.add_argument('--preprocessed_data_dir', default='./data/', type=str, help='directory of the preprocessed data.')
parser.add_argument('--save_dir', default="./result/", type=str, help='where to save the trained model')
parser.add_argument('--val_sample_interval', type=int, default=1, help='validation_data')
parser.add_argument('--early_stop', dest='early_stopping', action='store_true', help='early stopping?')
parser.add_argument('--pretrained', default=None, type=str, help='Path to the pretrained model. If not provided, no pretrained model will be loaded.')
Expand Down
7 changes: 0 additions & 7 deletions util/data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,6 @@ def get_test_data(eval_data_fp, norm_params, T, sample_intervals):
state = get_state_seq_from_traj(traj, min_val_st, max_val_st, min_val_brake, max_val_brake)
state = arrange_data_sample(state, T + 1, sample_intervals)
state = torch.from_numpy(state)

initial_x = state[0, :, 0].unsqueeze(0)
initial_y = state[0, :, 1].unsqueeze(0)
initial_x = initial_x.expand(state.size(0), -1)
initial_y = initial_y.expand(state.size(0), -1)
state[:, :, 0] = state[:, :, 0] - initial_x
state[:, :, 1] = state[:, :, 1] - initial_y
states_test.append(state)
test_x_cat = torch.cat(states_test, dim=1)
return test_x_cat

0 comments on commit 1795b3f

Please sign in to comment.