Update Context Manager to support tmp volume and other configurable options #7

Merged · 3 commits · Jun 14, 2024
Changes from 2 commits
@@ -1,10 +1,11 @@
import logging
from datetime import timedelta
from pathlib import Path
from typing import List, TypeVar

from aibs_informatics_aws_utils.data_sync.file_system import BaseFileSystem, Node, get_file_system
from aibs_informatics_aws_utils.efs import detect_mount_points, get_local_path
from aibs_informatics_core.models.aws.efs import EFSPath, Path
from aibs_informatics_core.models.aws.efs import EFSPath
from aibs_informatics_core.models.aws.s3 import S3URI
from aibs_informatics_core.utils.file_operations import get_path_size_bytes, remove_path

145 changes: 103 additions & 42 deletions src/aibs_informatics_aws_lambda/handlers/demand/context_manager.py
@@ -1,6 +1,6 @@
import logging
import os
import re
import sys
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
@@ -35,6 +35,11 @@
from aibs_informatics_core.utils.os_operations import write_env_file
from aibs_informatics_core.utils.units import BYTES_PER_GIBIBYTE

from aibs_informatics_aws_lambda.handlers.demand.model import (
ContextManagerConfiguration,
EnvFileWriteMode,
)

if TYPE_CHECKING: # pragma: no cover
from mypy_boto3_batch.type_defs import (
EFSVolumeConfigurationTypeDef,
@@ -99,13 +104,15 @@ class DemandExecutionContextManager:
demand_execution: DemandExecution
scratch_vol_configuration: BatchEFSConfiguration
shared_vol_configuration: BatchEFSConfiguration
tmp_vol_configuration: Optional[BatchEFSConfiguration] = None
configuration: ContextManagerConfiguration = field(default_factory=ContextManagerConfiguration)
env_base: EnvBase = field(default_factory=EnvBase.from_env)

def __post_init__(self):
self._batch_job_builder = None

self.demand_execution = update_demand_execution_parameter_inputs(
self.demand_execution, self.container_shared_path
self.demand_execution, self.container_shared_path, self.configuration.isolate_inputs
)
self.demand_execution = update_demand_execution_parameter_outputs(
self.demand_execution, self.container_working_path
@@ -136,6 +143,8 @@ def container_tmp_path(self) -> Path:
Returns:
Path: container path for tmp volume
"""
if self.tmp_vol_configuration:
return self.tmp_vol_configuration.mount_point_config.mount_point
return self.scratch_vol_configuration.mount_point_config.as_mounted_path("tmp")

@property
@@ -193,10 +202,13 @@ def efs_mount_points(self) -> List[MountPointConfiguration]:
Returns:
List[MountPointConfiguration]: list of mount point configurations
"""
return [
mpcs = [
self.scratch_vol_configuration.mount_point_config,
self.shared_vol_configuration.mount_point_config,
]
if self.tmp_vol_configuration:
mpcs.append(self.tmp_vol_configuration.mount_point_config)
return mpcs

@property
def batch_job_builder(self) -> BatchJobBuilder:
@@ -207,7 +219,11 @@ def batch_job_builder(self) -> BatchJobBuilder:
tmp_path=self.efs_tmp_path,
scratch_mount_point=self.scratch_vol_configuration.mount_point_config,
shared_mount_point=self.shared_vol_configuration.mount_point_config,
tmp_mount_point=self.tmp_vol_configuration.mount_point_config
if self.tmp_vol_configuration
else None,
env_base=self.env_base,
env_file_write_mode=self.configuration.env_file_write_mode,
)
return self._batch_job_builder

@@ -228,6 +244,8 @@ def pre_execution_data_sync_requests(self) -> List[PrepareBatchDataSyncRequest]:
retain_source_data=True,
require_lock=True,
batch_size_bytes_limit=75 * BYTES_PER_GIBIBYTE, # 75 GiB max
size_only=self.configuration.size_only,
force=self.configuration.force,
)
for param in self.demand_execution.execution_parameters.downloadable_job_param_inputs
]
@@ -241,12 +259,19 @@ def post_execution_data_sync_requests(self) -> List[PrepareBatchDataSyncRequest]:
retain_source_data=False,
require_lock=False,
batch_size_bytes_limit=75 * BYTES_PER_GIBIBYTE, # 75 GiB max
size_only=self.configuration.size_only,
force=self.configuration.force,
)
for param in self.demand_execution.execution_parameters.uploadable_job_param_outputs
]

@classmethod
def from_demand_execution(cls, demand_execution: DemandExecution, env_base: EnvBase):
def from_demand_execution(
cls,
demand_execution: DemandExecution,
env_base: EnvBase,
configuration: Optional[ContextManagerConfiguration] = None,
):
vol_configuration = get_batch_efs_configuration(
env_base=env_base,
container_path=f"/opt/efs{EFS_SCRATCH_PATH}",
@@ -259,12 +284,15 @@ def from_demand_execution(cls, demand_execution: DemandExecution, env_base: EnvB
access_point_name=EFS_SHARED_ACCESS_POINT_NAME,
read_only=True,
)
tmp_vol_configuration = None

logger.info(f"Using following efs configuration: {vol_configuration}")
return DemandExecutionContextManager(
demand_execution=demand_execution,
scratch_vol_configuration=vol_configuration,
shared_vol_configuration=shared_vol_configuration,
tmp_vol_configuration=tmp_vol_configuration,
configuration=configuration or ContextManagerConfiguration(),
env_base=env_base,
)
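
For context, a minimal usage sketch of the updated classmethod signature; the `DemandExecution` construction is elided, and the `ContextManagerConfiguration` fields shown are the ones this PR introduces:

```python
from aibs_informatics_core.env import EnvBase

from aibs_informatics_aws_lambda.handlers.demand.context_manager import (
    DemandExecutionContextManager,
)
from aibs_informatics_aws_lambda.handlers.demand.model import (
    ContextManagerConfiguration,
    EnvFileWriteMode,
)

# demand_execution is assumed to be an already-constructed DemandExecution
config = ContextManagerConfiguration(
    isolate_inputs=True,  # nest inputs under the execution id
    env_file_write_mode=EnvFileWriteMode.IF_REQUIRED,  # write env file only near the 8192-char limit
    force=False,  # data sync: do not force re-transfer of unchanged data
    size_only=True,  # data sync: compare objects by size only
)
context_manager = DemandExecutionContextManager.from_demand_execution(
    demand_execution=demand_execution,
    env_base=EnvBase.from_env(),
    configuration=config,
)
job_builder = context_manager.batch_job_builder
```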

@@ -275,7 +303,7 @@ def from_demand_execution(cls, demand_execution: DemandExecution, env_base: EnvB


def update_demand_execution_parameter_inputs(
demand_execution: DemandExecution, container_shared_path: Path
demand_execution: DemandExecution, container_shared_path: Path, isolate_inputs: bool = False
) -> DemandExecution:
"""Modifies demand execution input destinations with the location of the volume configuration

@@ -303,20 +331,25 @@
Args:
demand_execution (DemandExecution): Demand execution object to modify (copied)
container_shared_path (Path): container-local path of the shared volume
isolate_inputs (bool): if True, nest each input under the execution id instead of under a content-addressed hash

Returns:
DemandExecution: a demand execution with modified execution parameter inputs
"""

demand_execution = demand_execution.copy()
execution_params = demand_execution.execution_parameters
updated_params = {
param.name: Resolvable(
local=(container_shared_path / sha256_hexdigest(param.remote_value)).as_posix(),
remote=param.remote_value,
)
for param in execution_params.downloadable_job_param_inputs
}
# TODO: we should allow specifying the local path for the input
updated_params = {}
for param in execution_params.downloadable_job_param_inputs:
if isolate_inputs:
local = container_shared_path / demand_execution.execution_id / param.value
else:
local = container_shared_path / sha256_hexdigest(param.remote_value)

new_resolvable = Resolvable(local=local.as_posix(), remote=param.remote_value)
updated_params[param.name] = new_resolvable

execution_params.update_params(**updated_params)
return demand_execution
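
To illustrate the new `isolate_inputs` behavior, a small sketch of how the two modes resolve the container-local destination; paths and values are illustrative, and `sha256_hexdigest` from the diff is assumed to be a hex digest of the remote URI:

```python
from hashlib import sha256
from pathlib import Path

container_shared_path = Path("/opt/efs/shared")
execution_id = "exec-1234"  # hypothetical execution id
remote_value = "s3://my-bucket/inputs/data.csv"  # hypothetical remote input
param_value = "data.csv"  # hypothetical parameter value

# isolate_inputs=False (default): inputs are keyed by a hash of the remote
# location, so identical inputs can be reused across executions.
shared_local = container_shared_path / sha256(remote_value.encode()).hexdigest()

# isolate_inputs=True: inputs are nested under the execution id, giving each
# execution its own private copy.
isolated_local = container_shared_path / execution_id / param_value
```

The tradeoff: hashed destinations enable cross-execution reuse of downloaded data, while isolated destinations avoid any sharing between concurrent executions.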

@@ -403,16 +436,20 @@ def get_batch_efs_configuration(

def generate_batch_job_builder(
demand_execution: DemandExecution,
env_base: EnvBase,
working_path: EFSPath,
tmp_path: EFSPath,
scratch_mount_point: MountPointConfiguration,
shared_mount_point: MountPointConfiguration,
env_base: EnvBase,
tmp_mount_point: Optional[MountPointConfiguration] = None,
env_file_write_mode: EnvFileWriteMode = EnvFileWriteMode.ALWAYS,
) -> BatchJobBuilder:
logger.info(f"Constructing BatchJobBuilder instance")

demand_execution = demand_execution.copy()
efs_mount_points = [scratch_mount_point, shared_mount_point]
if tmp_mount_point is not None:
efs_mount_points.append(tmp_mount_point)
logger.info(f"Resolving local paths of working dir = {working_path} and tmp dir = {tmp_path}")
container_working_path = get_local_path(working_path, mount_points=efs_mount_points)
container_tmp_path = get_local_path(tmp_path, mount_points=efs_mount_points)
@@ -473,42 +510,64 @@ def generate_batch_job_builder(

# If the local environment file is not None, then the file is writable from this local machine
# We will now write a portion of environment variables to files that can be written.
if local_environment_file is not None:
# Steps for writing environment variables to file:
# 1. Identify all environment variables that are not referenced in the command
# if not referenced, then add to environment file.
# 2. Write environment file
# 3. Add environment file to command
ENVIRONMENT_FILE_VAR = "ENVIRONMENT_FILE"

# Step 1: split environment variables based on whether they are referenced in the command
writable_environment = environment.copy()
required_environment: Dict[str, str] = {}
for arg in command + [_ for c in pre_commands for _ in c]:
for match in re.findall(r"\$\{?([\w]+)\}?", arg):
if match in writable_environment:
required_environment[match] = writable_environment.pop(match)

# Add the environment file variable to the required environment variables
environment = required_environment.copy()
environment[ENVIRONMENT_FILE_VAR] = container_environment_file.as_posix()

# Step 2: write to the environment file
local_environment_file.parent.mkdir(parents=True, exist_ok=True)
write_env_file(writable_environment, local_environment_file)

# Finally, add the environment file to the command
pre_commands.append(f". ${{{ENVIRONMENT_FILE_VAR}}}".split(" "))
else:
if local_environment_file is None or env_file_write_mode == EnvFileWriteMode.NEVER:
# If the environment file cannot be written to, then the environment variables are
# passed directly to the container. This is a fallback option and will fail if the
# environment variables are too long.
if local_environment_file is None:
reason = f"Could not write environment variables to file {efs_environment_file_uri}."
else:
reason = "Environment file write mode set to NEVER."

logger.warning(
f"Could not write environment variables to file {efs_environment_file_uri}."
"Environment variables will be passed directly to the container. "
f"{reason} Environment variables will be passed directly to the container. "
"THIS MAY CAUSE THE CONTAINER TO FAIL IF THE ENVIRONMENT VARIABLES "
"ARE LONGER THAN 8192 CHARACTERS!!!"
)

else:
if env_file_write_mode == EnvFileWriteMode.IF_REQUIRED:
env_size = sum([sys.getsizeof(k) + sys.getsizeof(v) for k, v in environment.items()])

if env_size > 8192 * 0.9:
logger.info(
f"Environment variables are too large to pass directly to container (> 90% of 8192). "
f"Writing environment variables to file {efs_environment_file_uri}."
)
confirm_write = True
else:
confirm_write = False
Comment on lines +534 to +544

njmei (Collaborator) commented on Jun 14, 2024:

> I kind of prefer EnvFileWriteMode.ALWAYS to be the default, and maybe just getting rid of EnvFileWriteMode.IF_REQUIRED as an option entirely? For troubleshooting, it may end up being more confusing if some jobs have environment variables presented one way while other jobs have them another, based solely on whether their total size goes over the arbitrary AWS 8192 limit...

Collaborator reply:

> @rpmcginty Let's defer this decision. Maybe just add a TODO: revisit whether to remove IF_REQUIRED and only support ALWAYS or NEVER
elif env_file_write_mode == EnvFileWriteMode.ALWAYS:
logger.info(f"Writing environment variables to file {efs_environment_file_uri}.")
confirm_write = True

if confirm_write:
# Steps for writing environment variables to file:
# 1. Identify all environment variables that are not referenced in the command
# if not referenced, then add to environment file.
# 2. Write environment file
# 3. Add environment file to command
ENVIRONMENT_FILE_VAR = "_ENVIRONMENT_FILE"

# Step 1: split environment variables based on whether they are referenced in the command
writable_environment = environment.copy()
required_environment: Dict[str, str] = {}
for arg in command + [_ for c in pre_commands for _ in c]:
for match in re.findall(r"\$\{?([\w]+)\}?", arg):
if match in writable_environment:
required_environment[match] = writable_environment.pop(match)

# Add the environment file variable to the required environment variables
environment = required_environment.copy()
environment[ENVIRONMENT_FILE_VAR] = container_environment_file.as_posix()

# Step 2: write to the environment file
local_environment_file.parent.mkdir(parents=True, exist_ok=True)
write_env_file(writable_environment, local_environment_file)

# Finally, add the environment file to the command
pre_commands.append(f". ${{{ENVIRONMENT_FILE_VAR}}}".split(" "))

# ------------------------------------------------------------------

command_string = " && ".join([" ".join(_) for _ in pre_commands + [command]])
@@ -517,6 +576,8 @@
BatchEFSConfiguration(scratch_mount_point, read_only=False),
BatchEFSConfiguration(shared_mount_point, read_only=True),
]
if tmp_mount_point:
vol_configurations.append(BatchEFSConfiguration(tmp_mount_point, read_only=False))
logger.info(f"Constructing BatchJobBuilder instance...")
return BatchJobBuilder(
image=demand_execution.execution_image,
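The environment-file logic above hinges on splitting variables into those referenced by the command (which must stay in the container environment) and those that can be moved to the env file. A self-contained sketch of that split, with illustrative names:

```python
import re
from typing import Dict, List, Tuple


def split_environment(
    environment: Dict[str, str], commands: List[List[str]]
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Return (required, writable): variables referenced as $VAR or ${VAR}
    in any command argument stay required; the rest can go to the env file."""
    writable = environment.copy()
    required: Dict[str, str] = {}
    for arg in (a for cmd in commands for a in cmd):
        for match in re.findall(r"\$\{?([\w]+)\}?", arg):
            if match in writable:
                required[match] = writable.pop(match)
    return required, writable


required, writable = split_environment(
    {"INPUT_PATH": "/data/in", "LOG_LEVEL": "INFO"},
    [["my-tool", "--input", "${INPUT_PATH}"]],
)
assert required == {"INPUT_PATH": "/data/in"}
assert writable == {"LOG_LEVEL": "INFO"}
```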
29 changes: 27 additions & 2 deletions src/aibs_informatics_aws_lambda/handlers/demand/model.py
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from enum import Enum
from typing import Any, Dict, List, Literal, Optional

from aibs_informatics_core.models.base import SchemaModel, custom_field
from aibs_informatics_core.models.base import EnumField, SchemaModel, custom_field
from aibs_informatics_core.models.data_sync import DataSyncRequest
from aibs_informatics_core.models.demand_execution import DemandExecution

@@ -33,6 +34,26 @@ class DemandFileSystemConfigurations(SchemaModel):
scratch: FileSystemConfiguration = custom_field(
mm_field=FileSystemConfiguration.as_mm_field(), default_factory=FileSystemConfiguration
)
tmp: Optional[FileSystemConfiguration] = custom_field(
mm_field=FileSystemConfiguration.as_mm_field(), default=None
)


class EnvFileWriteMode(str, Enum):
NEVER = "never"
ALWAYS = "always"
IF_REQUIRED = "if_required"


@dataclass
class ContextManagerConfiguration(SchemaModel):
isolate_inputs: bool = custom_field(default=False)
env_file_write_mode: EnvFileWriteMode = custom_field(
mm_field=EnumField(EnvFileWriteMode), default=EnvFileWriteMode.IF_REQUIRED
)
# data sync configurations
force: bool = custom_field(default=False)
size_only: bool = custom_field(default=True)


@dataclass
@@ -42,6 +63,10 @@ class PrepareDemandScaffoldingRequest(SchemaModel):
mm_field=DemandFileSystemConfigurations.as_mm_field(),
default_factory=DemandFileSystemConfigurations,
)
context_manager_configuration: ContextManagerConfiguration = custom_field(
mm_field=ContextManagerConfiguration.as_mm_field(),
default_factory=ContextManagerConfiguration,
)


@dataclass
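A short sketch of constructing and round-tripping the new configuration model, assuming `SchemaModel` provides the usual `to_dict`/`from_dict` helpers from aibs_informatics_core:

```python
from aibs_informatics_aws_lambda.handlers.demand.model import (
    ContextManagerConfiguration,
    EnvFileWriteMode,
)

config = ContextManagerConfiguration(
    isolate_inputs=True,
    env_file_write_mode=EnvFileWriteMode.ALWAYS,
    # force and size_only keep their defaults (False / True)
)
data = config.to_dict()  # assumed SchemaModel serialization helper
restored = ContextManagerConfiguration.from_dict(data)
assert restored.env_file_write_mode is EnvFileWriteMode.ALWAYS
```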
19 changes: 18 additions & 1 deletion src/aibs_informatics_aws_lambda/handlers/demand/scaffolding.py
@@ -8,6 +8,8 @@
EFS_SCRATCH_PATH,
EFS_SHARED_ACCESS_POINT_NAME,
EFS_SHARED_PATH,
EFS_TMP_ACCESS_POINT_NAME,
EFS_TMP_PATH,
)
from aibs_informatics_aws_utils.efs import MountPointConfiguration
from aibs_informatics_core.env import EnvBase
@@ -56,13 +58,28 @@ def handle(self, request: PrepareDemandScaffoldingRequest) -> PrepareDemandScaff
read_only=True,
)

if request.file_system_configurations.tmp is not None:
tmp_vol_configuration = construct_batch_efs_configuration(
env_base=self.env_base,
file_system=request.file_system_configurations.tmp.file_system,
access_point=request.file_system_configurations.tmp.access_point
if request.file_system_configurations.tmp.access_point
else EFS_TMP_ACCESS_POINT_NAME,
container_path=request.file_system_configurations.tmp.container_path
if request.file_system_configurations.tmp.container_path
else f"/opt/efs{EFS_TMP_PATH}",
read_only=False,
)
else:
tmp_vol_configuration = None

context_manager = DemandExecutionContextManager(
demand_execution=request.demand_execution,
scratch_vol_configuration=scratch_vol_configuration,
shared_vol_configuration=shared_vol_configuration,
tmp_vol_configuration=tmp_vol_configuration,
env_base=self.env_base,
)

batch_job_builder = context_manager.batch_job_builder

self.setup_file_system(context_manager)
Expand Down
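
For reference, a hypothetical request payload exercising the new optional tmp file system; identifiers are placeholders, and when `access_point` or `container_path` are omitted the handler falls back to `EFS_TMP_ACCESS_POINT_NAME` and `/opt/efs{EFS_TMP_PATH}`:

```python
# Sketch of a PrepareDemandScaffoldingRequest payload; all values are illustrative.
request_payload = {
    "demand_execution": {},  # elided DemandExecution payload
    "file_system_configurations": {
        "scratch": {"file_system": "fs-0123456789abcdef0"},
        "shared": {"file_system": "fs-0123456789abcdef0"},
        "tmp": {"file_system": "fs-0123456789abcdef0"},  # new in this PR
    },
    "context_manager_configuration": {
        "isolate_inputs": False,
        "env_file_write_mode": "if_required",
        "force": False,
        "size_only": True,
    },
}
```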