Skip to content

Commit

Permalink
FEATURE: Allows examine_env to load hard coded policies (#87)
Browse files Browse the repository at this point in the history
* Allow examine_env to load hard-coded policies, e.g.
python robohive/utils/examine_env.py -e FrankaReachRandom-v0 -p robohive.utils.examine_env.rand_policy

* Remove unnecessary import
* examine_env loading scripted policies: update DESC in examine_env and add unit test
Co-authored-by: Patrick Lancaster <[email protected]>
  • Loading branch information
palanc authored May 1, 2023
1 parent 00533d7 commit 1a8bc83
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 12 deletions.
11 changes: 10 additions & 1 deletion robohive/tests/test_examine_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,19 @@ def test_offscreen_rendering(self):
result = runner.invoke(examine_env, ["--env_name", "door-v1", \
"--num_episodes", 1, \
"--render", "offscreen",\
"--camera_name", "top_acam"])
"--camera_name", "top_cam"])
print(result.output.strip())
self.assertEqual(result.exception, None)

def test_scripted_policy_loading(self):
    """Invoke examine_env via click's CliRunner with a policy given as a dotted path."""
    print("Testing scripted policy loading")
    cli_args = [
        "--env_name", "door-v1",
        "--num_episodes", 1,
        "--policy_path", "robohive.utils.examine_env.rand_policy",
    ]
    outcome = click.testing.CliRunner().invoke(examine_env, cli_args)
    print(outcome.output.strip())
    # The command must complete without recording an exception.
    self.assertEqual(outcome.exception, None)

# Run the unittest test runner when this file is executed directly.
if __name__ == '__main__':
unittest.main()
37 changes: 26 additions & 11 deletions robohive/utils/examine_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# Help text for the command line interface; handed to click via help=DESC.
# NOTE(review): the usage examples mix --policy / --episodes with --policy_path,
# while main() takes policy_path and num_episodes — confirm the flag spellings
# against the click option declarations.
DESC = '''
Helper script to examine an environment and associated policy for behaviors; \n
- either onscreen, or offscreen, or just rollout without rendering.\n
- save resulting paths as pickle or as 2D plots
- save resulting paths as pickle or as 2D plots \n
- rollout either learned policies or scripted policies (e.g. see rand_policy class below) \n
USAGE:\n
$ python examine_env.py --env_name door-v0 \n
$ python examine_env.py --env_name door-v0 --policy my_policy.pickle --mode evaluation --episodes 10 \n
$ python examine_env.py --env_name door-v1 \n
$ python examine_env.py --env_name door-v1 --policy_path robohive.utils.examine_env.rand_policy \n
$ python examine_env.py --env_name door-v1 --policy_path my_policy.pickle --mode evaluation --episodes 10 \n
'''

# Random policy
Expand All @@ -30,7 +32,14 @@ def __init__(self, env, seed):

# Sample an action from the environment's action space and return it with an
# info dict; the observation argument is not used.
def get_action(self, obs):
# return self.env.np_random.uniform(high=self.env.action_space.high, low=self.env.action_space.low)
# NOTE(review): the two returns below appear to be the before/after lines of a
# diff view — confirm which one the file actually contains.
return self.env.action_space.sample(), {'mode': 'random samples'}
# This variant also fills an 'evaluation' entry from a second call to
# action_space.sample().
return self.env.action_space.sample(), {'mode': 'random samples', 'evaluation':self.env.action_space.sample()}

def load_class_from_str(module_name, class_name):
    """Look up ``class_name`` in the module named ``module_name``.

    Args:
        module_name: Absolute dotted module path,
            e.g. ``"robohive.utils.examine_env"``.
        class_name: Name of the attribute to fetch from that module,
            e.g. ``"rand_policy"``.

    Returns:
        The attribute found on the imported module, or ``None`` when the
        module cannot be imported or does not define ``class_name``. Callers
        use ``None`` as the signal to fall back to another loading strategy,
        so this function must not raise for malformed names.
    """
    import importlib

    # A string without a dot yields an empty module name, and file-like paths
    # such as "./policy.pickle" yield a name with a leading dot. Neither is an
    # absolute module path, so report "not found" instead of letting the
    # import machinery raise ValueError/TypeError.
    if not module_name or not class_name or module_name.startswith('.'):
        return None

    try:
        module = importlib.import_module(module_name)
        return getattr(module, class_name)
    except (ImportError, AttributeError, ValueError):
        return None

# MAIN =========================================================
@click.command(help=DESC)
Expand All @@ -57,13 +66,19 @@ def main(env_name, policy_path, mode, seed, num_episodes, render, camera_name, o

# resolve policy and outputs
if policy_path is not None:
pi = pickle.load(open(policy_path, 'rb'))
if output_dir == './': # overide the default
output_dir, pol_name = os.path.split(policy_path)
output_name = os.path.splitext(pol_name)[0]
if output_name is None:
pol_name = os.path.split(policy_path)[1]
output_name = os.path.splitext(pol_name)[0]
policy_tokens = policy_path.split('.')
pi = load_class_from_str('.'.join(policy_tokens[:-1]), policy_tokens[-1])

if pi is not None:
pi = pi(env, seed)
else:
pi = pickle.load(open(policy_path, 'rb'))
if output_dir == './': # overide the default
output_dir, pol_name = os.path.split(policy_path)
output_name = os.path.splitext(pol_name)[0]
if output_name is None:
pol_name = os.path.split(policy_path)[1]
output_name = os.path.splitext(pol_name)[0]
else:
pi = rand_policy(env, seed)
mode = 'exploration'
Expand Down

0 comments on commit 1a8bc83

Please sign in to comment.