Skip to content

Commit

Permalink
Tests updated
Browse files Browse the repository at this point in the history
  • Loading branch information
1ssb committed Aug 28, 2023
1 parent ec4e699 commit 618447d
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 17 deletions.
2 changes: 2 additions & 0 deletions build/lib/mangroves/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .mangrove import Mangrove, MangroveException

127 changes: 127 additions & 0 deletions build/lib/mangroves/mangrove.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import torch
from typing import List, Union, Type, Any, Dict, Optional

class MangroveException(Exception):
"""Custom Exception for the Mangrove class."""
pass

class Mangrove:
__slots__ = ["depths", "data", "types", "levels"]

def __init__(self) -> None:
# Initialize with depth 0 pre-configured for basic types
self.depths: Dict[int, List[Type]] = {0: [int, float, str, torch.Tensor]}
self.data: Dict[str, Any] = {}
self.types: Dict[str, Type] = {}
self.levels: Dict[str, int] = {}

def config(self, depth: int, types: List[Type]) -> None:
# Configure a new depth level with allowed types
if depth == 0:
raise MangroveException("Depth 0 is pre-configured and cannot be modified.")
self.depths[depth] = types

def add_data(self, depth: int, data_type: Type, var: List[str], value: Optional[List[Any]] = None) -> None:
# Add data to a specific depth level
if len(var) != len(value):
raise MangroveException("Length of variable names and values must match.")

if depth not in self.depths:
raise MangroveException("Depth not configured. Please configure the depth first.")

if data_type not in self.depths[depth]:
raise MangroveException(f"Type {data_type} is not allowed at depth {depth}.")

for i, v in enumerate(var):
if v in self.data:
raise MangroveException(f"Variable name {v} is already in use.")

val = value[i] if value else None
self.data[v] = val
self.types[v] = data_type
self.levels[v] = depth

def summary(self) -> Dict[str, Union[Dict[str, Union[int, Type]], Dict[str, Type]]]:
# Generate a summary of all variables and their configurations
configured_vars = {}
unconfigured_vars = {}

for name, depth in self.levels.items():
dtype = self.types[name]
if depth == 0:
unconfigured_vars[name] = dtype
else:
configured_vars[name] = {"depth": depth, "type": dtype}

return {'configured': configured_vars, 'unconfigured (depth 0)': unconfigured_vars}

def __setattr__(self, name: str, value: Any) -> None:
# Overridden to allow setting value of already-defined variables
if name in self.__slots__:
object.__setattr__(self, name, value)
elif name in self.data:
dtype = self.types[name]
if isinstance(value, dtype):
self.data[name] = value
else:
raise MangroveException(f"Value must be of type {dtype}.")
else:
raise MangroveException(f"No such attribute: {name}")

def __getattr__(self, name: str) -> Any:
# Overridden to allow getting value of already-defined variables
if name in self.data:
return self.data[name]
else:
raise MangroveException(f"No such attribute: {name}")

def var(self, depth: Optional[int] = None, data_type: Optional[Type] = None) -> List[str]:
# Retrieve variable names based on optional depth and type filters
return [name for name in self.data.keys() if (depth is None or depth == self.levels[name]) and (data_type is None or data_type == self.types[name])]

def index(self, depth: Optional[int] = None, data_type: Optional[Type] = None) -> Dict[str, Any]:
# Retrieve variables based on optional depth and type filters
return {name: value for name, value in self.data.items() if (depth is None or depth == self.levels[name]) and (data_type is None or data_type == self.types[name])}

def push(self, depth: int, var_name: str) -> None:
# Move a variable from depth 0 to a different depth
if var_name not in self.data:
raise MangroveException(f"Variable {var_name} does not exist.")

if self.levels[var_name] != 0:
raise MangroveException(f"{var_name} is not at depth 0. Cannot push.")

self.levels[var_name] = depth

def tocuda(self, depth: Optional[int] = None, data_type: Optional[Type] = None) -> None:
# Move tensor variables to CUDA, if available
if torch.cuda.is_available():
for name in self.data.keys():
if (depth is None or depth == self.levels[name]) and (data_type is None or data_type == self.types[name]):
value = self.data[name]
if isinstance(value, torch.Tensor):
self.data[name] = value.cuda()
else:
raise MangroveException("A CUDA-enabled GPU is not available on this device. If nvidia-smi command returns correctly, check for compatibility of nvcc version.")

def shift(self, to: int, variable_name: str) -> None:
# Shift a variable to another depth if that depth is configured for the data type
if variable_name not in self.data:
raise MangroveException(f"Variable {variable_name} does not exist.")

# Special case: Shifting to depth 0 is always allowed
if to == 0:
self.levels[variable_name] = 0
return

# Check if the destination depth is configured
if to not in self.depths:
raise MangroveException(f"Depth {to} is not configured. Please configure the depth first.")

# Check if the data type of the variable is allowed at the destination depth
data_type = self.types[variable_name]
if data_type not in self.depths[to]:
raise MangroveException(f"Type {data_type} is not allowed at depth {to}.")

# Perform the shift
self.levels[variable_name] = to
30 changes: 13 additions & 17 deletions use/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,35 +82,31 @@ def test_push_variable_not_at_depth_zero(self):
m.add_data(1, int, ["x"], [1])
with self.assertRaises(MangroveException):
m.push(1, "x")

def test_tocuda(self):
m = Mangrove()
m.config(1, [torch.Tensor])
m.add_data(1, torch.Tensor, ["t"], [torch.zeros(5)])
m.tocuda()
self.assertTrue(m.t.is_cuda)

def test_shift(self):
m = Mangrove()
m.config(1, [int, torch.Tensor])
m.config(2, [int])
m.add_data(1, int, ["x"], [1])

# Test shifting to a compatible depth
m.shift(to=2, variable_name="x")
self.assertEqual(m.levels, {"x": 2})

# Test shifting back to depth 0
m.shift(to=0, variable_name="x")
self.assertEqual(m.levels, {"x": 0})

# Test attempting to shift to an incompatible depth

# Test valid shift operation
m.shift(0, "x")
self.assertEqual(m.levels["x"], 0)

# Test invalid depth
with self.assertRaises(MangroveException):
m.shift(to=1, variable_name="x")

# Test attempting to shift a non-existent variable
m.shift(2, "x")

# Test invalid data type at depth
m.config(2, [float])
with self.assertRaises(MangroveException):
m.shift(to=1, variable_name="non_existent_var")
m.shift(2, "x")

if __name__ == '__main__':
unittest.main()
Expand Down

0 comments on commit 618447d

Please sign in to comment.