diff --git a/robohive/tests/test_examine_env.py b/robohive/tests/test_examine_env.py
index 5819ff63..a58667d6 100644
--- a/robohive/tests/test_examine_env.py
+++ b/robohive/tests/test_examine_env.py
@@ -22,10 +22,19 @@ def test_offscreen_rendering(self):
         result = runner.invoke(examine_env, ["--env_name", "door-v1", \
                                             "--num_episodes", 1, \
                                             "--render", "offscreen",\
-                                            "--camera_name", "top_acam"])
+                                            "--camera_name", "top_cam"])
         print(result.output.strip())
         self.assertEqual(result.exception, None)
 
+    def test_scripted_policy_loading(self):
+        # Call your function and test its output/assertions
+        print("Testing scripted policy loading")
+        runner = click.testing.CliRunner()
+        result = runner.invoke(examine_env, ["--env_name", "door-v1", \
+                                            "--num_episodes", 1, \
+                                            "--policy_path", "robohive.utils.examine_env.rand_policy"])
+        print(result.output.strip())
+        self.assertEqual(result.exception, None)
 
 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file
diff --git a/robohive/utils/examine_env.py b/robohive/utils/examine_env.py
index ea0dc03a..f186ad5b 100644
--- a/robohive/utils/examine_env.py
+++ b/robohive/utils/examine_env.py
@@ -16,10 +16,12 @@
 DESC = '''
 Helper script to examine an environment and associated policy for behaviors; \n
 - either onscreen, or offscreen, or just rollout without rendering.\n
-- save resulting paths as pickle or as 2D plots
+- save resulting paths as pickle or as 2D plots \n
+- rollout either learned policies or scripted policies (e.g. see rand_policy class below) \n
 USAGE:\n
-    $ python examine_env.py --env_name door-v0 \n
-    $ python examine_env.py --env_name door-v0 --policy my_policy.pickle --mode evaluation --episodes 10 \n
+    $ python examine_env.py --env_name door-v1 \n
+    $ python examine_env.py --env_name door-v1 --policy_path robohive.utils.examine_env.rand_policy \n
+    $ python examine_env.py --env_name door-v1 --policy_path my_policy.pickle --mode evaluation --num_episodes 10 \n
 '''
 
 # Random policy
@@ -30,7 +32,20 @@ def __init__(self, env, seed):
 
     def get_action(self, obs):
         # return self.env.np_random.uniform(high=self.env.action_space.high, low=self.env.action_space.low)
-        return self.env.action_space.sample(), {'mode': 'random samples'}
+        return self.env.action_space.sample(), {'mode': 'random samples', 'evaluation':self.env.action_space.sample()}
+
+def load_class_from_str(module_name, class_name):
+    """Import ``module_name`` and return its attribute ``class_name``.
+
+    Returns None if the module cannot be imported or lacks the attribute.
+    """
+    try:
+        # a non-empty fromlist makes __import__ return the named (leaf) module
+        m = __import__(module_name, globals(), locals(), [class_name])
+        return getattr(m, class_name)
+    except (ImportError, AttributeError, ValueError):
+        # ValueError: empty module name, i.e. the input string had no '.'
+        return None
 
 # MAIN =========================================================
 @click.command(help=DESC)
@@ -57,13 +72,23 @@ def main(env_name, policy_path, mode, seed, num_episodes, render, camera_name, o
 
     # resolve policy and outputs
     if policy_path is not None:
-        pi = pickle.load(open(policy_path, 'rb'))
-        if output_dir == './': # overide the default
-            output_dir, pol_name = os.path.split(policy_path)
-            output_name = os.path.splitext(pol_name)[0]
-        if output_name is None:
-            pol_name = os.path.split(policy_path)[1]
-            output_name = os.path.splitext(pol_name)[0]
+        # try policy_path as a dotted "module.ClassName" first (scripted policy)
+        policy_module, _, policy_class = policy_path.rpartition('.')
+        pi = load_class_from_str(policy_module, policy_class)
+
+        if pi is not None:
+            pi = pi(env, seed)
+        else:
+            # otherwise treat it as a pickled policy file on disk
+            # NOTE: pickle can run arbitrary code; only load trusted files
+            with open(policy_path, 'rb') as policy_file:
+                pi = pickle.load(policy_file)
+            if output_dir == './': # override the default
+                output_dir, pol_name = os.path.split(policy_path)
+                output_name = os.path.splitext(pol_name)[0]
+            if output_name is None:
+                pol_name = os.path.split(policy_path)[1]
+                output_name = os.path.splitext(pol_name)[0]
     else:
         pi = rand_policy(env, seed)
         mode = 'exploration'