diff --git a/examples/example_fastapi/main.py b/examples/example_fastapi/main.py index b676e661b..ee2790d78 100644 --- a/examples/example_fastapi/main.py +++ b/examples/example_fastapi/main.py @@ -216,9 +216,9 @@ async def check_url(internet_host: HttpUrl, timeout_seconds: int = 5, socket_fam async def read_internet(): """Check Internet connectivity of the system, requiring IP connectivity, domain resolution and HTTPS/TLS.""" internet_hosts: list[HttpUrl] = [ - HttpUrl(url="https://aleph.im/", scheme="https"), - HttpUrl(url="https://ethereum.org", scheme="https"), - HttpUrl(url="https://ipfs.io/", scheme="https"), + HttpUrl(url="https://aleph.im/"), + HttpUrl(url="https://ethereum.org/"), + HttpUrl(url="https://ipfs.io/"), ] timeout_seconds = 5 diff --git a/packaging/Makefile b/packaging/Makefile index c797cf267..673633a51 100644 --- a/packaging/Makefile +++ b/packaging/Makefile @@ -19,7 +19,7 @@ debian-package-code: python3 -m venv build_venv build_venv/bin/pip install --progress-bar off --upgrade pip setuptools wheel # Fixing this protobuf dependency version to avoid getting CI errors as version 5.29.0 have this compilation issue - build_venv/bin/pip install --no-cache-dir --progress-bar off --target ./aleph-vm/opt/aleph-vm/ 'aleph-message==0.6.1' 'eth-account==0.10' 'sentry-sdk==1.31.0' 'qmp==1.1.0' 'aleph-superfluid~=0.2.1' 'sqlalchemy[asyncio]>=2.0' 'aiosqlite==0.19.0' 'alembic==1.13.1' 'aiohttp_cors==0.7.0' 'pyroute2==0.7.12' 'python-cpuid==0.1.0' 'solathon==1.0.2' 'protobuf==5.28.3' + build_venv/bin/pip install --no-cache-dir --progress-bar off --target ./aleph-vm/opt/aleph-vm/ 'git+https://github.com/aleph-im/aleph-message@108-upgrade-pydantic-version#egg=aleph-message' 'eth-account==0.10' 'sentry-sdk==1.31.0' 'qmp==1.1.0' 'aleph-superfluid~=0.2.1' 'sqlalchemy[asyncio]>=2.0' 'aiosqlite==0.19.0' 'alembic==1.13.1' 'aiohttp_cors==0.7.0' 'pydantic-settings==2.6.1' 'pyroute2==0.7.12' 'python-cpuid==0.1.0' 'solathon==1.0.2' 'protobuf==5.28.3' 
build_venv/bin/python3 -m compileall ./aleph-vm/opt/aleph-vm/ debian-package-resources: firecracker-bins vmlinux download-ipfs-kubo target/bin/sevctl diff --git a/pyproject.toml b/pyproject.toml index 3a12314d9..89da2a6bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "aioredis==1.3.1", "aiosqlite==0.19", "alembic==1.13.1", - "aleph-message==0.6.1", + "aleph-message @ git+https://github.com/aleph-im/aleph-message@108-upgrade-pydantic-version#egg=main", "aleph-superfluid~=0.2.1", "dbus-python==1.3.2", "eth-account~=0.10", @@ -49,9 +49,11 @@ dependencies = [ "protobuf==5.28.3", "psutil==5.9.5", "py-cpuinfo==9", - "pydantic[dotenv]~=1.10.13", + "pydantic>=2", + "pydantic-settings==2.6.1", "pyroute2==0.7.12", "python-cpuid==0.1.1", + "python-dotenv", "pyyaml==6.0.1", "qmp==1.1", "schedule==1.2.1", @@ -120,8 +122,9 @@ dependencies = [ "mypy==1.8.0", "ruff==0.4.6", "isort==5.13.2", - "yamlfix==1.16.1", + "yamlfix==1.17.0", "pyproject-fmt==2.2.1", + "pydantic>=2", ] [tool.hatch.envs.linting.scripts] typing = "mypy {args:src/aleph/vm/ tests/ examples/example_fastapi runtimes/aleph-debian-12-python}" diff --git a/src/aleph/vm/conf.py b/src/aleph/vm/conf.py index f1f8d75f6..b16115fa5 100644 --- a/src/aleph/vm/conf.py +++ b/src/aleph/vm/conf.py @@ -13,8 +13,9 @@ from aleph_message.models import Chain from aleph_message.models.execution.environment import HypervisorType -from pydantic import BaseSettings, Field, HttpUrl -from pydantic.env_settings import DotenvType, env_file_sentinel +from dotenv import load_dotenv +from pydantic import Field, HttpUrl +from pydantic_settings import BaseSettings, SettingsConfigDict from aleph.vm.orchestrator.chain import STREAM_CHAINS from aleph.vm.utils import ( @@ -24,6 +25,8 @@ is_command_available, ) +load_dotenv() + logger = logging.getLogger(__name__) Url = NewType("Url", str) @@ -165,7 +168,7 @@ class Settings(BaseSettings): default=True, description="Enable IPv6 forwarding on the host. 
Required for IPv6 connectivity in VMs.", ) - NFTABLES_CHAIN_PREFIX = "aleph" + NFTABLES_CHAIN_PREFIX: str = "aleph" USE_NDP_PROXY: bool = Field( default=True, description="Use the Neighbor Discovery Protocol Proxy to respond to Router Solicitation for instances on IPv6", @@ -176,8 +179,8 @@ class Settings(BaseSettings): description="Method used to resolve the dns server if DNS_NAMESERVERS is not present.", ) DNS_NAMESERVERS: list[str] | None = None - DNS_NAMESERVERS_IPV4: list[str] | None - DNS_NAMESERVERS_IPV6: list[str] | None + DNS_NAMESERVERS_IPV4: list[str] | None = None + DNS_NAMESERVERS_IPV6: list[str] | None = None FIRECRACKER_PATH: Path = Path("/opt/firecracker/firecracker") JAILER_PATH: Path = Path("/opt/firecracker/jailer") @@ -185,7 +188,7 @@ class Settings(BaseSettings): LINUX_PATH: Path = Path("/opt/firecracker/vmlinux.bin") INIT_TIMEOUT: float = 20.0 - CONNECTOR_URL = Url("http://localhost:4021") + CONNECTOR_URL: HttpUrl = HttpUrl("http://localhost:4021") CACHE_ROOT: Path = Path("/var/cache/aleph/vm") MESSAGE_CACHE: Path | None = Field( @@ -206,10 +209,10 @@ class Settings(BaseSettings): None, description="Location of executions log. Default to EXECUTION_ROOT/executions/" ) - PERSISTENT_VOLUMES_DIR: Path = Field( + PERSISTENT_VOLUMES_DIR: Path | None = Field( None, description="Persistent volumes location. 
Default to EXECUTION_ROOT/volumes/persistent/" ) - JAILER_BASE_DIR: Path = Field(None) + JAILER_BASE_DIR: Path | None = Field(None) MAX_PROGRAM_ARCHIVE_SIZE: int = 10_000_000 # 10 MB MAX_DATA_ARCHIVE_SIZE: int = 10_000_000 # 10 MB @@ -308,8 +311,10 @@ class Settings(BaseSettings): description="Identifier used for the 'fake instance' message defined in " "examples/instance_message_from_aleph.json", ) - FAKE_INSTANCE_MESSAGE = Path(abspath(join(__file__, "../../../../examples/instance_message_from_aleph.json"))) - FAKE_INSTANCE_QEMU_MESSAGE = Path(abspath(join(__file__, "../../../../examples/qemu_message_from_aleph.json"))) + FAKE_INSTANCE_MESSAGE: Path = Path(abspath(join(__file__, "../../../../examples/instance_message_from_aleph.json"))) + FAKE_INSTANCE_QEMU_MESSAGE: Path = Path( + abspath(join(__file__, "../../../../examples/qemu_message_from_aleph.json")) + ) CHECK_FASTAPI_VM_ID: str = "63faf8b5db1cf8d965e6a464a0cb8062af8e7df131729e48738342d956f29ace" LEGACY_CHECK_FASTAPI_VM_ID: str = "67705389842a0a1b95eaa408b009741027964edc805997475e95c505d642edd8" @@ -345,7 +350,7 @@ def check(self): assert isfile(self.JAILER_PATH), f"File not found {self.JAILER_PATH}" assert isfile(self.LINUX_PATH), f"File not found {self.LINUX_PATH}" assert self.NETWORK_INTERFACE, "Network interface is not specified" - assert self.CONNECTOR_URL.startswith("http://") or self.CONNECTOR_URL.startswith("https://") + assert str(self.CONNECTOR_URL).startswith("http://") or str(self.CONNECTOR_URL).startswith("https://") if self.ALLOW_VM_NETWORKING: assert exists( f"/sys/class/net/{self.NETWORK_INTERFACE}" @@ -480,18 +485,25 @@ def display(self) -> str: attributes[attr] = "" else: attributes[attr] = getattr(self, attr) - - return "\n".join(f"{self.Config.env_prefix}{attribute} = {value}" for attribute, value in attributes.items()) + return "\n".join( + f"{self.model_config.get('env_prefix', '')}{attribute} = {value}" for attribute, value in attributes.items() + ) def __init__( self, - _env_file: 
DotenvType | None = env_file_sentinel, + _env_file: str | Path | None = None, _env_file_encoding: str | None = None, _env_nested_delimiter: str | None = None, _secrets_dir: Path | None = None, **values: Any, ) -> None: - super().__init__(_env_file, _env_file_encoding, _env_nested_delimiter, _secrets_dir, **values) + super().__init__( + _env_file, + _env_file_encoding, + _env_nested_delimiter, + _secrets_dir, + **values, + ) if not self.MESSAGE_CACHE: self.MESSAGE_CACHE = self.CACHE_ROOT / "message" if not self.CODE_CACHE: @@ -515,10 +527,9 @@ def __init__( if not self.CONFIDENTIAL_SESSION_DIRECTORY: self.CONFIDENTIAL_SESSION_DIRECTORY = self.EXECUTION_ROOT / "sessions" - class Config: - env_prefix = "ALEPH_VM_" - case_sensitive = False - env_file = ".env" + model_config = SettingsConfigDict( + env_prefix="ALEPH_VM_", case_sensitive=False, env_file=".env", validate_default=False + ) def make_db_url(): diff --git a/src/aleph/vm/controllers/__main__.py b/src/aleph/vm/controllers/__main__.py index f3cef3171..d5e2f21ad 100644 --- a/src/aleph/vm/controllers/__main__.py +++ b/src/aleph/vm/controllers/__main__.py @@ -26,7 +26,7 @@ def configuration_from_file(path: Path): with open(path) as f: data = json.load(f) - return Configuration.parse_obj(data) + return Configuration.model_validate(data) def parse_args(args): diff --git a/src/aleph/vm/controllers/configuration.py b/src/aleph/vm/controllers/configuration.py index 34922e5b8..d5a5e6a99 100644 --- a/src/aleph/vm/controllers/configuration.py +++ b/src/aleph/vm/controllers/configuration.py @@ -30,26 +30,26 @@ class QemuGPU(BaseModel): class QemuVMConfiguration(BaseModel): qemu_bin_path: str - cloud_init_drive_path: str | None + cloud_init_drive_path: str | None = None image_path: str monitor_socket_path: Path qmp_socket_path: Path vcpu_count: int mem_size_mb: int - interface_name: str | None + interface_name: str | None = None host_volumes: list[QemuVMHostVolume] gpus: list[QemuGPU] class 
QemuConfidentialVMConfiguration(BaseModel): qemu_bin_path: str - cloud_init_drive_path: str | None + cloud_init_drive_path: str | None = None image_path: str monitor_socket_path: Path qmp_socket_path: Path vcpu_count: int mem_size_mb: int - interface_name: str | None + interface_name: str | None = None host_volumes: list[QemuVMHostVolume] gpus: list[QemuGPU] ovmf_path: Path @@ -76,7 +76,7 @@ def save_controller_configuration(vm_hash: str, configuration: Configuration) -> config_file_path = Path(f"{settings.EXECUTION_ROOT}/{vm_hash}-controller.json") with config_file_path.open("w") as controller_config_file: controller_config_file.write( - configuration.json( + configuration.model_dump_json( by_alias=True, exclude_none=True, indent=4, exclude={"settings": {"USE_DEVELOPER_SSH_KEYS"}} ) ) diff --git a/src/aleph/vm/controllers/qemu/instance.py b/src/aleph/vm/controllers/qemu/instance.py index 1a1dc9b19..d1596044b 100644 --- a/src/aleph/vm/controllers/qemu/instance.py +++ b/src/aleph/vm/controllers/qemu/instance.py @@ -220,7 +220,7 @@ async def configure(self): def save_controller_configuration(self): """Save VM configuration to be used by the controller service""" path = Path(f"{settings.EXECUTION_ROOT}/{self.vm_hash}-controller.json") - path.open("w").write(self.controller_configuration.json(by_alias=True, exclude_none=True, indent=4)) + path.open("w").write(self.controller_configuration.model_dump_json(by_alias=True, exclude_none=True, indent=4)) path.chmod(0o644) return path diff --git a/src/aleph/vm/hypervisors/firecracker/config.py b/src/aleph/vm/hypervisors/firecracker/config.py index b7e4fc77a..59560ce1c 100644 --- a/src/aleph/vm/hypervisors/firecracker/config.py +++ b/src/aleph/vm/hypervisors/firecracker/config.py @@ -1,6 +1,6 @@ from pathlib import Path -from pydantic import BaseModel, PositiveInt +from pydantic import BaseModel, ConfigDict, PositiveInt VSOCK_PATH = "/tmp/v.sock" @@ -51,12 +51,7 @@ class FirecrackerConfig(BaseModel): boot_source: BootSource 
drives: list[Drive] machine_config: MachineConfig - vsock: Vsock | None - network_interfaces: list[NetworkInterface] | None + vsock: Vsock | None = None + network_interfaces: list[NetworkInterface] | None = None - class Config: - allow_population_by_field_name = True - - @staticmethod - def alias_generator(x: str): - return x.replace("_", "-") + model_config = ConfigDict(populate_by_name=True, alias_generator=lambda x: x.replace("_", "-")) diff --git a/src/aleph/vm/hypervisors/firecracker/microvm.py b/src/aleph/vm/hypervisors/firecracker/microvm.py index d357fb6e0..1d34e2752 100644 --- a/src/aleph/vm/hypervisors/firecracker/microvm.py +++ b/src/aleph/vm/hypervisors/firecracker/microvm.py @@ -195,7 +195,7 @@ async def save_configuration_file(self, config: FirecrackerConfig) -> Path: if not self.use_jailer else open(f"{self.jailer_path}/tmp/config.json", "wb") ) as config_file: - config_file.write(config.json(by_alias=True, exclude_none=True, indent=4).encode()) + config_file.write(config.model_dump_json(by_alias=True, exclude_none=True, indent=4).encode()) config_file.flush() config_file_path = Path(config_file.name) config_file_path.chmod(0o644) diff --git a/src/aleph/vm/models.py b/src/aleph/vm/models.py index 4edaf2d47..19de70fe0 100644 --- a/src/aleph/vm/models.py +++ b/src/aleph/vm/models.py @@ -439,8 +439,8 @@ async def save(self): vcpus=self.vm.hardware_resources.vcpus, memory=self.vm.hardware_resources.memory, network_tap=self.vm.tap_interface.device_name if self.vm.tap_interface else "", - message=self.message.json(), - original_message=self.original.json(), + message=self.message.model_dump_json(), + original_message=self.original.model_dump_json(), persistent=self.persistent, ) ) @@ -463,8 +463,8 @@ async def save(self): io_write_bytes=None, vcpus=self.vm.hardware_resources.vcpus, memory=self.vm.hardware_resources.memory, - message=self.message.json(), - original_message=self.original.json(), + message=self.message.model_dump_json(), + 
original_message=self.original.model_dump_json(), persistent=self.persistent, gpus=json.dumps(self.gpus, default=pydantic_encoder), ) diff --git a/src/aleph/vm/orchestrator/README.md index c1d22ea0f..10a0569f5 100644 --- a/src/aleph/vm/orchestrator/README.md +++ b/src/aleph/vm/orchestrator/README.md @@ -80,12 +80,12 @@ cd aleph-vm/ ### 2.e. Install Pydantic -[PyDantic](https://pydantic-docs.helpmanual.io/) +[PyDantic](https://pydantic-docs.helpmanual.io/) is used to parse and validate Aleph messages. ```shell apt install -y --no-install-recommends --no-install-suggests python3-pip -pip3 install pydantic[dotenv] +pip3 install pydantic pydantic-settings pip3 install 'aleph-message==0.4.9' ``` diff --git a/src/aleph/vm/orchestrator/chain.py index 0b4174397..717fefa38 100644 --- a/src/aleph/vm/orchestrator/chain.py +++ b/src/aleph/vm/orchestrator/chain.py @@ -1,7 +1,7 @@ import logging from aleph_message.models import Chain -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, HttpUrl, model_validator logger = logging.getLogger(__name__) @@ -12,7 +12,7 @@ class ChainInfo(BaseModel): """ chain_id: int - rpc: str + rpc: HttpUrl standard_token: str | None = None super_token: str | None = None testnet: bool = False @@ -22,7 +22,8 @@ def token(self) -> str | None: return self.super_token or self.standard_token - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_tokens(cls, values): if not values.get("standard_token") and not values.get("super_token"): msg = "At least one of standard_token or super_token must be provided." 
@@ -34,7 +35,7 @@ def check_tokens(cls, values): # TESTNETS "SEPOLIA": ChainInfo( chain_id=11155111, - rpc="https://eth-sepolia.public.blastapi.io", + rpc=HttpUrl("https://eth-sepolia.public.blastapi.io"), standard_token="0xc4bf5cbdabe595361438f8c6a187bdc330539c60", super_token="0x22064a21fee226d8ffb8818e7627d5ff6d0fc33a", active=False, @@ -43,18 +44,18 @@ def check_tokens(cls, values): # MAINNETS Chain.ETH: ChainInfo( chain_id=1, - rpc="https://eth-mainnet.public.blastapi.io", + rpc=HttpUrl("https://eth-mainnet.public.blastapi.io"), standard_token="0x27702a26126e0B3702af63Ee09aC4d1A084EF628", active=False, ), Chain.AVAX: ChainInfo( chain_id=43114, - rpc="https://api.avax.network/ext/bc/C/rpc", + rpc=HttpUrl("https://api.avax.network/ext/bc/C/rpc"), super_token="0xc0Fbc4967259786C743361a5885ef49380473dCF", ), Chain.BASE: ChainInfo( chain_id=8453, - rpc="https://base-mainnet.public.blastapi.io", + rpc=HttpUrl("https://base-mainnet.public.blastapi.io"), super_token="0xc0Fbc4967259786C743361a5885ef49380473dCF", ), } diff --git a/src/aleph/vm/orchestrator/reactor.py b/src/aleph/vm/orchestrator/reactor.py index 785f2c233..f8326fa97 100644 --- a/src/aleph/vm/orchestrator/reactor.py +++ b/src/aleph/vm/orchestrator/reactor.py @@ -61,7 +61,7 @@ async def trigger(self, message: AlephMessage): for subscription in listener.content.on.message: if subscription_matches(subscription, message): vm_hash = listener.item_hash - event = message.json() + event = message.model_dump_json() # Register the listener in the list of coroutines to run asynchronously: coroutines.append(run_code_on_event(vm_hash, event, self.pubsub, pool=self.pool)) break diff --git a/src/aleph/vm/orchestrator/resources.py b/src/aleph/vm/orchestrator/resources.py index 673c324a1..ac5a627cb 100644 --- a/src/aleph/vm/orchestrator/resources.py +++ b/src/aleph/vm/orchestrator/resources.py @@ -76,8 +76,8 @@ class MachineProperties(BaseModel): class GpuProperties(BaseModel): - devices: list[GpuDevice] | None - 
available_devices: list[GpuDevice] | None + devices: list[GpuDevice] | None = None + available_devices: list[GpuDevice] | None = None class MachineUsage(BaseModel): @@ -158,7 +158,7 @@ async def about_system_usage(request: web.Request): gpu=get_machine_gpus(request), ) - return web.json_response(text=usage.json(exclude_none=True)) + return web.json_response(text=usage.model_dump_json(exclude_none=True)) @cors_allow_all diff --git a/src/aleph/vm/orchestrator/run.py index f82a8ae17..638a1b983 100644 --- a/src/aleph/vm/orchestrator/run.py +++ b/src/aleph/vm/orchestrator/run.py @@ -1,4 +1,5 @@ import asyncio +import json import logging from typing import Any @@ -55,7 +56,7 @@ async def create_vm_execution(vm_hash: ItemHash, pool: VmPool, persistent: bool message, original_message = await load_updated_message(vm_hash) pool.message_cache[vm_hash] = message - logger.debug(f"Message: {message.json(indent=4, sort_keys=True, exclude_none=True)}") + logger.debug(f"Message: {json.dumps(message.dict(exclude_none=True), indent=4, sort_keys=True, default=str)}") execution = await pool.create_a_vm( vm_hash=vm_hash, diff --git a/src/aleph/vm/orchestrator/tasks.py index 7f1574157..803d3ca32 100644 --- a/src/aleph/vm/orchestrator/tasks.py +++ b/src/aleph/vm/orchestrator/tasks.py @@ -80,10 +80,10 @@ async def subscribe_via_ws(url) -> AsyncIterable[AlephMessage]: try: yield parse_message(data) - except pydantic.error_wrappers.ValidationError as error: + except pydantic.ValidationError as error: item_hash = data.get("item_hash", "ITEM_HASH_NOT_FOUND") logger.warning( - f"Invalid Aleph message: {item_hash} \n {error.json()}\n {error.raw_errors}", + f"Invalid Aleph message: {item_hash} \n {error.errors()}", exc_info=False, ) continue diff --git a/src/aleph/vm/orchestrator/views/__init__.py index d8a18e227..f4705f015 100644 --- a/src/aleph/vm/orchestrator/views/__init__.py 
+++ b/src/aleph/vm/orchestrator/views/__init__.py @@ -384,7 +384,7 @@ async def update_allocations(request: web.Request): try: data = await request.json() - allocation = Allocation.parse_obj(data) + allocation = Allocation.model_validate(data) except ValidationError as error: return web.json_response(text=error.json(), status=web.HTTPBadRequest.status_code) @@ -478,7 +478,7 @@ async def notify_allocation(request: web.Request): await update_aggregate_settings() try: data = await request.json() - vm_notification = VMNotification.parse_obj(data) + vm_notification = VMNotification.model_validate(data) except JSONDecodeError: return web.HTTPBadRequest(reason="Body is not valid JSON") except ValidationError as error: diff --git a/src/aleph/vm/orchestrator/views/authentication.py b/src/aleph/vm/orchestrator/views/authentication.py index 55ed624ef..dee57e339 100644 --- a/src/aleph/vm/orchestrator/views/authentication.py +++ b/src/aleph/vm/orchestrator/views/authentication.py @@ -5,6 +5,8 @@ Can be enabled on an endpoint using the @require_jwk_authentication decorator """ +from __future__ import annotations + # Keep datetime import as is as it allow patching in test import datetime import functools @@ -22,7 +24,7 @@ from jwcrypto import jwk from jwcrypto.jwa import JWA from nacl.exceptions import BadSignatureError -from pydantic import BaseModel, ValidationError, root_validator, validator +from pydantic import BaseModel, ValidationError, field_validator, model_validator from solathon.utils import verify_signature from aleph.vm.conf import settings @@ -90,39 +92,40 @@ class SignedPubKeyHeader(BaseModel): signature: bytes payload: bytes - @validator("signature") + @field_validator("signature") + @classmethod def signature_must_be_hex(cls, v: bytes) -> bytes: """Convert the signature from hexadecimal to bytes""" return bytes.fromhex(v.removeprefix(b"0x").decode()) - @validator("payload") + @field_validator("payload") + @classmethod def payload_must_be_hex(cls, v: bytes) -> 
bytes: """Convert the payload from hexadecimal to bytes""" return bytes.fromhex(v.decode()) - @root_validator(pre=False, skip_on_failure=True) - def check_expiry(cls, values) -> dict[str, bytes]: + @model_validator(mode="after") + def check_expiry(values) -> SignedPubKeyHeader: """Check that the token has not expired""" - payload: bytes = values["payload"] - content = SignedPubKeyPayload.parse_raw(payload) + payload = values.payload + content = SignedPubKeyPayload.model_validate_json(payload) if not is_token_still_valid(content.expires): - msg = "Token expired" - raise ValueError(msg) + raise ValueError("Token expired") return values - @root_validator(pre=False, skip_on_failure=True) - def check_signature(cls, values) -> dict[str, bytes]: + @model_validator(mode="after") + def check_signature(values) -> SignedPubKeyHeader: """Check that the signature is valid""" - signature: list = values["signature"] - payload: bytes = values["payload"] - content = SignedPubKeyPayload.parse_raw(payload) + signature = values.signature + payload = values.payload + content = SignedPubKeyPayload.model_validate_json(payload) check_wallet_signature_or_raise(content.address, content.chain, payload, signature) return values @property def content(self) -> SignedPubKeyPayload: """Return the content of the header""" - return SignedPubKeyPayload.parse_raw(self.payload) + return SignedPubKeyPayload.model_validate_json(self.payload) class SignedOperationPayload(BaseModel): @@ -132,7 +135,8 @@ class SignedOperationPayload(BaseModel): path: str # body_sha256: str # disabled since there is no body - @validator("time") + @field_validator("time") + @classmethod def time_is_current(cls, v: datetime.datetime) -> datetime.datetime: """Check that the time is current and the payload is not a replay attack.""" max_past = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(minutes=2) @@ -152,7 +156,8 @@ class SignedOperation(BaseModel): signature: bytes payload: bytes - 
@validator("signature") + @field_validator("signature") + @classmethod def signature_must_be_hex(cls, v) -> bytes: """Convert the signature from hexadecimal to bytes""" try: @@ -162,17 +167,18 @@ def signature_must_be_hex(cls, v) -> bytes: logger.warning(v) raise error - @validator("payload") + @field_validator("payload") + @classmethod def payload_must_be_hex(cls, v) -> bytes: """Convert the payload from hexadecimal to bytes""" v = bytes.fromhex(v.decode()) - _ = SignedOperationPayload.parse_raw(v) + _ = SignedOperationPayload.model_validate_json(v) return v @property def content(self) -> SignedOperationPayload: """Return the content of the header""" - return SignedOperationPayload.parse_raw(self.payload) + return SignedOperationPayload.model_validate_json(self.payload) def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader: @@ -182,29 +188,30 @@ def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader: raise web.HTTPBadRequest(reason="Missing X-SignedPubKey header") try: - return SignedPubKeyHeader.parse_raw(signed_pubkey_header) + data = json.loads(signed_pubkey_header) + if "expires" in data and isinstance(data["expires"], float): + data["expires"] = str(data["expires"]) + return SignedPubKeyHeader.model_validate_json(json.dumps(data)) except KeyError as error: logger.debug(f"Missing X-SignedPubKey header: {error}") raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey fields") from error except json.JSONDecodeError as error: raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey format") from error - except ValueError as errors: + except ValidationError as errors: logging.debug(errors) - for err in errors.args[0]: - if isinstance(err.exc, json.JSONDecodeError): - raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey format") from errors - if str(err.exc) == "Token expired": + for err in errors.errors(): + if err["type"] == "value_error" and "Token expired" in str(err["msg"]): raise web.HTTPUnauthorized(reason="Token expired") from 
errors - if str(err.exc) == "Invalid signature": + elif err["type"] == "value_error" and "Invalid signature" in str(err["msg"]): raise web.HTTPUnauthorized(reason="Invalid signature") from errors - raise errors + raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey data") def get_signed_operation(request: web.Request) -> SignedOperation: """Get the signed operation public key that is signed by the ephemeral key from the request headers.""" try: signed_operation = request.headers["X-SignedOperation"] - return SignedOperation.parse_raw(signed_operation) + return SignedOperation.model_validate_json(signed_operation) except KeyError as error: raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error except json.JSONDecodeError as error: @@ -248,8 +255,8 @@ async def authenticate_websocket_message(message) -> str: """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" if not isinstance(message, dict): raise Exception("Invalid format for auth packet, see /doc/operator_auth.md") - signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) - signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) + signed_pubkey = SignedPubKeyHeader.model_validate(message["X-SignedPubKey"]) + signed_operation = SignedOperation.model_validate(message["X-SignedOperation"]) if signed_operation.content.domain != settings.DOMAIN_NAME: logger.debug(f"Invalid domain '{signed_operation.content.domain}' != '{settings.DOMAIN_NAME}'") raise web.HTTPUnauthorized(reason="Invalid domain") diff --git a/src/aleph/vm/orchestrator/views/operator.py b/src/aleph/vm/orchestrator/views/operator.py index b808e94fe..e811f73c3 100644 --- a/src/aleph/vm/orchestrator/views/operator.py +++ b/src/aleph/vm/orchestrator/views/operator.py @@ -362,7 +362,7 @@ async def operate_confidential_inject_secret(request: web.Request, authenticated """ try: data = await request.json() - params = InjectSecretParams.parse_obj(data) + params = 
InjectSecretParams.model_validate(data) except json.JSONDecodeError: return web.HTTPBadRequest(reason="Body is not valid JSON") except pydantic.ValidationError as error: diff --git a/src/aleph/vm/pool.py b/src/aleph/vm/pool.py index cacc0a077..12ceeb456 100644 --- a/src/aleph/vm/pool.py +++ b/src/aleph/vm/pool.py @@ -16,7 +16,7 @@ Payment, PaymentType, ) -from pydantic import parse_raw_as +from pydantic import TypeAdapter from aleph.vm.conf import settings from aleph.vm.controllers.firecracker.snapshot_manager import SnapshotManager @@ -283,7 +283,9 @@ async def load_persistent_executions(self): if execution.is_running: # TODO: Improve the way that we re-create running execution # Load existing GPUs assigned to VMs - execution.gpus = parse_raw_as(list[HostGPU], saved_execution.gpus) if saved_execution.gpus else [] + execution.gpus = ( + TypeAdapter(list[HostGPU]).validate_python(saved_execution.gpus) if saved_execution.gpus else [] + ) # Load and instantiate the rest of resources and already assigned GPUs await execution.prepare() if self.network: diff --git a/src/aleph/vm/resources.py b/src/aleph/vm/resources.py index 98b317865..fe9276568 100644 --- a/src/aleph/vm/resources.py +++ b/src/aleph/vm/resources.py @@ -1,8 +1,9 @@ import subprocess from enum import Enum +from typing import Optional from aleph_message.models import HashableModel -from pydantic import BaseModel, Extra, Field +from pydantic import BaseModel, ConfigDict, Field from aleph.vm.orchestrator.utils import get_compatible_gpus @@ -12,9 +13,7 @@ class HostGPU(BaseModel): pci_host: str = Field(description="GPU PCI host address") supports_x_vga: bool = Field(description="Whether the GPU supports x-vga QEMU parameter", default=True) - - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class GpuDeviceClass(str, Enum): @@ -28,7 +27,7 @@ class GpuDevice(HashableModel): """GPU properties.""" vendor: str = Field(description="GPU vendor name") - model: str | None = 
Field(description="GPU model name on Aleph Network") + model: Optional[str] = Field(default=None, description="GPU model name on Aleph Network") device_name: str = Field(description="GPU vendor card name") device_class: GpuDeviceClass = Field( description="GPU device class. Look at https://admin.pci-ids.ucw.cz/read/PD/03" @@ -47,8 +46,7 @@ def has_x_vga_support(self) -> bool: """ return self.device_class == GpuDeviceClass.VGA_COMPATIBLE_CONTROLLER - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class CompatibleGPU(BaseModel): diff --git a/src/aleph/vm/storage.py b/src/aleph/vm/storage.py index 5d0ed3cd8..22425a938 100644 --- a/src/aleph/vm/storage.py +++ b/src/aleph/vm/storage.py @@ -136,7 +136,7 @@ async def get_latest_amend(item_hash: str) -> str: if settings.FAKE_DATA_PROGRAM: return item_hash else: - url = f"{settings.CONNECTOR_URL}/compute/latest_amend/{item_hash}" + url = f"{settings.CONNECTOR_URL}compute/latest_amend/{item_hash}" async with aiohttp.ClientSession() as session: resp = await session.get(url) resp.raise_for_status() @@ -154,7 +154,7 @@ async def get_message(ref: str) -> ProgramMessage | InstanceMessage: logger.debug("Using the fake data message") else: cache_path = (Path(settings.MESSAGE_CACHE) / ref).with_suffix(".json") - url = f"{settings.CONNECTOR_URL}/download/message/{ref}" + url = f"{settings.CONNECTOR_URL}download/message/{ref}" await download_file(url, cache_path) with open(cache_path) as cache_file: @@ -190,7 +190,7 @@ async def get_code_path(ref: str) -> Path: raise ValueError(msg) cache_path = Path(settings.CODE_CACHE) / ref - url = f"{settings.CONNECTOR_URL}/download/code/{ref}" + url = f"{settings.CONNECTOR_URL}download/code/{ref}" await download_file(url, cache_path) return cache_path @@ -202,7 +202,7 @@ async def get_data_path(ref: str) -> Path: return Path(f"{data_dir}.zip") cache_path = Path(settings.DATA_CACHE) / ref - url = f"{settings.CONNECTOR_URL}/download/data/{ref}" + url = 
f"{settings.CONNECTOR_URL}download/data/{ref}" await download_file(url, cache_path) return cache_path @@ -223,7 +223,7 @@ async def get_runtime_path(ref: str) -> Path: return Path(settings.FAKE_DATA_RUNTIME) cache_path = Path(settings.RUNTIME_CACHE) / ref - url = f"{settings.CONNECTOR_URL}/download/runtime/{ref}" + url = f"{settings.CONNECTOR_URL}download/runtime/{ref}" if not cache_path.is_file(): # File does not exist, download it @@ -241,7 +241,7 @@ async def get_rootfs_base_path(ref: ItemHash) -> Path: return Path(settings.FAKE_INSTANCE_BASE) cache_path = Path(settings.RUNTIME_CACHE) / ref - url = f"{settings.CONNECTOR_URL}/download/runtime/{ref}" + url = f"{settings.CONNECTOR_URL}download/runtime/{ref}" await download_file(url, cache_path) await chown_to_jailman(cache_path) return cache_path @@ -363,7 +363,7 @@ async def get_existing_file(ref: str) -> Path: return Path(settings.FAKE_DATA_VOLUME) cache_path = Path(settings.DATA_CACHE) / ref - url = f"{settings.CONNECTOR_URL}/download/data/{ref}" + url = f"{settings.CONNECTOR_URL}download/data/{ref}" await download_file(url, cache_path) await chown_to_jailman(cache_path) return cache_path diff --git a/src/aleph/vm/utils/__init__.py b/src/aleph/vm/utils/__init__.py index d8eecad95..9bf30d001 100644 --- a/src/aleph/vm/utils/__init__.py +++ b/src/aleph/vm/utils/__init__.py @@ -24,9 +24,9 @@ def get_message_executable_content(message_dict: dict) -> ExecutableContent: try: - return ProgramContent.parse_obj(message_dict) + return ProgramContent.model_validate(message_dict) except ValueError: - return InstanceContent.parse_obj(message_dict) + return InstanceContent.model_validate(message_dict) def cors_allow_all(function): diff --git a/tests/supervisor/test_gpu_x_vga_support.py b/tests/supervisor/test_gpu_x_vga_support.py index 9081d49f7..d3a62ef08 100644 --- a/tests/supervisor/test_gpu_x_vga_support.py +++ b/tests/supervisor/test_gpu_x_vga_support.py @@ -4,7 +4,7 @@ from aleph.vm.controllers.configuration import 
QemuGPU from aleph.vm.hypervisors.qemu.qemuvm import QemuVM -from aleph.vm.resources import GpuDevice, GpuDeviceClass, HostGPU +from aleph.vm.resources import GpuDevice, GpuDeviceClass class TestGpuXVgaSupport: diff --git a/tests/supervisor/test_status.py index 0e0449dbf..3197133f0 100644 --- a/tests/supervisor/test_status.py +++ b/tests/supervisor/test_status.py @@ -16,6 +16,7 @@ async def test_check_internet_wrong_result_code(): mock_session.get.return_value.__aenter__.return_value.json = AsyncMock( return_value={"result": 200, "headers": {"Server": "nginx"}} ) + assert await check_internet(mock_session, vm_id) is True mock_session.get.return_value.__aenter__.return_value.json = AsyncMock( diff --git a/tests/supervisor/test_views.py index 246066d2e..4d0438769 100644 --- a/tests/supervisor/test_views.py +++ b/tests/supervisor/test_views.py @@ -25,15 +25,20 @@ async def test_allocation_fails_on_invalid_item_hash(aiohttp_client): response: web.Response = await client.post( "/control/allocations", json={"persistent_vms": ["not-an-ItemHash"]}, headers={"X-Auth-Signature": "test"} ) + assert response.status == 400 - assert await response.json() == [ + + response = await response.json() + for error in response: + error.pop("url", None) + + assert response == [ { - "loc": [ - "persistent_vms", - 0, - ], - "msg": "Could not determine hash type: 'not-an-ItemHash'", - "type": "value_error.unknownhash", + "loc": ["persistent_vms", 0], + "msg": "Value error, Could not determine hash type: 'not-an-ItemHash'", + "type": "value_error", + "ctx": {"error": "Could not determine hash type: 'not-an-ItemHash'"}, + "input": "not-an-ItemHash", }, ] diff --git a/vm_connector/conf.py index d2ee465fc..8164c0320 100644 --- a/vm_connector/conf.py +++ b/vm_connector/conf.py @@ -1,7 +1,7 @@ import logging from typing import NewType -from pydantic import BaseSettings +from pydantic_settings import BaseSettings, 
SettingsConfigDict logger = logging.getLogger(__name__) @@ -27,10 +27,7 @@ def display(self) -> str: f"{annotation:<17} = {getattr(self, annotation)}" for annotation, value in self.__annotations__.items() ) - class Config: - env_prefix = "ALEPH_" - case_sensitive = False - env_file = ".env" + model_config = SettingsConfigDict(env_prefix="ALEPH_", case_sensitive=False, env_file=".env") # Settings singleton