Skip to content

Commit

Permalink
add test for create_env
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 2, 2024
1 parent 04761dc commit c87c09b
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 27 deletions.
54 changes: 27 additions & 27 deletions src/aiida_pythonjob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,34 +39,31 @@ def add_imports(type_hint):

def inspect_function(func: Callable) -> Dict[str, Any]:
"""Serialize a function for storage or transmission."""
# we need save the source code explicitly, because in the case of jupyter notebook,
# the source code is not saved in the pickle file
try:
# we need save the source code explicitly, because in the case of jupyter notebook,
# the source code is not saved in the pickle file
source_code = inspect.getsource(func)
# Split the source into lines for processing
source_code_lines = source_code.split("\n")
function_source_code = "\n".join(source_code_lines)
# Find the first line of the actual function definition
for i, line in enumerate(source_code_lines):
if line.strip().startswith("def "):
break
function_source_code_without_decorator = "\n".join(source_code_lines[i:])
function_source_code_without_decorator = textwrap.dedent(function_source_code_without_decorator)
# we also need to include the necessary imports for the types used in the type hints.
try:
required_imports = get_required_imports(func)
except Exception as e:
required_imports = {}
print(f"Failed to get required imports for function {func.__name__}: {e}")
# Generate import statements
import_statements = "\n".join(
f"from {module} import {', '.join(types)}" for module, types in required_imports.items()
)
except Exception as e:
print(f"Failed to inspect function {func.__name__}: {e}")
function_source_code = ""
function_source_code_without_decorator = ""
import_statements = ""
except OSError:
raise ValueError("Failed to get the source code of the function.")

# Split the source into lines for processing
source_code_lines = source_code.split("\n")
function_source_code = "\n".join(source_code_lines)
# Find the first line of the actual function definition
for i, line in enumerate(source_code_lines):
if line.strip().startswith("def "):
break
function_source_code_without_decorator = "\n".join(source_code_lines[i:])
function_source_code_without_decorator = textwrap.dedent(function_source_code_without_decorator)
# we also need to include the necessary imports for the types used in the type hints.
try:
required_imports = get_required_imports(func)
except Exception as exception:
raise ValueError(f"Failed to get the required imports for the function: {exception}")
# Generate import statements
import_statements = "\n".join(
f"from {module} import {', '.join(types)}" for module, types in required_imports.items()
)
return {
"name": func.__name__,
"source_code": function_source_code,
Expand Down Expand Up @@ -225,6 +222,9 @@ def create_conda_env(
scheduler.set_transport(transport)
try:
retval, stdout, stderr = transport.exec_command_wait(script)
print("retval", retval)
print("stdout", stdout)
print("stderr", stderr)
except NotImplementedError:
return (
True,
Expand All @@ -235,7 +235,7 @@ def create_conda_env(
if retval != 0:
return (
False,
f"The command `echo -n` returned a non-zero return code ({retval})",
f"The command returned a non-zero return code ({retval})",
)

template = """
Expand Down
60 changes: 60 additions & 0 deletions tests/test_create_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from unittest.mock import MagicMock, patch


def test_create_conda_env():
# Define test parameters
computer_name = "localhost"
env_name = "test_env"
pip_packages = ["numpy", "pandas"]
conda_deps = ["scipy"]
python_version = "3.8"
shell = "posix"

# Mock the computer and related objects
mock_computer = MagicMock()
mock_computer.label = computer_name
mock_user = MagicMock()
mock_user.email = "[email protected]"
mock_authinfo = MagicMock()
mock_transport = MagicMock()
mock_scheduler = MagicMock()

mock_authinfo.get_transport.return_value = mock_transport
mock_computer.get_authinfo.return_value = mock_authinfo
mock_authinfo.computer.get_scheduler.return_value = mock_scheduler

# Mock successful transport behavior
mock_transport.exec_command_wait.return_value = (
0, # retval
"Environment setup is complete.\n", # stdout
"", # stderr
)

# Patch `load_computer` and `User.collection.get_default` to return mocked objects
with (
patch("aiida.orm.utils.loaders.load_computer", return_value=mock_computer),
patch("aiida.orm.User.collection.get_default", return_value=mock_user),
):
from aiida_pythonjob.utils import create_conda_env

# Call the function
success, message = create_conda_env(
computer=computer_name,
name=env_name,
pip=pip_packages,
conda={"dependencies": conda_deps},
python_version=python_version,
shell=shell,
)

# Assertions for successful case
assert success is True
assert message == "Environment setup is complete."

# Validate that exec_command_wait was called with the generated script
mock_transport.exec_command_wait.assert_called_once()
called_script = mock_transport.exec_command_wait.call_args[0][0]
assert f"conda create -y -n {env_name} python={python_version}" in called_script
assert "pip install numpy pandas" in called_script

mock_transport.close()

0 comments on commit c87c09b

Please sign in to comment.