Skip to content

Commit

Permalink
Merge pull request #5 from xsuite/release/v0.1.1
Browse files Browse the repository at this point in the history
Release 0.1.1
  • Loading branch information
freddieknets authored Feb 26, 2024
2 parents 0c7f9b5 + 0226f28 commit 5eafc64
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 33 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "xaux"
version = "0.1.0"
version = "0.1.1"
description = "Support tools for Xsuite packages"
authors = ["Frederik Van der Veken <[email protected]>"]
license = "Apache 2.0"
Expand Down
69 changes: 69 additions & 0 deletions tests/test_deliberate_failure_and_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from multiprocessing import Pool
import pytest
from xaux import ProtectFile
import json
from pathlib import Path
import time
import shutil


def rewrite(pf, with_copy=False):
    """Increment the 'myint' field of the JSON file behind *pf*.

    Reads the JSON content, sleeps briefly (widens the race window for
    the concurrency tests), bumps 'myint' by one, then writes the result
    back — either in place through the open handle, or via an
    intermediate copy file when ``with_copy`` is True.
    """
    data = json.load(pf)
    time.sleep(0.2)
    data["myint"] += 1
    if with_copy:
        # Dump into a sibling file, then copy it over the original.
        copy_name = "_copy_" + pf.name
        with open(copy_name, "w") as copy_file:
            json.dump(data, copy_file, indent=4, sort_keys=True)
        shutil.copyfile(copy_name, pf.name)
        Path(copy_name).unlink()
    else:
        # Rewrite in place: rewind, dump, and drop any trailing bytes.
        pf.seek(0)
        json.dump(data, pf, indent=4, sort_keys=True)
        pf.truncate()


def change_file_protected(fname, with_copy=False):
    """Increment the counter in *fname* under a ProtectFile lock.

    This is the safe variant: concurrent callers serialize on the
    lockfile, so no increment is lost.
    """
    with ProtectFile(fname, "r+", backup=False, wait=0.06) as pf:
        rewrite(pf, with_copy=with_copy)


def change_file_standard(fname, with_copy=False):
    """Increment the counter in *fname* through a plain open().

    Unsafe variant: no locking, so concurrent callers race and lose
    updates (this is exercised deliberately by test_deliberate_failure).

    Fix: ``with_copy`` was accepted but never forwarded to ``rewrite``,
    so the copy-based write path could never be exercised here.
    """
    with open(fname, "r+") as pf:  # fails with this context
        rewrite(pf, with_copy=with_copy)
    return


def init_file(fname):
    """Create *fname* containing ``{"myint": 0}``, guarded by a lock."""
    initial = {"myint": 0}
    with ProtectFile(fname, "w", backup=False, wait=1) as pf:
        json.dump(initial, pf, indent=4)


def test_deliberate_failure():
    """Show that unprotected concurrent writers lose updates.

    Four workers read-modify-write the same JSON file through plain
    open(); the race makes at least one increment disappear, so the
    final counter must differ from the number of workers.
    """
    fname = "test_standard.json"
    assert not Path(fname).exists()
    init_file(fname)
    workers = 4
    with Pool(processes=workers) as pool:
        # Fix: use `workers` instead of a hard-coded 4 so the job count
        # stays in sync with the pool size and the assertion below.
        pool.map(change_file_standard, [fname] * workers)

    with open(fname, "r") as pf:
        data = json.load(pf)
    # NOTE(review): probabilistic — with enough luck the writes could
    # serialize cleanly; the 0.2 s sleep in rewrite() makes that unlikely.
    assert data["myint"] != workers  # assert that result is wrong
    Path(fname).unlink()


@pytest.mark.parametrize("with_copy", [False, True])
def test_protection(with_copy):
    """Concurrent updates through ProtectFile must not lose increments."""
    fname = "test_protection.json"
    assert not Path(fname).exists()
    init_file(fname)
    workers = 4
    with Pool(processes=workers) as pool:
        # Fix: `with_copy` was parametrized but never reached the
        # workers — `pool.map(..., [(fname)] * 4)` passed only the
        # filename ((fname) is not a tuple). starmap forwards both
        # arguments, and `workers` replaces the hard-coded 4.
        pool.starmap(change_file_protected, [(fname, with_copy)] * workers)

    with open(fname, "r") as pf:
        data = json.load(pf)
    assert data["myint"] == workers  # every increment survived the race
    Path(fname).unlink()
2 changes: 1 addition & 1 deletion tests/test_version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from xaux import __version__

def test_version():
assert __version__ == '0.1.0'
assert __version__ == '0.1.1'

2 changes: 1 addition & 1 deletion xaux/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
# ===================
# Do not change
# ===================
__version__ = '0.1.0'
__version__ = '0.1.1'
# ===================
124 changes: 94 additions & 30 deletions xaux/protectfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import tempfile
import time
import json
import subprocess

tempdir = tempfile.TemporaryDirectory()
protected_open = {}
Expand Down Expand Up @@ -46,18 +47,27 @@ def get_hash(filename, size=128):
def get_fstat(filename):
    """Return a plain-dict snapshot of the file's stat metadata.

    Each value is cast to int so the snapshot is serializable and can be
    compared field-by-field later (ProtectFile uses this to detect that
    the original file changed while it was locked).

    Parameters
    ----------
    filename : str or Path
        Path of the file to inspect.

    Returns
    -------
    dict
        Selected fields of ``os.stat_result``, all as plain ints.
    """
    stats = Path(filename).stat()
    fields = (
        'n_sequence_fields', 'n_unnamed_fields', 'st_mode', 'st_ino',
        'st_dev', 'st_uid', 'st_gid', 'st_size', 'st_mtime_ns',
        'st_ctime_ns',
    )
    return {field: int(getattr(stats, field)) for field in fields}

def xrdcp_installed():
    """Return True if the ``xrdcp`` command is available on this system."""
    try:
        result = subprocess.run(
            ["xrdcp", "--version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Command missing, or it exited with a non-zero status.
        return False
    return result.returncode == 0


class ProtectFile:
"""A wrapper around a file pointer, protecting it with a lockfile and backups.
Expand Down Expand Up @@ -106,23 +116,23 @@ class ProtectFile:
--------
Reading in a file (while making sure it is not written to by another process):
>>> from protectfile import ProtectedFile
>>> with ProtectedFile('thebook.txt', 'r', backup=False, wait=1) as pf:
>>> from protectfile import ProtectFile
>>> with ProtectFile('thebook.txt', 'r', backup=False, wait=1) as pf:
>>> text = pf.read()
Reading and appending to a file:
>>> from protectfile import ProtectedFile
>>> with ProtectedFile('thebook.txt', 'r+', backup=False, wait=1) as pf:
>>> from protectfile import ProtectFile
>>> with ProtectFile('thebook.txt', 'r+', backup=False, wait=1) as pf:
>>> text = pf.read()
>>> pf.write("This string will be added at the end of the file, \
... however, it won't be added to the 'text' variable")
Reading and updating a JSON file:
>>> import json
>>> from protectfile import ProtectedFile
>>> with ProtectedFile(info.json, 'r+', backup=False, wait=1) as pf:
>>> from protectfile import ProtectFile
>>> with ProtectFile(info.json, 'r+', backup=False, wait=1) as pf:
>>> meta = json.load(pf)
>>> meta.update({'author': 'Emperor Claudius'})
>>> pf.truncate(0) # Delete file contents (to avoid appending)
Expand All @@ -132,13 +142,21 @@ class ProtectFile:
Reading and updating a Parquet file:
>>> import pandas as pd
>>> from protectfile import ProtectedFile
>>> with ProtectedFile(mydata.parquet, 'r+b', backup=False, wait=1) as pf:
>>> from protectfile import ProtectFile
>>> with ProtectFile(mydata.parquet, 'r+b', backup=False, wait=1) as pf:
>>> data = pd.read_parquet(pf)
>>> data['x'] += 5
>>> pf.truncate(0) # Delete file contents (to avoid appending)
>>> pf.seek(0) # Move file pointer to start of file
>>> data.to_parquet(pf, index=True)
Reading and updating a json file in EOS with xrdcp:
>>> from protectfile import ProtectFile
>>> eos_url = 'root://eosuser.cern.ch/'
>>> fname = '/eos/user/k/kparasch/test.json'
>>> with ProtectFile(fname, 'r+', eos_url=eos_url) as pf:
>>> pass
"""

def __init__(self, *args, **kwargs):
Expand All @@ -164,6 +182,8 @@ def __init__(self, *args, **kwargs):
max_lock_time : float, default None
If provided, it will write the maximum runtime in seconds inside the
lockfile. This is to avoid crashed accesses locking the file forever.
eos_url : string, default None
If provided, it will use xrdcp to copy the temporary file to eos and back.
Additionally, the following parameters are inherited from open():
'file', 'mode', 'buffering', 'encoding', 'errors', 'newline', 'closefd', 'opener'
Expand All @@ -186,6 +206,20 @@ def __init__(self, *args, **kwargs):
self._backup_if_readonly = arg.pop('backup_if_readonly', False)
self._check_hash = arg.pop('check_hash', True)

# Make sure conditions are satisfied when using EOS-XRDCP
self._eos_url = arg.pop('eos_url', None)
if self._eos_url is not None:
self.original_eos_path = arg['file']
if self._do_backup or self._backup_if_readonly:
raise NotImplementedError("Backup not supported with eos_url.")
if not self._eos_url.startswith("root://eos") or not self._eos_url.endswith('.cern.ch/'):
raise NotImplementedError(f'Invalid EOS url provided: {self._eos_url=}')
if not str(self.original_eos_path).startswith("/eos"):
raise NotImplementedError(f'Only /eos paths are supported with eos_url.')
if not xrdcp_installed():
raise RuntimeError("xrdcp is not installed.")
self.original_eos_path = self._eos_url + self.original_eos_path

# Initialise paths
arg['file'] = Path(arg['file']).resolve()
file = arg['file']
Expand All @@ -207,7 +241,6 @@ def __init__(self, *args, **kwargs):
if self._exists:
raise FileExistsError


# Provide an expected running time (to free a file in case of crash)
max_lock_time = arg.pop('max_lock_time', None)
if max_lock_time is not None and self._readonly == False \
Expand Down Expand Up @@ -251,7 +284,7 @@ def __init__(self, *args, **kwargs):
self._flock = io.open(self.lockfile, 'r+')
break
else:
raise RunTimeError("Too many lockfiles!")
raise RuntimeError("Too many lockfiles!")

# Store lock information
if max_lock_time is not None:
Expand Down Expand Up @@ -280,8 +313,12 @@ def __init__(self, *args, **kwargs):
# slow if many processes write to it concurrently
if not self._readonly:
if self._exists:
_print_debug("Init", f"cp {self.file=} to {self.tempfile=}")
shutil.copy2(self.file, self.tempfile)
if self._eos_url is not None:
_print_debug("Init", f"xrdcp {self.original_eos_path=} to {self.tempfile=}")
self.xrdcp(self.original_eos_path, self.tempfile)
else:
_print_debug("Init", f"cp {self.file=} to {self.tempfile=}")
shutil.copy2(self.file, self.tempfile)
arg['file'] = self.tempfile
self._fd = io.open(**arg)

Expand All @@ -302,14 +339,20 @@ def __exit__(self, *args, **kwargs):
# Check that original file was not modified in between (i.e. corrupted)
# TODO: verify that checking file stats is 1) enough, and 2) not
# potentially problematic on certain file systems
if self._exists and get_fstat(self.file) != self._fstat:
file_changed = False
if self._exists:
new_stats = get_fstat(self.file)
for key, val in self._fstat.items():
if key not in new_stats or val != new_stats[key]:
file_changed = True
if file_changed:
print(f"Error: File {self.file} changed during lock!")
# If corrupted, restore from backup
# and move result of calculation (i.e. tempfile) to the parent folder
print("Old stats:")
print(self._fstat)
print("New stats:")
print(get_fstat(self.file))
print(new_stats)
self.restore()
else:
# All is fine: move result from temporary file to original
Expand Down Expand Up @@ -338,16 +381,24 @@ def mv_temp(self, destination=None):
if not self._readonly:
if destination is None:
# Move temporary file to original file
_print_debug("Mv_temp", f"cp {self.tempfile=} to {self.file=}")
shutil.copy2(self.tempfile, self.file)
if self._eos_url is not None:
_print_debug("Mv_temp", f"xrdcp {self.tempfile=} to {self.original_eos_path=}")
self.xrdcp(self.tempfile, self.original_eos_path)
else:
_print_debug("Mv_temp", f"cp {self.tempfile=} to {self.file=}")
shutil.copy2(self.tempfile, self.file)
# Check if copy succeeded
if self._check_hash and get_hash(self.tempfile) != get_hash(self.file):
print(f"Warning: tried to copy temporary file {self.tempfile} into {self.file}, "
+ "but hashes do not match!")
self.restore()
else:
_print_debug("Mv_temp", f"cp {self.tempfile=} to {destination=}")
shutil.copy2(self.tempfile, destination)
if self._eos_url is not None:
_print_debug("Mv_temp", f"xrdcp {self.tempfile=} to {destination=}")
self.xrdcp(self.tempfile, destination)
else:
_print_debug("Mv_temp", f"cp {self.tempfile=} to {destination=}")
shutil.copy2(self.tempfile, destination)
_print_debug("Mv_temp", f"unlink {self.tempfile=}")
self.tempfile.unlink()

Expand All @@ -359,8 +410,11 @@ def restore(self):
self.backupfile.rename(self.file)
print('Restored file to previous state.')
if not self._readonly:
alt_file = Path(self.file.parent, self.file.name + '__' \
+ datetime.datetime.now().isoformat() + '.result').resolve()
extension = f"__{datetime.datetime.now().isoformat()}.result"
if self._eos_url is not None:
alt_file = self.original_eos_path + extension
else:
alt_file = Path(self.file.parent, self.file.name + extension).resolve()
self.mv_temp(alt_file)
print(f"Saved calculation results in {alt_file.name}.")

Expand All @@ -386,3 +440,13 @@ def release(self, pop=True):
if pop:
protected_open.pop(self._file, 0)

def xrdcp(self, source=None, destination=None):
    """Copy *source* to *destination* with the ``xrdcp -f`` command.

    Both endpoints must be given; only meaningful when this ProtectFile
    was opened with an ``eos_url``.
    """
    if source is None or destination is None:
        raise RuntimeError("Source or destination not specified in xrdcp command.")
    if self._eos_url is None:
        raise RuntimeError("self._eos_url is None, while it shouldn't have been.")
    command = ["xrdcp", "-f", str(source), str(destination)]
    subprocess.run(command, check=True)


0 comments on commit 5eafc64

Please sign in to comment.