diff --git a/epi_judge_python/test_framework/generic_test.py b/epi_judge_python/test_framework/generic_test.py index b7b9ffb4b..36358a16e 100644 --- a/epi_judge_python/test_framework/generic_test.py +++ b/epi_judge_python/test_framework/generic_test.py @@ -1,3 +1,4 @@ +import io import json import os import sys @@ -82,8 +83,9 @@ def run_tests(handler, config, res_printer): test_failure = TestFailure() try: + test_stdout = io.StringIO() test_output = handler.run_test(config.timeout_seconds, - config.metrics_override, test_case) + config.metrics_override, test_case, test_stdout) result = TestResult.PASSED tests_passed += 1 metrics.append(test_output.metrics) @@ -107,6 +109,7 @@ def run_tests(handler, config, res_printer): test_failure.get_description(), test_output.timer) if result != TestResult.PASSED: + print(test_stdout.getvalue()) if not handler.expected_is_void(): test_case.pop() if test_explanation not in ('', 'TODO'): diff --git a/epi_judge_python/test_framework/generic_test_handler.py b/epi_judge_python/test_framework/generic_test_handler.py index c3c0ff61f..33788022a 100644 --- a/epi_judge_python/test_framework/generic_test_handler.py +++ b/epi_judge_python/test_framework/generic_test_handler.py @@ -2,6 +2,7 @@ import json import math +from contextlib import redirect_stdout from test_framework import test_utils from test_framework.binary_tree_utils import (assert_equal_binary_trees, is_object_tree_type) @@ -63,7 +64,7 @@ def parse_signature(self, signature): self._ret_value_trait = get_trait(signature[-1]) - def run_test(self, timeout_seconds, metrics_override, test_args): + def run_test(self, timeout_seconds, metrics_override, test_args, stdout_capturer): """ This method is invoked for each row in a test data file (except the header). It deserializes the list of arguments and calls the user function with them. @@ -71,35 +72,37 @@ def run_test(self, timeout_seconds, metrics_override, test_args): :param timeout_seconds: number of seconds to timeout :param metrics_override: metrics transformation for customizing metrics calculation :param test_args: serialized arguments + :param stdout_capturer: Captures stdout such that a test case's stdout is only printed if it fails. :return: list, that contains [result of comparison of expected and result, expected, result]. Two last entries are omitted in case of the void return type """ - expected_param_count = len( - self._param_traits) + (0 if self.expected_is_void() else 1) - if len(test_args) != expected_param_count: - raise RuntimeError( - 'Invalid argument count: expected {}, actual {}'.format( - expected_param_count, len(test_args))) - - parsed = [ - param_trait.parse(json.loads(test_arg)) - for param_trait, test_arg in zip(self._param_traits, test_args) - ] - - metrics = self.calculate_metrics(parsed) - metrics = metrics_override(metrics, *parsed) - - executor = TimedExecutor(timeout_seconds) - if self._has_executor_hook: - result = self._func(executor, *parsed) - else: - result = executor.run(lambda: self._func(*parsed)) - - if not self.expected_is_void(): - expected = self._ret_value_trait.parse(json.loads(test_args[-1])) - self._assert_results_equal(expected, result) - - return TestOutput(executor.get_timer(), metrics) + with redirect_stdout(stdout_capturer): + expected_param_count = len( + self._param_traits) + (0 if self.expected_is_void() else 1) + if len(test_args) != expected_param_count: + raise RuntimeError( + 'Invalid argument count: expected {}, actual {}'.format( + expected_param_count, len(test_args))) + + parsed = [ + param_trait.parse(json.loads(test_arg)) + for param_trait, test_arg in zip(self._param_traits, test_args) + ] + + metrics = self.calculate_metrics(parsed) + metrics = metrics_override(metrics, *parsed) + + executor = TimedExecutor(timeout_seconds) + if self._has_executor_hook: + result = self._func(executor, *parsed) + else: + result = executor.run(lambda: self._func(*parsed)) + + if not self.expected_is_void(): + expected = self._ret_value_trait.parse(json.loads(test_args[-1])) + self._assert_results_equal(expected, result) + + return TestOutput(executor.get_timer(), metrics) def _assert_results_equal(self, expected, result): if self._comp is not None: