Skip to content

Commit

Permalink
Update mangrove.py
Browse files Browse the repository at this point in the history
Signed-off-by: Subhransu Bhattacharjee <[email protected]>
  • Loading branch information
1ssb committed Aug 28, 2023
1 parent 222b30f commit 2a7b324
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions mangroves/mangrove.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def add_data(self, depth: int, data_type: Type, var: List[str], value: Optional[
for i, v in enumerate(var):
if v in self.data:
raise MangroveException(f"Variable name {v} is already in use.")

val = value[i] if value else None
self.data[v] = val
self.types[v] = data_type
Expand All @@ -52,7 +52,7 @@ def summary(self) -> Dict[str, Union[Dict[str, Union[int, Type]], Dict[str, Type
unconfigured_vars[name] = dtype
else:
configured_vars[name] = {"depth": depth, "type": dtype}

return {'configured': configured_vars, 'unconfigured (depth 0)': unconfigured_vars}

def __setattr__(self, name: str, value: Any) -> None:
Expand Down Expand Up @@ -82,7 +82,7 @@ def var(self, depth: Optional[int] = None, data_type: Optional[Type] = None) ->
def index(self, depth: Optional[int] = None, data_type: Optional[Type] = None) -> Dict[str, Any]:
# Retrieve variables based on optional depth and type filters
return {name: value for name, value in self.data.items() if (depth is None or depth == self.levels[name]) and (data_type is None or data_type == self.types[name])}

def push(self, depth: int, var_name: str) -> None:
# Move a variable from depth 0 to a different depth
if var_name not in self.data:
Expand All @@ -103,3 +103,25 @@ def tocuda(self, depth: Optional[int] = None, data_type: Optional[Type] = None)
self.data[name] = value.cuda()
else:
raise MangroveException("A CUDA-enabled GPU is not available on this device. If nvidia-smi command returns correctly, check for compatibility of nvcc version.")

def shift(self, to: int, variable_name: str) -> None:
# Shift a variable to another depth if that depth is configured for the data type
if variable_name not in self.data:
raise MangroveException(f"Variable {variable_name} does not exist.")

# Special case: Shifting to depth 0 is always allowed
if to == 0:
self.levels[variable_name] = 0
return

# Check if the destination depth is configured
if to not in self.depths:
raise MangroveException(f"Depth {to} is not configured. Please configure the depth first.")

# Check if the data type of the variable is allowed at the destination depth
data_type = self.types[variable_name]
if data_type not in self.depths[to]:
raise MangroveException(f"Type {data_type} is not allowed at depth {to}.")

# Perform the shift
self.levels[variable_name] = to

0 comments on commit 2a7b324

Please sign in to comment.