Skip to content

Commit

Permalink
context is a class field
Browse files Browse the repository at this point in the history
  • Loading branch information
luisa-mao committed May 28, 2024
1 parent 64c5217 commit dfd0ee5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 19 deletions.
10 changes: 5 additions & 5 deletions src/navigation/terrain_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,11 @@ std::shared_ptr<PathRolloutBase> TerrainEvaluator::FindBest(
DrawPathCosts(paths, best_path);
// save the image to a file
// latest vis from rgb to bgr
cv::cvtColor(latest_vis_image_, latest_vis_image_, cv::COLOR_RGB2BGR);
cv::imwrite("latest_vis.png", latest_vis_image_);
cv::cvtColor(latest_vis_image_, latest_vis_image_, cv::COLOR_BGR2RGB);
cv::imwrite("latest_cost.png", latest_cost_image_);
exit(0);
// cv::cvtColor(latest_vis_image_, latest_vis_image_, cv::COLOR_RGB2BGR);
// cv::imwrite("latest_vis.png", latest_vis_image_);
// cv::cvtColor(latest_vis_image_, latest_vis_image_, cv::COLOR_BGR2RGB);
// cv::imwrite("latest_cost.png", latest_cost_image_);
// exit(0);


return best_path;
Expand Down
25 changes: 11 additions & 14 deletions src/navigation/terrain_evaluator2.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ namespace motion_primitives {

class CustomTerrainEvaluator : public TerrainEvaluator {
public:
CustomTerrainEvaluator() : TerrainEvaluator() {}
torch::Tensor context_tensor_;
CustomTerrainEvaluator() : TerrainEvaluator() {
// todo: store the context as private variable
torch::jit::script::Module tensors = torch::jit::load("terrain_models/context.pt");
context_tensor_ = tensors.run_method("return_tensor").toTensor();
}


// std::shared_ptr<PathRolloutBase> FindBest(
// const std::vector<std::shared_ptr<PathRolloutBase>>& paths) override {
// // Custom implementation of FindBest
// }

// latest bev image 749 1476
// latest bev image channels 3
Expand All @@ -36,14 +38,9 @@ class CustomTerrainEvaluator : public TerrainEvaluator {
img_tensor = img_tensor.to(torch::kFloat32);
img_tensor = img_tensor / 255.0;

auto example_context = torch::randn({1, 9, 128, 64});
// todo: store the context as private variable
torch::jit::script::Module tensors = torch::jit::load("terrain_models/context.pt");
// torch::Tensor prior = tensors.get_attribute("context").toTensor();
auto context_tensor = tensors.run_method("return_tensor").toTensor();
// print the context shape
std::cout << "Context shape: ";
for (auto& size : context_tensor.sizes()) {
for (auto& size : context_tensor_.sizes()) {
std::cout << size << " ";
}
std::cout << std::endl;
Expand All @@ -62,7 +59,7 @@ class CustomTerrainEvaluator : public TerrainEvaluator {


std::vector<torch::jit::IValue> inputs;
inputs.push_back(context_tensor);
inputs.push_back(context_tensor_);
inputs.push_back(img_tensor);

torch::NoGradGuard no_grad;
Expand Down Expand Up @@ -94,9 +91,9 @@ class CustomTerrainEvaluator : public TerrainEvaluator {


// write the latest_bev_image to a file
cv::imwrite("latest_bev_image.png", latest_bev_image);
// cv::imwrite("latest_bev_image.png", latest_bev_image);
// write the scalar_cost_map to a file
cv::imwrite("scalar_cost_map.png", scalar_cost_map*255);
// cv::imwrite("scalar_cost_map.png", scalar_cost_map*255);

// return the output
return scalar_cost_map;
Expand Down

0 comments on commit dfd0ee5

Please sign in to comment.