Merge pull request #7 from AllenInstitute/feature/add-tmp-vol-config-support

Update Context Manager to support tmp volume and other configurable options
rpmcginty authored Jun 14, 2024
2 parents 4dce0d1 + 4b28bfc commit 8af165d
Showing 5 changed files with 384 additions and 52 deletions.
@@ -1,10 +1,11 @@
import logging
from datetime import timedelta
from pathlib import Path
from typing import List, TypeVar

from aibs_informatics_aws_utils.data_sync.file_system import BaseFileSystem, Node, get_file_system
from aibs_informatics_aws_utils.efs import detect_mount_points, get_local_path
from aibs_informatics_core.models.aws.efs import EFSPath, Path
from aibs_informatics_core.models.aws.efs import EFSPath
from aibs_informatics_core.models.aws.s3 import S3URI
from aibs_informatics_core.utils.file_operations import get_path_size_bytes, remove_path

150 changes: 108 additions & 42 deletions src/aibs_informatics_aws_lambda/handlers/demand/context_manager.py
@@ -1,6 +1,6 @@
import logging
import os
import re
import sys
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
@@ -35,6 +35,11 @@
from aibs_informatics_core.utils.os_operations import write_env_file
from aibs_informatics_core.utils.units import BYTES_PER_GIBIBYTE

from aibs_informatics_aws_lambda.handlers.demand.model import (
ContextManagerConfiguration,
EnvFileWriteMode,
)

if TYPE_CHECKING: # pragma: no cover
from mypy_boto3_batch.type_defs import (
EFSVolumeConfigurationTypeDef,
@@ -99,13 +104,18 @@ class DemandExecutionContextManager:
demand_execution: DemandExecution
scratch_vol_configuration: BatchEFSConfiguration
shared_vol_configuration: BatchEFSConfiguration
tmp_vol_configuration: Optional[BatchEFSConfiguration] = None
configuration: ContextManagerConfiguration = field(default_factory=ContextManagerConfiguration)
env_base: EnvBase = field(default_factory=EnvBase.from_env)

def __post_init__(self):
self._batch_job_builder = None

self.demand_execution = update_demand_execution_parameter_inputs(
self.demand_execution, self.container_shared_path
self.demand_execution,
self.container_shared_path,
self.container_working_path,
self.configuration.isolate_inputs,
)
self.demand_execution = update_demand_execution_parameter_outputs(
self.demand_execution, self.container_working_path
@@ -136,6 +146,8 @@ def container_tmp_path(self) -> Path:
Returns:
Path: container path for tmp volume
"""
if self.tmp_vol_configuration:
return self.tmp_vol_configuration.mount_point_config.mount_point
return self.scratch_vol_configuration.mount_point_config.as_mounted_path("tmp")

@property
@@ -193,10 +205,13 @@ def efs_mount_points(self) -> List[MountPointConfiguration]:
Returns:
List[MountPointConfiguration]: list of mount point configurations
"""
return [
mpcs = [
self.scratch_vol_configuration.mount_point_config,
self.shared_vol_configuration.mount_point_config,
]
if self.tmp_vol_configuration:
mpcs.append(self.tmp_vol_configuration.mount_point_config)
return mpcs

@property
def batch_job_builder(self) -> BatchJobBuilder:
@@ -207,7 +222,11 @@ def batch_job_builder(self) -> BatchJobBuilder:
tmp_path=self.efs_tmp_path,
scratch_mount_point=self.scratch_vol_configuration.mount_point_config,
shared_mount_point=self.shared_vol_configuration.mount_point_config,
tmp_mount_point=self.tmp_vol_configuration.mount_point_config
if self.tmp_vol_configuration
else None,
env_base=self.env_base,
env_file_write_mode=self.configuration.env_file_write_mode,
)
return self._batch_job_builder

@@ -228,6 +247,8 @@ def pre_execution_data_sync_requests(self) -> List[PrepareBatchDataSyncRequest]:
retain_source_data=True,
require_lock=True,
batch_size_bytes_limit=75 * BYTES_PER_GIBIBYTE, # 75 GiB max
size_only=self.configuration.size_only,
force=self.configuration.force,
)
for param in self.demand_execution.execution_parameters.downloadable_job_param_inputs
]
@@ -241,12 +262,19 @@ def post_execution_data_sync_requests(self) -> List[PrepareBatchDataSyncRequest]:
retain_source_data=False,
require_lock=False,
batch_size_bytes_limit=75 * BYTES_PER_GIBIBYTE, # 75 GiB max
size_only=self.configuration.size_only,
force=self.configuration.force,
)
for param in self.demand_execution.execution_parameters.uploadable_job_param_outputs
]

@classmethod
def from_demand_execution(cls, demand_execution: DemandExecution, env_base: EnvBase):
def from_demand_execution(
cls,
demand_execution: DemandExecution,
env_base: EnvBase,
configuration: Optional[ContextManagerConfiguration] = None,
):
vol_configuration = get_batch_efs_configuration(
env_base=env_base,
container_path=f"/opt/efs{EFS_SCRATCH_PATH}",
@@ -259,12 +287,15 @@ def from_demand_execution(cls, demand_execution: DemandExecution, env_base: EnvBase):
access_point_name=EFS_SHARED_ACCESS_POINT_NAME,
read_only=True,
)
tmp_vol_configuration = None

logger.info(f"Using following efs configuration: {vol_configuration}")
return DemandExecutionContextManager(
demand_execution=demand_execution,
scratch_vol_configuration=vol_configuration,
shared_vol_configuration=shared_vol_configuration,
tmp_vol_configuration=tmp_vol_configuration,
configuration=configuration or ContextManagerConfiguration(),
env_base=env_base,
)

@@ -275,7 +306,10 @@ def from_demand_execution(cls, demand_execution: DemandExecution, env_base: EnvBase):


def update_demand_execution_parameter_inputs(
demand_execution: DemandExecution, container_shared_path: Path
demand_execution: DemandExecution,
container_shared_path: Path,
container_scratch_path: Path,
isolate_inputs: bool = False,
) -> DemandExecution:
"""Modifies demand execution input destinations with the location of the volume configuration
@@ -303,20 +337,24 @@
Args:
demand_execution (DemandExecution): Demand execution object to modify (copied)
vol_configuration (BatchEFSConfiguration): volume configuration
isolate_inputs (bool): flag to determine if inputs should be isolated
Returns:
DemandExecution: a demand execution with modified execution parameter inputs
"""

demand_execution = demand_execution.copy()
execution_params = demand_execution.execution_parameters
updated_params = {
param.name: Resolvable(
local=(container_shared_path / sha256_hexdigest(param.remote_value)).as_posix(),
remote=param.remote_value,
)
for param in execution_params.downloadable_job_param_inputs
}
updated_params = {}
for param in execution_params.downloadable_job_param_inputs:
if isolate_inputs:
local = container_scratch_path / demand_execution.execution_id / param.value
else:
local = container_shared_path / sha256_hexdigest(param.remote_value)

new_resolvable = Resolvable(local=local.as_posix(), remote=param.remote_value)
updated_params[param.name] = new_resolvable

execution_params.update_params(**updated_params)
return demand_execution
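
For context, the two input-placement strategies this function now switches between reduce to a small helper. A minimal sketch (the function name is hypothetical, and hashlib stands in for the sha256_hexdigest helper imported above):

import hashlib
from pathlib import Path

def resolve_input_local_path(
    shared_path: Path,
    scratch_path: Path,
    execution_id: str,
    param_value: str,
    remote_value: str,
    isolate_inputs: bool,
) -> Path:
    if isolate_inputs:
        # Isolated: each execution gets a private copy under its own scratch
        # subdirectory, keyed by execution id, so concurrent executions cannot
        # share (or clobber) each other's inputs.
        return scratch_path / execution_id / param_value
    # Shared (default): inputs are content-addressed under the shared volume by
    # a sha256 of the remote URI, so identical inputs are fetched once and
    # reused across executions.
    return shared_path / hashlib.sha256(remote_value.encode()).hexdigest()

# e.g. isolated -> /opt/efs/scratch/exec-123/input.csv
#      shared   -> /opt/efs/shared/<sha256 of s3://bucket/input.csv>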

@@ -403,16 +441,20 @@ def get_batch_efs_configuration(

def generate_batch_job_builder(
demand_execution: DemandExecution,
env_base: EnvBase,
working_path: EFSPath,
tmp_path: EFSPath,
scratch_mount_point: MountPointConfiguration,
shared_mount_point: MountPointConfiguration,
env_base: EnvBase,
tmp_mount_point: Optional[MountPointConfiguration] = None,
env_file_write_mode: EnvFileWriteMode = EnvFileWriteMode.ALWAYS,
) -> BatchJobBuilder:
logger.info(f"Constructing BatchJobBuilder instance")

demand_execution = demand_execution.copy()
efs_mount_points = [scratch_mount_point, shared_mount_point]
if tmp_mount_point is not None:
efs_mount_points.append(tmp_mount_point)
logger.info(f"Resolving local paths of working dir = {working_path} and tmp dir = {tmp_path}")
container_working_path = get_local_path(working_path, mount_points=efs_mount_points)
container_tmp_path = get_local_path(tmp_path, mount_points=efs_mount_points)
@@ -473,42 +515,64 @@ def generate_batch_job_builder(

# If the local environment file is not None, then the file is writable from this local machine
# We will now write a portion of environment variables to files that can be written.
if local_environment_file is not None:
# Steps for writing environment variables to file:
# 1. Identify all environment variables that are not referenced in the command
# if not referenced, then add to environment file.
# 2. Write environment file
# 3. Add environment file to command
ENVIRONMENT_FILE_VAR = "ENVIRONMENT_FILE"

# Step 1: split environment variables based on whether they are referenced in the command
writable_environment = environment.copy()
required_environment: Dict[str, str] = {}
for arg in command + [_ for c in pre_commands for _ in c]:
for match in re.findall(r"\$\{?([\w]+)\}?", arg):
if match in writable_environment:
required_environment[match] = writable_environment.pop(match)

# Add the environment file variable to the required environment variables
environment = required_environment.copy()
environment[ENVIRONMENT_FILE_VAR] = container_environment_file.as_posix()

# Step 2: write to the environment file
local_environment_file.parent.mkdir(parents=True, exist_ok=True)
write_env_file(writable_environment, local_environment_file)

# Finally, add the environment file to the command
pre_commands.append(f". ${{{ENVIRONMENT_FILE_VAR}}}".split(" "))
else:
if local_environment_file is None or env_file_write_mode == EnvFileWriteMode.NEVER:
# If the environment file cannot be written to, then the environment variables are
# passed directly to the container. This is a fallback option and will fail if the
# environment variables are too long.
if local_environment_file is None:
reason = f"Could not write environment variables to file {efs_environment_file_uri}."
else:
reason = "Environment file write mode set to NEVER."

logger.warning(
f"Could not write environment variables to file {efs_environment_file_uri}."
"Environment variables will be passed directly to the container. "
f"{reason} Environment variables will be passed directly to the container. "
"THIS MAY CAUSE THE CONTAINER TO FAIL IF THE ENVIRONMENT VARIABLES "
"ARE LONGER THAN 8192 CHARACTERS!!!"
)

else:
if env_file_write_mode == EnvFileWriteMode.IF_REQUIRED:
env_size = sum([sys.getsizeof(k) + sys.getsizeof(v) for k, v in environment.items()])

if env_size > 8192 * 0.9:
logger.info(
f"Environment variables are too large to pass directly to container (> 90% of 8192). "
f"Writing environment variables to file {efs_environment_file_uri}."
)
confirm_write = True
else:
confirm_write = False
elif env_file_write_mode == EnvFileWriteMode.ALWAYS:
logger.info(f"Writing environment variables to file {efs_environment_file_uri}.")
confirm_write = True

if confirm_write:
# Steps for writing environment variables to file:
# 1. Identify all environment variables that are not referenced in the command
# if not referenced, then add to environment file.
# 2. Write environment file
# 3. Add environment file to command
ENVIRONMENT_FILE_VAR = "_ENVIRONMENT_FILE"

# Step 1: split environment variables based on whether they are referenced in the command
writable_environment = environment.copy()
required_environment: Dict[str, str] = {}
for arg in command + [_ for c in pre_commands for _ in c]:
for match in re.findall(r"\$\{?([\w]+)\}?", arg):
if match in writable_environment:
required_environment[match] = writable_environment.pop(match)

# Add the environment file variable to the required environment variables
environment = required_environment.copy()
environment[ENVIRONMENT_FILE_VAR] = container_environment_file.as_posix()

# Step 2: write to the environment file
local_environment_file.parent.mkdir(parents=True, exist_ok=True)
write_env_file(writable_environment, local_environment_file)

# Finally, add the environment file to the command
pre_commands.append(f". ${{{ENVIRONMENT_FILE_VAR}}}".split(" "))

# ------------------------------------------------------------------

command_string = " && ".join([" ".join(_) for _ in pre_commands + [command]])
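
The IF_REQUIRED branch above boils down to a single predicate. A minimal sketch, isolated for clarity (the constant names are invented; the 8192 limit and 90% margin come from the code above):

import sys
from typing import Dict

ENV_CHAR_LIMIT = 8192  # limit cited in the warning above
SAFETY_MARGIN = 0.9    # spill to a file once usage passes 90% of the limit

def should_write_env_file(environment: Dict[str, str]) -> bool:
    # Approximate the payload size from the Python string objects themselves.
    env_size = sum(sys.getsizeof(k) + sys.getsizeof(v) for k, v in environment.items())
    return env_size > ENV_CHAR_LIMIT * SAFETY_MARGIN

Note that sys.getsizeof counts CPython object overhead (roughly 49 bytes per ASCII str plus one byte per character) rather than raw characters, so the estimate overshoots and the file fallback kicks in earlier than a pure character count would require.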
@@ -517,6 +581,8 @@ def generate_batch_job_builder(
BatchEFSConfiguration(scratch_mount_point, read_only=False),
BatchEFSConfiguration(shared_mount_point, read_only=True),
]
if tmp_mount_point:
vol_configurations.append(BatchEFSConfiguration(tmp_mount_point, read_only=False))
logger.info(f"Constructing BatchJobBuilder instance...")
return BatchJobBuilder(
image=demand_execution.execution_image,
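Taken together, a hypothetical invocation of the new configuration surface in this file (demand_execution and env_base are assumed to already exist; class and field names come from the diff):

from aibs_informatics_aws_lambda.handlers.demand.context_manager import DemandExecutionContextManager
from aibs_informatics_aws_lambda.handlers.demand.model import ContextManagerConfiguration, EnvFileWriteMode

config = ContextManagerConfiguration(
    isolate_inputs=True,                               # per-execution copies of inputs
    env_file_write_mode=EnvFileWriteMode.IF_REQUIRED,  # only spill env vars to a file near the limit
    size_only=True,                                    # data sync compares sizes, not checksums
    force=False,
)
context_manager = DemandExecutionContextManager.from_demand_execution(
    demand_execution, env_base, configuration=config
)
batch_job_builder = context_manager.batch_job_builder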
29 changes: 27 additions & 2 deletions src/aibs_informatics_aws_lambda/handlers/demand/model.py
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from enum import Enum
from typing import Any, Dict, List, Literal, Optional

from aibs_informatics_core.models.base import SchemaModel, custom_field
from aibs_informatics_core.models.base import EnumField, SchemaModel, custom_field
from aibs_informatics_core.models.data_sync import DataSyncRequest
from aibs_informatics_core.models.demand_execution import DemandExecution

@@ -33,6 +34,26 @@ class DemandFileSystemConfigurations(SchemaModel):
scratch: FileSystemConfiguration = custom_field(
mm_field=FileSystemConfiguration.as_mm_field(), default_factory=FileSystemConfiguration
)
tmp: Optional[FileSystemConfiguration] = custom_field(
mm_field=FileSystemConfiguration.as_mm_field(), default=None
)


class EnvFileWriteMode(str, Enum):
NEVER = "NEVER"
ALWAYS = "ALWAYS"
IF_REQUIRED = "IF_REQUIRED"


@dataclass
class ContextManagerConfiguration(SchemaModel):
isolate_inputs: bool = custom_field(default=False)
env_file_write_mode: EnvFileWriteMode = custom_field(
mm_field=EnumField(EnvFileWriteMode), default=EnvFileWriteMode.ALWAYS
)
# data sync configurations
force: bool = custom_field(default=False)
size_only: bool = custom_field(default=True)


@dataclass
@@ -42,6 +63,10 @@ class PrepareDemandScaffoldingRequest(SchemaModel):
mm_field=DemandFileSystemConfigurations.as_mm_field(),
default_factory=DemandFileSystemConfigurations,
)
context_manager_configuration: ContextManagerConfiguration = custom_field(
mm_field=ContextManagerConfiguration.as_mm_field(),
default_factory=ContextManagerConfiguration,
)


@dataclass
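For reference, a sketch of a scaffolding request that opts into a dedicated tmp volume. The FileSystemConfiguration keyword arguments are inferred from the attribute access in scaffolding.py below and may not match the actual constructor:

from aibs_informatics_aws_lambda.handlers.demand.model import (
    ContextManagerConfiguration,
    DemandFileSystemConfigurations,
    EnvFileWriteMode,
    FileSystemConfiguration,
    PrepareDemandScaffoldingRequest,
)

request = PrepareDemandScaffoldingRequest(
    demand_execution=demand_execution,  # an existing DemandExecution, assumed here
    file_system_configurations=DemandFileSystemConfigurations(
        # scratch and shared keep their defaults
        tmp=FileSystemConfiguration(file_system="fs-0123456789abcdef0"),  # placeholder EFS id
    ),
    context_manager_configuration=ContextManagerConfiguration(
        env_file_write_mode=EnvFileWriteMode.IF_REQUIRED,
    ),
)
# With access_point and container_path left unset, the handler below falls back
# to EFS_TMP_ACCESS_POINT_NAME and f"/opt/efs{EFS_TMP_PATH}".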
19 changes: 18 additions & 1 deletion src/aibs_informatics_aws_lambda/handlers/demand/scaffolding.py
@@ -8,6 +8,8 @@
EFS_SCRATCH_PATH,
EFS_SHARED_ACCESS_POINT_NAME,
EFS_SHARED_PATH,
EFS_TMP_ACCESS_POINT_NAME,
EFS_TMP_PATH,
)
from aibs_informatics_aws_utils.efs import MountPointConfiguration
from aibs_informatics_core.env import EnvBase
@@ -56,13 +58,28 @@ def handle(self, request: PrepareDemandScaffoldingRequest) -> PrepareDemandScaff
read_only=True,
)

if request.file_system_configurations.tmp is not None:
tmp_vol_configuration = construct_batch_efs_configuration(
env_base=self.env_base,
file_system=request.file_system_configurations.tmp.file_system,
access_point=request.file_system_configurations.tmp.access_point
if request.file_system_configurations.tmp.access_point
else EFS_TMP_ACCESS_POINT_NAME,
container_path=request.file_system_configurations.tmp.container_path
if request.file_system_configurations.tmp.container_path
else f"/opt/efs{EFS_TMP_PATH}",
read_only=False,
)
else:
tmp_vol_configuration = None

context_manager = DemandExecutionContextManager(
demand_execution=request.demand_execution,
scratch_vol_configuration=scratch_vol_configuration,
shared_vol_configuration=shared_vol_configuration,
tmp_vol_configuration=tmp_vol_configuration,
env_base=self.env_base,
)

batch_job_builder = context_manager.batch_job_builder

self.setup_file_system(context_manager)