diff --git a/src/aiida_pythonjob/utils.py b/src/aiida_pythonjob/utils.py index 167fc23..ee0c34a 100644 --- a/src/aiida_pythonjob/utils.py +++ b/src/aiida_pythonjob/utils.py @@ -39,34 +39,31 @@ def add_imports(type_hint): def inspect_function(func: Callable) -> Dict[str, Any]: """Serialize a function for storage or transmission.""" + # we need save the source code explicitly, because in the case of jupyter notebook, + # the source code is not saved in the pickle file try: - # we need save the source code explicitly, because in the case of jupyter notebook, - # the source code is not saved in the pickle file source_code = inspect.getsource(func) - # Split the source into lines for processing - source_code_lines = source_code.split("\n") - function_source_code = "\n".join(source_code_lines) - # Find the first line of the actual function definition - for i, line in enumerate(source_code_lines): - if line.strip().startswith("def "): - break - function_source_code_without_decorator = "\n".join(source_code_lines[i:]) - function_source_code_without_decorator = textwrap.dedent(function_source_code_without_decorator) - # we also need to include the necessary imports for the types used in the type hints. - try: - required_imports = get_required_imports(func) - except Exception as e: - required_imports = {} - print(f"Failed to get required imports for function {func.__name__}: {e}") - # Generate import statements - import_statements = "\n".join( - f"from {module} import {', '.join(types)}" for module, types in required_imports.items() - ) - except Exception as e: - print(f"Failed to inspect function {func.__name__}: {e}") - function_source_code = "" - function_source_code_without_decorator = "" - import_statements = "" + except OSError: + raise ValueError("Failed to get the source code of the function.") + + # Split the source into lines for processing + source_code_lines = source_code.split("\n") + function_source_code = "\n".join(source_code_lines) + # Find the first line of the actual function definition + for i, line in enumerate(source_code_lines): + if line.strip().startswith("def "): + break + function_source_code_without_decorator = "\n".join(source_code_lines[i:]) + function_source_code_without_decorator = textwrap.dedent(function_source_code_without_decorator) + # we also need to include the necessary imports for the types used in the type hints. + try: + required_imports = get_required_imports(func) + except Exception as exception: + raise ValueError(f"Failed to get the required imports for the function: {exception}") + # Generate import statements + import_statements = "\n".join( + f"from {module} import {', '.join(types)}" for module, types in required_imports.items() + ) return { "name": func.__name__, "source_code": function_source_code, @@ -225,6 +222,9 @@ def create_conda_env( scheduler.set_transport(transport) try: retval, stdout, stderr = transport.exec_command_wait(script) + print("retval", retval) + print("stdout", stdout) + print("stderr", stderr) except NotImplementedError: return ( True, @@ -235,7 +235,7 @@ def create_conda_env( if retval != 0: return ( False, - f"The command `echo -n` returned a non-zero return code ({retval})", + f"The command returned a non-zero return code ({retval})", ) template = """ diff --git a/tests/test_create_env.py b/tests/test_create_env.py new file mode 100644 index 0000000..6eda4ea --- /dev/null +++ b/tests/test_create_env.py @@ -0,0 +1,60 @@ +from unittest.mock import MagicMock, patch + + +def test_create_conda_env(): + # Define test parameters + computer_name = "localhost" + env_name = "test_env" + pip_packages = ["numpy", "pandas"] + conda_deps = ["scipy"] + python_version = "3.8" + shell = "posix" + + # Mock the computer and related objects + mock_computer = MagicMock() + mock_computer.label = computer_name + mock_user = MagicMock() + mock_user.email = "test_user@test.com" + mock_authinfo = MagicMock() + mock_transport = MagicMock() + mock_scheduler = MagicMock() + + mock_authinfo.get_transport.return_value = mock_transport + mock_computer.get_authinfo.return_value = mock_authinfo + mock_authinfo.computer.get_scheduler.return_value = mock_scheduler + + # Mock successful transport behavior + mock_transport.exec_command_wait.return_value = ( + 0, # retval + "Environment setup is complete.\n", # stdout + "", # stderr + ) + + # Patch `load_computer` and `User.collection.get_default` to return mocked objects + with ( + patch("aiida.orm.utils.loaders.load_computer", return_value=mock_computer), + patch("aiida.orm.User.collection.get_default", return_value=mock_user), + ): + from aiida_pythonjob.utils import create_conda_env + + # Call the function + success, message = create_conda_env( + computer=computer_name, + name=env_name, + pip=pip_packages, + conda={"dependencies": conda_deps}, + python_version=python_version, + shell=shell, + ) + + # Assertions for successful case + assert success is True + assert message == "Environment setup is complete." + + # Validate that exec_command_wait was called with the generated script + mock_transport.exec_command_wait.assert_called_once() + called_script = mock_transport.exec_command_wait.call_args[0][0] + assert f"conda create -y -n {env_name} python={python_version}" in called_script + assert "pip install numpy pandas" in called_script + + mock_transport.close()