Skip to content

Commit f1f8933

Browse files
authored
fix: cleanup session dir on windows (#241)
Signed-off-by: Cody Edwards <edwards@amazon.com>
1 parent bd39896 commit f1f8933

File tree

4 files changed

+41
-14
lines changed

4 files changed

+41
-14
lines changed

src/openjd/sessions/_session.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848
)
4949
from ._version import version
5050

51+
if is_windows(): # pragma: nocover
52+
from subprocess import HIGH_PRIORITY_CLASS # type: ignore
53+
5154
if TYPE_CHECKING:
5255
from openjd.model.v2023_09._model import EnvironmentVariableObject
5356

@@ -463,30 +466,30 @@ def cleanup(self) -> None:
463466
if self._user is not None:
464467
files = [str(f) for f in self.working_directory.glob("*")]
465468

469+
creation_flags = None
466470
if is_posix():
467471
recursive_delete_cmd = ["rm", "-rf"]
468472
else:
469473
recursive_delete_cmd = [
470-
"start",
471-
'"Powershell"',
472-
"/high",
473-
"/wait",
474-
"/b",
475474
"powershell",
476475
"-Command",
477476
"Remove-Item",
478477
"-Recurse",
479478
"-Force",
480479
]
481480
files = [", ".join(files)]
481+
# The cleanup needs to run as a high priority
482+
# https://learn.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-getpriorityclass#return-value
483+
creation_flags = HIGH_PRIORITY_CLASS
482484

483-
subprocess = LoggingSubprocess(
485+
_subprocess = LoggingSubprocess(
484486
logger=self._logger,
485487
args=recursive_delete_cmd + files,
486488
user=self._user,
489+
creation_flags=creation_flags,
487490
)
488491
# Note: Blocking call until the process has exited
489-
subprocess.run()
492+
_subprocess.run()
490493

491494
self._working_dir.cleanup()
492495
except RuntimeError as exc:

src/openjd/sessions/_subprocess.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class LoggingSubprocess(object):
6464
_has_started: Event
6565
_os_env_vars: Optional[dict[str, Optional[str]]]
6666
_working_dir: Optional[str]
67+
_creation_flags: Optional[int]
6768

6869
_pid: Optional[int]
6970
_sudo_child_process_group_id: Optional[int]
@@ -79,13 +80,16 @@ def __init__(
7980
callback: Optional[Callable[[], None]] = None,
8081
os_env_vars: Optional[dict[str, Optional[str]]] = None,
8182
working_dir: Optional[str] = None,
83+
creation_flags: Optional[int] = None,
8284
):
8385
if len(args) < 1:
8486
raise ValueError("'args' kwarg must be a sequence of at least one element")
8587
if user is not None and os.name == "posix" and not isinstance(user, PosixSessionUser):
8688
raise ValueError("Argument 'user' must be a PosixSessionUser on posix systems.")
8789
if user is not None and is_windows() and not isinstance(user, WindowsSessionUser):
8890
raise ValueError("Argument 'user' must be a WindowsSessionUser on Windows systems.")
91+
if not is_windows() and creation_flags is not None:
92+
raise ValueError("Argument 'creation_flags' is only supported on Windows")
8993

9094
self._logger = logger
9195
self._args = args[:] # Make a copy
@@ -100,6 +104,7 @@ def __init__(
100104
self._pid = None
101105
self._returncode = None
102106
self._sudo_child_process_group_id = None
107+
self._creation_flags = creation_flags
103108

104109
@property
105110
def pid(self) -> Optional[int]:
@@ -275,6 +280,9 @@ def _start_subprocess(self) -> Optional[Popen]:
275280
# https://docs.python.org/2/library/subprocess.html#subprocess.CREATE_NEW_PROCESS_GROUP
276281
popen_args["creationflags"] = CREATE_NEW_PROCESS_GROUP
277282

283+
if self._creation_flags:
284+
popen_args["creationflags"] |= self._creation_flags
285+
278286
# Get the command string for logging
279287
cmd_line_for_logger: str
280288
if is_posix():

test/openjd/sessions/test_runner_step_script.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,9 @@ def test_run_file_in_session_dir_as_windows_user(
309309
# GIVEN
310310
tmpdir = TempDir(user=windows_user)
311311
script = StepScript_2023_09(
312-
actions=StepActions_2023_09(onRun=Action_2023_09(command=r"test.bat")),
312+
actions=StepActions_2023_09(
313+
onRun=Action_2023_09(command=CommandString_2023_09(r"test.bat"))
314+
),
313315
embeddedFiles=[
314316
EmbeddedFileText_2023_09(
315317
name="Foo",

test/openjd/sessions/test_subprocess.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,14 @@ def test_terminate_ends_process_tree(
322322
all_messages = []
323323
# Note: This is the number of *CHILD* processes of the main process that we start.
324324
# The total number of processes in flight will be this plus one.
325-
expected_num_child_procs: int
326-
if is_posix():
327-
# Process tree: python -> python
328-
# Children: python
329-
expected_num_child_procs = 1
330-
else:
325+
326+
# On Posix and on Windows not in a virutal environment:
327+
# Process tree: python -> python
328+
# Children: python
329+
expected_num_child_procs = 1
330+
331+
# Check if we're in a virtual environment on Windows, see https://docs.python.org/3/library/venv.html#how-venvs-work
332+
if is_windows() and sys.prefix != sys.base_prefix:
331333
# Windows starts an extra python process due to running in a virtual environment
332334
# Process tree: conhost -> python -> python -> python
333335
# Children: python, python, python
@@ -489,6 +491,18 @@ def test_run_gracetime_when_process_ends_but_grandchild_uses_stdout(
489491
m not in messages for m in not_expected_messages
490492
), f"Unexpected messages: {', '.join(repr(m) for m in not_expected_messages if m in messages)}"
491493

494+
@pytest.mark.skipif(is_windows(), reason="Posix-specific tests")
495+
def test_creation_flags_posix(self, queue_handler: QueueHandler) -> None:
496+
497+
with pytest.raises(ValueError):
498+
logger = build_logger(queue_handler)
499+
LoggingSubprocess(
500+
logger=logger,
501+
args=[sys.executable, "-c", 'print("this should not run")'],
502+
# Creation flags aren't supported on Posix systems.
503+
creation_flags=1337,
504+
)
505+
492506

493507
def list_has_items_in_order(expected: list, actual: list) -> bool:
494508
"""

0 commit comments

Comments
 (0)