Skip to content

Commit

Permalink
Update mangrove.py
Browse files Browse the repository at this point in the history
Signed-off-by: Subhransu Bhattacharjee <[email protected]>
  • Loading branch information
1ssb committed Aug 28, 2023
1 parent 36afbb6 commit f560bc0
Showing 1 changed file with 40 additions and 60 deletions.
100 changes: 40 additions & 60 deletions mangroves/mangrove.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,38 @@ class Mangrove:
__slots__ = ["depths", "data", "types", "levels"]

def __init__(self) -> None:
# Initialize with depth 0 pre-configured for basic types
self.depths: Dict[int, List[Type]] = {0: [int, float, str, torch.Tensor]}
self.data: Dict[str, Any] = {}
self.types: Dict[str, Type] = {}
self.levels: Dict[str, int] = {}
self.depths = {0: [int, float, str, torch.Tensor]}
self.data = {}
self.types = {}
self.levels = {}

def config(self, depth: int, types: List[Type]) -> None:
# Configure a new depth level with allowed types
if depth == 0:
raise MangroveException("Depth 0 is pre-configured and cannot be modified.")
self.depths[depth] = types

def add_data(self, depth: int, data_type: Type, var: List[str], value: Optional[List[Any]] = None) -> None:
# Add data to a specific depth level
if len(var) != len(value):
raise MangroveException("Length of variable names and values must match.")

if depth not in self.depths:
depth_types = self.depths.get(depth, None)
if depth_types is None:
raise MangroveException("Depth not configured. Please configure the depth first.")

if data_type not in self.depths[depth]:
if data_type not in depth_types:
raise MangroveException(f"Type {data_type} is not allowed at depth {depth}.")

for i, v in enumerate(var):
if v in self.data:
try:
_ = self.data[v]
raise MangroveException(f"Variable name {v} is already in use.")

val = value[i] if value else None
self.data[v] = val
self.types[v] = data_type
self.levels[v] = depth
except KeyError:
val = value[i] if value else None
self.data[v] = val
self.types[v] = data_type
self.levels[v] = depth

def summary(self) -> Dict[str, Union[Dict[str, Union[int, Type]], Dict[str, Type]]]:
# Generate a summary of all variables and their configurations
configured_vars = {}
unconfigured_vars = {}

Expand All @@ -56,72 +54,54 @@ def summary(self) -> Dict[str, Union[Dict[str, Union[int, Type]], Dict[str, Type
return {'configured': configured_vars, 'unconfigured (depth 0)': unconfigured_vars}

def __setattr__(self, name: str, value: Any) -> None:
# Overridden to allow setting value of already-defined variables
if name in self.__slots__:
object.__setattr__(self, name, value)
elif name in self.data:
dtype = self.types[name]
if isinstance(value, dtype):
self.data[name] = value
else:
raise MangroveException(f"Value must be of type {dtype}.")
else:
raise MangroveException(f"No such attribute: {name}")
try:
dtype = self.types[name]
if isinstance(value, dtype):
self.data[name] = value
else:
raise MangroveException(f"Value must be of type {dtype}.")
except KeyError:
raise MangroveException(f"No such attribute: {name}")

def __getattr__(self, name: str) -> Any:
# Overridden to allow getting value of already-defined variables
if name in self.data:
try:
return self.data[name]
else:
except KeyError:
raise MangroveException(f"No such attribute: {name}")

def var(self, depth: Optional[int] = None, data_type: Optional[Type] = None) -> List[str]:
# Retrieve variable names based on optional depth and type filters
return [name for name in self.data.keys() if (depth is None or depth == self.levels[name]) and (data_type is None or data_type == self.types[name])]
return list(name for name in self.data.keys() if (depth is None or depth == self.levels.get(name, None)) and (data_type is None or data_type == self.types.get(name, None)))

def index(self, depth: Optional[int] = None, data_type: Optional[Type] = None) -> Dict[str, Any]:
# Retrieve variables based on optional depth and type filters
return {name: value for name, value in self.data.items() if (depth is None or depth == self.levels[name]) and (data_type is None or data_type == self.types[name])}
return {name: value for name, value in self.data.items() if (depth is None or depth == self.levels.get(name, None)) and (data_type is None or data_type == self.types.get(name, None))}

def push(self, depth: int, var_name: str) -> None:
# Move a variable from depth 0 to a different depth
if var_name not in self.data:
try:
if self.levels[var_name] != 0:
raise MangroveException(f"{var_name} is not at depth 0. Cannot push.")
self.levels[var_name] = depth
except KeyError:
raise MangroveException(f"Variable {var_name} does not exist.")

if self.levels[var_name] != 0:
raise MangroveException(f"{var_name} is not at depth 0. Cannot push.")

self.levels[var_name] = depth

def tocuda(self, depth: Optional[int] = None, data_type: Optional[Type] = None) -> None:
# Move tensor variables to CUDA, if available
if torch.cuda.is_available():
for name in self.data.keys():
if (depth is None or depth == self.levels[name]) and (data_type is None or data_type == self.types[name]):
if (depth is None or depth == self.levels.get(name, None)) and (data_type is None or data_type == self.types.get(name, None)):
value = self.data[name]
if isinstance(value, torch.Tensor):
self.data[name] = value.cuda()
else:
raise MangroveException("A CUDA-enabled GPU is not available on this device. If nvidia-smi command returns correctly, check for compatibility of nvcc version.")

def shift(self, to: int, variable_name: str) -> None:
# Shift a variable to another depth if that depth is configured for the data type
if variable_name not in self.data:
try:
data_type = self.types[variable_name]
if to == 0 or data_type in self.depths.get(to, []):
self.levels[variable_name] = to
else:
raise MangroveException(f"Type {data_type} is not allowed at depth {to}.")
except KeyError:
raise MangroveException(f"Variable {variable_name} does not exist.")

# Special case: Shifting to depth 0 is always allowed
if to == 0:
self.levels[variable_name] = 0
return

# Check if the destination depth is configured
if to not in self.depths:
raise MangroveException(f"Depth {to} is not configured. Please configure the depth first.")

# Check if the data type of the variable is allowed at the destination depth
data_type = self.types[variable_name]
if data_type not in self.depths[to]:
raise MangroveException(f"Type {data_type} is not allowed at depth {to}.")

# Perform the shift
self.levels[variable_name] = to

0 comments on commit f560bc0

Please sign in to comment.