Skip to content

Commit

Permalink
updates to efs, file system
Browse files Browse the repository at this point in the history
  • Loading branch information
rpmcginty committed Dec 23, 2023
1 parent e125fce commit fdcc708
Show file tree
Hide file tree
Showing 16 changed files with 1,565 additions and 169 deletions.
39 changes: 30 additions & 9 deletions src/aibs_informatics_aws_utils/constants/efs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
from dataclasses import dataclass

EFS_MOUNT_PATH_VAR = "EFS_MOUNT_PATH"

@dataclass
class EFSTag:
key: str
value: str


## EFS Environment Variable Constants
EFS_MOUNT_POINT_PATH_VAR = "EFS_MOUNT_POINT_PATH"
EFS_MOUNT_POINT_ID_VAR = "EFS_MOUNT_POINT_ID"


EFS_MOUNT_POINT_PATH_VAR_PREFIX = "EFS_MOUNT_POINT_PATH_"
EFS_MOUNT_POINT_ID_VAR_PREFIX = "EFS_MOUNT_POINT_ID_"


# ------------------------------------
# Standard Names and Paths for EFS


# fmt: off
Expand All @@ -11,13 +28,17 @@
# fmt: on


@dataclass
class EFSTag:
key: str
value: str
# fmt: off
EFS_ROOT_ACCESS_POINT_NAME = "root"
EFS_SHARED_ACCESS_POINT_NAME = "shared"
EFS_SCRATCH_ACCESS_POINT_NAME = "scratch"
EFS_TMP_ACCESS_POINT_NAME = "tmp"
# fmt: on


EFS_ROOT_ACCESS_POINT_TAG = EFSTag("Name", "root")
EFS_SHARED_ACCESS_POINT_TAG = EFSTag("Name", "shared")
EFS_SCRATCH_ACCESS_POINT_TAG = EFSTag("Name", "scratch")
EFS_TMP_ACCESS_POINT_TAG = EFSTag("Name", "tmp")
# fmt: off
EFS_ROOT_ACCESS_POINT_TAG = EFSTag("Name", EFS_ROOT_ACCESS_POINT_NAME)
EFS_SHARED_ACCESS_POINT_TAG = EFSTag("Name", EFS_SHARED_ACCESS_POINT_NAME)
EFS_SCRATCH_ACCESS_POINT_TAG = EFSTag("Name", EFS_SCRATCH_ACCESS_POINT_NAME)
EFS_TMP_ACCESS_POINT_TAG = EFSTag("Name", EFS_TMP_ACCESS_POINT_NAME)
# fmt: on
21 changes: 7 additions & 14 deletions src/aibs_informatics_aws_utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"get_resource",
"get_session",
"AWSService",
"AWS_REGION_VAR",
]
import logging
import os
Expand All @@ -15,7 +16,7 @@
from typing import TYPE_CHECKING, ClassVar, Generic, Literal, Optional, Pattern, TypeVar, cast

import boto3
from aibs_informatics_core.collections import ValidatedStr
from aibs_informatics_core.models.aws.core import AWSRegion
from aibs_informatics_core.models.aws.iam import IAMArn, UserId
from aibs_informatics_core.utils.decorators import cache
from boto3 import Session
Expand All @@ -31,6 +32,7 @@
from mypy_boto3_ecr import ECRClient
from mypy_boto3_ecs import ECSClient
from mypy_boto3_efs import EFSClient
from mypy_boto3_lambda import LambdaClient
from mypy_boto3_logs import CloudWatchLogsClient
from mypy_boto3_s3 import S3Client, S3ServiceResource
from mypy_boto3_secretsmanager import SecretsManagerClient
Expand All @@ -52,6 +54,7 @@
ECSClient = object
EFSClient = object
GetCallerIdentityResponseTypeDef = dict
LambdaClient = object
S3Client, S3ServiceResource = object, object
SecretsManagerClient = object
SESClient = object
Expand All @@ -71,20 +74,8 @@
# AWS Session / Account / Region utilties
# ----------------------------------------------------------------------------

AWS_REGION_VAR = "AWS_REGION"

AWS_REGION_PATTERN = re.compile(
r"(us(?:-gov)?|ap|ca|cn|eu|sa)-(central|(?:north|south)?(?:east|west)?)-(\d)"
)
AWS_ACCOUNT_PATTERN = re.compile(r"[\d]{10-12}")


class AWSAccountId(ValidatedStr):
regex_pattern: ClassVar[Pattern] = AWS_ACCOUNT_PATTERN


class AWSRegion(ValidatedStr):
regex_pattern: ClassVar[Pattern] = AWS_REGION_PATTERN
AWS_REGION_VAR = "AWS_REGION"


def get_session(session: Optional[Session] = None) -> Session:
Expand Down Expand Up @@ -197,6 +188,7 @@ def client_error_code_check(client_error: ClientError, *error_codes: str) -> boo
"ecr",
"ecs",
"efs",
"lambda",
"logs",
"s3",
"secretsmanager",
Expand Down Expand Up @@ -303,6 +295,7 @@ class AWSService:
ECR = AWSServiceProvider[ECRClient]("ecr")
ECS = AWSServiceProvider[ECSClient]("ecs")
EFS = AWSServiceProvider[EFSClient]("efs")
LAMBDA = AWSServiceProvider[LambdaClient]("lambda")
LOGS = AWSServiceProvider[CloudWatchLogsClient]("logs")
S3 = AWSServiceAndResourceProvider[S3Client, S3ServiceResource]("s3")
SECRETSMANAGER = AWSServiceProvider[SecretsManagerClient]("secretsmanager")
Expand Down
49 changes: 41 additions & 8 deletions src/aibs_informatics_aws_utils/data_sync/file_system.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

__all__ = ["BaseFileSystem", "LocalFileSystem", "S3FileSystem"]

import errno
Expand All @@ -10,6 +12,7 @@
from typing import Dict, List, Optional, Union

import pytz
from aibs_informatics_core.models.aws.efs import EFSPath
from aibs_informatics_core.models.aws.s3 import S3URI
from aibs_informatics_core.models.base import CustomAwareDateTime, custom_field
from aibs_informatics_core.models.base.model import SchemaModel
Expand All @@ -18,6 +21,7 @@
from aibs_informatics_core.utils.time import BEGINNING_OF_TIME
from aibs_informatics_core.utils.tools.strtools import removeprefix

from aibs_informatics_aws_utils.efs import get_efs_path, get_local_path
from aibs_informatics_aws_utils.s3 import get_s3_resource

logger = get_logger(__name__)
Expand Down Expand Up @@ -159,11 +163,11 @@ def __post_init__(self):

@abstractmethod
def initialize_node(self) -> Node:
raise NotImplementedError() # pragma: no cover
raise NotImplementedError()

@abstractmethod
def refresh(self, **kwargs):
raise NotImplementedError() # pragma: no cover
raise NotImplementedError()

def partition(
self,
Expand Down Expand Up @@ -219,17 +223,16 @@ def partition(

@classmethod
@abstractmethod
def from_path(cls, path: str, **kwargs) -> "BaseFileSystem":
raise NotImplementedError() # pragma: no cover
def from_path(cls, path: str, **kwargs) -> BaseFileSystem:
pass


# TODO: figure out better package to house this and the base class
@dataclass
class LocalFileSystem(BaseFileSystem):
path: Path

def initialize_node(self) -> Node:
return Node(path_part=str(self.path))
return Node(path_part=self.path.as_posix())

def refresh(self, **kwargs):
self.node = self.initialize_node()
Expand Down Expand Up @@ -259,13 +262,34 @@ def refresh(self, **kwargs):
raise ose

@classmethod
def from_path(cls, path: Union[str, Path], **kwargs) -> "LocalFileSystem":
def from_path(cls, path: Union[str, Path], **kwargs) -> LocalFileSystem:
local_path = Path(path)
local_root = LocalFileSystem(path=local_path)
local_root.refresh(**kwargs)
return local_root


@dataclass
class EFSFileSystem(LocalFileSystem):
efs_path: EFSPath

def initialize_node(self) -> Node:
return Node(path_part=self.efs_path)

@classmethod
def from_path(cls, path: Union[str, Path], **kwargs) -> EFSFileSystem:
if isinstance(path, str) and EFSPath.is_valid(path):
efs_path = EFSPath(path)
local_path = get_local_path(efs_path=efs_path)
else:
local_path = Path(path)
efs_path = get_efs_path(local_path=local_path)

efs_root = EFSFileSystem(path=local_path, efs_path=efs_path)
efs_root.refresh(**kwargs)
return efs_root


@dataclass
class S3FileSystem(BaseFileSystem):
"""Generates a FS tree structure of an S3 path with size and object count stats
Expand Down Expand Up @@ -295,8 +319,17 @@ def refresh(self, **kwargs):
)

@classmethod
def from_path(cls, path: str, **kwargs) -> "S3FileSystem":
def from_path(cls, path: str, **kwargs) -> S3FileSystem:
s3_path = S3URI(path)
s3_root = S3FileSystem(bucket=s3_path.bucket, key=s3_path.key)
s3_root.refresh(**kwargs)
return s3_root


def get_file_system(path: Optional[Union[str, Path]]) -> BaseFileSystem:
if isinstance(path, str) and S3URI.is_valid(path):
return S3FileSystem.from_path(path)
elif isinstance(path, str) and EFSPath.is_valid(path):
return EFSFileSystem.from_path(path)
else:
return LocalFileSystem.from_path(path)
Loading

0 comments on commit fdcc708

Please sign in to comment.