diff --git a/src/pytomlpp/_io.py b/src/pytomlpp/_io.py index 14ba1f8..817ae25 100644 --- a/src/pytomlpp/_io.py +++ b/src/pytomlpp/_io.py @@ -3,25 +3,47 @@ from . import _impl -def dump(data, fl): +def dump(data, fl, mode="w"): """Serialise data to TOML. Args: data: data to serialise - fl (io.TextIOBase): file-object to write to (supports ``write``) + fl (io.TextIOBase, str or os.Pathlike): + file-object to write to (supports ``write``) + or a file name (str, os.Pathlike) that supports ``open`` + mode (str): opening mode. Either "w" or "wt" (text), or "wb" (binary). Defaults to "w". + Ignored if fl supports ``write``. """ + data = _impl.dumps(data) + if mode == "wb": + data = data.encode("utf-8") + if hasattr(fl, "write"): + fl.write(data) + return + with open(fl, mode=mode) as fh: + fh.write(data) - fl.write(_impl.dumps(data)) -def load(fl): +def load(fl, mode="r"): """Deserialise from TOML. Args: - fl (io.TextIOBase): file-object to read from (supports ``read``) + fl (io.TextIOBase, str or os.Pathlike): + file-object to read from (supports ``read``) + or a file name (str, os.Pathlike) that supports ``open`` + mode (str): opening mode. Either "r" or "rt" (text) or "rb" (binary). Defaults to "r". + Ignored if fl supports ``read``. Returns: deserialised data """ - return _impl.loads(fl.read()) + if hasattr(fl, "read"): + data = fl.read() + else: + with open(fl, mode=mode) as fh: + data = fh.read() + if isinstance(data, bytes): + return _impl.loads(data.decode("utf-8")) + return _impl.loads(data) diff --git a/tests/python-tests/test_api.py b/tests/python-tests/test_api.py index f156030..8a65c1b 100644 --- a/tests/python-tests/test_api.py +++ b/tests/python-tests/test_api.py @@ -87,14 +87,16 @@ def test_loads_valid_toml_files(toml_file): table_json = json.loads(toml_file.with_suffix(".json").read_text()) table_expected = value_from_json(table_json) assert table == table_expected - + table = pytomlpp.load(toml_file) + assert table == table_expected @pytest.mark.parametrize("toml_file", invalid_toml_files) def test_loads_invalid_toml_files(toml_file): with pytest.raises(pytomlpp.DecodeError): with open(str(toml_file), "r") as f: pytomlpp.load(f) - + with pytest.raises(pytomlpp.DecodeError): + pytomlpp.load(str(toml_file)) @pytest.mark.parametrize("toml_file", valid_toml_files) def test_round_trip_for_valid_toml_files(toml_file): @@ -115,3 +117,9 @@ class A: pass with pytest.raises(TypeError): pytomlpp.dumps({'a': A()}) + +@pytest.mark.parametrize("toml_file", valid_toml_files) +def test_decode_encode_binary(toml_file, tmp_path): + data = pytomlpp.load(toml_file) + pytomlpp.dump(data, str(tmp_path / "tmp.toml"), mode="wb") + assert pytomlpp.load(str(tmp_path / "tmp.toml"), mode="rb") == data