Skip to content

Commit

Permalink
Merge pull request #32 from fal-ai/handle-invalid-parameters
Browse files Browse the repository at this point in the history
fix: validate invalid parameters when initializing environments
  • Loading branch information
isidentical authored Nov 1, 2022
2 parents 9570c59 + da10552 commit 50c3cc6
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 22 deletions.
7 changes: 3 additions & 4 deletions src/isolate/backends/conda.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import shutil
import subprocess
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar, Dict, List

Expand All @@ -23,16 +23,15 @@
class CondaEnvironment(BaseEnvironment[Path]):
BACKEND_NAME: ClassVar[str] = "conda"

packages: List[str]
packages: List[str] = field(default_factory=list)

@classmethod
def from_config(
cls,
config: Dict[str, Any],
settings: IsolateSettings = DEFAULT_SETTINGS,
) -> BaseEnvironment:
user_provided_packages = config.get("packages", [])
environment = cls(user_provided_packages)
environment = cls(**config)
environment.apply_settings(settings)
return environment

Expand Down
2 changes: 1 addition & 1 deletion src/isolate/backends/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def from_config(
config: Dict[str, Any],
settings: IsolateSettings = DEFAULT_SETTINGS,
) -> BaseEnvironment:
environment = cls()
environment = cls(**config)
environment.apply_settings(settings)
return environment

Expand Down
6 changes: 1 addition & 5 deletions src/isolate/backends/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ def from_config(
config: Dict[str, Any],
settings: IsolateSettings = DEFAULT_SETTINGS,
) -> BaseEnvironment:
environment = cls(
host=config["host"],
target_environment_kind=config["target_environment_kind"],
target_environment_config=config["target_environment_config"],
)
environment = cls(**config)
environment.apply_settings(settings)

return environment
Expand Down
12 changes: 3 additions & 9 deletions src/isolate/backends/virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import shutil
import subprocess
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Union

Expand All @@ -21,7 +21,7 @@
class VirtualPythonEnvironment(BaseEnvironment[Path]):
BACKEND_NAME: ClassVar[str] = "virtualenv"

requirements: List[str]
requirements: List[str] = field(default_factory=list)
constraints_file: Optional[os.PathLike] = None

@classmethod
Expand All @@ -30,13 +30,7 @@ def from_config(
config: Dict[str, Any],
settings: IsolateSettings = DEFAULT_SETTINGS,
) -> BaseEnvironment:
requirements = config.get("requirements", [])
# TODO: we probably should validate that this file actually exists
constraints_file = config.get("constraints_file", None)
environment = cls(
requirements=requirements,
constraints_file=constraints_file,
)
environment = cls(**config)
environment.apply_settings(settings)
return environment

Expand Down
7 changes: 6 additions & 1 deletion src/isolate/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,16 @@ def Run(
messages: Queue[definitions.PartialRunResult] = Queue()
try:
environment = from_grpc(request.environment)
except ValueError as exc:
except ValueError:
return self.abort_with_msg(
f"Unknown environment kind: {request.environment.kind}.",
context,
)
except TypeError as exc:
return self.abort_with_msg(
f"Invalid environment parameter: {str(exc)}.",
context,
)

run_settings = replace(
self.default_settings,
Expand Down
39 changes: 39 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,42 @@ def test_isolate_server_demo(isolate_server):
assert local_connection.run(target_func) == remote_connection.run(
target_func
)


@pytest.mark.parametrize(
"kind, config",
[
(
"virtualenv",
{
"packages": [
"pyjokes==1.0.0",
]
},
),
(
"conda",
{
"requirements": [
"pyjokes=1.0.0",
]
},
),
(
"isolate-server",
{
"host": "localhost",
"port": 1234,
"target_environment_kind": "virtualenv",
"target_environment_config": {
"requirements": [
"pyjokes==1.0.0",
]
},
},
),
],
)
def test_wrong_options(kind, config):
with pytest.raises(TypeError):
isolate.prepare_environment(kind, **config)
46 changes: 44 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from concurrent import futures
from functools import partial
from pathlib import Path
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, cast

import grpc
import pytest
Expand Down Expand Up @@ -80,7 +80,7 @@ def run_request(
if return_value is _NOT_SET:
raise ValueError("Never sent the result")
else:
return return_value
return cast(definitions.SerializedObject, return_value)


@pytest.mark.parametrize("inherit_local", [True, False])
Expand Down Expand Up @@ -178,3 +178,45 @@ def test_user_logs_immediate(stub: definitions.IsolateStub, monkeypatch: Any) ->
by_stream = {log.level: log.message for log in user_logs}
assert by_stream[LogLevel.STDOUT] == "0.6.0"
assert by_stream[LogLevel.STDERR] == "error error!"


def test_unknown_environment(stub: definitions.IsolateStub, monkeypatch: Any) -> None:
inherit_from_local(monkeypatch)

env_definition = define_environment("unknown")
request = definitions.BoundFunction(
function=to_serialized_object(
partial(
eval,
"__import__('pyjokes').__version__",
),
method="dill",
),
environment=env_definition,
)

with pytest.raises(grpc.RpcError) as exc:
run_request(stub, request)

assert exc.match("Unknown environment kind")


def test_invalid_param(stub: definitions.IsolateStub, monkeypatch: Any) -> None:
inherit_from_local(monkeypatch)

env_definition = define_environment("virtualenv", packages=["pyjokes==1.0"])
request = definitions.BoundFunction(
function=to_serialized_object(
partial(
eval,
"__import__('pyjokes').__version__",
),
method="dill",
),
environment=env_definition,
)

with pytest.raises(grpc.RpcError) as exc:
run_request(stub, request)

assert exc.match("unexpected keyword argument 'packages'")

0 comments on commit 50c3cc6

Please sign in to comment.