Skip to content

Commit

Permalink
Python EPI: Only print stdout on failures
Browse files Browse the repository at this point in the history
  • Loading branch information
LeafyLi committed Nov 25, 2022
1 parent b736406 commit da8aaae
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 28 deletions.
5 changes: 4 additions & 1 deletion epi_judge_python/test_framework/generic_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import json
import os
import sys
Expand Down Expand Up @@ -82,8 +83,9 @@ def run_tests(handler, config, res_printer):
test_failure = TestFailure()

try:
test_stdout = io.StringIO()
test_output = handler.run_test(config.timeout_seconds,
config.metrics_override, test_case)
config.metrics_override, test_case, test_stdout)
result = TestResult.PASSED
tests_passed += 1
metrics.append(test_output.metrics)
Expand All @@ -107,6 +109,7 @@ def run_tests(handler, config, res_printer):
test_failure.get_description(), test_output.timer)

if result != TestResult.PASSED:
print(test_stdout.getvalue())
if not handler.expected_is_void():
test_case.pop()
if test_explanation not in ('', 'TODO'):
Expand Down
57 changes: 30 additions & 27 deletions epi_judge_python/test_framework/generic_test_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import math

from contextlib import redirect_stdout
from test_framework import test_utils
from test_framework.binary_tree_utils import (assert_equal_binary_trees,
is_object_tree_type)
Expand Down Expand Up @@ -63,43 +64,45 @@ def parse_signature(self, signature):

self._ret_value_trait = get_trait(signature[-1])

def run_test(self, timeout_seconds, metrics_override, test_args):
def run_test(self, timeout_seconds, metrics_override, test_args, stdout_capturer):
"""
This method is invoked for each row in a test data file (except the header).
It deserializes the list of arguments and calls the user function with them.
:param timeout_seconds: number of seconds to timeout
:param metrics_override: metrics transformation for customizing metrics calculation
:param test_args: serialized arguments
:param stdout_capturer: Captures stdout such that a test case's stdout is only printed if it fails.
:return: list, that contains [result of comparison of expected and result, expected, result].
Two last entries are omitted in case of the void return type
"""
expected_param_count = len(
self._param_traits) + (0 if self.expected_is_void() else 1)
if len(test_args) != expected_param_count:
raise RuntimeError(
'Invalid argument count: expected {}, actual {}'.format(
expected_param_count, len(test_args)))

parsed = [
param_trait.parse(json.loads(test_arg))
for param_trait, test_arg in zip(self._param_traits, test_args)
]

metrics = self.calculate_metrics(parsed)
metrics = metrics_override(metrics, *parsed)

executor = TimedExecutor(timeout_seconds)
if self._has_executor_hook:
result = self._func(executor, *parsed)
else:
result = executor.run(lambda: self._func(*parsed))

if not self.expected_is_void():
expected = self._ret_value_trait.parse(json.loads(test_args[-1]))
self._assert_results_equal(expected, result)

return TestOutput(executor.get_timer(), metrics)
with redirect_stdout(stdout_capturer):
expected_param_count = len(
self._param_traits) + (0 if self.expected_is_void() else 1)
if len(test_args) != expected_param_count:
raise RuntimeError(
'Invalid argument count: expected {}, actual {}'.format(
expected_param_count, len(test_args)))

parsed = [
param_trait.parse(json.loads(test_arg))
for param_trait, test_arg in zip(self._param_traits, test_args)
]

metrics = self.calculate_metrics(parsed)
metrics = metrics_override(metrics, *parsed)

executor = TimedExecutor(timeout_seconds)
if self._has_executor_hook:
result = self._func(executor, *parsed)
else:
result = executor.run(lambda: self._func(*parsed))

if not self.expected_is_void():
expected = self._ret_value_trait.parse(json.loads(test_args[-1]))
self._assert_results_equal(expected, result)

return TestOutput(executor.get_timer(), metrics)

def _assert_results_equal(self, expected, result):
if self._comp is not None:
Expand Down

0 comments on commit da8aaae

Please sign in to comment.