Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stack height classifier #1

Open
wants to merge 15 commits into
base: grasp_pytorch0.4+
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import cv2
import torch
import csv
# import h5py

class Logger():
Expand Down Expand Up @@ -73,6 +74,11 @@ def save_heightmaps(self, iteration, color_heightmap, depth_heightmap, mode):
def write_to_log(self, log_name, log):
    """ Save a log array as space-delimited text.

    # Params

    log_name: base name of the log; the file is written as '<log_name>.log.txt'
        inside self.transitions_directory.
    log: array-like data handed to np.savetxt.
    """
    log_filename = '%s.log.txt' % log_name
    log_path = os.path.join(self.transitions_directory, log_filename)
    np.savetxt(log_path, log, delimiter=' ')

# For stack classifier data (might not need)
def save_label(self, log_name, log):
    """ Write label rows as space-delimited text next to the color images.

    # Params

    log_name: base name of the label file; written as '<log_name>.txt'
        inside self.color_images_directory.
    log: iterable of rows, each row an iterable of fields, as accepted by
        csv.writer.writerows.
    """
    label_path = os.path.join(self.color_images_directory, '%s.txt' % log_name)
    with open(label_path, 'w') as label_file:
        label_writer = csv.writer(label_file, delimiter=' ')
        label_writer.writerows(log)

def save_model(self, model, name):
    """ Save a snapshot of the model weights.

    The model is moved to the CPU via model.cpu() before its state_dict is
    serialized to 'snapshot.<name>.pth' inside self.models_directory.

    # Params

    model: object providing cpu() and state_dict(), e.g. a torch.nn.Module.
    name: label inserted into the snapshot file name.
    """
    cpu_weights = model.cpu().state_dict()
    snapshot_path = os.path.join(self.models_directory, 'snapshot.%s.pth' % (name))
    torch.save(cpu_weights, snapshot_path)

Expand Down
72 changes: 66 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
from trainer import Trainer
from logger import Logger
import utils
# efficientnet_pytorch is an optional dependency. When it is missing, both
# names below are set to None so later code can test for availability
# instead of hitting a NameError on EfficientNet.
try:
    import efficientnet_pytorch
    from efficientnet_pytorch import EfficientNet
except ImportError:
    # Each fragment ends with a space so the concatenated message stays readable.
    print('efficientnet_pytorch is not available, using densenet. '
          'Try installing https://github.com/ahundt/EfficientNet-PyTorch for all features. '
          'A version of EfficientNets without dilation can be installed with the command: '
          'pip3 install efficientnet-pytorch --user --upgrade '
          'See https://github.com/lukemelas/EfficientNet-PyTorch for details')
    efficientnet_pytorch = None
    EfficientNet = None

# to convert action names to the corresponding ID number and vice-versa
ACTION_TO_ID = {'push': 0, 'grasp': 1, 'place': 2}
Expand Down Expand Up @@ -169,6 +179,23 @@ def main(args):
else:
goal_condition_len = 0

#TODO(hkwon214) temporary
# ------ Image Classifier options -----
# model_stack is the stack height classifier that check_stack_update_goal()
# hands to robot.stack_reward() when use_classifier is enabled.
use_classifier = args.use_classifier
checkpoint_path = args.checkpoint_path
# Always bind the name so it is defined even when the classifier is disabled.
model_stack = None
#TODO(hkwon214) hard coded to use efficientnet for now. modify for future?
if use_classifier:
    if efficientnet_pytorch is None:
        # The optional import at the top of the file failed, so EfficientNet is unusable.
        raise ImportError('--use_classifier requires the efficientnet_pytorch package')
    if checkpoint_path is None:
        raise NotImplementedError('No checkpoints')
    # Bind the network to the same name used below; the previous code created
    # `model` and then referenced an undefined `model_stack`.
    model_stack = EfficientNet.from_name('efficientnet-b0')
    #model_stack = nn.DataParallel(model_stack)
    model_stack = model_stack.cuda()
    checkpoint = torch.load(checkpoint_path)
    # The checkpoint is expected to hold the weights under the 'state_dict' key.
    model_stack.load_state_dict(checkpoint['state_dict'])
    # Inference only: switch to evaluation mode.
    model_stack.eval()



# Set random seed
np.random.seed(random_seed)

Expand Down Expand Up @@ -223,7 +250,7 @@ def set_nonlocal_success_variables_false():
nonlocal_variables['grasp_color_success'] = False
nonlocal_variables['place_color_success'] = False

def check_stack_update_goal(place_check=False, top_idx=-1):
def check_stack_update_goal(place_check=False, top_idx=-1, use_classifier = False, input_img = None):
""" Check nonlocal_variables for a good stack and reset if it does not match the current goal.

# Params
Expand All @@ -246,8 +273,13 @@ def check_stack_update_goal(place_check=False, top_idx=-1):
# only the place check expects the current goal to be met
current_stack_goal = current_stack_goal[:-1]
stack_shift = 0
# TODO(ahundt) BUG Figure out why a real stack of size 2 or 3 and a push which touches no blocks does not pass the stack_check and ends up a MISMATCH in need of reset. (update: may now be fixed, double check then delete when confirmed)
stack_matches_goal, nonlocal_variables['stack_height'] = robot.check_stack(current_stack_goal, top_idx=top_idx, stack_axis=stack_axis)
if use_classifier:
# TODO(hkwon214) Add image classifier
stack_matches_goal, nonlocal_variables['stack_height'] = robot.stack_reward(model_stack, input_img, current_stack_goal)
else:
# TODO(ahundt) BUG Figure out why a real stack of size 2 or 3 and a push which touches no blocks does not pass the stack_check and ends up a MISMATCH in need of reset. (update: may now be fixed, double check then delete when confirmed)
stack_matches_goal, nonlocal_variables['stack_height'] = robot.check_stack(current_stack_goal, top_idx=top_idx, stack_axis=stack_axis)

nonlocal_variables['partial_stack_success'] = stack_matches_goal
if nonlocal_variables['stack_height'] == 1:
# A stack of size 1 does not meet the criteria for a partial stack success
Expand Down Expand Up @@ -422,7 +454,14 @@ def process_actions():
# Check if the push caused a topple, size shift zero because
# place operations expect increased height,
# while push expects constant height.
needed_to_reset = check_stack_update_goal()
#TODO(hkwon214) temp
#needed_to_reset = check_stack_update_goal()
color_img, depth_img = robot.get_camera_data()
depth_img = depth_img * robot.cam_depth_scale # Apply depth scale from calibration
# Get heightmap from RGB-D image (by re-projecting 3D point cloud)
color_heightmap_after_action, depth_heightmap_after_action = utils.get_heightmap(color_img, depth_img, robot.cam_intrinsics, robot.cam_pose, workspace_limits, heightmap_resolution)

needed_to_reset = check_stack_update_goal(use_classifier = use_classifier, input_img = depth_heightmap_after_action)
if not place or not needed_to_reset:
print('Push motion successful (no crash, need not move blocks): %r' % (nonlocal_variables['push_success']))
elif nonlocal_variables['primitive_action'] == 'grasp':
Expand All @@ -441,7 +480,15 @@ def process_actions():
top_idx = -2
# check if a failed grasp led to a topple, or if the top block was grasped
# TODO(ahundt) in check_stack() support the check after a specific grasp in case of successful grasp topple. Perhaps allow the top block to be specified?
needed_to_reset = check_stack_update_goal(top_idx=top_idx)
#needed_to_reset = check_stack_update_goal(top_idx=top_idx)
#TODO(hkwon214) temp

color_img, depth_img = robot.get_camera_data()
depth_img = depth_img * robot.cam_depth_scale # Apply depth scale from calibration
# Get heightmap from RGB-D image (by re-projecting 3D point cloud)
color_heightmap_after_action, depth_heightmap_after_action = utils.get_heightmap(color_img, depth_img, robot.cam_intrinsics, robot.cam_pose, workspace_limits, heightmap_resolution)

needed_to_reset = check_stack_update_goal(top_idx=top_idx, use_classifier = use_classifier, input_img = depth_heightmap_after_action)
if nonlocal_variables['grasp_success']:
# robot.restart_sim()
successful_grasp_count += 1
Expand All @@ -464,7 +511,15 @@ def process_actions():
elif nonlocal_variables['primitive_action'] == 'place':
place_count += 1
nonlocal_variables['place_success'] = robot.place(primitive_position, best_rotation_angle)
needed_to_reset = check_stack_update_goal(place_check=True)

#TODO(hkwon214) robot executed task -> capture image for image classifier
color_img, depth_img = robot.get_camera_data()
depth_img = depth_img * robot.cam_depth_scale # Apply depth scale from calibration
# Get heightmap from RGB-D image (by re-projecting 3D point cloud)
color_heightmap_after_action, depth_heightmap_after_action = utils.get_heightmap(color_img, depth_img, robot.cam_intrinsics, robot.cam_pose, workspace_limits, heightmap_resolution)

# needed_to_reset = check_stack_update_goal(place_check=True)
needed_to_reset = check_stack_update_goal(place_check=True, use_classifier = use_classifier, input_img = depth_heightmap_after_action)
if not needed_to_reset and nonlocal_variables['place_success'] and nonlocal_variables['partial_stack_success']:
partial_stack_count += 1
nonlocal_variables['stack'].next()
Expand Down Expand Up @@ -971,6 +1026,11 @@ def experience_replay(method, prev_primitive_action, prev_reward_value, trainer,
parser.add_argument('--grasp_color_task', dest='grasp_color_task', action='store_true', default=False, help='enable grasping specific colored objects')
parser.add_argument('--grasp_count', dest='grasp_cout', type=int, action='store', default=0, help='number of successful task based grasps')

# TODO(hkwon214)
# ------ Image Classifier Options (Temporary) ------
parser.add_argument('--use_classifier', dest='use_classifier', action='store_true', default=False, help='use image classifier weights')
# Default is None (not an unrelated directory) so main() can detect a missing checkpoint;
# the value is passed to torch.load(), so it must be a checkpoint file path.
parser.add_argument('--checkpoint_path', dest='checkpoint_path', action='store', default=None, help='path to the image classifier checkpoint file, required with --use_classifier')

# Run main program with specified arguments
args = parser.parse_args()
main(args)
39 changes: 38 additions & 1 deletion robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import numpy as np
import utils
from simulation import vrep

import torch
import cv2
from scipy import ndimage, misc
class Robot(object):
"""
Key member variables:
Expand Down Expand Up @@ -1243,6 +1245,41 @@ def check_stack(self, object_color_sequence, distance_threshold=0.06, top_idx=-1
# TODO(ahundt) add check_stack for real robot
return goal_success, detected_height

# TODO(hkwon214): From image classifier
def stack_reward(self, model, input_img, current_stack_goal):
    """ Estimate the stack height with an image classifier and compare it to the goal.

    # Params

    model: callable classifier. It is called once with the prepared image and must
        return a torch tensor: either class scores (the index of the largest score
        is the predicted class) or a single element holding the class itself.
        Class k is interpreted as a stack of height k + 1.
    input_img: image given to the classifier. A torch tensor is passed through
        unchanged. A numpy array is converted to a float32 tensor first.
    current_stack_goal: the goal, either as an integer height or as a sequence
        whose length is the goal height (e.g. the list of block colors).

    # Returns

    [goal_success, detected_height], where goal_success is True when the detected
    height equals the goal height.
    """
    print('IMAGE SHAPE: ' + str(input_img.shape))
    if isinstance(input_img, np.ndarray):
        input_img = torch.from_numpy(np.ascontiguousarray(input_img, dtype=np.float32))
        if input_img.dim() == 2:
            # NOTE(review): assumes the classifier expects a (batch, 3, H, W) input and
            # that repeating the single-channel heightmap over 3 channels matches how
            # the classifier was trained -- TODO confirm against the training code.
            input_img = input_img.expand(1, 3, -1, -1).contiguous()
        # Put the input on the same device as the model weights, when the model has any.
        get_parameters = getattr(model, 'parameters', None)
        if callable(get_parameters):
            first_parameter = next(iter(get_parameters()), None)
            if first_parameter is not None:
                input_img = input_img.to(first_parameter.device)
    # Inference only, so no gradients are needed.
    with torch.no_grad():
        scores = model(input_img)
    if scores.numel() > 1:
        # Multi-class scores for a single image: pick the most likely class.
        stack_class = torch.argmax(scores.reshape(-1)).item()
    else:
        stack_class = scores.item()
    detected_height = stack_class + 1
    # The goal may be a sequence (its length is the height) or already a number.
    try:
        goal_height = len(current_stack_goal)
    except TypeError:
        goal_height = current_stack_goal
    goal_success = bool(goal_height == detected_height)
    return goal_success, detected_height

def check_incremental_height(self, input_img, current_stack_goal):
    """ Estimate the stack height from the maximum value of a median filtered image.

    The image is smoothed with a 5x5 median filter, its maximum is taken, and the
    maximum is matched against fixed windows that each correspond to a height of
    1 to 4.

    # Params

    input_img: array-like image; presumably a depth heightmap -- verify against caller.
    current_stack_goal: the goal, either as an integer height or as a sequence
        whose length is the goal height.

    # Returns

    [goal_success, detected_height]. detected_height is 0 when the maximum falls
    outside every window (previously this case raised UnboundLocalError).
    """
    img_median = ndimage.median_filter(input_img, size=5)
    max_z = np.max(img_median)
    print('MAXZ ' + str(max_z))
    # (exclusive lower bound, exclusive upper bound, height).
    # NOTE(review): the windows are narrow and unevenly sized; they look hand
    # calibrated, so confirm they still match the blocks in use.
    height_windows = (
        (0.051, 0.052, 1),
        (0.10, 0.11, 2),
        (0.155, 0.156, 3),
        (0.20, 0.21, 4),
    )
    detected_height = 0
    for lower, upper, height in height_windows:
        if lower < max_z < upper:
            detected_height = height
            break
    # The goal may be a sequence (its length is the height) or already a number.
    try:
        goal_height = len(current_stack_goal)
    except TypeError:
        goal_height = current_stack_goal
    goal_success = bool(goal_height == detected_height)
    return goal_success, detected_height

def check_z_height(self, input_img):
    """ Return the maximum value of the image after a 5x5 median filter.

    # Params

    input_img: array-like image handed to scipy.ndimage.median_filter.

    # Returns

    The largest value of the filtered image, as computed by np.max.
    """
    # The median filter runs before the maximum, so isolated spikes are
    # replaced by their neighborhood median and do not set the result.
    return np.max(ndimage.median_filter(input_img, size=5))


def restart_real(self):

Expand Down
Loading