diff --git a/unit_test/problems/test_brax.py b/unit_test/problems/test_brax.py index 19e5363b..8aac1ba5 100644 --- a/unit_test/problems/test_brax.py +++ b/unit_test/problems/test_brax.py @@ -35,11 +35,14 @@ def neuroevolution_process( print(f"In generation {index}:") t = time.time() workflow.step() - print(f"\tTime elapsed: {time.time() - t: .4f}(s).") - monitor: EvalMonitor = workflow.get_submodule("monitor") - print(f"\tTop fitness: {monitor.topk_fitness}") - best_params = adapter.to_params(monitor.topk_solutions[0]) - print(f"\tBest params: {best_params}") + + print(f"\tTime elapsed: {time.time() - t: .4f}(s).") + monitor: EvalMonitor = workflow.get_submodule("monitor") + print(f"\tTop fitness: {monitor.topk_fitness}") + best_params = adapter.to_params(monitor.topk_solutions[0]) + print(f"\tBest params: {best_params}") + + return best_params class TestBraxProblem(unittest.TestCase): @@ -96,8 +99,10 @@ def test_brax_problem(self): monitor=pop_monitor, device=self.device, ) - neuroevolution_process( + best_params = neuroevolution_process( workflow=workflow, adapter=adapter, max_generation=3, ) + + problem.visualize(best_params)