diff --git a/robohive/logger/grouped_datasets.py b/robohive/logger/grouped_datasets.py index 3450af9e..0ba24e65 100644 --- a/robohive/logger/grouped_datasets.py +++ b/robohive/logger/grouped_datasets.py @@ -19,7 +19,6 @@ # access pattern for pickle and h5 backbone post load isn't the same # - Should we get rid of pickle support and double down on h5? # - other way would to make the default container (trace.trace) h5 container instead of a dict -# Should we explicitely keep tract if the trace has been flattened/ stacked/ closed etc? class TraceType(enum.Enum): @@ -32,7 +31,7 @@ def get_type(input_type): """ A more robust way of getting trace type. Supports strings """ - if type(input_type) == str: + if isinstance(input_type, str): if input_type.lower() == "robohive": return TraceType.ROBOHIVE elif input_type.lower() == "roboset": @@ -49,22 +48,23 @@ def __init__(self, name): self.trace = self.root[name] self.index = 0 self.type = TraceType.ROBOHIVE + self.closed = False # False: Trace is open for edits. True: Trace can be analyzed but not edited. # Create a group in your logs def create_group(self, name): self.trace[name] = {} - # Directly add a full dataset to a given group + # Directly add a full dataset to a given group. If data appending is needed, use create_datum() instead def create_dataset(self, group_key, dataset_key, dataset_val): if group_key not in self.trace.keys(): self.create_group(name=group_key) - self.trace[group_key][dataset_key] = [dataset_val] + self.trace[group_key][dataset_key] = dataset_val # Remove dataset from an existing group(s) def remove_dataset(self, group_keys:list, dataset_key:str): - if type(group_keys)==str: + if isinstance(group_keys, str): if group_keys==":": group_keys = self.trace.keys() else: @@ -76,6 +76,13 @@ def remove_dataset(self, group_keys:list, dataset_key:str): del self.trace[group_key][dataset_key] + # Create the first datum of an existing group. 
Use append_datum() to append more elements + def create_datum(self, group_key, dataset_key, dataset_val): + if group_key not in self.trace.keys(): + self.create_group(name=group_key) + self.trace[group_key][dataset_key] = [dataset_val] + + # Append dataset datum to an existing group def append_datum(self, group_key, dataset_key, dataset_val): assert group_key in self.trace.keys(), "Group:{} does not exist".format(group_key) @@ -114,7 +121,21 @@ def set(self, group_key, dataset_key, dataset_ind=None, dataset_val=None): # verify if a data can be a part of an existing datasets def verify_type(self, dataset, data): dataset_type = type(dataset[0]) - assert type(data) == dataset_type, TypeError("Type mismatch while appending. Datum should be {}".format(dataset_type)) + assert isinstance(data, dataset_type), TypeError("Type mismatch while appending. Datum should be {}".format(dataset_type)) + + # check for array + if isinstance(data, np.ndarray): + assert data.shape == dataset[0].shape, ValueError(f"Data dimenstion({data.shape}) not compatible with dataset dimensions({dataset[0].shape})") + # check for list + if isinstance(data, list): + assert len(data) == len(dataset[0]), ValueError(f"Data dimenstion({len(data)}) not compatible with dataset dimensions({len(dataset[0])})") + # check for dictionary + if isinstance(data, dict): + flattened_data = flatten_dict(data) + flattened_dataset = flatten_dict(dataset[0]) + assert flattened_data.keys() == flattened_dataset.keys(), ValueError(f"Data keys {flattened_data.keys()} not compatible with dataset keys {flattened_dataset.keys()}") + for key in flattened_data: + assert np.array(flattened_data[key]).shape == np.array(flattened_dataset[key]).shape, ValueError(f"Data dimension for key '{key}' ({np.array(flattened_data[key]).shape}) not compatible with dataset dimensions ({np.array(flattened_dataset[key]).shape})") # Verify that all datasets in each groups are of same length. 
Helpful for time synced traces @@ -126,11 +147,14 @@ def verify_len(self): trace_len = len(self.trace[grp_k][key]) else: key_len = len(self.trace[grp_k][key]) - assert trace_len == key_len, ValueError("len({}[{}]={}, should be {}".format(grp_k, key, key_len, trace_len)) + assert trace_len == key_len, ValueError("Dataset length mismatch: len({}[{}]={}, should be {}".format(grp_k, key, key_len, trace_len)) # Very if trace is stacked and flattened. Useful for utilities like render, save etc def verify_stacked_flattened(self): + if self.closed: + return True + + for grp_k, grp_v in self.trace.items(): for dst_k, dst_v in grp_v.items(): # Check if stacked @@ -141,6 +165,120 @@ return False return True + # plot data + def plot(self, output_dir, output_format, groups:list, datasets:list, x_dataset:str='time'): + # Plot dataset traces using the groups and datasets keys list. T + # ARGUMENTS: + # output_dir: path for output + # output_format: pdf/png/None(for onscreen) + # groups: - list(Groups)_ng to plot: + # - ng = len(groups) == number of subplots + # - ":" to consider each group once + # - None entry in the list will leave the subplot empty + # datasets: - list(list(datasets))_ng to plot([['left',], ['right', 'top']]), + # - ng = len(groups) == len(datasets) == number of subplots + # - ":" to plot each dataset once + # x_dataset: - dataset key to use as x-axis if available + # EXAMPLES + # 1. plot(..., groups=":", data=":") + # produces a plot with len(groups) subplots + # 2. plot(...,groups=['traj1', 'traj1', 'traj2'], data=[['qpos'], ['qpos','qvel'], ['qvel']]) + # produces a plot with three subplots + + if not self.closed: + prompt("Trace is still open for edits. 
Close the trace to enable plotting", type=Prompt.WARN) + return + + import matplotlib as mpl + # mpl.use('Agg') + import matplotlib.pyplot as plt + plt.rcParams.update({'font.size': 5}) + h_fig = plt.figure(self.name) + plt.clf() + + # Resolve groups + if isinstance(groups, str) and groups==":": + groups = list(self.trace.keys()) + elif isinstance(groups, str): + groups = [groups] + else: + assert isinstance(groups, list), TypeError(f"Expected a list of groups. Got {groups}") + + # number of subplots + n_subplot = len(groups) + + # Check for datasets + if isinstance(datasets, str) and datasets==":": + datasets = n_subplot*[":"] + elif isinstance(datasets, str): + datasets = [datasets] + else: + assert (isinstance(datasets, list)), TypeError(f"Dataset keys needs to be a list. Got {datasets}") + + # Check for group and datasets sizes + assert len(datasets)==n_subplot, ValueError(f"len(groups):{n_subplot} has to match len(datasets):{len(datasets)}") + # print(groups) + # print(datasets) + + # Run through all groups + for i_grp, grp_key in enumerate(groups): + + # Leave empty if requested + if grp_key is None: + continue + + # process group / subplot + assert isinstance(grp_key, str), TypeError(f"Dataset key needs to be a string. Got {grp_key}") + assert grp_key in self.trace.keys(), "Unknown group {}. Available groups {}".format(grp_key, self.trace.keys()) + grp_val = self.trace[grp_key] + + # print('selected group', grp_key) + + # Resolve datasets within existing group + if isinstance(datasets, str) and datasets==":": + i_grp_datasets = list(grp_val.keys()) + elif isinstance(datasets[i_grp], str) and datasets[i_grp]==":": + i_grp_datasets = list(grp_val.keys()) + else: + i_grp_datasets = datasets[i_grp] + + assert isinstance(i_grp_datasets, list) and isinstance(i_grp_datasets[0], str), TypeError(f"Unrecognized dataset input for group:{grp_key}. Expected ':', or a list from {grp_val.keys()}. 
Got: {i_grp_datasets}") + + # Run through all dataset requests within the group + for ds_key in i_grp_datasets: + assert ds_key in grp_val.keys(), f"Group: {grp_key} :> Unknown dataset {ds_key}. Available datasets {grp_val.keys()}" + ds_val = grp_val[ds_key] + + assert isinstance(ds_val, np.ndarray), ValueError(f"Dataset for plotting needs to be an array. Provided data:{ds_val}, type:{type(ds_val)}") + assert np.issubdtype(ds_val.dtype, np.number), ValueError(f"Dataset for plotting needs to of numerical dtype. Provided dtype: {ds_val.dtype}") + assert len(ds_val.shape)<3, ValueError(f"Plotting is only supported for 1D and 2D Dataset. Provided data dims: {ds_val.shape}") + + # print(f"g:{grp_key}/ d:{ds_key}") + h_axis = plt.subplot(n_subplot, 1, i_grp+1) + # h_axis.set_prop_cycle(None) + + if x_dataset in grp_val.keys(): + plt.plot(grp_val[x_dataset][:], ds_val, label=f"{ds_key}", marker='') + h_axis.set_xlabel(x_dataset) + elif x_dataset in self.trace.keys(): + plt.plot(self.trace[x_dataset], ds_val, label=f"{ds_key}", marker='') + h_axis.set_xlabel(x_dataset) + else: + plt.plot(ds_val, label=f"{grp_key}/{ds_key}", marker='*') + h_axis.set_title(grp_key) + h_axis.legend() + + + # show/save plot + if output_format is None: + plt.show() + return False + else: + file_name = os.path.join(output_dir, f"{self.name}_{grp_key}_{ds_key}_{output_format}".replace("/", "_")) + # plt.savefig(file_name) + print("saved ", file_name) + return h_fig + # Render frames/videos def render(self, output_dir, output_format, groups:list, datasets:list, input_fps:int=25): @@ -249,10 +387,10 @@ def items(self): return zip(self.trace.keys(), self) # return length - """ - returns the number of groups in the trace - """ def __len__(self) -> str: + """ + returns the number of groups in the trace + """ return len(self.trace.keys()) @@ -337,6 +475,8 @@ def close(self, if verify_length: self.verify_len() + self.closed = True + # Save def save(self, @@ -391,34 +531,109 @@ def load(trace_path, 
trace_type=TraceType.UNSET): trace.trace = file_data[trace.name] # load data trace.root = file_data # build root trace.trace_type=TraceType.get_type(trace_type) + trace.closed = True return trace +def test_trace_plot(): + trace = Trace("root_name") + + data1 = np.sin(np.arange(0,100)) + data2 = np.cos(np.arange(0,100)) + data3 = np.sin(np.arange(0,200))+np.cos(np.arange(0,200)) + time = 0.01*np.arange(0,200) + + trace.create_group("grp1") + trace.create_dataset(group_key="grp1", dataset_key="dst1", dataset_val=data1) + trace.create_dataset(group_key="grp1", dataset_key="dst2", dataset_val=data2) + + trace.create_group("grp2") + trace.create_dataset(group_key="grp2", dataset_key="time", dataset_val=time) + trace.create_dataset(group_key="grp2", dataset_key="dst3", dataset_val=data3) + trace.close() + + trace.plot(output_format='plot0.pdf', output_dir=".", groups=["grp1",], datasets=[["dst1",],], x_dataset="dst1") + trace.plot(output_format='plot1.pdf', output_dir=".", groups=":", datasets=":") + trace.plot(output_format='plot2.pdf', output_dir=".", groups=":", datasets=[":", ":"]) + trace.plot(output_format='plot3.pdf', output_dir=".", groups=":", datasets=[["dst2",], ":"]) + + # Catch issues plotting string array + try: + trace = Trace("string") + trace.create_dataset(group_key="grp1", dataset_key="dst_k1", dataset_val=np.array(["v1", "v2", "v3"])) + trace.close() + trace.plot(output_format=None, output_dir=".", groups=["grp1",], datasets=[["dst_k1",],]) + except Exception as e: + prompt(f"EXPECTED: Caught exception while trying to plot array of strings: {e}", type=Prompt.WARN) + + # Catch issues plotting list(strings) + try: + trace = Trace("string") + trace.create_dataset(group_key="grp1", dataset_key="dst_k1", dataset_val=["v1", "v2", "v3"]) + trace.close() + trace.plot(output_format='plot4.pdf', output_dir=".", groups=["grp1",], datasets=[["dst_k1",],]) + except Exception as e: + prompt(f"EXPECTED: Caught exception while trying to plot list of strings: {e}", 
type=Prompt.WARN) + + + # plot complex dicts + trace = Trace("root_dict") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":1, "two":2.0, "three":"3"}) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":11, "two":22.0, "three":"33"}) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":111, "two":222.0, "three":"333"}) + trace.close() + trace.plot(output_format='plot5.pdf', output_dir=".", groups=["grp1"], datasets=[["dst_k1/one","dst_k1/two"],]) + trace.plot(output_format='plot6.pdf', output_dir=".", groups=["grp1","grp1"], datasets=[["dst_k1/one"],["dst_k1/two"]]) + trace.plot(output_format='plot7.pdf', output_dir=".", groups=["grp1","grp1"], datasets=[["dst_k1/one"],["dst_k1/one","dst_k1/two",]]) + trace.plot(output_format='plot8.pdf', output_dir=".", groups=[None, "grp1"], datasets=[None, ["dst_k1/one","dst_k1/two"]]) + # catch trying to plot strings + try: + trace.plot(output_format=None, output_dir=".", groups=["grp1"], datasets=[":"]) + except Exception as e: + prompt(f"EXPECTED: Caught exception while trying to plot dict with strings: {e}", type=Prompt.WARN) + + + # Catch trying to plot >2D array + trace = Trace("root_3darray") + trace.create_group("grp1") + trace.create_dataset(group_key="grp1", dataset_key="dst_k1", dataset_val=np.ones([4, 2, 4])) + trace.close() + try: + trace.plot(output_format='plot9.pdf', output_dir=".", groups=["grp1"], datasets=[["dst_k1",],]) + except Exception as e: + prompt(f"EXPECTED: Caught expected exception during plotting >2D dataset: {e}", type=Prompt.WARN) + + # Test trace def test_trace(): trace = Trace("Root_name") # Create a group: append and verify trace.create_group("grp1") - trace.create_dataset(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v1") trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v11") - 
trace.create_dataset(group_key="grp1", dataset_key="dst_k2", dataset_val="dst_v2") + trace.create_datum(group_key="grp1", dataset_key="dst_k2", dataset_val="dst_v2") trace.append_datum(group_key="grp1", dataset_key="dst_k2", dataset_val="dst_v22") trace.verify_len() # Add another group trace.create_group("grp2") - trace.create_dataset(group_key="grp2", dataset_key="dst_k3", dataset_val={"dst_v3":[3]}) - trace.create_dataset(group_key="grp2", dataset_key="dst_k4", dataset_val={"dst_v4":[4]}) + trace.create_datum(group_key="grp2", dataset_key="dst_k3", dataset_val={"dst_v3":[3]}) + trace.create_datum(group_key="grp2", dataset_key="dst_k4", dataset_val={"dst_v4":[4]}) print(trace) # get set methods datum = "dst_v111" trace.set('grp1','dst_k1', 0, datum) assert datum == trace.get('grp1','dst_k1', 0), "Get-Set error" - datum = {"dst_v33":[33]} + datum = {"dst_v4":[0]} trace.set('grp2','dst_k4', 0, datum) assert datum == trace.get('grp2','dst_k4', 0), "Get-Set error" + try: + datum = {"dst_diff_name":[33]} + trace.set('grp2','dst_k4', 0, datum) + except Exception as e: + prompt(f"Caught expected exception trying to insert an inconsistent datum: {e}", type=Prompt.WARN) # save-load methods trace.save(trace_name='test_trace.pickle', verify_length=True) @@ -432,10 +647,77 @@ def test_trace(): print("PKL trace") print(pkl_trace) -if __name__ == '__main__': - test_trace() +def test_trace_append(): + # Create a group: append str + trace = Trace("string") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v1") + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v11") + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v111") + # print(trace) + trace.close() + print(trace) + # Create a group: append list(string) + trace = Trace("list(string)") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=["dst_v1","dst_v2"]) + 
trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=["dst_v11","dst_v22"]) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=["dst_v111","dst_v222"]) + # print(trace) + trace.close() + print(trace) + # Create a group: append list(float) + trace = Trace("list(float)") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=[1, 2]) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=[11, 22]) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=[111, 222]) + # print(trace) + trace.close(i_res=np.int16) + print(trace) + # Create a group: append dict + trace = Trace("dict") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":1, "two":2.0, "three":"3"}) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":11, "two":22.0, "three":"33"}) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":111, "two":222.0, "three":"333"}) + # print(trace) + trace.close() + print(trace) + # Create a group: append ndarray + trace = Trace("ndarray") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.array([1, 2])) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.array([11, 22])) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.array([111, 222])) + print(trace) + trace.close(i_res=np.int16) + print(trace) + + # Create a group: append ndarray + trace = Trace("ndarray_stack") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.ones([4, 2])) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.zeros([4, 2])) + try: + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.array([11, 22])) + except Exception as e: + prompt(f"Caught expected exception during 
append_datum: {e}", type=Prompt.WARN) + trace.close(i_res=np.int16) + assert trace['grp1']['dst_k1'].shape==(2, 4 ,2), ValueError("Check ndarray concatenation") + try: + trace.plot(output_format=None, output_dir=".", groups=["grp1"], datasets=[["dst_k1",],]) + except Exception as e: + prompt(f"Caught expected exception during plotting >2D dataset: {e}", type=Prompt.WARN) + + + +if __name__ == '__main__': + test_trace() + test_trace_append() + test_trace_plot() diff --git a/robohive/tests/test_logger.py b/robohive/tests/test_logger.py index d07e0010..7cbf5787 100644 --- a/robohive/tests/test_logger.py +++ b/robohive/tests/test_logger.py @@ -2,18 +2,42 @@ import click import click.testing -from robohive.logger.grouped_datasets import test_trace +from robohive.logger.grouped_datasets import test_trace, test_trace_append, test_trace_plot from robohive.logger.examine_logs import examine_logs from robohive.utils.examine_env import main as examine_env import os import re +import glob class TestTrace(unittest.TestCase): - def teast_trace(self): + def test_trace(self): # Call your function and test its output/assertions print("Testing Trace Basics") test_trace() + def test_trace_append(self): + # Call your function and test its output/assertions + print("Testing Trace complex appends") + test_trace_append() + + def test_trace_plot(self): + # Call your function and test its output/assertions + print("Testing Trace plotting") + test_trace_plot() + # Define the pattern for the files you want to delete + pattern = "./*plot*.pdf" + + # Use glob to find all files matching the pattern + files_to_delete = glob.glob(pattern) + + # Iterate over the list of files and delete each one + for file_path in files_to_delete: + try: + os.remove(file_path) + print(f"Deleted: {file_path}") + except Exception as e: + print(f"Error deleting file {file_path}: {e}") + class TestExamineTrace(unittest.TestCase): def test_logs_playback(self): diff --git a/robohive/utils/prompt_utils.py 
b/robohive/utils/prompt_utils.py index 0bc06f5f..5d1752f7 100644 --- a/robohive/utils/prompt_utils.py +++ b/robohive/utils/prompt_utils.py @@ -33,7 +33,7 @@ class Prompt(enum.IntEnum): # Infer verbose mode to be used VERBOSE_MODE = os.getenv('ROBOHIVE_VERBOSITY') -if VERBOSE_MODE==None: +if VERBOSE_MODE is None: VERBOSE_MODE = Prompt.WARN else: VERBOSE_MODE = VERBOSE_MODE.upper() diff --git a/robohive/utils/tensor_utils.py b/robohive/utils/tensor_utils.py index 5f54d284..6de21549 100644 --- a/robohive/utils/tensor_utils.py +++ b/robohive/utils/tensor_utils.py @@ -1,6 +1,4 @@ -# Source: https://github.dev/aravindr93/mjrl/tree/master/mjrl -import operator - +# Adapted from Source: https://github.com/aravindr93/mjrl/tree/master/mjrl import numpy as np