Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
Signed-off-by: Subhransu Bhattacharjee <[email protected]>
  • Loading branch information
1ssb committed Aug 28, 2023
1 parent 4acf41f commit e351201
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion use/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def main():
mangrove = Mangrove()

print("Configuring depth 1...")
mangrove.config(1, [int, float])
mangrove.config(1, [int, float, torch.Tensor])
print("Depth 1 configured.")

print("Adding data to depth 1...")
Expand Down Expand Up @@ -49,8 +49,37 @@ def main():
mangrove.tocuda(data_type=torch.Tensor)
print("All tensor variables moved to CUDA.")

# Retrieving tensor_data from CUDA and adding it to depth 1
print("Retrieving 'tensor_data' back from CUDA...")
tensor_data_cpu = mangrove.tensor_data.cpu()
print("Retrieved 'tensor_data' back from CUDA.")

print("Adding retrieved 'tensor_data' to depth 1...")
mangrove.add_data(1, torch.Tensor, ["retrieved_tensor_data"], [tensor_data_cpu])
print("Retrieved 'tensor_data' added to depth 1.")

print("Adding tensor data to depth 0 (pre-configured)...")
mangrove.add_data(0, torch.Tensor, ["tensor_data"], [torch.tensor([1, 2, 3])])
print("Tensor data added to depth 0.")

# Test shift operation
print("Shifting 'tensor_data' from depth 0 to depth 1...")
mangrove.shift(to=1, variable_name="tensor_data")
print("'tensor_data' shifted to depth 1.")

print("Attempting to shift 'age' to depth that doesn't support its type...")
try:
mangrove.shift(to=2, variable_name="age")
except MangroveException as e:
print("Expected error occurred:", e)

print("Shifting 'age' back to depth 0...")
mangrove.shift(to=0, variable_name="age")
print("'age' shifted to depth 0.")

except MangroveException as e:
print("An error occurred:", e)

if __name__ == "__main__":
main()

0 comments on commit e351201

Please sign in to comment.