Skip to content

Commit

Permalink
Merge pull request #155 from yilei/push_up_to_339298457
Browse files Browse the repository at this point in the history
Push up to 339298457
  • Loading branch information
yilei authored Oct 27, 2020
2 parents 86d81c3 + 9a090c9 commit ddbd7d4
Show file tree
Hide file tree
Showing 17 changed files with 321 additions and 85 deletions.
15 changes: 15 additions & 0 deletions absl/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).

Nothing notable unreleased.

## 0.11.0 (2020-10-27)

### Changed

* (testing) Surplus entries in AssertionError stack traces from absltest are
now suppressed and no longer reported in the xml_reporter.
* (logging) An exception is now raised instead of `logging.fatal` when logging
directories cannot be found.
* (testing) Multiple flags are now set together before their validators run.
This resolves an issue where multi-flag validators rely on specific flag
combinations.
* (flags) As a deterrent for misuse, FlagHolder objects will now raise a
TypeError exception when used in a conditional statement or equality
expression.

## 0.10.0 (2020-08-19)

### Added
Expand Down
1 change: 1 addition & 0 deletions absl/_build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def py2py3_test_binary(name, **kwargs):
if len(kwargs.get("srcs", [])) != 1:
fail("py2py3_test_binary requires main or len(srcs)==1")
kwargs["main"] = kwargs["srcs"][0]
kwargs.setdefault("tags", []).append("py3-compatible")

native.alias(name = name, actual = select({
"//absl:py3_mode": name + "_py3",
Expand Down
5 changes: 4 additions & 1 deletion absl/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,10 @@ def run(
Args:
main: The main function to execute. It takes an single argument "argv",
which is a list of command line arguments with parsed flags removed.
If it returns an integer, it is used as the process's exit code.
The return value is passed to `sys.exit`, and so for example
a return value of 0 or None results in a successful termination, whereas
a return value of 1 results in abnormal termination.
For more details, see https://docs.python.org/3/library/sys#sys.exit
argv: A non-empty list of the command line arguments including program name,
sys.argv is used if None.
flags_parser: Callable[[List[Text]], Any], the function used to parse flags.
Expand Down
39 changes: 31 additions & 8 deletions absl/flags/_flagvalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,16 +499,25 @@ def __getattr__(self, name):

def __setattr__(self, name, value):
"""Sets the 'value' attribute of the flag --name."""
fl = self._flags()
if name in self.__dict__['__hiddenflags']:
raise AttributeError(name)
if name not in fl:
return self._set_unknown_flag(name, value)
fl[name].value = value
self._assert_validators(fl[name].validators)
fl[name].using_default_value = False
self._set_attributes(**{name: value})
return value

def _set_attributes(self, **attributes):
"""Sets multiple flag values together, triggers validators afterwards."""
fl = self._flags()
known_flags = set()
for name, value in six.iteritems(attributes):
if name in self.__dict__['__hiddenflags']:
raise AttributeError(name)
if name in fl:
fl[name].value = value
known_flags.add(name)
else:
self._set_unknown_flag(name, value)
for name in known_flags:
self._assert_validators(fl[name].validators)
fl[name].using_default_value = False

def validate_all_flags(self):
"""Verifies whether all flags pass validation.
Expand Down Expand Up @@ -1348,6 +1357,20 @@ def __init__(self, flag_values, flag, ensure_non_none_value=False):
# This allows future use of this for "required flags with None default"
self._ensure_non_none_value = ensure_non_none_value

def __eq__(self, other):
raise TypeError(
"unsupported operand type(s) for ==: '{0}' and '{1}' "
"(did you mean to use '{0}.value' instead?)".format(
type(self).__name__, type(other).__name__))

def __bool__(self):
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(
type(self).__name__))

__nonzero__ = __bool__

@property
def name(self):
return self._name
Expand Down
10 changes: 4 additions & 6 deletions absl/flags/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,16 +334,13 @@ def mark_flag_as_required(flag_name, flag_values=_flagvalues.FLAGS):
Important note: validator will pass for any non-None value, such as False,
0 (zero), '' (empty string) and so on.
It is recommended to call this method like this:
If your module might be imported by others, and you only wish to make the flag
required when the module is directly executed, call this method like this:
if __name__ == '__main__':
flags.mark_flag_as_required('your_flag_name')
app.run()
Because validation happens at app.run() we want to ensure required-ness
is enforced at that time. You generally do not want to force users who import
your code to have additional required flags for their own binaries or tests.
Args:
flag_name: str, name of the flag
flag_values: flags.FlagValues, optional FlagValues instance where the flag
Expand All @@ -367,7 +364,8 @@ def mark_flag_as_required(flag_name, flag_values=_flagvalues.FLAGS):
def mark_flags_as_required(flag_names, flag_values=_flagvalues.FLAGS):
"""Ensures that flags are not None during program execution.
Recommended usage:
If your module might be imported by others, and you only wish to make the flag
required when the module is directly executed, call this method like this:
if __name__ == '__main__':
flags.mark_flags_as_required(['flag1', 'flag2', 'flag3'])
Expand Down
67 changes: 55 additions & 12 deletions absl/flags/tests/_flagvalues_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _test_find_module_or_id_defining_flag(self, test_id):
# Delete the changelist flag, its short name should still be registered.
del fv.changelist
module_or_id_changelist = testing_fn('changelist')
self.assertEqual(module_or_id_changelist, None)
self.assertIsNone(module_or_id_changelist)
module_or_id_c = testing_fn('c')
self.assertEqual(module_or_id_c, current_module_or_id)
module_or_id_l = testing_fn('l')
Expand Down Expand Up @@ -333,30 +333,30 @@ def test_invalid_flag_name(self):

def test_len(self):
fv = _flagvalues.FlagValues()
self.assertEqual(0, len(fv))
self.assertEmpty(fv)
self.assertFalse(fv)

_defines.DEFINE_boolean('boolean', False, 'help', flag_values=fv)
self.assertEqual(1, len(fv))
self.assertLen(fv, 1)
self.assertTrue(fv)

_defines.DEFINE_boolean(
'bool', False, 'help', short_name='b', flag_values=fv)
self.assertEqual(3, len(fv))
self.assertLen(fv, 3)
self.assertTrue(fv)

def test_pickle(self):
fv = _flagvalues.FlagValues()
with self.assertRaisesRegexp(TypeError, "can't pickle FlagValues"):
with self.assertRaisesRegex(TypeError, "can't pickle FlagValues"):
pickle.dumps(fv)

def test_copy(self):
fv = _flagvalues.FlagValues()
_defines.DEFINE_integer('answer', 0, 'help', flag_values=fv)
fv(['', '--answer=1'])

with self.assertRaisesRegexp(
TypeError, 'FlagValues does not support shallow copies'):
with self.assertRaisesRegex(TypeError,
'FlagValues does not support shallow copies'):
copy.copy(fv)

fv2 = copy.deepcopy(fv)
Expand Down Expand Up @@ -640,6 +640,7 @@ def test_gnu_getopt_raise(self, *argv):
class SettingUnknownFlagTest(absltest.TestCase):

def setUp(self):
super(SettingUnknownFlagTest, self).setUp()
self.setter_called = 0

def set_undef(self, unused_name, unused_val):
Expand Down Expand Up @@ -679,9 +680,39 @@ def setter(unused_name, unused_val):
new_flags.undefined_flag = 0


class SetAttributesTest(absltest.TestCase):

def setUp(self):
super(SetAttributesTest, self).setUp()
self.new_flags = _flagvalues.FlagValues()
_defines.DEFINE_boolean(
'defined_flag', None, '', flag_values=self.new_flags)
_defines.DEFINE_boolean(
'another_defined_flag', None, '', flag_values=self.new_flags)
self.setter_called = 0

def set_undef(self, unused_name, unused_val):
self.setter_called += 1

def test_two_defined_flags(self):
self.new_flags._set_attributes(
defined_flag=False, another_defined_flag=False)
self.assertEqual(self.setter_called, 0)

def test_one_defined_one_undefined_flag(self):
with self.assertRaises(_exceptions.UnrecognizedFlagError):
self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)

def test_register_unknown_flag_setter(self):
self.new_flags._register_unknown_flag_setter(self.set_undef)
self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)
self.assertEqual(self.setter_called, 1)


class FlagsDashSyntaxTest(absltest.TestCase):

def setUp(self):
super(FlagsDashSyntaxTest, self).setUp()
self.fv = _flagvalues.FlagValues()
_defines.DEFINE_string(
'long_name', 'default', 'help', flag_values=self.fv, short_name='s')
Expand Down Expand Up @@ -754,15 +785,15 @@ def test_allow_overwrite_false(self):

fv.mark_as_parsed()
self.assertEqual('foo', fv.default_foo)
self.assertEqual(None, fv.default_none)
self.assertIsNone(fv.default_none)

fv(['', '--default_foo=notFoo', '--default_none=notNone'])
self.assertEqual('notFoo', fv.default_foo)
self.assertEqual('notNone', fv.default_none)

fv.unparse_flags()
self.assertEqual('foo', fv['default_foo'].value)
self.assertEqual(None, fv['default_none'].value)
self.assertIsNone(fv['default_none'].value)

fv(['', '--default_foo=alsoNotFoo', '--default_none=alsoNotNone'])
self.assertEqual('alsoNotFoo', fv.default_foo)
Expand All @@ -772,15 +803,15 @@ def test_multi_string_default_none(self):
fv = _flagvalues.FlagValues()
_defines.DEFINE_multi_string('foo', None, 'help', flag_values=fv)
fv.mark_as_parsed()
self.assertEqual(None, fv.foo)
self.assertIsNone(fv.foo)
fv(['', '--foo=aa'])
self.assertEqual(['aa'], fv.foo)
fv.unparse_flags()
self.assertEqual(None, fv['foo'].value)
self.assertIsNone(fv['foo'].value)
fv(['', '--foo=bb', '--foo=cc'])
self.assertEqual(['bb', 'cc'], fv.foo)
fv.unparse_flags()
self.assertEqual(None, fv['foo'].value)
self.assertIsNone(fv['foo'].value)

def test_multi_string_default_string(self):
fv = _flagvalues.FlagValues()
Expand Down Expand Up @@ -883,6 +914,18 @@ def test_allow_override(self):
self.assertEqual(3, first.value)
self.assertEqual(3, second.value)

def test_eq(self):
with self.assertRaises(TypeError):
self.name_flag == 'value' # pylint: disable=pointless-statement

def test_eq_reflection(self):
with self.assertRaises(TypeError):
'value' == self.name_flag # pylint: disable=pointless-statement

def test_bool(self):
with self.assertRaises(TypeError):
bool(self.name_flag)


if __name__ == '__main__':
absltest.main()
12 changes: 11 additions & 1 deletion absl/logging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,10 @@ def find_log_dir_and_names(program_name=None, log_dir=None):
Returns:
(log_dir, file_prefix, symlink_prefix)
Raises:
FileNotFoundError: raised in Python 3 when it cannot find a log directory.
OSError: raised in Python 2 when it cannot find a log directory.
"""
if not program_name:
# Strip the extension (foobar.par becomes foobar, and
Expand Down Expand Up @@ -699,6 +703,10 @@ def find_log_dir(log_dir=None):
directory. Otherwise if the --log_dir command-line flag is provided,
the logfile will be created in that directory. Otherwise the logfile
will be created in a standard location.
Raises:
FileNotFoundError: raised in Python 3 when it cannot find a log directory.
OSError: raised in Python 2 when it cannot find a log directory.
"""
# Get a list of possible log dirs (will try to use them in order).
if log_dir:
Expand All @@ -715,7 +723,9 @@ def find_log_dir(log_dir=None):
for d in dirs:
if os.path.isdir(d) and os.access(d, os.W_OK):
return d
_absl_logger.fatal("Can't find a writable directory for logs, tried %s", dirs)
exception_class = OSError if six.PY2 else FileNotFoundError
raise exception_class(
"Can't find a writable directory for logs, tried %s" % dirs)


def get_absl_log_prefix(record):
Expand Down
9 changes: 4 additions & 5 deletions absl/logging/tests/logging_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,13 +753,12 @@ def test_find_log_dir_with_tmp(self):

def test_find_log_dir_with_nothing(self):
with mock.patch.object(os.path, 'exists'), \
mock.patch.object(os.path, 'isdir'), \
mock.patch.object(logging.get_absl_logger(), 'fatal') as mock_fatal:
mock.patch.object(os.path, 'isdir'):
os.path.exists.return_value = False
os.path.isdir.return_value = False
log_dir = logging.find_log_dir()
mock_fatal.assert_called()
self.assertEqual(None, log_dir)
exception_class = OSError if six.PY2 else FileNotFoundError
with self.assertRaises(exception_class):
logging.find_log_dir()

def test_find_log_dir_and_names_with_args(self):
user = 'test_user'
Expand Down
1 change: 0 additions & 1 deletion absl/testing/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//absl/flags",
"@six_archive//:six",
],
)

Expand Down
8 changes: 7 additions & 1 deletion absl/testing/absltest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@

_TEXT_OR_BINARY_TYPES = (six.text_type, six.binary_type)

# Suppress surplus entries in AssertionError stack traces.
__unittest = True # pylint: disable=invalid-name


def expectedFailureIf(condition, reason): # pylint: disable=invalid-name
"""Expects the test to fail if the run condition is True.
Expand Down Expand Up @@ -2444,7 +2447,10 @@ def _get_qualname(cls):
def _rmtree_ignore_errors(path):
# type: (Text) -> None
if os.path.isfile(path):
os.unlink(path)
try:
os.unlink(path)
except OSError:
pass
else:
shutil.rmtree(path, ignore_errors=True)

Expand Down
4 changes: 1 addition & 3 deletions absl/testing/flagsaver.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def some_func():
import inspect

from absl import flags
import six

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -156,8 +155,7 @@ def __call__(self, func):
def __enter__(self):
self._saved_flag_values = save_flag_values(FLAGS)
try:
for name, value in six.iteritems(self._overrides):
setattr(FLAGS, name, value)
FLAGS._set_attributes(**self._overrides)
except:
# It may fail because of flag validators.
restore_flag_values(self._saved_flag_values, FLAGS)
Expand Down
Loading

0 comments on commit ddbd7d4

Please sign in to comment.