diff --git a/README.md b/README.md index 0a4081b8..0d8351bf 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,15 @@ Ideally you can start your GELLO near a known configuration each time. If this i We have provided a simple example for collecting data with gello. To save trajectories with the keyboard, add the following flag `--use-save-interface` -Data can then be processed using the demo_to_gdict script. +In order to process the data using the demo_to_gdict script, you will need to install the requirements by running +``` +pip install -r requirements_data_process.txt +``` +and then installing `ffmpeg` using your package manager. On Ubuntu, you can run +``` +sudo apt install ffmpeg +``` +Finally, run the demo_to_gdict script. ``` python gello/data_utils/demo_to_gdict.py --source-dir= ``` diff --git a/gello/data_utils/gdict/README.md b/gello/data_utils/gdict/README.md new file mode 100644 index 00000000..837d0415 --- /dev/null +++ b/gello/data_utils/gdict/README.md @@ -0,0 +1,3 @@ +### `gdict`: dictionaries which support {np, torch} slicing. 
+--- +taken from MS2: https://github.com/haosulab/ManiSkill2-Learn/tree/main/maniskill2_learn/utils \ No newline at end of file diff --git a/gello/data_utils/gdict/__init__.py b/gello/data_utils/gdict/__init__.py new file mode 100644 index 00000000..8d16d005 --- /dev/null +++ b/gello/data_utils/gdict/__init__.py @@ -0,0 +1 @@ +from .data import * \ No newline at end of file diff --git a/gello/data_utils/gdict/data/__init__.py b/gello/data_utils/gdict/data/__init__.py new file mode 100644 index 00000000..73803150 --- /dev/null +++ b/gello/data_utils/gdict/data/__init__.py @@ -0,0 +1,88 @@ +from .array_ops import ( + unsqueeze, + squeeze, + zeros_like, + ones_like, + repeat, + tile, + shuffle, + take, + concat, + stack, + share_memory, + to_item, + select_with_mask, + recover_with_mask, + slice_to_range, + detach, + split, + norm, + normalize, + clip, + arr_sum, + is_pcd, + arr_min, + arr_max, + arr_mean, + batch_shuffle, + batch_perm, + pad_item, + pad_clip, + clip_item, + to_gc, + to_nc, + encode_np, + decode_np, + gather, + to_two_dims, + reshape, + transpose, + contiguous, + split_dim, + slice_item, + sample_and_pad, + batch_index_select, + einsum, + broadcast_to, + deepcopy, + to_float +) +from .compression import float_to_int, int_to_float, f64_to_f32, to_f32, to_f16 +from .converter import as_dtype, to_np, to_torch, to_array, dict_to_str, list_to_str, dict_to_seq, seq_to_dict, slice_to_range, range_to_slice, index_to_slice +from .dict_array import GDict, DictArray, SharedGDict, SharedDictArray, create_smm, delete_smm +from .filtering import filter_none, filter_with_regex +from .misc import equal, SLICE_ALL +from .seq_utils import ( + concat_seq, + concat_list, + concat_tuple, + auto_pad_seq, + flatten_seq, + split_list_of_parameters, + select_by_index, + random_pad_clip_list, +) +from .string_utils import regex_match, custom_format, prefix_match, num_to_str, float_str, regex_replace, any_string, is_regex +from .type_utils import ( + is_str, + is_dict, + 
is_num, + is_integer, + is_type, + is_seq_of, + is_list_of, + is_tuple_of, + is_iterable, + get_dtype, + is_np, + is_np_arr, + is_torch, + is_arr, + is_slice, + is_torch_distribution, + is_h5, + is_null, + is_not_null, +) +from .dict_utils import update_dict_with_begin_keys, first_dict_key, map_dict_keys, update_dict +from .wrappers import process_input, process_output, seq_to_np diff --git a/gello/data_utils/gdict/data/array_ops.py b/gello/data_utils/gdict/data/array_ops.py new file mode 100644 index 00000000..30862b2a --- /dev/null +++ b/gello/data_utils/gdict/data/array_ops.py @@ -0,0 +1,816 @@ +from io import BytesIO +import random +import sys +import base64 +import numpy as np + +from .converter import range_to_slice, to_np, to_torch, slice_to_range +from .type_utils import get_dtype, is_np, is_np_arr, is_num, is_torch, is_integer, is_torch_distribution, is_not_null, is_arr, is_h5 +from .wrappers import seq_to_np + + +""" Unified API for torch and numpy """ + +# We import torch inside the function to reduce the memory usage when you only want to work with numpy array + + +def to_float(item): + if np.isscalar(item): + item = float(item) + elif isinstance(item, np.ndarray): + item = item.astype(np.float32) + return item + + +def deepcopy(item): + from copy import deepcopy + + if is_np_arr(item): + item = item.copy() + elif is_torch(item): + item = item.clone() + elif not is_h5(item): + item = deepcopy(item) + + return item + + +def unsqueeze(item, axis): + if is_np_arr(item): + # Trick to speed up, expand_dims is very slow .... 
+ if axis == 0: + return item[None] + elif axis == -1: + return item[..., None] + item = np.expand_dims(item, axis) + elif is_torch(item): + if axis == 0: + return item[None] + elif axis == -1: + return item[..., None] + item = item.unsqueeze(axis) + elif is_torch_distribution(item): + item = ops_single_torch_distribution(item, unsqueeze, axis=axis) + return item + + +def squeeze(item, axis): + if is_np_arr(item): + if axis == 0: + return item[0] + elif axis == -1: + return item[..., 0] + + if axis is None: + return np.squeeze(item) + elif item.shape[axis] != 1: + return item + else: + return np.squeeze(item, axis) + elif is_torch(item): + if axis == 0: + return item[0] + elif axis == -1: + return item[..., 0] + + if axis is None: + return item.squeeze() + elif item.shape[axis] != 1: + return item + else: + return item.squeeze(axis) + elif is_torch_distribution(item): + return ops_single_torch_distribution(item, squeeze, axis=axis) + else: + return item + + +def zeros_like(item): + if is_np_arr(item): + return np.zeros_like(item) + elif is_torch(item): + import torch + + return torch.zeros_like(item) + else: + return item + + +def ones_like(item): + if is_np_arr(item): + return np.ones_like(item) + elif is_torch(item): + import torch + + return torch.ones_like(item) + else: + return item + + +def repeat(item, rep, axis=None): + # when axis=0, we will use tile for numpy and repeat for torch + if not (is_np_arr(item) or is_torch(item)): + return item + if is_np_arr(item): + if axis is None: + return np.tile(item, rep) + else: + return np.repeat(item, rep, axis) + else: + import torch + + if axis is None: + return item.repeat(*rep) + else: + return torch.repeat_interleave(item, rep, axis) + + +def tile(item, rep): + if is_integer(rep): + rep = (rep,) + if is_np_arr(item): + return np.tile(item, rep) + elif is_torch(item): + import torch + + return torch.tile(item, rep) + else: + return item + + +def slice_item(item, slice, axis=0): + # Avoid copying the data for numpy 
array and torch tensor + if is_arr(item) or is_torch(item): + if axis == 0: + ret = item[slice] + elif axis == 1: + ret = item[:, slice] + elif axis == 2: + ret = item[:, :, slice] + elif axis == 3: + ret = item[:, :, :, slice] + else: + raise NotImplementedError("Axis is too large!") + return ret + else: + return item + + +def take(item, indices, axis=None): + # It will copy the data for np array and torch tensor + if isinstance(item, list): + assert axis == 0 and is_num(indices), "For list we only support operation on the first dimension!" + return item[indices] + + # Convert slice and range to np array + if isinstance(indices, slice): + indices = slice_to_range(indices) + if isinstance(indices, range): + indices = list(indices) + indices = np.array(indices, dtype=np.int64) if isinstance(indices, (list, tuple)) else indices + + if is_np_arr(item): + return item.take(indices=indices, axis=axis) + elif is_torch(item): + single = False + if not is_torch(indices): + single = is_integer(indices) + indices = to_torch(indices, device=item.device, non_blocking=True) + if indices.ndim > 1: + new_shape = list(item.shape) + new_shape = new_shape[:axis] + list(indices.shape) + new_shape[axis + 1 :] + ret = item.index_select(index=indices.reshape(-1), dim=axis).reshape(new_shape) + else: + ret = item.index_select(index=indices, dim=axis) + if single: + ret = ret.squeeze(dim=axis) + return ret + elif is_torch_distribution(item): + return ops_single_torch_distribution(item, take, indices=indices, axis=axis) + else: + return item + + +def shuffle(item, axis=0): + if isinstance(item, (list, tuple)): + is_tuple = type(item) == tuple + ret = list(item) if is_tuple else item + random.shuffle(ret) + return tuple(ret) if is_tuple else item + elif is_np_arr(item): + indices = np.random.permutation(item.shape[axis]) + return take(item, indices, axis=axis) + elif is_torch(item): + import torch + + indices = torch.randperm(item.shape[axis], device=item.device) + return take(item, indices, 
axis=axis) + else: + return item + + +def reshape(item, newshape): + if hasattr(item, "reshape"): + item = item.reshape(newshape) + return item + + +def split_dim(item, axis, newaxes): + if is_arr(item) and len(newaxes) > 1: + und_index = np.where(np.array(newaxes) == -1)[0] + assert len(und_index) <= 1 + if len(und_index) == 1: + und_index = und_index[0] + newaxes[und_index] = 1 + newaxes[und_index] = item.shape[axis] // np.prod(newaxes) + assert np.prod(newaxes) == item.shape[axis] + item_shape = list(item.shape) + item_shape = item_shape[:axis] + newaxes + item_shape[axis + 1 :] + item = item.reshape(item_shape) + elif is_torch_distribution(item) and len(newaxes) > 1: + item = ops_single_torch_distribution(item, split_dim, axis=axis, newaxes=newaxes) + return item + + +def transpose(item, axis0, axis1, contiguous=True): + if is_np_arr(item): + item = np.swapaxes(item, axis0, axis1) + elif is_torch(item): + import torch + + item = torch.transpose(item, axis0, axis1) + if contiguous: + item = item.contiguous() + return item + + +def contiguous(item): + if is_torch(item): + item = item.contiguous() + return item + + +def einsum(subscripts, *items): + items = list(items) + if is_np_arr(items[0]): + return np.einsum(subscripts, *items) + elif is_torch(items[0]): + import torch + + return torch.einsum(subscripts, *items) + return items + + +def concat(item, axis=0): + if len(item) == 1: + return item[0] + elif is_np_arr(item[0]): + return np.concatenate(item, axis=axis) + elif is_torch(item[0]): + import torch + + return torch.cat(item, dim=axis) + elif is_torch_distribution(item[0]): + return concat_torch_distribution(item, axis=axis) + else: + return item + + +def stack(item, axis=0): + if len(item) == 1: + return unsqueeze(item[0], axis) + elif is_np_arr(item[0]): + return np.stack(item, axis=axis) + elif is_torch(item[0]): + import torch + + return torch.stack(item, dim=axis) + else: + return item + + +def share_memory(x, y): + if type(x) != type(y): + return 
False + elif is_np_arr(x): + ret = x.base is not None and y.base is not None and x.base == y.base + return ret.any() if is_np_arr(ret) else ret + elif is_torch(x): + sign = x.storage().data_ptr() == y.storage().data_ptr() + return sign if isinstance(sign, bool) else sign.any() + else: + if isinstance(x, (int, str, float)): + return False + else: + return id(x) == id(y) + + +def to_cpu(x): + import torch + + if isinstance(x, torch.Tensor): + x = x.cpu() + return x + + +def to_cuda(x, device="cuda"): + import torch + + if isinstance(x, torch.Tensor): + x = x.to(device) + return x + + +def type_as(item, other): + if is_np_arr(item): + return item.astype(other.dtype) + elif is_torch(item): + return item.type_as(other) + else: + return item + + +@seq_to_np(True) +def arr_sum(item, axis=None, keepdim=False, mask=None, dtype=None): + if is_np_arr(item): + item = item if mask is None else (type_as(mask, item) * item) + return np.sum(item, axis, dtype, None, keepdim) + elif is_torch(item): + import torch + + item = item if mask is None else (type_as(mask, item) * item) + if axis is None: + return torch.sum(item) + else: + return torch.sum(item, axis, keepdim, dtype=dtype) + else: + return item + + +@seq_to_np(True) +def arr_mean(item, axis=None, keepdim=False, mask=None, dtype=None, mask_clip=1e-12): + if is_np_arr(item) or is_torch(item): + if mask is None: + if is_np_arr(item): + return np.mean(item, axis, dtype, None, keepdim) + elif is_torch(item): + import torch + + return torch.mean(item, axis, keepdim, dtype=dtype) if axis is not None else torch.mean(item, dtype=dtype) + else: + return arr_sum(item, axis, keepdim, mask) / (arr_sum(mask, axis, keepdim) + mask_clip) + else: + return item + + +# def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue): +# @array_function_dispatch(_amin_dispatcher) +# def amin(a, axis=None, out=None, keepdims=np._NoValue, initial=np._NoValue, +# where=np._NoValue): + + +@seq_to_np(True) +def arr_min(item, axis=None, 
keepdim=False, mask=None, inf=1e30): + if is_np_arr(item) or is_torch(item): + if mask is not None: + item = item * mask + inf * (1 - mask) # Both torch and numpy can deal with inf + if is_np_arr(item): + return np.min(item, axis, None, keepdim) + else: + import torch + + return torch.min(item, axis, keepdim).values + else: + return item + + +@seq_to_np(True) +def arr_max(item, axis=None, keepdim=False, mask=None, inf=1e30): + if is_np_arr(item) or is_torch(item): + if mask is not None: + item = item * mask + -inf * (1 - mask) # Both torch and numpy can deal with inf + if is_np_arr(item): + return np.max(item, axis, None, keepdim) + else: + import torch + + return torch.max(item, axis, keepdim).values + else: + return item + + +def to_item(item): + if is_np_arr(item): + if item.size == 1: + return item.reshape(-1)[0] + elif is_torch(item): + if item.numel() == 1: + return item.item() + return item + + +def select_with_mask(item, mask): + if is_arr(item): + return item[mask] + else: + return ops_single_torch_distribution(item, select_with_mask, mask=mask) + + +def recover_with_mask(item, mask): + ret_shape = list(mask.shape) + list(item[0].shape) + if is_np_arr(item): + ret = np.zeros(ret_shape, dtype=item.dtype, device=item.device) + else: + import torch + + ret = torch.zeros(*ret_shape, dtype=item.dtype, device=item.device) + ret[mask] = item + return ret + + +def get_nbytes(item): + if is_np_arr(item): + if item.dtype == object: + tmp = item.reshape(-1) + return sum([get_nbytes(i) for i in tmp]) + else: + return item.nbytes + elif is_torch(item): + return item.view(-1).shape[0] * item.element_size() + else: + return sys.getsizeof(item) + + +def split(item, split_size_or_sections, axis=0): + # Use the torch style + if is_np_arr(item): + if is_integer(split_size_or_sections): + num_blocks = int(item.shape[axis] // split_size_or_sections) + tmp = [ + split_size_or_sections, + ] * num_blocks + if split_size_or_sections * num_blocks < item.shape[axis]: + 
tmp.append(item.shape[axis] - split_size_or_sections * num_blocks) + split_size_or_sections = tmp + elif np.sum(split_size_or_sections) != item.shape[axis]: + split_size_or_sections.append(item.shape[axis] - np.sum(split_size_or_sections)) + split_size_or_sections = np.cumsum(split_size_or_sections) + return np.split(item, split_size_or_sections, axis=axis)[:-1] + elif is_torch(item): + import torch + + return torch.split(item, split_size_or_sections, dim=axis) + return item + + +def norm(item, ord=None, axis=None, keepdim=False): + if is_np_arr(item): + return np.linalg.norm(item, ord, axis, keepdim) + elif is_torch(item): + return item.norm(ord, axis, keepdim) + return item + + +def normalize(item, p=2.0, axis=1, eps=1e-12): + if is_np_arr(item): + return item / np.maximum(norm(item, p, axis, True), eps) + elif is_torch(item): + import torch.nn.functional as F + + return F.normalize(item, p, axis, eps) + return item + + +def clip(item, a_min=None, a_max=None): + if is_np_arr(item): + return np.clip(item, a_min, a_max) + elif is_torch(item): + import torch + + return torch.clamp(item, a_min, a_max) + + +def to_gc(item, dim=None): + """ + To generealized coordinates + dim = 3 means transform 3-dim vectors to 4-dim vectors. + """ + if dim is not None: + assert item.shape[-1] == dim or item.shape[-1] == dim + 1 + if item.shape[-1] == dim + 1: + return item + return concat([item, ones_like(item[..., :1])], axis=-1) + + +def to_nc(item, dim=None): + """ + To normal coordinates + dim = 3 means transform 4-dim vectors to 3-dim vectors. 
+ """ + + if dim is not None: + assert item.shape[-1] == dim or item.shape[-1] == dim + 1 + if item.shape[-1] == dim: + return item + return item[..., :-1] / item[..., -1:] + + +def is_pcd(item, axis=-1): + return item.shape[axis] == 3 + + +def minimum(a, b): + if is_np(a) and is_np(b): + return np.minimum(a, b) + elif is_torch(a) and is_torch(b): + import torch + + return torch.minimum(a, b) + else: + raise ValueError(f"Bad inputs {type(a)} {type(b)}") + + +def broadcast_to(item, shape): + if is_np_arr(item): + return np.broadcast_to(item, shape) + elif is_torch(item): + return item.expand(shape) + return item + + +def expand_as(item, other, exclude_axis=[]): + if is_np_arr(item) or is_torch(item): + assert item.ndim == other.ndim, f"{item.ndim}, {other.ndim}" + other_shape = other.shape + item_shape = item.shape + rep_shape = [(other_shape[i] // item_shape[i] if i not in exclude_axis else 1) for i in range(item.ndim)] + return repeat(item, rep=rep_shape) + return item + + +def gather(item, axis, index): + """ + Refer + https://stackoverflow.com/questions/46065873/how-to-do-scatter-and-gather-operations-in-numpy + """ + if is_np_arr(item): + if item.ndim != index.ndim: + return item + index = expand_as( + index, + item, + [ + axis, + ], + ) + index_xsec_shape = index.shape[:axis] + index.shape[axis + 1 :] + item_xsec_shape = item.shape[:axis] + item.shape[axis + 1 :] + if index_xsec_shape != item_xsec_shape: + raise ValueError(f"Except for dimension {axis}, all dimensions of index and self should be the same size") + data_swaped = np.swapaxes(item, 0, axis) + index_swaped = np.swapaxes(index, 0, axis) + gathered = np.choose(index_swaped, data_swaped) + return np.swapaxes(gathered, 0, axis) + elif is_torch(item): + if item.ndim != index.ndim: + return item + import torch + + return torch.gather( + item, + axis, + expand_as( + index, + item, + [ + axis, + ], + ), + ) + else: + return item + + +def batch_perm(item, axis=1, num_samples=None): + # This is slow for 
large arries. + if is_np_arr(item) or is_torch(item): + assert axis > 0 + if num_samples is None: + num_samples = item.shape[axis] + num_samples = min(num_samples, item.shape[axis]) + shape = [item.shape[0], item.shape[axis]] + if is_np_arr(item): + index = np.argsort(np.random.rand(*shape), axis) + else: + import torch + + index = torch.rand(*shape, device=item.device).argsort(axis) + index = index[:, :num_samples] + rep = [ + 1, + ] + for i in range(1, axis): + index = index[..., None, :] + rep.append(item.shape[i]) + rep.append(1) + for i in range(axis + 1, item.ndim): + index = index[..., None] + rep.append(item.shape[i]) + index = repeat(index, rep, axis=None) + return index + else: + return item + + +def batch_shuffle(item, axis=1, num_samples=None): + """ + item [B, ...] + For each item in batch, we use independently shuffle the items. + = concat([shuffle(item[i], axis) for i in range(item.shape[0])], axis=0) + """ + if is_np_arr(item) or is_torch(item): + index = batch_perm(item, axis, num_samples) + return gather(item, axis, index) + else: + return item + + +def clip_item(item, num, axis=1): + if (is_np_arr(item) or is_torch(item)) and item.shape[axis] > num: + item = take(item, slice(0, num), axis) + return item + + +def pad_item(item, num, axis=1, pad_value=None): + if (is_np_arr(item) or is_torch(item)) and item.shape[axis] < num: + padded_shape = list(item.shape) + padded_shape[axis] = num - padded_shape[axis] + if is_not_null(pad_value): + if is_np_arr(item): + pad = np.full(padded_shape, pad_value, dtype=item.dtype) + else: + import torch + + pad = torch.ones(padded_shape, dtype=item.dtype, device=item.device) + else: + pad = repeat(take(item, range(1), axis), padded_shape[axis], axis) + item = concat([item, pad], axis) + return item + + +def pad_clip(item, num, axis=1, pad_value=None): + item = pad_item(item, num, axis, pad_value) + item = clip_item(item, num, axis) + return item + + +def to_two_dims(item): + if (is_np_arr(item) or is_torch(item)) 
and item.ndim == 1: + return item[..., None] + return item + + +def to_list(item): + if is_np_arr(item) or is_torch(item): + item = item.reshape(-1) + item = [item[i] for i in range(item.shape[0])] + return item + + +def allreduce(item, op="MEAN", device="cuda"): + assert op in ["MEAN", "SUM", "AVG", "PRODUCT", "MIN", "MAX", "BAND", "BOR"] # 'BXOR' is not supported for NCLL + """Allreduce items. + # allreduce is a inplaced operation for torch tensor. + Args: + items ([Number, numpy, tensor]): any numbers or tensors. + coalesce (bool, optional): Whether allreduce parameters as a whole. Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. Defaults to -1. + """ + from ..torch import get_dist_info + import torch + from torch import distributed as dist + from gdict.data import as_dtype + + _, world_size = get_dist_info() + if world_size == 1: + return item + + if is_num(item): + data_type = ("number", type(item)) + item = torch.tensor(item, device=device) + elif is_np(item): + data_type = ("np", item.dtype) + item = torch.tensor(item, device=device) + elif is_torch(item): + data_type = ("torch", item.dtype) + item_device = item.device + if item_device != device: + item = item.to(device=device) + else: + return item + if op == "BAND": + op = "PRODUCT" + elif op == "BOR": + op = "SUM" + + item = item.double() + + if op == "MEAN": + dist.all_reduce(item.div_(world_size), op=torch.distributed.ReduceOp.SUM) + else: + dist.all_reduce(item, op=getattr(torch.distributed.ReduceOp, op)) + + if data_type[0] == "number": + item = item > 0.5 if item is bool else data_type[1](item.item()) + elif data_type[0] == "np": + item = item > 0.5 if data_type[1].name == "bool" else as_dtype(item.detach().cpu().numpy(), dtype=data_type[1]) + elif data_type[0] == "torch": + item = item > 0.5 if "bool" in str(data_type[1]) else item.to(dtype=data_type[1]) + if item_device != device: + item = item.to(device=item_device) + return item + + +""" Torch only functions, 
which means it will works only for torch.Tensor """ + + +def detach(item): + if is_torch(item): + return item.detach() + else: + return item + +def batch_index_select(input, index, axis): + """Batch index_select + + Args: + input (torch.Tensor): [B, ...] + index (torch.Tensor): [B, N] or [B] + dim (int): the dimension to index + + References: + https://discuss.pytorch.org/t/batched-index-select/9115/7 + https://github.com/vacancy/AdvancedIndexing-PyTorch + """ + import torch + + if index.dim() == 1: + index = index.unsqueeze(1) + squeeze_dim = True + else: + assert index.dim() == 2, "index is expected to be 2-dim (or 1-dim), but {} received.".format(index.dim()) + squeeze_dim = False + assert input.size(0) == index.size(0), "Mismatched batch size: {} vs {}".format(input.size(0), index.size(0)) + views = [1 for _ in range(input.dim())] + views[0] = index.size(0) + views[axis] = index.size(1) + expand_shape = list(input.shape) + expand_shape[axis] = -1 + index = index.view(views).expand(expand_shape) + out = torch.gather(input, axis, index) + if squeeze_dim: + out = out.squeeze(1) + return out + + +""" Numpy only functions """ + + +def encode_np(item): + from gdict.file import dump + + item = dump(item, file_format="pkl") + item = base64.binascii.b2a_base64(item) + return item + + +def decode_np(item, dtype=None, shape_template=None): + if is_num(shape_template): + shape_template = (shape_template,) + if isinstance(item, (bytes, np.void, str)): + item = base64.binascii.a2b_base64(item) + from gdict.file import load + + item = load(BytesIO(item), file_format="pkl") + if is_not_null(shape_template): + item = item.reshape(*shape_template) + elif isinstance(item, np.ndarray) and True and get_dtype(item) == "object": + item_shape = item.shape + ret = [decode_np(item_i, dtype, shape_template) for item_i in item.reshape(-1)] + item = np.array(ret, dtype=object) + item = item.reshape(*item_shape) + return item + + +def sample_and_pad(n, num=1200): + index = np.arange(n) + if 
n == 0: + return np.zeros(num, dtype=np.int64) + if index.shape[0] > num: + np.random.shuffle(index) + index = index[:num] + elif index.shape[0] < num: + num_repeat = num // index.shape[0] + index = np.concatenate([index for i in range(num_repeat)]) + index = np.concatenate([index, index[: num - index.shape[0]]]) + return index diff --git a/gello/data_utils/gdict/data/compression.py b/gello/data_utils/gdict/data/compression.py new file mode 100644 index 00000000..1b84122f --- /dev/null +++ b/gello/data_utils/gdict/data/compression.py @@ -0,0 +1,66 @@ +import numpy as np, base64 +from .dict_array import GDict +from .array_ops import encode_np, decode_np +from .converter import as_dtype +from .type_utils import is_np_arr, get_dtype, is_dict, is_not_null, is_null, is_seq_of + +def float_to_int(data, vrange=[0.0, 1.0], res=None, dtype="uint8"): + data_dtype = get_dtype(data) + if "int" in data_dtype: + return as_dtype(data, dtype) if data_dtype != dtype else data + assert data_dtype.startswith("float"), f"{type(data), data}" + min_v = np.iinfo(getattr(np, dtype)).min + max_v = np.iinfo(getattr(np, dtype)).max + if is_not_null(vrange): + assert vrange[0] < vrange[1] and is_null(res) + data = (np.clip(data, a_min=vrange[0], a_max=vrange[1]) - vrange[0]) / (vrange[1] - vrange[0]) # Normalize value to [0, 1] + data = data * max_v + (1 - data) * min_v + else: + assert is_not_null(res) + data = data / res + + data = as_dtype(np.clip(data, a_min=min_v, a_max=max_v), dtype) + return data + + +def int_to_float(data, vrange=[0.0, 1.0], res=None, *dtype): + data_dtype = get_dtype(data) + if data_dtype == "object": + assert data.shape == (1,) + data = data[0] + elif data_dtype.startswith("float"): + return as_dtype(data, dtype) if data_dtype != dtype else data + + data_dtype = get_dtype(data) + + assert data_dtype.startswith("int") or data_dtype.startswith("uint"), f"{data_dtype}" + min_v = np.float32(np.iinfo(getattr(np, data_dtype)).min) + max_v = np.float32(np.iinfo(getattr(np, 
data_dtype)).max) + if is_not_null(vrange): + assert vrange[0] < vrange[1] and is_null(res) + data = (data - min_v) / (max_v - min_v) # [0, 1] + data = data * np.float32(vrange[1]) + (1 - data) * np.float32(vrange[0]) + else: + assert is_not_null(res) + res = np.float32(res) + data = data * res + return as_dtype(data, "float32") + + +def f64_to_f32(item): + """ + Convert all float64 data to float32 + """ + from .type_utils import get_dtype + from .converter import as_dtype + + sign = get_dtype(item) in ["float64", "double"] + return as_dtype(item, "float32") if sign else item + + +def to_f32(item): + return as_dtype(item, "float32") + + +def to_f16(item): + return as_dtype(item, "float16") diff --git a/gello/data_utils/gdict/data/converter.py b/gello/data_utils/gdict/data/converter.py new file mode 100644 index 00000000..f3b87487 --- /dev/null +++ b/gello/data_utils/gdict/data/converter.py @@ -0,0 +1,151 @@ +from typing import Optional + +import numpy as np +from numbers import Number +from .misc import equal +from .type_utils import is_np_arr, is_type, get_dtype, is_integer, is_torch, is_np, is_seq_of, is_num + + +""" Convert tensor type """ + + +def as_dtype(item, dtype: str): + if is_np(item): + return item.astype(dtype) + elif is_torch(item): + import torch + + return item.to(getattr(torch, dtype)) + else: + try: + dtype = eval(dtype) + return dtype(item) + except (RuntimeError, NameError, ValueError): + return item + + +def to_torch( + item, use_copy: bool = False, device: Optional[str] = "cpu", non_blocking: bool = False, dtype: Optional[str] = None, requires_grad: bool = False +): + import torch + + dtype_torch_map = {"uint16": "int"} + + item_dtype = get_dtype(item) + same_type = equal(item_dtype, dtype) + if hasattr(device, "type"): + device = f"{device.type}:{device.index}" if device.index is not None else f"{device.type}" + + if dtype is not None and dtype: + dtype = getattr(torch, dtype) + + if is_seq_of(item, Number) or is_num(item): + item = 
to_np(item) + + if is_np(item): + if item_dtype in dtype_torch_map: + item = item.astype(dtype_torch_map[item_dtype]) + if use_copy or not device.startswith("cpu") or requires_grad or np.isscalar(item) or not same_type: + extra_kwargs = {} if dtype is None else {"dtype": dtype} + return torch.tensor(item, requires_grad=requires_grad, device=device, **extra_kwargs) + else: + return torch.from_numpy(item) + elif is_torch(item): + if not equal(item.device.type, device): + item = item.to(device=device, non_blocking=non_blocking) + if not same_type: + item = item.to(dtype=dtype, non_blocking=non_blocking) + if use_copy: + item = item.clone().detach() + item = item.requires_grad_(requires_grad) + return item + else: + return item + + +def to_np(item, use_copy=False, dtype=None): + use_copy = use_copy or np.isscalar(item) or not equal(get_dtype(item), dtype) + if isinstance(item, (str, bytes)): + return np.array([item], dtype=object) + elif is_seq_of(item, Number): + return np.array(item, dtype=get_dtype(item[0]) if dtype is None else dtype) + + if is_np(item): + kwargs = {} if dtype is None else {"dtype": dtype} + return np.array(item, **kwargs) if use_copy else item + elif is_torch(item): + item = item.detach().cpu().numpy() + return to_np(item, False, dtype) + else: + return item + + +def to_array(item): + if is_torch(item): + return item.reshape(1) if item.nelement() == 1 else item + elif is_np_arr(item): + return item if item.ndim > 0 else item.reshape(1) + elif is_num(item) or (hasattr(item, "ndim") and item.ndim == 0): + return np.array([item]).reshape(1) + else: + try: + return np.array([item], dtype=object).reshape(1) + except: + print(item) + + +""" Convert normal python type """ + + +def dict_to_seq(x): + keys = list(sorted(x.keys())) + values = [x[k] for k in keys] + return keys, values + + +def seq_to_dict(keys, values): + return {keys[i]: values[i] for i in range(len(keys))} + + +def dict_to_str(x): + ret = "" + for key in x: + if ret != "": + ret += " " + 
if isinstance(x[key], (float, np.float32, np.float64)): + from math import log10 + + if abs(x[key]) < 1e-8: + ret += f"{key}: 0" + elif -1 <= log10(abs(x[key])) <= 5: # > 10000 or < 0.0001 + ret += f"{key}: {x[key]:.3f}" + else: + ret += f"{key}: {x[key]:.3e}" + else: + ret += f"{key}: {x[key]}" + return ret + + +def list_to_str(x): + return '[' + ','.join([f'{x[i]:.3f}' for i in range(len(x))]) + ']' + + +def slice_to_range(item): + start = item.start if item.start is not None else 0 + step = item.step if item.step is not None else 1 + return range(start, item.stop, step) + + +def range_to_slice(item): + return slice(item.start, item.stop, item.step) + + +def index_to_slice(index): + if len(index) == 1: + return index + diff = np.diff(index) + is_sorted = np.all(diff[0] == diff) + if is_sorted: + si, ei = index[0], index[-1] + index = slice(si, ei + 1, diff[0]) + return index \ No newline at end of file diff --git a/gello/data_utils/gdict/data/dict_array.py b/gello/data_utils/gdict/data/dict_array.py new file mode 100644 index 00000000..c39fc6ae --- /dev/null +++ b/gello/data_utils/gdict/data/dict_array.py @@ -0,0 +1,959 @@ +""" +TODO: Merge or improved with pytree in jax. 
+""" + +from collections import defaultdict +import numpy as np +from functools import wraps +from multiprocessing.shared_memory import SharedMemory + +from .array_ops import ( + squeeze, + unsqueeze, + zeros_like, + repeat, + tile, + shuffle, + take, + share_memory, + concat, + stack, + arr_mean, + to_item, + select_with_mask, + recover_with_mask, + detach, + get_nbytes, + split, + batch_shuffle, + decode_np, + to_two_dims, + to_list, + gather, + reshape, + transpose, + contiguous, + split_dim, + to_item, + to_cpu, + to_cuda, + allreduce, + slice_item, + deepcopy, +) +from .converter import as_dtype, to_np, to_torch, slice_to_range, to_array +from .type_utils import get_dtype, is_list_of, is_dict, is_h5, is_arr, is_num, is_np, is_str + +SMM, use_shared_mem = None, False + + +def create_smm(): + global SMM, use_shared_mem + if not use_shared_mem: + from multiprocessing.managers import SharedMemoryManager + + use_shared_mem = True + SMM = SharedMemoryManager() + SMM.start() + + +def delete_smm(): + global SMM, use_shared_mem + if use_shared_mem: + use_shared_mem = False + SMM.shutdown() + + +def replace_empty_with_none(*args): + args = list(args) + for i, x in enumerate(args): + if x is not None and isinstance(x, (list, dict)) and len(x) == 0: + x = None + args[i] = x + return args + + +def count_none(*args): + ret = 0 + for _ in list(args): + if _ is None: + ret += 1 + return ret + + +def get_first_not_none(*args): + for _ in list(args): + if _ is not None: + return _ + return None + + +class GDict: + """ + Generalized Dict(GDict) + Unified interface for dict, single element, HDF5 File. + GDict are defined with syntax: + GDict = GDict-Final | GDict-List | GDict-Dict + GDict-Final = Any object not with type list, tuple, dict + GDict-Dict or GDict-List = Dict or List of GDict + + Examples: + 1. 
GDict-Final: + 1) np-array: x = np.zeros(100) + 2) tensor: x = torch.tensor(100) + 3) HDF5 File: x = File('tmp.h5', 'r') + 4) Other python basic element: string, scalar, object. + 3. GDict-Dict or GDict-List or GDict-Tuple: + GDict-Dict: x = {'0': {'b': np.zeros(100)}} + GDict-List: x = [{'b': np.zeros(100)}, ] + x['0/b'][0] = 1 (x['0/b/0'] is wrong!) + Rules: + 1. No '\<>|:&?*"' in any keys (Compatible with filename rules in windows and unix) + '/' is used to separate two keys between two layers. + 2. All integer key will be converted to string + 3. tuple object will be converted to list + 4. key does not contain any index in GDict-Final (See example 3) + 5. Rules for converting a GDict object to HDF5 + 1) any number in keys of GDict-Dict will be converted to 'int_hdf5_' + number + 2) For GDict-List, the list will be converted to a dict with key 'list_int_hdf5_' + number + 3) GDict-Final: + 1) torch.Tensor will be converted to numpy array when is saved as HDF5 File and cannot be recovered. + 2) np.array will be saved as h5py.Dataset + 3) h5py object will be deep copied. 
+ 4) other object will be serialized with pickle + + More Examples: + >>> GDict(np.ones(3)).memory + array([1., 1., 1.]) + >>> GDict(np.ones(3)).shape.memory + 3 + >>> d={'a': np.ones([1,1]), 'b': np.ones([2,3])} + >>> GDict(d).memory + {'a': array([[1.]]), 'b': array([[1., 1., 1.], + [1., 1., 1.]])} + >>> GDict(d).shape.memory + {'a': (1, 1), 'b': (2, 3)} + >>> l = [d,d] + >>> GDict(l).memory + [{'a': array([[1.]]), 'b': array([[1., 1., 1.], + [1., 1., 1.]])}, {'a': array([[1.]]), 'b': array([[1., 1., 1.], + [1., 1., 1.]])}] + >>> GDict(l).shape.memory + [{'a': (1, 1), 'b': (2, 3)}, {'a': (1, 1), 'b': (2, 3)}] + """ + + def __init__(self, item=None, faster=False, **kwargs): + self.memory = item if faster else self.to_item(item) + self.capacity = getattr(item, "capacity", None) + + @classmethod + def _is_final(cls, item): + return not isinstance(item, (list, dict)) + + @classmethod + def to_item(cls, item): + if isinstance(item, GDict): + return cls.to_item(item.memory) + elif is_dict(item): + ret = {key: cls.to_item(item[key]) for key in item} + return ret + elif isinstance(item, (list, tuple)): + return [cls.to_item(x) for x in item] + else: + return item + + @classmethod + def check_item(cls, item): + if isinstance(item, dict): + for key in item: + if not cls.check_item(item[key]): + return False + elif isinstance(item, list): + for x in item: + if not cls.check_item(x): + return False + elif isinstance(item, (tuple, GDict)): + return False + return True + + @classmethod + def assert_item(cls, item): + assert cls.check_item(item), "Tuple and GDict should be missing in self.memory" + + @classmethod + def _recursive_do_on_memory(cls, memory, function, new=True, ignore_list=False, *args, **kwargs): + """Apply an operation to all elements in GDict. 
The operator can be functions in array_ops.""" + if isinstance(memory, dict): + ret = {} if new else memory + for key, value in memory.items(): + if cls._is_final(value): + ret[key] = function(value, *args, **kwargs) + else: + ret[key] = cls._recursive_do_on_memory(memory[key], function, new, ignore_list, *args, **kwargs) + return ret + elif isinstance(memory, list) and not ignore_list: + ret = [None for x in memory] if new else memory + for key, value in enumerate(memory): + if cls._is_final(value): + ret[key] = function(value, *args, **kwargs) + else: + ret[key] = cls._recursive_do_on_memory(memory[key], function, new, ignore_list, *args, **kwargs) + return ret + else: + return function(memory, *args, **kwargs) + + @classmethod + def _recursive_do(cls, memory, function, new=True, wrapper=True, capacity=None, *args, **kwargs): + item = cls._recursive_do_on_memory(memory, function, new, *args, **kwargs) + return cls(item, capacity=capacity, faster=True) if wrapper else item + + @classmethod + def _recursive_do_gdict(cls, memory, function, new=True, wrapper=True, *args, **kwargs): + item = cls._recursive_do_on_memory(memory, function, new, *args, **kwargs) + return GDict(item, faster=True) if wrapper else item + + @classmethod + def _recursive_compare(cls, a, b, function): + if isinstance(a, dict): + inter_set = set(a.keys()) & set(b.keys()) + for key in inter_set: + if not cls._recursive_compare(a[key], b[key], function): + return False + elif isinstance(a, list): + for i in range(min(len(a), len(b))): + if not cls._recursive_compare(a[i], b[i], function): + return False + else: + return function(a, b) + return True + + @classmethod + def _get_item(cls, memory, keys): + if len(keys) == 0 or memory is None: + return memory + elif is_dict(memory): + key = keys[0] + return cls._get_item(memory.get(key, None), keys[1:]) + elif is_list_of(memory): + key = eval(keys[0]) + return cls._get_item(memory[key], keys[1:]) + else: + print(f"Error! 
Keys should not cover the item in {type(memory)}, recent keys {keys}.") + + @classmethod + def _set_item(cls, memory, keys, value): + if isinstance(memory, GDict): + memory = memory.memory + if len(keys) == 0: + return value + elif is_dict(memory): + key = keys[0] + memory[key] = cls._set_item(memory.get(key, None), keys[1:], value) + elif is_list_of(memory): + key = eval(keys[0]) + if key > len(memory): + for i in range(key - len(memory) + 1): + memory.append(None) + memory[key] = cls._set_item(memory[key], keys[1:], value) + else: + print(f"Error! Keys should not cover the item in {type(memory)}, recent keys {keys}.") + return memory + + @classmethod + def _update_memory(cls, target, other): + if is_list_of(target): + if len(other) > len(target): + for i in range(len(other) - len(target)): + target.append(None) + for i in range(len(other)): + target[i] = cls._update_memory(target[i], other[i]) + elif is_dict(target): + for key in other: + target[key] = cls._update_memory(target.get(key, None), other[key]) + else: + target = other + return target + + def update(self, other): + if isinstance(other, GDict): + other = other.memory + self.memory = self._update_memory(self.memory, other) + + def compatible(self, other): + if isinstance(other, GDict): + other = other.memory + + def _compatible(a, b): + return type(a) == type(b) + + return self._recursive_compare(self.memory, other, _compatible) + + def shared_memory(self, other): + other = type(self)(other) + return self._recursive_compare(self.memory, other.memory, share_memory) + + def copy(self, wrapper=True): + return self._recursive_do(self.memory, deepcopy, wrapper=wrapper) + + def to_torch(self, use_copy=False, device="cpu", non_blocking=False, dtype=None, requires_grad=False, wrapper=True): + return self._recursive_do( + self.memory, + to_torch, + use_copy=use_copy, + device=device, + non_blocking=non_blocking, + dtype=dtype, + requires_grad=requires_grad, + wrapper=wrapper, + ) + + def to_array(self, 
wrapper=True): + return self._recursive_do(self.memory, to_array, wrapper=wrapper) + + def to_numpy(self, use_copy=False, dtype=None, wrapper=True): + return self._recursive_do(self.memory, to_np, use_copy=use_copy, dtype=dtype, wrapper=wrapper) + + def to_hdf5(self, file): + from gdict.file import dump_hdf5 + + dump_hdf5(self.memory, file) + + @classmethod + def from_hdf5(cls, file, wrapper=True): + from gdict.file import load_hdf5 + + ret = load_hdf5(file) + if wrapper: + ret = cls(ret) + return ret + + @property + def shape(self): + def get_shape(x): + shape = getattr(x, "shape", None) + if shape is not None and len(shape) == 1: + shape = shape[0] + return shape + + return self._recursive_do_on_memory(self.memory, get_shape) + + @property + def list_shape(self): + def get_shape(x): + shape = getattr(x, "shape", None) + if shape is not None and len(shape) == 1: + shape = shape[0] + else: + shape = list(shape) # For torch.Size + return shape + + return self._recursive_do_on_memory(self.memory, get_shape) + + @property + def type(self): + return self._recursive_do_on_memory(self.memory, type) + + @property + def dtype(self): + return self._recursive_do_on_memory(self.memory, get_dtype) + + @property + def nbytes(self): + return self._recursive_do_on_memory(self.memory, get_nbytes) + + @property + def is_np(self): + return self._recursive_do_on_memory(self.memory, is_np) + + @property + def is_np_all(self): + ret = self._flatten(self._recursive_do_on_memory(self.memory, is_np)) + return np.alltrue([v for k, v in ret.items()]) if isinstance(ret, dict) else ret + + @property + def nbytes_all(self): + ret = self._flatten(self._recursive_do_on_memory(self.memory, get_nbytes)) + return sum([v for k, v in ret.items()]) if isinstance(ret, dict) else ret + + @property + def is_big(self): + return self.nbytes_all / 1024 / 1024 > 1 + + @property + def device(self): + def get_device(x): + device = getattr(x, "device", None) + if device is not None: + device = 
f"{device.type}:{device.index}" if device.index is not None else f"{device.type}" + return device + + return self._recursive_do_on_memory(self.memory, get_device) + + def cpu(self, wrapper=True): + return self._recursive_do_gdict(self.memory, to_cpu, wrapper=wrapper) + + def cuda(self, device="cuda", wrapper=True): + return self._recursive_do_gdict(self.memory, to_cuda, device=device, wrapper=wrapper) + + def item(self, wrapper=True): + return self._recursive_do_gdict(self.memory, to_item, wrapper=wrapper) + + def item(self, wrapper=True): + return self._recursive_do_gdict(self.memory, to_item, wrapper=wrapper) + + def astype(self, dtype, wrapper=True): + return self._recursive_do(self.memory, as_dtype, dtype=dtype, wrapper=wrapper, capacity=self.capacity) + + def float(self, wrapper=True): + return self.astype("float32", wrapper=wrapper) + + def f64_to_f32(self, wrapper=True): + from .compression import f64_to_f32 + + return self._recursive_do(self.memory, f64_to_f32, wrapper=wrapper, capacity=self.capacity) + + def squeeze(self, axis=None, wrapper=True): + return self._recursive_do(self.memory, squeeze, axis=axis, wrapper=wrapper) + + def unsqueeze(self, axis, wrapper=True): + return self._recursive_do(self.memory, unsqueeze, axis=axis, wrapper=wrapper, + capacity=self.capacity if axis != 0 else 1) + + def detach(self, wrapper=True): + return self._recursive_do(self.memory, detach, wrapper=wrapper, capacity=self.capacity) + + def to_zeros(self, wrapper=True): + return self._recursive_do(self.memory, zeros_like, wrapper=wrapper, capacity=self.capacity) + + def repeat(self, rep, axis=None, wrapper=True): + return self._recursive_do( + self.memory, repeat, rep=rep, axis=axis, wrapper=wrapper, + capacity=self.capacity if axis != 0 and axis is not None else None + ) + + def reshape(self, newshape, wrapper=True): + return self._recursive_do(self.memory, reshape, newshape=newshape, wrapper=wrapper, capacity=newshape) + + def split_dim(self, axis, newaxes, wrapper=True): 
+ assert isinstance(newaxes, (list, tuple)) + return self._recursive_do( + self.memory, split_dim, axis=axis, newaxes=newaxes, wrapper=wrapper, + capacity=self.capacity if axis != 0 else newaxes[0] + ) + + def transpose(self, axis0, axis1, contiguous=True, wrapper=True): + return self._recursive_do( + self.memory, + transpose, + axis0=axis0, + axis1=axis1, + contiguous=contiguous, + wrapper=wrapper, + capacity=self.capacity if 0 not in [axis0, axis1] else None, + ) + + def contiguous(self, wrapper=True): + return self._recursive_do(self.memory, contiguous, wrapper=wrapper, capacity=self.capacity) + + def tile(self, rep, wrapper=True): + return self._recursive_do(self.memory, tile, rep=rep, wrapper=wrapper) + + def mean(self, axis=None, keepdim=False, wrapper=True): + return self._recursive_do( + self.memory, arr_mean, axis=axis, keepdim=keepdim, wrapper=wrapper, + capacity=self.capacity if axis != 0 and axis is not None else None + ) + + @classmethod + def _assign(cls, memory, indices, value, ignore_list=False): + if isinstance(value, tuple): + value = list(value) + if is_dict(memory): + assert type(memory) == type(value), f"{type(memory), type(value)}" + for key in memory: + if key in value: + memory[key] = cls._assign(memory[key], indices, value[key], ignore_list) + elif is_arr(memory): + assert type(memory) == type(value) or np.isscalar(value), f"{type(memory), type(value)}" + if share_memory(memory, value): + memory[indices] = deepcopy(value) + else: + memory[indices] = value + elif is_list_of(memory): + if ignore_list: + memory[indices] = value + else: + # if is_num(indices): + # memory[indices] = value if is_num(value) else value[indices] + # else: + # assert type(memory) == type(value), f"{type(memory), type(value)}" + for i in range(min(len(memory), len(value))): + memory[i] = cls._assign(memory[i], indices, value[i], ignore_list) + return memory + + def assign_list(self, index, value): + if isinstance(value, GDict): + value = value.memory + assert 
is_num(index) + self.memory = self._assign(self.memory, index, value, True) + + def to_two_dims(self, wrapper=True): + return self._recursive_do(self.memory, to_two_dims, wrapper=wrapper) + + def take_list(self, index, wrapper=True): + assert is_num(index) + return self._recursive_do_gdict(self.memory, take, indices=index, axis=0, ignore_list=True, wrapper=wrapper) + + def to_list(self, wrapper=True): + return self._recursive_do(self.memory, to_list, wrapper=wrapper) + + def select_with_mask(self, mask, wrapper=True): + return self._recursive_do(self.memory, select_with_mask, mask=mask, wrapper=wrapper, + capacity=to_item(mask.sum())) + + def recover_with_mask(self, mask, wrapper=True): + return self._recursive_do(self.memory, select_with_mask, mask=mask, wrapper=wrapper, capacity=mask.shape[0]) + + def allreduce(self, op="MEAN", device="cuda", wrapper=True): + return self._recursive_do(self.memory, allreduce, op=op, device=device, wrapper=wrapper, capacity=self.capacity) + + def to_gdict(self): + return GDict(self.memory, faster=True) + + @property + def one_device(self): + return self._get_one_attr(self.memory, "device") + + @property + def one_shape(self): + return self._get_one_attr(self.memory, "shape") + + @property + def one_dtype(self): + return self._get_one_attr(self.memory, "dtype") + + def _flatten(cls, memory, root_key="", full=True): + if is_dict(memory): + ret = {} + for key in memory: + ret.update(cls._flatten(memory[key], f"{root_key}/{key}", full)) + elif is_list_of(memory) and (full or len(memory) > 10): + # Simplify flatten result for small list or tuple + ret = {} + for i in range(len(memory)): + ret.update(cls._flatten(memory[i], f"{root_key}/{i}", full)) + else: + return memory if root_key == "" else {root_key.replace("//", "/"): memory} + return ret + + def flatten(self, full=True): + return type(self)(self._flatten(self.memory, "", full)) + + @classmethod + def wrapper(cls, class_method=False): + if not class_method: + + def 
decorator(func): + @wraps(func) + def wrapper(item, *args, **kwargs): + if isinstance(item, GDict): + return func(item, *args, **kwargs) + else: + return func(GDict(item), *args, **kwargs).memory + + return wrapper + + else: + + def decorator(func): + @wraps(func) + def wrapper(self, item, *args, **kwargs): + if isinstance(item, GDict): + return func(self, item, *args, **kwargs) + else: + return func(self, GDict(item), *args, **kwargs).memory + + return wrapper + + return decorator + + def select_by_keys(self, keys=None, to_list=False, wrapper=True): + def _dfs_select(memory, keys=None): + if keys is None: + return memory + if isinstance(memory, dict): + new_keys = {} + for key in keys: + fk = key[0] + if len(key) > 1: + if fk not in new_keys: + new_keys[fk] = [] + new_keys[fk].append(key[1:]) + else: + new_keys[fk] = None + return {key: _dfs_select(memory[key], new_keys[key]) for key in new_keys} + elif isinstance(memory, list): + new_keys = {} + for key in keys: + fk = eval(key[0]) if is_str(key[0]) else key[0] + if len(key) > 1: + if fk not in new_keys: + new_keys[fk] = [] + new_keys[fk].append(key[1:]) + else: + new_keys[fk] = None + return [_dfs_select(memory[key], new_keys[key]) for key in sorted(new_keys)] + else: + raise ValueError(f"{keys}") + + if not isinstance(keys, (list, tuple)) and keys is not None: + keys = [keys] + single = True + else: + single = False + keys = [self._process_key(key) for key in keys] + memory = _dfs_select(self.memory, keys) + if to_list: + memory = type(self)(memory) + memory = [memory[key] for key in keys] + if single: + memory = memory[0] + if wrapper: + memory = type(self)(memory) + return memory + + def take(self, indices, axis=0, wrapper=True): # will always copy data, needs double check + if is_num(indices): + return self._recursive_do_gdict(self.memory, take, indices=indices, axis=axis, wrapper=wrapper) + else: + + if isinstance(indices, slice): + len_indices = len(slice_to_range(indices)) + else: + len_indices = 
len(indices) + new_capacity = len_indices if axis == 0 else self.capacity + return self._recursive_do(self.memory, take, indices=indices, axis=axis, wrapper=wrapper, + capacity=new_capacity) + + def slice(self, slice, axis=0, wrapper=True): # no copy + return self._recursive_do(self.memory, slice_item, slice=slice, axis=axis, wrapper=wrapper) + + def assign_all(self, value): + if isinstance(value, GDict): + value = value.memory + self.memory = self._assign(self.memory, slice(None, None, None), value) + + @classmethod + def _do_on_list_of_array(cls, memories, function, **kwargs): + for i in range(len(memories)): + assert type(memories[i]) is type(memories[0]), f"{type(memories[i]), type(memories[0])}" + if isinstance(memories[0], (tuple, list)): + for i in range(len(memories)): + assert len(memories[i]) == len(memories[0]) + ret = [] + for i in range(len(memories[0])): + ret.append(cls._do_on_list_of_array([memories[j][i] for j in range(len(memories))], function, **kwargs)) + elif isinstance(memories[0], dict): + for i in range(len(memories)): + assert set(memories[i].keys()) == set( + memories[0].keys()), f"{set(memories[i].keys())}, {set(memories[0].keys())}" + ret = {} + for key in memories[0]: + ret[key] = cls._do_on_list_of_array([memories[j][key] for j in range(len(memories))], function, + **kwargs) + else: + ret = function(memories, **kwargs) + return ret + + @classmethod + def concat(cls, items, axis=0, wrapper=True): + ret = cls._do_on_list_of_array([_.memory if isinstance(_, GDict) else _ for _ in items], concat, axis=axis) + if wrapper: + capacity = 0 + for item in items: + if isinstance(item, GDict) and item.capacity is not None: + capacity += item.capacity + else: + capacity = None + break + return cls(ret, capacity=capacity, faster=True) + else: + return ret + + @classmethod + def stack(cls, items, axis=0, wrapper=True): + ret = cls._do_on_list_of_array([_.memory if isinstance(_, GDict) else _ for _ in items], stack, axis=axis) + if wrapper: + if axis 
== 0: + capacity = len(items) + else: + capacity = None + for item in items: + if isinstance(item, cls) and item.capacity is not None: + capacity = item.capacity + break + return cls(ret, capacity=capacity, faster=True) + else: + return ret + + @classmethod + def _process_key(cls, key): + if is_num(key): + key = str(key) + return key if isinstance(key, (list, tuple)) else key.strip("/").replace("//", "/").split("/") + + def __getitem__(self, key): + return self._get_item(self.memory, self._process_key(key)) + + def __setitem__(self, key, value): + self.memory = self._set_item(self.memory, self._process_key(key), value) + return self.memory + + def __str__(self): + return str(self._flatten(self.memory, "", False)) + + def __dict__(self): + assert isinstance(self.memory, dict), "self.memory is not a dict!" + return self.memory + + def __getattr__(self, key): + if key == 'memory': + assert False, "GDict should always have a memory attribute!" + return getattr(self.memory, key) + + def __getstate__(self): + return self.memory + + def __setstate__(self, state): + self.memory = state + + def __contains__(self, key): + if "/" in key: + key = self._process_key(key) + memory = self.memory + for _ in key: + if _ not in memory: + return False + memory = memory[_] + return True + else: + return key in self.memory + + def __delitem__(self, key): + keys = list(self._process_key(key)) + last_memory = None + memory = self.memory + for i, key in enumerate(keys): + if isinstance(last_memory, list) and isinstance(key, str): + key = eval(key) + keys[i] = key + last_memory = memory + memory = memory[key] + + if last_memory is None: + self.memory = None + elif isinstance(last_memory, (dict, list)): + last_memory.pop(key) + + +class DictArray(GDict): + """ + DictArray is a special GDict which requires the first dimension of all GDict-Final must be same + """ + + def __init__(self, item=None, capacity=None, faster=False): + super(DictArray, self).__init__(item, faster=faster) + if item is 
None: + self.capacity = None + return + if capacity is not None: + self.capacity = capacity + if not faster: + self.memory = self.to_array(wrapper=False) + self.memory = self.unsqueeze(axis=0, wrapper=False) # .to_zeros(wrapper=False) + if capacity != 1: + self.memory = self.repeat(capacity, axis=0, wrapper=False) + elif self.capacity is None: + self.capacity = self._get_one_attr(self.memory, "shape")[0] + if not faster: + self.assert_shape(self.memory, self.capacity) + + @classmethod + def _get_one_attr(cls, memory, attr): + # print(type(memory), attr) + if isinstance(memory, dict): + for key in memory: + if hasattr(memory[key], attr): + return getattr(memory[key], attr) + ans = cls._get_one_attr(memory[key], attr) + if ans is not None: + return ans + elif isinstance(memory, list): + for x in memory: + if hasattr(x, attr): + return getattr(x, attr) + ans = cls._get_one_attr(x, attr) + if ans is not None: + return ans + elif hasattr(memory, attr): + return getattr(memory, attr) + return None + + @classmethod + def check_shape(cls, memory, capacity): + if isinstance(memory, dict): + for key in memory: + if not cls.check_shape(memory[key], capacity): + return False + elif isinstance(memory, list): + for x in memory: + if not cls.check_shape(x, capacity): + return False + elif hasattr(memory, "shape"): + return memory.shape[0] == capacity + return True + + @classmethod + def assert_shape(cls, memory, capacity): + assert cls.check_shape(memory, capacity), f"The first dimension is not {capacity}!" 
+ + def sample(self, batch_size, valid_capacity=None, wrapper=True): + capacity = self.capacity if valid_capacity is None else valid_capacity + indices = np.random.randint(low=0, high=capacity, size=batch_size) + return self._recursive_do(self.memory, take, indices=indices, axis=0, wrapper=wrapper, capacity=batch_size) + + def shuffle(self, valid_capacity=None, wrapper=True, in_place=True): + capacity = self.capacity if valid_capacity is None else valid_capacity + indices = shuffle(np.arange(capacity), axis=0) + # print(valid_capacity, self.capacity) + # print(np.unique(indices).shape, len(indices)) + # exit(0) + # print(capacity, self.capacity) + if in_place: + # print(indices) + items = self.take(slice(0, capacity), wrapper=False) + # print(items.shape, share_memory(items['actions'], self.memory['actions'])) + self.assign(indices, items) + # self._recursive_do(self.memory, take, indices=indices, axis=0, wrapper=False, capacity=self.capacity) + else: + if capacity < self.capacity: + indices = np.concatenate([indices, np.arange(self.capacity - capacity) + capacity], axis=0) + return self._recursive_do(self.memory, take, indices=indices, axis=0, wrapper=wrapper, + capacity=self.capacity) + + def assign(self, indices, value): + if isinstance(value, GDict): + value = value.memory + self.memory = self._assign(self.memory, indices, value) + + def gather(self, axis, index, wrapper=True): + return self._recursive_do(self.memory, gather, axis=axis, index=index, wrapper=wrapper) + + def to_dict_array(self): + return DictArray(self.memory, capacity=self.capacity, faster=True) + + def __len__(self): + return self.capacity + + +class SharedGDict(GDict): + def __init__(self, gdict=None, shape=None, dtype=None, name=None): + if gdict is not None: + assert shape is None and dtype is None and name is None + assert isinstance(gdict, GDict) and gdict.is_np_all + shape = gdict.shape + dtype = gdict.dtype + nbytes = gdict.nbytes + else: + assert not (shape is None or dtype is None or 
name is None) + nbytes = None + + self.is_new = name is None + + name, self.shared_memory = self._create_shared_memory(shape, dtype, nbytes, name) + memory = self._create_np_from_memory(self.shared_memory, shape, dtype) + + self.shared_shape = shape + self.shared_dtype = dtype + self.shared_name = name + + super(SharedGDict, self).__init__(memory) + + def _create_np_from_memory(cls, shared_memory, shape, dtype): + if isinstance(shared_memory, dict): + memory = {k: cls._create_np_from_memory(shared_memory[k], shape[k], dtype[k]) for k in shared_memory} + elif isinstance(shared_memory, list): + memory = [cls._create_np_from_memory(shared_memory[k], shape[k], dtype[k]) for k in + range(len(shared_memory))] + else: + if isinstance(dtype, str): + dtype = np.dtype(dtype) + memory = np.ndarray(shape, dtype=dtype, buffer=shared_memory.buf) + return memory + + def _create_shared_memory(cls, shape, dtype, nbytes, name=None): + if name is None: + # Create new shared buffer + if isinstance(nbytes, dict): + ret_name, ret_memory = {}, {} + for key in nbytes: + name_k, memory_k = cls._create_shared_memory(shape[key], dtype[key], nbytes[key], None) + ret_name[key] = name_k + ret_memory[key] = memory_k + elif isinstance(nbytes, (list, tuple)): + ret_name, ret_memory = [], [] + for key in range(len(nbytes)): + name_k, memory_k = cls._create_shared_memory(shape[key], dtype[key], nbytes[key], None) + ret_name.append(name_k) + ret_memory.append(memory_k) + else: + assert is_num(nbytes), f"{nbytes}" + ret_memory = SharedMemory(size=nbytes, create=True) + ret_name = ret_memory.name + else: + ret_name = name + if isinstance(name, dict): + ret_memory = {k: cls._create_shared_memory(shape[k], dtype[k], None, name[k])[1] for k in name} + elif isinstance(name, (list, tuple)): + ret_memory = [cls._create_shared_memory(shape[k], dtype[k], None, name[k])[1] for k in range(len(name))] + else: + assert isinstance(name, str), f"{name}" + ret_memory = SharedMemory(name=name, create=False) + return 
ret_name, ret_memory + + def get_infos(self): + return self.shared_shape, self.shared_dtype, self.shared_name + + def _unlink(self): + memory = self._flatten(self.shared_memory) + if isinstance(memory, dict): + for k, v in memory.items(): + v.unlink() + else: + memory.unlink() + + def _close(self): + memory = self._flatten(self.shared_memory) + if isinstance(memory, dict): + for k, v in memory.items(): + v.close() + elif not callable(memory): + memory.close() + + def __del__(self): + self._close() + if self.is_new: + self._unlink() + + def get_full_by_key(self, key): + ret = [] + for name in ["shared_shape", "shared_dtype", "shared_name"]: + ret.append(self._get_item(getattr(self, name), self._process_key(key))) + return type(self)(None, *ret) + + def __setitem__(self, key, value): + assert False, "Please convert to GDict or Dictarray then change the value!" + + +class SharedDictArray(SharedGDict, DictArray): + pass diff --git a/gello/data_utils/gdict/data/dict_utils.py b/gello/data_utils/gdict/data/dict_utils.py new file mode 100644 index 00000000..af301c4b --- /dev/null +++ b/gello/data_utils/gdict/data/dict_utils.py @@ -0,0 +1,66 @@ +from copy import deepcopy +from .type_utils import is_dict, is_seq_of + + +def update_dict(x, y): + """ + Update x with y + """ + assert type(x) == type(y), f"{type(x), type(y)}" + if is_dict(x): + ret = deepcopy(x) + for key in y: + if key in x: + ret[key] = update_dict(x[key], y[key]) + else: + ret[key] = deepcopy(y[key]) + else: + ret = deepcopy(y) + return ret + + +def update_dict_with_begin_keys(x, y, keys, begin=False, history_key=()): + if len(keys) == 0: + if type(x) == type(y): + return update_dict(x, y) + elif is_seq_of(x, dict) and is_dict(y): + return [update_dict(_, y) for _ in x] + else: + raise NotImplementedError() + if not is_dict(x): + return deepcopy(x) + + ret = {} + for key in x: + if key == keys[0]: + ret[key] = update_dict_with_begin_keys(x[key], y, keys[1:], True, history_key + (key,)) + elif not begin: + 
ret[key] = update_dict_with_begin_keys(x[key], y, keys, False, history_key + (key,)) + else: + ret[key] = deepcopy(x[key]) + return ret + + +def first_dict_key(item): + return sorted(item.keys())[0] + + +def map_dict_keys(inputs, keys_map, logger_print=None): + from .string_utils import regex_replace, regex_match, is_regex + import re + + outputs = {} + for key, value in inputs.items(): + new_key = key + for in_pattern, out_pattern in keys_map.items(): + if regex_match(key, in_pattern): + new_key = regex_replace(key, in_pattern, out_pattern) + break + if new_key == "None" or new_key is None: + if logger_print is not None: + logger_print(f"Delete {key}!") + continue + if new_key != key and logger_print is not None: + logger_print(f"Change {key} to {new_key}.") + outputs[new_key] = value + return outputs diff --git a/gello/data_utils/gdict/data/filtering.py b/gello/data_utils/gdict/data/filtering.py new file mode 100644 index 00000000..027da5f9 --- /dev/null +++ b/gello/data_utils/gdict/data/filtering.py @@ -0,0 +1,37 @@ +from .string_utils import regex_match +from .type_utils import is_dict, is_tuple_of, is_list_of + + +def custom_filter(item, func, value=True): + """ + Recursively filter all elements with function func. + Assumptions: + None means the item does not pass func. 
+ """ + if is_tuple_of(item): + item = list(item) + if is_list_of(item): + ret = [] + for i in range(len(item)): + x = custom_filter(item[i], func, value) + if x is not None: + ret.append(x) + item = ret + elif is_dict(item): + ret = {} + for key in item: + x = custom_filter(item[key], func, value) + if x is not None: + ret[key] = x + item = ret + return item if not value or (item is not None and func(item)) else None + + +def filter_none(x): + func = lambda _: _ is not None + return custom_filter(x, func, True) + + +def filter_with_regex(x, regex, value=True): + func = lambda _: _ is not None and regex_match(_, regex) + return custom_filter(x, func, value) diff --git a/gello/data_utils/gdict/data/misc.py b/gello/data_utils/gdict/data/misc.py new file mode 100644 index 00000000..ce52e7f1 --- /dev/null +++ b/gello/data_utils/gdict/data/misc.py @@ -0,0 +1,5 @@ +def equal(x, y): + return True if x is None or y is None else x == y + + +SLICE_ALL = slice(None, None, None) diff --git a/gello/data_utils/gdict/data/seq_utils.py b/gello/data_utils/gdict/data/seq_utils.py new file mode 100644 index 00000000..5ede4243 --- /dev/null +++ b/gello/data_utils/gdict/data/seq_utils.py @@ -0,0 +1,75 @@ +import itertools +from copy import deepcopy +from random import shuffle +from .type_utils import is_seq_of + + +def concat_seq(in_list, dtype): + assert dtype in [list, tuple] + return dtype(itertools.chain(*in_list)) + + +def concat_list(in_list): + return concat_seq(in_list, list) + + +def concat_tuple(in_list): + return concat_seq(in_list, tuple) + + +def auto_pad_seq(a, b): + """ + Input two sequence, then output two list of objects with the same size. 
+ """ + a = list(a) if isinstance(a, (list, tuple)) else [a] + b = list(b) if isinstance(b, (list, tuple)) else [b] + if len(a) > len(b): + for i in range(len(a) - len(b)): + b.append(a[0]) + elif len(a) < len(b): + for i in range(len(b) - len(a)): + a.append(b[0]) + return a, b + + +def flatten_seq(x, dtype=list): + if not is_seq_of(x, (tuple, list)): + return x + return dtype(concat_list([flatten_seq(_) for _ in x])) + + +def split_list_of_parameters(num_procsess, *args, **kwargs): + from ..math import split_num + + args = [_ for _ in args if _ is not None] + kwargs = {_: __ for _, __ in kwargs.items() if __ is not None} + assert len(args) > 0 or len(kwargs) > 0 + first_item = args[0] if len(args) > 0 else kwargs[list(kwargs.keys())[0]] + n, running_steps = split_num(len(first_item), num_procsess) + start_idx = 0 + paras = [] + for i in range(n): + slice_i = slice(start_idx, start_idx + running_steps[i]) + start_idx += running_steps[i] + args_i = list([_[slice_i] for _ in args]) + kwargs_i = {_: kwargs[_][slice_i] for _ in kwargs} + paras.append([args_i, kwargs_i]) + return paras + + +def select_by_index(files, indices): + return [files[i] for i in indices] + + +def random_pad_clip_list(x, num): + x = deepcopy(list(x)) + if len(x) > num: + shuffle(x) + return x[:num] + else: + ret = [] + for i in range(num // len(x)): + shuffle(x) + ret = ret + x + ret = ret + x[: num - len(ret)] + return ret diff --git a/gello/data_utils/gdict/data/string_utils.py b/gello/data_utils/gdict/data/string_utils.py new file mode 100644 index 00000000..7842b447 --- /dev/null +++ b/gello/data_utils/gdict/data/string_utils.py @@ -0,0 +1,65 @@ +""" +Useful regex expression + 1. nothing else classifier: '^((?!classifier).)*$' + 2. 
def is_regex(s):
    """Return True if *s* compiles as a regular expression.

    Fix: the previous bare ``except:`` also swallowed ``KeyboardInterrupt``
    and ``SystemExit``; catch only the failures ``re.compile`` can raise
    (invalid pattern -> ``re.error``, non-string input -> ``TypeError``).
    """
    try:
        re.compile(s)
        return True
    except (re.error, TypeError):
        return False
def is_seq_of(seq, expected_type=None, seq_type=None):
    """Check that *seq* is a sequence (of class *seq_type*, default any
    ``collections.abc.Sequence``) and, when *expected_type* is given, that
    every item is an instance of it."""
    if seq_type is not None:
        assert isinstance(seq_type, type)
        container_type = seq_type
    else:
        container_type = Sequence
    if not isinstance(seq, container_type):
        return False
    if not expected_type:
        return True
    return all(isinstance(elem, expected_type) for elem in seq)
def change_dtype(x, keys=None, dtypes=None, np=False):
    """Convert ``x`` (or selected entries of it) to the given dtype(s).

    Args:
        x: A single array-like, or a list/tuple/dict of them.
        keys: Indices (for list/tuple) or keys (for dict) selecting which
            entries to convert; ``None`` converts ``x`` as a whole.
        dtypes: Target dtype, or one dtype per key. ``None`` is a no-op.
        np: If True convert with ``to_np``, otherwise with ``to_torch``.
    """
    if dtypes is None:
        return x
    processor = to_np if np else to_torch
    if not isinstance(dtypes, (list, tuple)):
        dtypes = [dtypes]

    # No container (or no key selection): convert the whole object.
    if not isinstance(x, (tuple, list, dict)) or keys is None:
        assert len(dtypes) == 1
        return processor(x, dtypes[0])

    if not isinstance(keys, (list, tuple)):
        keys = [keys]
    # Broadcast a single dtype over all selected keys.
    if len(dtypes) == 1:
        dtypes = [dtypes[0]] * len(keys)
    assert len(keys) == len(dtypes)

    ret = list(x) if isinstance(x, (list, tuple)) else x
    # Fix: the previous ``for k, dtype in enumerate(keys, dtypes)`` raised a
    # TypeError (enumerate's second argument must be an int); the intent is
    # to pair each key with its dtype.
    for k, dtype in zip(keys, dtypes):
        ret[k] = processor(ret[k], dtype)
    return type(x)(ret)
def process_input(keys=None, dtypes=None, np=True):
    """Decorator factory: convert a function's (selected) positional and
    keyword arguments via ``change_dtype`` before calling it.

    Args:
        keys: Argument positions / keyword names to convert (see ``change_dtype``).
        dtypes: Target dtype(s); ``None`` leaves arguments untouched.
        np: Convert to numpy when True, otherwise to torch.
    """
    check_consistent(keys, dtypes)

    def decorator(func):
        # Fix: ``wraps(func)`` was previously called on its own line and the
        # result discarded — a no-op. It must decorate ``wrapper`` so the
        # wrapped function keeps ``func``'s name and docstring.
        @wraps(func)
        def wrapper(*args, **kwargs):
            args = change_dtype(list(args), keys, dtypes, np)
            kwargs = change_dtype(dict(kwargs), keys, dtypes, np)
            return func(*args, **kwargs)

        return wrapper

    return decorator
key.startswith("dict"): + key_type = eval(key.split("_")[1]) + key_value = key_type(key[len(f"dict_{key.split('_')[1]}_") :]) + else: + key_value = key + if is_not_null(load_keys) and f"{key_value}" not in load_keys: + continue + load_keys_i = load_keys[f"{key_value}"] if is_not_null(load_keys) else None + ret[key_value] = _load_hdf5(file[key], load_keys_i, only_one) + ret = ret[list(ret.keys())[0]] if only_one and len(ret) > 0 else ret + elif len(keys) == 1 and keys[0] == "GDict": + ret = _load_hdf5(file["GDict"], load_keys, only_one) + else: + ret = {} + for key in keys: + if key.startswith("int__"): + key_value = key[len("int__") :] + else: + key_value = key + # print(key_value, load_keys, key_value in load_keys) + if is_not_null(load_keys) and f"{key_value}" not in load_keys: + continue + load_keys_i = load_keys[f"{key_value}"] if is_not_null(load_keys) else None + ret[key_value] = _load_hdf5(file[key], load_keys_i, only_one) + ret = ret[list(ret.keys())[0]] if only_one and len(ret) > 0 else ret + return ret + elif isinstance(file, Dataset): + assert load_keys is None or len(load_keys) == 0, f"{load_keys}" + ret = file[()] + if isinstance(ret, np.void): + from .serialization import load + from io import BytesIO + + return load(BytesIO(ret), file_format="pkl") + else: + return ret + + if is_str(keys): + keys = [keys] + only_one = True + else: + only_one = False + if is_not_null(keys): + keys = [key.strip("/").replace("//", "/").split("/") for key in keys] + if not is_h5(file): + file = File(file, "r") + ret = _load_hdf5(file, keys, only_one) + file.close() + else: + ret = _load_hdf5(file, keys, only_one) + return ret + + +def dump_hdf5(obj, file): + def _dump_hdf5(memory, file, root_key=""): + if isinstance(memory, (list, dict)): + keys = range(len(memory)) if is_list_of(memory) else memory.keys() + for key in keys: + _dump_hdf5(memory[key], file, f"{root_key}/{type(memory).__name__}_{type(key).__name__}_{key}") + else: + root_key = root_key.replace("//", "/") 
def add_suffix_to_filename(x, suffix=""):
    """Insert ``_{suffix}`` before the extension of path ``x``.

    e.g. ``a/b.txt`` -> ``a/b_{suffix}.txt``; ``b.tar.gz`` -> ``b.tar_{suffix}.gz``.
    Fix: filenames without any ``.`` previously raised IndexError
    (``dot_split[-2]`` on a one-element list); they now get the suffix
    appended to the whole name.
    """
    x = str(x)
    dirname = osp.dirname(x)
    filename = osp.basename(x)
    dot_split = filename.split(".")
    if len(dot_split) >= 2:
        dot_split[-2] += f"_{suffix}"
    else:
        # No extension: append the suffix to the bare name.
        dot_split[-1] += f"_{suffix}"
    return osp.join(dirname, ".".join(dot_split))
def scandir(dir_path, suffix=None, recursive=False):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str | obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the directory. Default: False.
    Returns:
        A generator for all the interested files with relative paths.
    """
    if isinstance(dir_path, (str, Path)):
        dir_path = str(dir_path)
    else:
        raise TypeError('"dir_path" must be a string or Path object')

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path  # yielded paths are relative to the scan root

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith(".") and entry.is_file():
                rel_path = osp.relpath(entry.path, root)
                # str.endswith also accepts a tuple of suffixes.
                if suffix is None or rel_path.endswith(suffix):
                    yield rel_path
            elif recursive and entry.is_dir():
                # Fix: the old else-branch recursed into *any* non-matching
                # entry, so a hidden *file* made os.scandir raise
                # NotADirectoryError; only directories are recursed now.
                yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
def parse_files(filenames):
    """
    filenames can contain four types of files: txt, h5, record, record_episode

    A txt file is treated as an index: its whitespace/comma/semicolon
    separated tokens are parsed recursively. Directories are searched for all
    supported file types. Paths that do not exist are silently skipped.
    """
    from maniskill2_learn.utils.data import is_seq_of, concat_list
    from maniskill2_learn.utils.file import load

    supported_types = ["txt", "h5", "record", "record_episode"]
    ret_names = []
    if isinstance(filenames, str):
        filenames = [filenames]
    assert is_seq_of(filenames, str)

    def process_txt(file):
        # Normalize separators, then split the index file into tokens.
        file = load(file)
        replacements = (",", ";")
        for r in replacements:
            file = file.replace(r, " ")
        return file.split()

    for name in filenames:
        name = osp.expanduser(name)
        if not osp.exists(name):
            continue
        if osp.isdir(name):
            for file_type in supported_types:
                # NOTE(review): glob is called without recursive=True, so the
                # "**" only matches one directory level — confirm deeper
                # nesting is not expected.
                files = list(glob.glob(osp.join(name, "**", f"*.{file_type}"))) + list(glob.glob(osp.join(name, f"*.{file_type}")))
                if len(files) == 0:
                    continue
                if file_type == "txt":
                    files = parse_files(concat_list([process_txt(_) for _ in files]))
                ret_names += files
        else:
            file_suffix = get_filename_suffix(name)
            if file_suffix == "txt":
                # Fix: this branch referenced the undefined name ``files``
                # (NameError at runtime); a single txt file is its own index,
                # so parse its tokens recursively.
                ret_names += parse_files(process_txt(name))
            elif file_suffix in supported_types:
                ret_names.append(name)
    return ret_names
class BaseFileHandler(metaclass=ABCMeta):
    """Abstract interface for per-format (de)serialization backends.

    Subclasses implement the three fileobj/str primitives below; the
    path-based helpers are built on top of them.
    """

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
        """Deserialize and return the object read from file-like *file*."""
        pass

    @abstractmethod
    def dump_to_fileobj(self, obj, file, **kwargs):
        """Serialize *obj* and write it to file-like *file*."""
        pass

    @abstractmethod
    def dump_to_str(self, obj, **kwargs):
        """Serialize *obj* and return the result as a string."""
        pass

    def load_from_path(self, filepath, mode='r', **kwargs):
        # Default text mode; binary handlers (e.g. pickle) override this.
        with open(filepath, mode) as f:
            return self.load_from_fileobj(f, **kwargs)

    def dump_to_path(self, obj, filepath, mode='w', **kwargs):
        with open(filepath, mode) as f:
            self.dump_to_fileobj(obj, f, **kwargs)
csv_writer.writerows(obj) + + def dump_to_str(self, obj, **kwargs): + output = io.StringIO() + self.dump_to_fileobj(output, obj) + return output.getvalue() diff --git a/gello/data_utils/gdict/file/serialization/handlers/json_handler.py b/gello/data_utils/gdict/file/serialization/handlers/json_handler.py new file mode 100644 index 00000000..6c01b707 --- /dev/null +++ b/gello/data_utils/gdict/file/serialization/handlers/json_handler.py @@ -0,0 +1,30 @@ +import json, numpy as np +from .base import BaseFileHandler + + +def set_default(obj): + """Set default json values for non-serializable values. + It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. + It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, etc.) into plain numbers of plain python + built-in types. + """ + if isinstance(obj, (set, range)): + return list(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.generic): + return obj.item() + raise TypeError(f'{type(obj)} is unsupported for json dump') + + +class JsonHandler(BaseFileHandler): + def load_from_fileobj(self, file): + return json.load(file) + + def dump_to_fileobj(self, obj, file, **kwargs): + kwargs.setdefault('default', set_default) + json.dump(obj, file, **kwargs) + + def dump_to_str(self, obj, **kwargs): + kwargs.setdefault('default', set_default) + return json.dumps(obj, **kwargs) diff --git a/gello/data_utils/gdict/file/serialization/handlers/pickle_handler.py b/gello/data_utils/gdict/file/serialization/handlers/pickle_handler.py new file mode 100644 index 00000000..069074d0 --- /dev/null +++ b/gello/data_utils/gdict/file/serialization/handlers/pickle_handler.py @@ -0,0 +1,57 @@ +import pickle, importlib, bz2, gzip +from .base import BaseFileHandler +from ...path_utils import get_filename_suffix + + +class PickleProtocol: + def __init__(self, level): + self.previous = pickle.HIGHEST_PROTOCOL + self.level = level + + def __enter__(self): + importlib.reload(pickle) + 
class PickleHandler(BaseFileHandler):
    """Pickle (de)serialization; for path-based I/O the compression is chosen
    from the file suffix: ``.pkl`` (plain), ``.pgz`` (gzip), ``.pbz2`` (bz2)."""

    def load_from_fileobj(self, file, **kwargs):
        return pickle.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        # Protocol 5 (Python 3.8+) supports out-of-band buffers.
        kwargs.setdefault('protocol', 5)
        pickle.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('protocol', 5)
        return pickle.dumps(obj, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        """Load a pickle from *filepath*; compression picked by suffix."""
        file_suffix = get_filename_suffix(filepath)
        assert file_suffix in ['pkl', 'pgz', 'pbz2'], f'{file_suffix} is not supported. Please use of pkl, pgz, pbz2'
        if file_suffix == 'pkl':
            with open(filepath, 'rb') as f:
                return self.load_from_fileobj(f, **kwargs)
        elif file_suffix == 'pgz':
            with gzip.GzipFile(filepath, 'r') as f:
                return self.load_from_fileobj(f, **kwargs)
        elif file_suffix == 'pbz2':
            with bz2.BZ2File(filepath, 'r') as f:
                return self.load_from_fileobj(f, **kwargs)

    def dump_to_path(self, obj, filepath, **kwargs):
        """Dump *obj* as a pickle to *filepath*; compression picked by suffix."""
        file_suffix = get_filename_suffix(filepath)
        assert file_suffix in ['pkl', 'pgz', 'pbz2'], f'{file_suffix} is not supported. Please use of pkl, pgz, pbz2'
        if file_suffix == 'pkl':
            with open(filepath, 'wb') as f:
                return self.dump_to_fileobj(obj, f, **kwargs)
        elif file_suffix == 'pgz':
            with gzip.GzipFile(filepath, 'w') as f:
                return self.dump_to_fileobj(obj, f, **kwargs)
        elif file_suffix == 'pbz2':
            with bz2.BZ2File(filepath, 'w') as f:
                return self.dump_to_fileobj(obj, f, **kwargs)
def load(file, file_format=None, **kwargs):
    """Load data from a file of one of the registered formats.

    ``file`` may be a path (str or Path), raw ``bytes``, or a file-like
    object with a ``read`` method. When ``file_format`` is omitted, it is
    inferred from the filename extension.
    """
    if isinstance(file, Path):
        file = str(file)
    if isinstance(file, bytes):
        file = BytesIO(file)

    path_like = is_str(file)
    if path_like and file_format is None:
        file_format = file.split(".")[-1]
    if file_format not in file_handlers:
        raise TypeError(f"Unsupported format: {file_format}")

    handler = file_handlers[file_format]
    if path_like:
        return handler.load_from_path(osp.expanduser(file), **kwargs)
    if hasattr(file, "read"):
        return handler.load_from_fileobj(file, **kwargs)
    raise TypeError('"file" must be a filepath str or a file-object')
def csv_table_to_dict(x):
    """Convert a two-column CSV table (iterable of ``[key, value]`` rows)
    back into a dict. Inverse of ``dict_to_csv_table``.
    """
    ret = {}
    # Validate and build in a single pass; the old code looped over ``x``
    # twice and reused the same loop variable for both passes.
    for row in x:
        assert len(row) == 2
        ret[row[0]] = row[1]
    return ret
b/gello/data_utils/simple_bc/_interfaces/README.md @@ -0,0 +1,20 @@ +`Encoder`s are responsible for handling raw input. + +Classes should be entire algorithms, such as `impala.py` or `ddt.py`. Because `Encoder`s are fully responsible for processing the input, they also must preprocess (i.e. transformations, etc.) the output of the `Data` class. + +Methods: +``` +forward(self, obs): + +loss(self, obs, act): + +update(self, obs, act): + +save(self, path): + +load(self, path): + +_build_network(self, nn_cfg, device): +``` + +TODO: add typing hints and more documentation to this page. \ No newline at end of file diff --git a/gello/data_utils/simple_bc/_interfaces/encoder.py b/gello/data_utils/simple_bc/_interfaces/encoder.py new file mode 100644 index 00000000..856edf83 --- /dev/null +++ b/gello/data_utils/simple_bc/_interfaces/encoder.py @@ -0,0 +1,63 @@ +import torch +from abc import ABC, abstractmethod +import os + + +class Encoder(ABC, torch.nn.Module): + def __init__( + self, + obs_shapes, # dict of shapes of the observations. + out_shape=None, # Shape of the output of the encoder. + num_frames=1, + **kwargs, + ): + "Base class for all encoders." + super().__init__() + self.obs_shapes = obs_shapes + self.out_shape = out_shape + self.num_frames = num_frames + + @staticmethod + def build_encoder(encoder_cfg): + import simple_bc.encoder as e + + Encoder = eval(f"e.{encoder_cfg.name}") + kwargs = dict(encoder_cfg) + + kwargs.pop("name") + if "vit_cfg" in kwargs: + kwargs.update(**kwargs.pop("vit_cfg")) + encoder = Encoder(**kwargs) + encoder.out_shape = eval( + str(encoder.out_shape) + ) # encoder shape may have already been evaluated in constructor + return encoder + + def save(self, path): + """ + Save the encoder's state dict to a file. + """ + save_dir = os.path.dirname(path) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + torch.save(self.state_dict(), path) + + def load(self, path): + """ + Load the encoder's state dict from a file. 
class Policy(ABC, torch.nn.Module):
    """Abstract base class for BC policies; subclasses implement ``forward``."""

    def __init__(self):
        super().__init__()
        # FIFO of actions predicted by forward() but not yet returned by act().
        self.cached_actions = []

    def reset(self):
        """Drop any cached actions (call at episode boundaries)."""
        self.cached_actions = []

    @staticmethod
    def build_policy(encoder_out_shape, policy_cfg, encoder_cfg):
        """Instantiate the policy class named by ``policy_cfg.name``; the
        remaining config entries are forwarded as constructor kwargs."""
        import simple_bc.policy as p

        # NOTE(review): eval() of a config-supplied name executes arbitrary
        # code if the config is untrusted; getattr(p, policy_cfg.name) would
        # be the safer equivalent.
        Policy = eval(f"p.{policy_cfg.name}")
        kwargs = dict(policy_cfg)
        kwargs.pop("name")
        policy = Policy(obs_shape=encoder_out_shape, encoder_cfg=encoder_cfg, **kwargs)
        return policy

    def save(self, path):
        """
        Save the policy's state dict to a file, creating the directory if needed.
        """
        save_dir = os.path.dirname(path)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(self.state_dict(), path)

    def load(self, path):
        """
        Load the policy's state dict from a file.
        """
        self.load_state_dict(torch.load(path))

    @abstractmethod
    def forward(self, *args, **kwargs):
        # Expected to return (actions, info) with batch size 1 — see act().
        pass

    def act(self, *args, **kwargs):
        """Return one action (and forward info) for the current env step."""
        # NOTE(review): the ``or True`` makes this condition always true, so a
        # fresh forward() runs every call and the action cache below is
        # effectively disabled — confirm whether this is a leftover debug switch.
        if len(self.cached_actions) == 0 or True:
            actions, info = self.forward(*args, **kwargs)
            assert len(actions) == 1  # Batch size should be 1
            self.cached_actions = actions[0].detach().cpu().numpy()
            # Keep only the first two predicted actions of the horizon.
            self.cached_actions = self.cached_actions[:2]
        # Pop the oldest cached action.
        action, self.cached_actions = self.cached_actions[0], self.cached_actions[1:]
        return action, info
b/gello/data_utils/simple_bc/dataset/replay_dataset.py @@ -0,0 +1,259 @@ +import copy +import numpy as np +import torch +import torchvision.transforms as T +from torch.utils.data import IterableDataset +from torchvision.transforms.functional import InterpolationMode + +from simple_bc.utils.data_utils import load_traj_from_memory, load_traj_files +from simple_bc.utils.torch_utils import pack_one, unpack_one + +""" +This file contains the dataset for BC training. +""" + + +class CachedTrajLoader(object): + def __init__( + self, + shuffle, + stack_idx, + stack_window, + all_traj_cache=dict(), # note that this is shared between all CachedTrajLoader instances + ): + """ + shuffle: If False, the trajectories are loaded in order. If True, the trajectories are shuffled. + """ + self.shuffle = shuffle + self.stack_idx = np.array(stack_idx) + self.stack_window = stack_window + self.num_cached_traj = 4 + self.all_traj_cache = all_traj_cache + + self.worker_filenames, self.unload_filenames = None, None + self.cached_trajs, self.cached_trajs_starts, self.cached_trajs_shuffled_idx = ( + None, + None, + None, + ) + + MAX_LEN = 2000 + ts = np.arange(MAX_LEN) + self.padded_ts = np.concatenate( + [np.zeros(self.stack_window, dtype=np.int32), ts] + ) + + def reset(self, worker_filenames): + "Reset the buffer" + self.worker_filenames = copy.copy(worker_filenames) + self.unload_filenames = copy.copy(worker_filenames) + del self.cached_trajs, self.cached_trajs_starts, self.cached_trajs_shuffled_idx + self.cached_trajs = [] + self.cached_trajs_starts = ( + [] + ) # The starting index of each trajectory which has not been sampled + self.cached_trajs_shuffled_idx = [] + + def _load_to_cache(self): + if len(self.unload_filenames) == 0: + return -1 # No more trajectory to load + + while ( + len(self.cached_trajs) < self.num_cached_traj + and len(self.unload_filenames) > 0 + ): + traj_filename = self.unload_filenames.pop() + if traj_filename in self.all_traj_cache: + traj = 
self.all_traj_cache[traj_filename] + else: + traj = load_traj_from_memory(traj_filename) + self.cached_trajs.append(traj) + self.cached_trajs_starts.append(0) + + T = len(traj) + ts = np.arange( + T - self.stack_window - 1 + len(self.stack_idx), dtype=np.int32 + ) + if self.shuffle: + np.random.shuffle(ts) + self.cached_trajs_shuffled_idx.append( + ts + ) # The oldest time step of frame stacking + + return 0 + + def sample(self): + """ + Randomly sample one time step from one trajectory. No repeat + Return (traj, ts) the corresponding trajectory and the sampled time steps (For frame stacking) + Return (None, None) if no trajectory is available + """ + if len(self.cached_trajs) == 0: + ret = self._load_to_cache() + if ret == -1: + return None, None + + traj_idx = np.random.randint(len(self.cached_trajs)) if self.shuffle else 0 + traj = self.cached_trajs[traj_idx] + start_t = self.cached_trajs_shuffled_idx[traj_idx][ + self.cached_trajs_starts[traj_idx] + ] + ts = self.padded_ts[start_t + self.stack_idx] + self.cached_trajs_starts[traj_idx] += 1 + + if self.cached_trajs_starts[traj_idx] == len( + self.cached_trajs_shuffled_idx[traj_idx] + ): + self.cached_trajs.pop(traj_idx) + self.cached_trajs_starts.pop(traj_idx) + self.cached_trajs_shuffled_idx.pop(traj_idx) + return traj, ts + + +class ReplayDataset(IterableDataset): + def __init__( + self, + dataset_dir, + aug_cfg, + token_name, + shuffle=True, + obs_shapes=None, + act_shape=None, + action_horizon=1, + stride=1, + cache_all_traj=False, + **kwargs, + ): + self.dataset_dir = dataset_dir + self.obs_shapes = obs_shapes + self.act_shape = act_shape + self.token_name = token_name + self.aug_cfg = aug_cfg + self.shuffle = shuffle + self.cache_all_traj = cache_all_traj + + self.stride = stride + + # stacking + self.stack_idx = aug_cfg.stack_idx + self.stack_window = max(aug_cfg.stack_idx) + + # augmentation + if self.aug_cfg.aug_prob > 0: + self.aug_transform = T.Compose( + [ + T.RandomResizedCrop( + size=224, + 
scale=(0.7, 1.0), + interpolation=InterpolationMode.BILINEAR, + antialias=False, + ), + T.ColorJitter(brightness=0.3), + ] + ) + else: + self.aug_transform = None + + # proprioception: + self.use_proprio = aug_cfg.use_proprio + if "use_rotation" in aug_cfg: + self.use_rotation = aug_cfg.use_rotation + else: + self.use_rotation = True + if "use_mv" in aug_cfg: + self.use_mv = aug_cfg.use_mv + else: + self.use_mv = True + + if "use_depth" in aug_cfg: + self.use_depth = aug_cfg.use_depth + if not self.use_depth: + print("ReplayDataset: not using depth.") + else: + self.use_depth = True + self.buffer_filenames = load_traj_files( + self.dataset_dir, self.token_name, stride=self.stride + ) + self.all_trajs = {} + if self.cache_all_traj: + for traj_filename in self.buffer_filenames: + self.all_trajs[traj_filename] = load_traj_from_memory(traj_filename) + print(f"ReplayDataset: cached all trajectories.") + + assert ( + len(self.buffer_filenames) > 0 + ), "No trajectories found in the specified folders." + print( + f"ReplayDataset: found {len(self.buffer_filenames)} trajectories in the specified folders." + ) + + # for multiprocessing on workers. see utils/_worker_init_fn. 
+ self.world_rng = None + self.worker_filenames = [] + self.worker_start, self.worker_end = None, None + + self.cached_rgb_frames = None + self.action_horizon = action_horizon + self.traj_loader = CachedTrajLoader( + shuffle=shuffle, + stack_idx=self.stack_idx, + stack_window=self.stack_window, + all_traj_cache=self.all_trajs, + ) + + def __iter__(self): + self.reset_buffer() + + return self.__next__() + + def __next__(self): + while True: + traj, ts = self.traj_loader.sample() + if ts is None: + break + ret_traj = traj.slice(ts) + + ret_traj["obs"]["rgb"] = self._augment_frames(ret_traj["obs"]["rgb"]) + if not self.use_depth: + ret_traj["obs"]["depth"] = np.zeros_like(ret_traj["obs"]["depth"]) + elif not self.use_proprio: + ret_traj["obs"]["state"] = np.zeros_like(ret_traj["obs"]["state"]) + last_t = ts[-1] + actions = traj["actions"][last_t : last_t + self.action_horizon] + # pad actions to the same length + if len(actions) < self.action_horizon: + actions = np.concatenate( + [ + actions, + np.zeros((self.action_horizon - len(actions), *self.act_shape)), + ] + ) + ret_traj["actions"] = actions + ret_traj["dones"] = np.array([ts[-1] == len(traj) - 1], dtype=np.int32) + ret_traj["steps"] = np.array([ts[-1]], dtype=np.int32) + yield ret_traj + + def _augment_frames(self, rgb_frames): + """ + Augment the trajectory according to the augmentation config. + This includes RGB augmentation and frame stacking. 
+ """ + p = np.random.rand() + if p < self.aug_cfg.aug_prob and self.aug_transform is not None: + rgb_frames, sh = pack_one(rgb_frames, "* c h w") + rgb_frames = torch.from_numpy(rgb_frames).float() / 255.0 + rgb_frames = self.aug_transform(rgb_frames).numpy() + ret = unpack_one(rgb_frames, sh, "* c h w") + # back to [0, 255] + ret = np.clip(ret * 255, 0, 255).astype(np.uint8) + return ret + else: + return rgb_frames + + def reset_buffer(self): + if self.shuffle: + self.world_rng.shuffle(self.buffer_filenames) + self.worker_filenames = self.buffer_filenames[ + self.worker_start : self.worker_end + ] + self.traj_loader.reset(self.worker_filenames) diff --git a/gello/data_utils/simple_bc/dataset/replay_online.py b/gello/data_utils/simple_bc/dataset/replay_online.py new file mode 100644 index 00000000..be4d5201 --- /dev/null +++ b/gello/data_utils/simple_bc/dataset/replay_online.py @@ -0,0 +1,152 @@ +import numpy as np +import torch +from einops import rearrange + +import gdict +from simple_bc.utils.torch_utils import pack_one, to_cpu, unpack_one +import torchvision.transforms as T +from torchvision.transforms.functional import InterpolationMode +from einops import repeat + +""" +This file contains the dataset for DAgger training. +""" + + +class ReplayBuffer(object): + def __init__( + self, + aug_cfg, + shuffle=True, + obs_shapes=None, + act_shape=None, + action_horizon=1, + MAX_LEN=200, + use_rotation=True, # Not used + use_proprio=True, + ): + self.obs_shapes = obs_shapes + self.act_shape = act_shape + self.aug_cfg = aug_cfg + self.shuffle = shuffle + self.action_horizon = action_horizon + + # stacking + self.stack_idx = aug_cfg.stack_idx + self.stack_window = max(aug_cfg.stack_idx) + 1 + + # Buffer + self.all_trajs = [] + self.MAX_LEN = MAX_LEN + self.num_trajs, self.num_timesteps = 0, 0 + + # random shift augmentation, as described in DRQ v1 by Denis Yarats + # and in "Revisiting LfS Baseline" by Hansen et al. 
+ if self.aug_cfg.aug_prob > 0: + aug_pad_length = 5 + aug_crop_min = 0.7 + + self.aug_transform = T.Compose( + [ + T.Pad(aug_pad_length, padding_mode="edge"), + T.RandomResizedCrop( + size=224, + scale=(aug_crop_min, 1.0), + interpolation=InterpolationMode.BILINEAR, + ), + ] + ) + + else: + self.aug_transform = None + + def _augment_frames(self, rgb_frames, depth_frames): + """ + Augment the trajectory according to the augmentation config. + This includes RGB augmentation and frame stacking. + """ + p = np.random.rand() + if p < self.aug_cfg.aug_prob and self.aug_transform is not None: + rgb_frames, sh = pack_one(rgb_frames, "* c h w") + depth_frames, _ = pack_one(depth_frames, "* c h w") + depth_frames = repeat(depth_frames, "b c h w -> b (c d) h w", d=3) + + all_frames = torch.cat( + [rgb_frames.float() / 255.0, depth_frames / torch.max(depth_frames)], + axis=0, + ) + all_frames = self.aug_transform(all_frames) + B, _, _, _ = rgb_frames.shape + aug_rgb_frames, aug_depth_frames = all_frames[:B], all_frames[B:] + + aug_rgb_frames *= 255.0 + aug_rgb_frames = unpack_one(aug_rgb_frames, sh, "* c h w") + + aug_depth_frames *= torch.max(depth_frames) + aug_depth_frames = aug_depth_frames[:, 0:1] # depth is 1-channel + aug_depth_frames = unpack_one(aug_depth_frames, sh, "* c h w") + + return aug_rgb_frames, aug_depth_frames + else: + return rgb_frames, depth_frames + + def add_traj(self, obses, actions): + """List (time) of obses and actions as input. + Each action is B x A. Save into list of T x A. + """ + actions = rearrange(torch.stack(actions, dim=0), "t b a -> b t a") + rgb = rearrange( + torch.stack([obs["rgb"] for obs in obses], dim=0), "t b ... -> b t ..." + ) + depth = rearrange( + torch.stack([obs["depth"] for obs in obses], dim=0), "t b ... -> b t ..." + ) + state = rearrange( + torch.stack([obs["state"] for obs in obses], dim=0), "t b ... -> b t ..." 
+ ) + actions, rgb, depth, state = ( + to_cpu(actions), + to_cpu(rgb), + to_cpu(depth), + to_cpu(state), + ) + num_traj = len(actions) + for i in range(num_traj): + traj = gdict.GDict( + { + "actions": actions[i], + "obs": gdict.GDict( + {"state": state[i], "rgb": rgb[i], "depth": depth[i]} + ), + } + ) + self.all_trajs.append(traj) + if len(self.all_trajs) > self.MAX_LEN: + self.all_trajs = self.all_trajs[-self.MAX_LEN :] + self.num_trajs += num_traj + self.num_timesteps += num_traj * actions.shape[1] + + def sample(self, batch_size): + traj_id = np.random.randint(len(self.all_trajs), size=batch_size) + trajs = [self.all_trajs[i] for i in traj_id] + + Ts = ( + np.array([traj["actions"].shape[0] for traj in trajs]) + - self.stack_window + + 1 + ) + ts = np.random.randint(Ts) + ret_traj = [ + trajs[i].slice(np.arange(ts[i], ts[i] + self.stack_window)) + for i in range(batch_size) + ] + + ret_traj = gdict.GDict.stack(ret_traj, axis=0) + + ret_traj["obs"]["rgb"], ret_traj["obs"]["depth"] = self._augment_frames( + ret_traj["obs"]["rgb"], ret_traj["obs"]["depth"] + ) + + if self.action_horizon > 1: + pass + return ret_traj diff --git a/gello/data_utils/simple_bc/encoder/__init__.py b/gello/data_utils/simple_bc/encoder/__init__.py new file mode 100644 index 00000000..8e5a3f2c --- /dev/null +++ b/gello/data_utils/simple_bc/encoder/__init__.py @@ -0,0 +1,4 @@ +from .impala import IMPALA +from .spawnnet import SpawnNet +from .r3m_encoder import R3MEncoder +from .vit_descriptor import ViTDescriptor \ No newline at end of file diff --git a/gello/data_utils/simple_bc/encoder/impala.py b/gello/data_utils/simple_bc/encoder/impala.py new file mode 100644 index 00000000..c238473b --- /dev/null +++ b/gello/data_utils/simple_bc/encoder/impala.py @@ -0,0 +1,203 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +from simple_bc._interfaces.encoder import Encoder +from simple_bc.utils.torch_utils import 
to_torch, to_numpy + +import hydra +from omegaconf import DictConfig + + +class IMPALA(Encoder): + def __init__(self, + in_channels, + shape, + use_depth=True, + large=True, + larger=True, + num_views=2, + **kwargs): + super().__init__(**kwargs) + self.feat_convs = [] + self.resnet1 = [] + self.resnet2 = [] + self.convs = [] + + self.use_depth = use_depth + self.num_views = num_views + + H, W = shape + if larger: + fcs = [128, 128, 128] + else: + fcs = [64, 64, 64] + self.shape = [H, W] + + self.large = large + if self.large: + in_channels = 4 + print("IMPALA: using large network, so using 4 channels, and not convolving over time and views.") + self.stem = nn.Conv2d(in_channels, fcs[0], kernel_size=4, stride=4) + in_channels = fcs[0] + + for num_ch in fcs: + feats_convs = [] + feats_convs.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=num_ch, + kernel_size=3, + stride=1, + padding=1, + ) + ) + feats_convs.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + self.feat_convs.append(nn.Sequential(*feats_convs)) + in_channels = num_ch + for i in range(2): + resnet_block = [] + resnet_block.append(nn.ReLU()) + resnet_block.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=num_ch, + kernel_size=3, + stride=1, + padding=1, + ) + ) + resnet_block.append(nn.ReLU()) + resnet_block.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=num_ch, + kernel_size=3, + stride=1, + padding=1, + ) + ) + if i == 0: + self.resnet1.append(nn.Sequential(*resnet_block)) + else: + self.resnet2.append(nn.Sequential(*resnet_block)) + + self.feat_convs = nn.ModuleList(self.feat_convs) + self.resnet1 = nn.ModuleList(self.resnet1) + self.resnet2 = nn.ModuleList(self.resnet2) + self.img_feat_size = (H * W) // (4 ** len(fcs) * 16) * fcs[-1] + + if self.large: + self.fc = nn.Identity() + self.out_shape = self.img_feat_size * self.num_views * self.num_frames + else: + self.fc = nn.Linear(self.img_feat_size, self.out_shape) + + 
self._update_out_shape(self.out_shape) + + def preprocess(self, obs): + with torch.no_grad(): + feature = [] + if "rgb" in self.obs_shapes and "rgb" in obs: + B = len(obs["rgb"]) + assert ( + obs["rgb"].shape[-3:] == self.obs_shapes["rgb"] + ), f"Observation shape of rgb is {obs['rgb'].shape}, but should be {(B, *self.obs_shapes['rgb'])}" + # B, F (Frames), V (views), C, H, W + if self.large: + rgb = rearrange(obs["rgb"], "b f v c h w -> (b f v) c h w") + else: + rgb = rearrange(obs["rgb"], "b f v c h w -> b (f v c) h w") + feature.append((rgb / 255.0)) + + if "depth" in self.obs_shapes and "depth" in obs: + B = len(obs["depth"]) + assert ( + obs["depth"].shape[-3:] == self.obs_shapes["depth"] + ), f"Observation shape of depth is {obs['depth'].shape}, but should be {(B, *self.obs_shapes['depth'])}" + + if self.large: + depth = rearrange(obs["depth"], "b f v c h w -> (b f v) c h w") + else: + depth = rearrange(obs["depth"], "b f v c h w -> b (f v c) h w") + + if not self.use_depth: + depth = torch.zeros_like(depth) + + feature.append(depth) + + feature = torch.cat(feature, dim=1) + if "state" in obs: + state = obs["state"] + else: + state = None + + return feature, state + + def forward(self, obs): + """Return feature and info.""" + feature, state = self.preprocess(obs) + x = self.stem(feature) + res_input = None + + for i, fconv in enumerate(self.feat_convs): + x = fconv(x) + res_input = x + x = self.resnet1[i](x) + x += res_input + res_input = x + x = self.resnet2[i](x) + x += res_input + + x = F.relu(x) + x = x.reshape(x.shape[0], self.img_feat_size) + x = F.relu(self.fc(x)) + + if self.large: + x = rearrange(x, "(b f v) c -> b (f v c)", f=self.num_frames, v=self.num_views) + + B = x.shape[0] + state = state.view(B, -1) + + out = torch.cat([x, state], dim=1) + return out, {} + + def _update_out_shape(self, out_shape): + if "state" in self.obs_shapes: + state_shape = np.prod(self.obs_shapes["state"]) * self.num_frames + print(f"IMPALA: updated out shape to 
{self.out_shape + state_shape}") + self.out_shape = out_shape + state_shape + else: + self.out_shape = out_shape + + +@hydra.main(config_path="../../conf/encoder", config_name="impala", version_base="1.1") +def test(cfg): + cfg = DictConfig(cfg) + cfg.in_channels = 32 + cfg.shape = [224, 224] + cfg.obs_shapes = {"rgb": [3, 224, 224], "depth": [1, 224, 224], "state": [26]} + cfg.num_frames = 4 + cfg.out_shape = 384 + print(cfg) + + encoder = IMPALA(**cfg) + print(encoder) + + obs = { + "rgb": to_torch(torch.randn(6, 4, 2, 3, 224, 224)), + "depth": to_torch(torch.randn(6, 4, 2, 1, 224, 224)), + "state": to_torch(torch.randn(6, 4, 26)), + } + + out, _ = encoder(obs) + assert ( + out.shape[1:] == encoder.out_shape + ), f"out shape is {out.shape}, but should be {encoder.out_shape}" + + +if __name__ == "__main__": + test() diff --git a/gello/data_utils/simple_bc/encoder/r3m_encoder.py b/gello/data_utils/simple_bc/encoder/r3m_encoder.py new file mode 100644 index 00000000..0dae313e --- /dev/null +++ b/gello/data_utils/simple_bc/encoder/r3m_encoder.py @@ -0,0 +1,94 @@ +import r3m +import os +from gpu_info import vulkan_cuda_idxes + +from simple_bc._interfaces.encoder import Encoder +from einops import rearrange, repeat +from r3m.models.models_r3m import R3M + +import os +from os.path import expanduser +import omegaconf +import hydra +import gdown +import torch + + +def load_r3m(modelid): + # copied from Suraj Nair's R3M repo, repurposed for different HYDRA launchers + home = os.path.join(expanduser("~"), ".r3m") + if "HYDRA_LAUNCHER" not in os.environ: + r3m.device = 0 + else: + cuda_gpu_idxes, _ = vulkan_cuda_idxes(os.environ["HYDRA_LAUNCHER"], 1) + r3m.device = cuda_gpu_idxes[0] + if modelid == "resnet50": + foldername = "r3m_50" + modelurl = "https://drive.google.com/uc?id=1Xu0ssuG0N1zjZS54wmWzJ7-nb0-7XzbA" + configurl = "https://drive.google.com/uc?id=10jY2VxrrhfOdNPmsFdES568hjjIoBJx8" + elif modelid == "resnet34": + foldername = "r3m_34" + modelurl = 
"https://drive.google.com/uc?id=15bXD3QRhspIRacOKyWPw5y2HpoWUCEnE" + configurl = "https://drive.google.com/uc?id=1RY0NS-Tl4G7M1Ik_lOym0b5VIBxX9dqW" + elif modelid == "resnet18": + foldername = "r3m_18" + modelurl = "https://drive.google.com/uc?id=1A1ic-p4KtYlKXdXHcV2QV0cUzI4kn0u-" + configurl = "https://drive.google.com/uc?id=1nitbHQ-GRorxc7vMUiEHjHWP5N11Jvc6" + else: + raise NameError("Invalid Model ID") + + if not os.path.exists(os.path.join(home, foldername)): + os.makedirs(os.path.join(home, foldername)) + modelpath = os.path.join(home, foldername, "model.pt") + configpath = os.path.join(home, foldername, "config.yaml") + if not os.path.exists(modelpath): + gdown.download(modelurl, modelpath, quiet=False) + gdown.download(configurl, configpath, quiet=False) + + modelcfg = omegaconf.OmegaConf.load(configpath) + cleancfg = r3m.cleanup_config(modelcfg) + rep = hydra.utils.instantiate(cleancfg) + rep = torch.nn.DataParallel(rep, device_ids=[r3m.device]) + r3m_state_dict = r3m.remove_language_head( + torch.load(modelpath, map_location=torch.device(r3m.device))["r3m"] + ) + rep.load_state_dict(r3m_state_dict) + return rep + + +class R3MEncoder(Encoder): + def __init__(self, model_type="resnet50", freeze_pretrained=True, **kwargs): + super().__init__(**kwargs) + self.model = load_r3m(model_type) + + if freeze_pretrained: + for param in self.model.parameters(): + param.requires_grad = False + + self.out_shape = eval(str(self.out_shape)) + + self.batch_norm = torch.nn.BatchNorm1d( + self.model.module.outdim + ) # as described in R3M paper, this occurs prior to MLP layers + + def forward(self, obs): + # obs is of shape b f v c h w, (state is shape b f d) + img = obs["rgb"] + + B, F, V, _, _, _ = obs["rgb"].shape + img = rearrange(img, "b f v c h w -> (b f v) c h w") + + feats = self.model(img) + feats = self.batch_norm(feats) + feats = rearrange(feats, "(b f v) d -> b (f v d)", b=B, f=F, v=V) + + if "state" in obs: + state = obs["state"] + state = repeat(state, "b f d -> 
b f v d", v=V) + state = rearrange(state, "b f v d -> b (f v d)") + feats = torch.cat([feats, state], dim=1) + + return feats, {} + + def preprocess(self, obs): # for encoder interface + return obs diff --git a/gello/data_utils/simple_bc/encoder/spawnnet/__init__.py b/gello/data_utils/simple_bc/encoder/spawnnet/__init__.py new file mode 100644 index 00000000..92d6d7d2 --- /dev/null +++ b/gello/data_utils/simple_bc/encoder/spawnnet/__init__.py @@ -0,0 +1 @@ +from .spawnnet import SpawnNet diff --git a/gello/data_utils/simple_bc/encoder/spawnnet/spawnnet.py b/gello/data_utils/simple_bc/encoder/spawnnet/spawnnet.py new file mode 100644 index 00000000..5dca0b4c --- /dev/null +++ b/gello/data_utils/simple_bc/encoder/spawnnet/spawnnet.py @@ -0,0 +1,380 @@ +import math +from typing import List, Tuple +from simple_bc.encoder.r3m_encoder import load_r3m + +import torch +import torch.nn as nn +from omegaconf import DictConfig, OmegaConf +from torchvision import transforms +from einops import rearrange +from simple_bc._interfaces.encoder import Encoder +from simple_bc.utils.torch_utils import get_named_trainable_params, pack_one, unpack_one + +""" +ViT feature extraction largely based on 'Deep ViT Features as Dense Visual Descriptors': +https://github.com/ShirAmir/dino-vit-features/blob/main/extractor.py +""" + + +class SpawnNet(Encoder): + def __init__( + self, + conv_cfg: DictConfig, # 'token_5':1, 'token_8':2, 'token_11':3 + pretrained_feat_info: dict, + model_type="dino_vits8", + stride=8, + freeze_pretrained=False, + freeze_vit_to_random=False, + **kwargs, + ): + """ + :param model_type: A string specifying the type of model to extract from + """ + super(SpawnNet, self).__init__(**kwargs) + self.model_type = model_type + self.conv_cfg = conv_cfg + self.vit = self.build_vit(model_type) + self.conv = self.build_conv(conv_cfg, pretrained_feat_info) + self.out_shape = self.conv.out_shape + self.mean = ( + (0.485, 0.456, 0.406) + if ("dino" in self.model_type or "r3m" in 
self.model_type) + else (0.5, 0.5, 0.5) + ) + self.std = ( + (0.229, 0.224, 0.225) + if ("dino" in self.model_type or "r3m" in self.model_type) + else (0.5, 0.5, 0.5) + ) + self.normalize = transforms.Normalize(mean=self.mean, std=self.std) + + print("Freezing ViT to random:", freeze_vit_to_random) + + self.freeze_vit_to_random = freeze_vit_to_random + + if freeze_pretrained: + for param in self.vit.parameters(): + param.requires_grad = False + else: + for param in self.vit.parameters(): + param.requires_grad = True + + if "r3m" not in model_type: + self.p = self.vit.patch_embed.patch_size + self.stride = [stride, stride] + if not isinstance(self.p, int): + self.p = self.p[0] + + self.pretrained_feat_info = pretrained_feat_info + self._feats = {} + self.hook_handlers = [] + self.load_size = None + self.num_patches = None + + @staticmethod + def build_vit(model_type: str) -> nn.Module: + assert "dino" in model_type or "r3m" in model_type or "mvp" in model_type + if "dinov2" in model_type: + model = torch.hub.load("facebookresearch/dinov2", model_type) + elif "dino" in model_type: + model = torch.hub.load("facebookresearch/dino:main", model_type) + elif "r3m" in model_type: + model = load_r3m("resnet50") + elif "mvp" in model_type: + model_type = model_type.replace('mvp_', '') + import mvp + model = mvp.load(model_type) + return model + + @staticmethod + def build_conv(model_cfg, pretrained_feat_info) -> nn.Module: + from simple_bc.encoder.spawnnet.vit_conv import ViTConv + + model = ViTConv(**model_cfg, pretrained_feat_info=pretrained_feat_info) + return model + + @staticmethod + def prune_model(model: nn.Module, layer: int = 9) -> nn.Module: + """ + :param model: the model to prune + :param layer: the layer to extract from. 
0 is the first layer after the input + :return: the pruned model + """ + try: + model.transformer.resblocks = model.transformer.resblocks[: layer + 1] + except AttributeError: + model.blocks = model.blocks[: layer + 1] + return model + + @staticmethod + def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]): + """ + Creates a method for position encoding interpolation. + :param patch_size: patch size of the model. + :param stride_hw: A tuple containing the new height and width stride respectively. + :return: the interpolation method + """ + + def interpolate_pos_encoding( + self, x: torch.Tensor, w: int, h: int + ) -> torch.Tensor: + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + # compute number of tokens taking stride into account + w0 = 1 + (w - patch_size) // stride_hw[1] + h0 = 1 + (h - patch_size) // stride_hw[0] + assert ( + w0 * h0 == npatch + ), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and + stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}""" + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape( + 1, int(math.sqrt(N)), int(math.sqrt(N)), dim + ).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + align_corners=False, + recompute_scale_factor=False, + ) + assert ( + int(w0) == patch_pos_embed.shape[-2] + and int(h0) == patch_pos_embed.shape[-1] + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return interpolate_pos_encoding + + def _get_hook(self, facet: str, layer: int): + """ + generate a hook method 
for a specific block and facet. + """ + feature_name = f"{facet}_{layer}" + if facet in ["attn", "token"]: + + def _hook(model, input, output): + self._feats[feature_name] = output + + return _hook + + if facet == "query": + facet_idx = 0 + elif facet == "key": + facet_idx = 1 + elif facet == "value": + facet_idx = 2 + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _inner_hook(module, input, output): + input = input[0] + B, N, C = input.shape + qkv = ( + module.qkv(input) + .reshape(B, N, 3, module.num_heads, C // module.num_heads) + .permute(2, 0, 3, 1, 4) + ) + self._feats[feature_name] = qkv[facet_idx] # Bxhxtxd + + return _inner_hook + + def _register_hooks(self, feature_names) -> None: + """ + register hook to extract features. + :param layers: layers from which to extract features. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + """ + # feature names are in format of 'facet_layer' + facets, layers = zip(*[name.split("_") for name in feature_names]) + layers = [int(layer) for layer in layers] + facets_from_layers = {layer: facet for layer, facet in zip(layers, facets)} + + for block_idx, block in enumerate(self.vit.blocks): + if block_idx in layers: + facet = facets_from_layers[block_idx] + if facet == "token": + self.hook_handlers.append( + block.register_forward_hook(self._get_hook(facet, block_idx)) + ) + elif facet == "attn": + self.hook_handlers.append( + block.attn.attn_drop.register_forward_hook( + self._get_hook(facet, block_idx) + ) + ) + elif facet in ["key", "query", "value"]: + self.hook_handlers.append( + block.attn.register_forward_hook( + self._get_hook(facet, block_idx) + ) + ) + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _unregister_hooks(self) -> None: + """ + unregisters the hooks. should be called after feature extraction. 
+ """ + for handle in self.hook_handlers: + handle.remove() + self.hook_handlers = [] + + def _extract_features(self, batch: torch.Tensor, feature_names: List[str]) -> dict: + """ + extract features from the model + :param batch: batch to extract features for. Has shape BxCxHxW. + :param layers: layer to extract. A number between 0 to 11. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + :param feature_names: A list of names of the features to extract. Should be in the format of "{faucet}_{layer}" + :return : Dictionary of features + if facet is 'key' | 'query' | 'value' has shape Bxhxtxd + if facet is 'attn' has shape Bxhxtxt + if facet is 'token' has shape Bxtxd + """ + B, C, H, W = batch.shape + self._feats = {} + self._register_hooks(feature_names) + + _ = self.vit(batch) + + self._unregister_hooks() + self.load_size = (H, W) + self.num_patches = ( + 1 + (H - self.p) // self.stride[0], + 1 + (W - self.p) // self.stride[1], + ) + return self._feats + + def _extract_clip_feature(self, batch: torch.Tensor): + B, C, H, W = batch.shape + self.num_patches = (H // self.p[0], W // self.p[1]) + return self.vit.forward(batch) + + def _extract_r3m_feature(self, batch: torch.Tensor, feature_names): + # copied from the torchvision resnet package + resnet_enc = self.vit.module.convnet + x = resnet_enc.conv1(batch) + x = resnet_enc.bn1(x) + x = resnet_enc.relu(x) + x = resnet_enc.maxpool(x) + + feat_dict = {} + + for i in range(1, 5): + x = eval(f"resnet_enc.layer{i}")(x) + + if f"layer_{i}" in feature_names: + feat = rearrange(x, "b c h w -> b (h w) c") + 1e-12 # norm errors + feat = torch.cat([feat[:, 0:1], feat], dim=1) # dummy cls token + feat_dict[f"layer_{i}"] = feat + + return feat_dict + + def extract_descriptors(self, batch: torch.Tensor, feature_names) -> torch.Tensor: + """ + extract descriptors from the model + :param batch: batch to extract descriptors for. Has shape BxCxHxW. 
+ :param layers: layer to extract. A number between 0 to 11. + :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors. + """ + if "clip" in self.model_type: + facet = "token" + features = self._extract_clip_feature(batch) + raise NotImplementedError + elif "r3m" in self.model_type: + features = self._extract_r3m_feature(batch, feature_names) + else: + features = self._extract_features(batch, feature_names) + + for name, val in features.items(): + # Token shape: * x (1 + pH x pW) x d + val = val / val.norm(dim=-1, keepdim=True) + features[name] = val + return features + + def preprocess(self, obs): + with torch.no_grad(): + ret = {} + if "tokens" in obs: + ret["tokens"], ret["cls"], ret["proprio"] = ( + obs["tokens"], + obs["cls"], + obs["state"], + ) + raise NotImplementedError + else: + if "rgb" in obs: + # normalize RGB image + rgb, ps = pack_one(obs["rgb"], "* c h w") # [0, 255] + rgb = rgb.float() / 255.0 + rgb = self.normalize(rgb) # use imagenet mean and std + rgb = unpack_one(rgb, ps, "* c h w") + # debug + ret["video"] = rgb + + if "depth" in obs: + depth, ps = pack_one(obs["depth"], "* c h w") + depth = torch.clip(depth, 0.0, 2.0) + depth = unpack_one(depth, ps, "* c h w") + ret["depth"] = depth.float() + + if "state" in obs: + ret["proprio"] = obs["state"] + return ret + + def forward_vit(self, video, depth, pretrained_feats=None, **kwargs): + """Video shape: (B, F, V, C, H, W) + Conv feature shape: (B, F, V, D) + """ + images, packed_shape = pack_one(video, "* c h w") + depth, _ = pack_one(depth, "* c h w") + rgbd = torch.cat([images, depth], dim=1) + + if pretrained_feats is None: + with torch.no_grad(): + pretrained_feats = self.extract_descriptors( + images, self.pretrained_feat_info + ) # * x (1 + pH x pW) x D + if self.freeze_vit_to_random: + for key, val in pretrained_feats.items(): + pretrained_feats[key] = torch.zeros_like(val) + conv_features, attn = self.conv.forward_feature(rgbd, pretrained_feats) + conv_features 
= unpack_one(conv_features, packed_shape, "* c") + attn = unpack_one(attn, packed_shape, "* g h w") + return conv_features, attn + + def forward(self, obs, pretrained_feats=None, **kwargs): + # V: Number of views + # Feature: B x F x V x C + processed_obs = self.preprocess(obs) + feature, attn = self.forward_vit(processed_obs["video"], processed_obs["depth"], pretrained_feats, **kwargs) + state = processed_obs["proprio"] # B F D + feature = rearrange(feature, "b f v c -> b f (v c)") + feature = torch.cat([feature, state], dim=-1) + feature = rearrange(feature, "b f d -> b (f d)") + + processed_obs["attn"] = attn + + return feature, processed_obs + + def get_pretrained_feats(self, obs): + processed_obs = self.preprocess(obs) + video = processed_obs["video"] + depth = processed_obs["depth"] + + images, packed_shape = pack_one(video, "* c h w") + + with torch.no_grad(): + pretrained_feats = self.extract_descriptors( + images, self.pretrained_feat_info + ) # * x (1 + pH x pW) x D + + return pretrained_feats \ No newline at end of file diff --git a/gello/data_utils/simple_bc/encoder/spawnnet/vit_conv.py b/gello/data_utils/simple_bc/encoder/spawnnet/vit_conv.py new file mode 100644 index 00000000..3355adf3 --- /dev/null +++ b/gello/data_utils/simple_bc/encoder/spawnnet/vit_conv.py @@ -0,0 +1,228 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + + +class ViTConv(torch.nn.Module): + """ + Perform a sequence of convolutional operations and also merge with a pretrained ViT model. + Note: Code operates on assumption of three layers being extracted from ViT. + Different modes exist for fusing all three, the last two, or the last one layer. 
+ """ + + def __init__( + self, + in_channels, + shape, + out_shape, + pretrained_feat_info, + version="default", + conv_size=1, + pretrained_feature_dim=64, + channel_mask="default", + pretrained_input_dims=[384, 384, 384], + use_dense=True, + ): + super().__init__() + self.feat_convs = [] + self.resnet1 = [] + self.resnet2 = [] + self.convs = [] + + self.out_shape = out_shape + fcs = [64, 64, 64] + + self.layerwise_feature = {val: key for key, val in pretrained_feat_info.items()} + + self.adapters = [] + self.linears = [] + self.pretrained_input_dims = pretrained_input_dims + self.version = version + self.pretrained_feature_dims = [pretrained_feature_dim] * len(fcs) + self.conv_size = conv_size + + self.use_dense = use_dense + assert version in [ + "default", + "last_two", + "last_only", + ], f"Version {version} not supported (default, last_two, last_only)" + + self.feature_spatial_sizes = [28, 14, 7] + for i, (feature_spatial_size, num_ch, pretrained_feature_dim) in enumerate( + zip(self.feature_spatial_sizes, fcs, self.pretrained_feature_dims) + ): + if version == "last_two" and i < 1: + continue + elif version == "last_only" and i < 2: + continue + if conv_size == -1: + conv_size = 55 // feature_spatial_size + self.adapters.append( + nn.Sequential( + nn.Conv2d( + self.pretrained_input_dims[i], + pretrained_feature_dim, + kernel_size=conv_size, + stride=conv_size, + padding=0, + ), + nn.ReLU(), + nn.Upsample( + size=(feature_spatial_size, feature_spatial_size), + mode="bilinear", + ), + ) + ) + if i != 2: + self.linears.append(nn.Linear(num_ch + pretrained_feature_dim, num_ch)) + + self.stem = nn.Conv2d(in_channels, fcs[0], kernel_size=4, stride=4) + in_channels = fcs[0] + out_shape = self.out_shape + + for layer, num_ch in enumerate(fcs): + feats_convs = [] + feats_convs.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=num_ch, + kernel_size=3, + stride=1, + padding=1, + ) + ) + feats_convs.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 
+ self.feat_convs.append(nn.Sequential(*feats_convs)) + in_channels = num_ch + + if self.version == "last_only" and layer != len(fcs) - 1: + feat_dim = num_ch + elif self.version == "last_two" and layer < len(fcs) - 2: + feat_dim = num_ch + else: + feat_dim = num_ch + self.pretrained_feature_dims[layer] + + for i in range(2): + resnet_block = [] + resnet_block.append(nn.ReLU()) + resnet_block.append( + nn.Conv2d( + in_channels=feat_dim, + out_channels=feat_dim, + kernel_size=3, + stride=1, + padding=1, + ) + ) + resnet_block.append(nn.ReLU()) + resnet_block.append( + nn.Conv2d( + in_channels=feat_dim, + out_channels=feat_dim, + kernel_size=3, + stride=1, + padding=1, + ) + ) + if i == 0: + self.resnet1.append(nn.Sequential(*resnet_block)) + else: + self.resnet2.append(nn.Sequential(*resnet_block)) + + self.feat_convs = nn.ModuleList(self.feat_convs) + self.resnet1 = nn.ModuleList(self.resnet1) + self.resnet2 = nn.ModuleList(self.resnet2) + self.adapters = nn.ModuleList(self.adapters) + self.linears = nn.ModuleList(self.linears) + self.channel_mask = channel_mask + + def forward_feature(self, feature, pretrained_feats): + """ + Input: Feature: (B x F x V) x 4 x H x W + Return: Feature of size (B x F x V) x D x pH (7) x pW (7) + """ + assert feature.shape[1] == 4 + if self.channel_mask == "default": + pass + elif self.channel_mask == "rgb_only": + feature[:, 3:] = 0 + elif self.channel_mask == "depth_only": + feature[:, :3] = 0 + elif self.channel_mask == 'no_rgbd': + feature[:, :] = 0 + + x = self.stem(feature) + res_input = None + + attn = [] + + for i, fconv in enumerate(self.feat_convs): + x = fconv(x) + # Pretrained feature: (B x F x V) x (1 + pH x pW) x D + if self.use_dense: + pretrained_x = pretrained_feats[self.layerwise_feature[i]][ + :, 1: + ] # Remove cls token + else: # let's use the cls token only. make it match the shape. 
+ _, phpw, _ = pretrained_feats[self.layerwise_feature[i]].shape + phpw -=1 + pretrained_x = pretrained_feats[self.layerwise_feature[i]][:, 0] + pretrained_x = repeat(pretrained_x, 'b d -> b h d', h=phpw) + + def view_as_patches(feat): + h = int(np.sqrt(feat.shape[-2])) + return rearrange(feat, "b (h w) d -> b d h w", h=h, w=h) + + if self.version == "default": + pretrained_x = view_as_patches(pretrained_x) + + for j in range(3): + pretrained_x = self.adapters[i][j](pretrained_x) + if j == 1: + # extract post RELU, pre downsample features + attn.append(pretrained_x.clone().detach()) + x = torch.cat([x, pretrained_x], dim=1) + elif self.version == "last_only": + if i == 2: + pretrained_x = view_as_patches(pretrained_x) + for j in range(3): + pretrained_x = self.adapters[0][j](pretrained_x) + if j == 1: + # extract post RELU, pre downsample features + attn.append(pretrained_x.clone().detach()) + x = torch.cat([x, pretrained_x], dim=1) + + res_input = x + x = self.resnet1[i](x) + x += res_input + res_input = x + x = self.resnet2[i](x) + x += res_input + # Map back the dimension + if i != 2 and self.version == "default": + b, d, h, w = x.shape + x = rearrange(x, "b d h w -> (b h w) d") + x = self.linears[i](x) + x = rearrange(x, "(b h w) d -> b d h w", b=b, h=h, w=w) + elif i == 1 and self.version == "last_two": + b, d, h, w = x.shape + x = rearrange(x, "b d h w -> (b h w) d") + x = self.linears[0](x) + x = rearrange(x, "(b h w) d -> b d h w", b=b, h=h, w=w) + + attn = [F.interpolate(a, size=(28, 28), mode="bilinear") for a in attn] + attn = torch.stack(attn, dim=1) + attn = rearrange(attn, "b g d h w -> b g h w d") + attn = torch.norm(attn, dim=-1) + attn_b, attn_g, attn_h, attn_w = attn.shape + attn = rearrange(attn, "b g h w -> (b g) (h w)") + attn = torch.nn.functional.softmax(attn / 0.1, dim=-1) + attn = rearrange( + attn, "(b g) (h w) -> b g h w", b=attn_b, g=attn_g, h=attn_h, w=attn_w + ) + + x = rearrange(x, "b d h w -> b (h w d)") + return x, attn \ No newline 
at end of file diff --git a/gello/data_utils/simple_bc/encoder/vit_descriptor.py b/gello/data_utils/simple_bc/encoder/vit_descriptor.py new file mode 100644 index 00000000..81305f1e --- /dev/null +++ b/gello/data_utils/simple_bc/encoder/vit_descriptor.py @@ -0,0 +1,317 @@ +import math +import types +from typing import List, Tuple + +import hydra +import torch +import torch.nn.modules.utils as nn_utils +from omegaconf import DictConfig +from torch import nn +from torchvision import transforms + +from simple_bc._interfaces.encoder import Encoder +from simple_bc.utils.torch_utils import pack_one, unpack_one, to_torch, rearrange + + +class ViTDescriptor(Encoder): + def __init__(self, + model_type: str = 'dino_vits8', + stride: int = 4, + layer: int = 9, + freeze_pretrained: bool = True, + downsample: bool = False, + use_cached_token: bool = False, + patch_size: int = 0, # Not used, but keep it here so patch size can be passed in + use_pretrained=True, + **kwargs + ): + """ + :param model_type: A string specifying the type of model to extract from + :param stride: stride of first convolution layer. small stride -> higher resolution + :param layer: the layer to extract from. 
0 is the first layer after the input + """ + super(ViTDescriptor, self).__init__(**kwargs) + self.model_type = model_type + self.use_pretrained=use_pretrained + if not use_pretrained: + assert not freeze_pretrained + self.vit = self.build_vit(model_type, freeze_pretrained, use_pretrained) + + # TODO Note this is not True for MVP + self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5) + self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5) + self.normalize = transforms.Normalize(mean=self.mean, std=self.std) + + if freeze_pretrained: + for param in self.vit.parameters(): + param.requires_grad = False + else: + for param in self.vit.parameters(): + param.requires_grad = True + self.downsample = downsample + if 'clip' not in model_type: + self.vit = ViTDescriptor.patch_vit_resolution(self.vit, stride=stride) + self.p = self.vit.patch_embed.patch_size + self.stride = self.vit.patch_embed.proj.stride + if not isinstance(self.p, int): + self.p = self.p[0] + else: + self.p = self.vit.p + self.stride = self.vit.stride + + self.layer = layer + + self._feats = [] + self.hook_handlers = [] + self.load_size = None + self.num_patches = None + + @staticmethod + def build_vit(model_type: str, freeze_pretrained, use_pretrained) -> nn.Module: + if 'dinov2' in model_type: + model = torch.hub.load('facebookresearch/dinov2', model_type, pretrained=use_pretrained) + elif 'dino' in model_type: + model = torch.hub.load('facebookresearch/dino:main', model_type, pretrained=use_pretrained) + elif 'mvp' in model_type: + assert use_pretrained + model_type = model_type.replace('mvp_', '') + import mvp + model = mvp.load(model_type) + if freeze_pretrained: + model.freeze() + return model + + @staticmethod + def prune_model(model: nn.Module, layer: int = 9) -> nn.Module: + """ + :param model: the model to prune + :param layer: the layer to extract from. 
0 is the first layer after the input + :return: the pruned model + """ + try: + model.transformer.resblocks = model.transformer.resblocks[:layer + 1] + except AttributeError: + model.blocks = model.blocks[:layer + 1] + return model + + @staticmethod + def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]): + """ + Creates a method for position encoding interpolation. + :param patch_size: patch size of the model. + :param stride_hw: A tuple containing the new height and width stride respectively. + :return: the interpolation method + """ + + def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor: + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + # compute number of tokens taking stride into account + w0 = 1 + (w - patch_size) // stride_hw[1] + h0 = 1 + (h - patch_size) // stride_hw[0] + assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and + stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}""" + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + align_corners=False, recompute_scale_factor=False + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + return interpolate_pos_encoding + + @staticmethod + def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module: + """ + change resolution 
of model output by changing the stride of the patch extraction. + :param model: the model to change resolution for. + :param stride: the new stride parameter. + :return: the adjusted model + """ + patch_size = model.patch_embed.patch_size + if isinstance(patch_size, tuple): + assert len(patch_size) == 2, f'patch_size should be a tuple of length 2, got {patch_size}' + assert patch_size[0] == patch_size[1] + patch_size = patch_size[0] + + if stride == patch_size: # nothing to do + return model + + stride = nn_utils._pair(stride) + + assert all( + [(patch_size // s_) * s_ == patch_size for s_ in + stride]), f'stride {stride} should divide patch_size {patch_size}' + + # fix the stride + model.patch_embed.proj.stride = stride + # fix the positional encoding code + model.interpolate_pos_encoding = types.MethodType(ViTDescriptor._fix_pos_enc(patch_size, stride), model) + return model + + def _get_hook(self, facet: str): + """ + generate a hook method for a specific block and facet. + """ + if facet in ['attn', 'token']: + def _hook(model, input, output): + self._feats.append(output) + + return _hook + + if facet == 'query': + facet_idx = 0 + elif facet == 'key': + facet_idx = 1 + elif facet == 'value': + facet_idx = 2 + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _inner_hook(module, input, output): + input = input[0] + B, N, C = input.shape + qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4) + self._feats.append(qkv[facet_idx]) # Bxhxtxd + + return _inner_hook + + def _register_hooks(self, layers: List[int], facet: str) -> None: + """ + register hook to extract features. + :param layers: layers from which to extract features. + :param facet: facet to extract. 
One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + """ + for block_idx, block in enumerate(self.vit.blocks): + if block_idx in layers: + if facet == 'token': + self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet))) + elif facet == 'attn': + self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet))) + elif facet in ['key', 'query', 'value']: + self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet))) + else: + raise TypeError(f"{facet} is not a supported facet.") + + def _unregister_hooks(self) -> None: + """ + unregisters the hooks. should be called after feature extraction. + """ + for handle in self.hook_handlers: + handle.remove() + self.hook_handlers = [] + + def _extract_features(self, batch: torch.Tensor, layers: List[int] = 11, facet: str = 'key') -> List[torch.Tensor]: + """ + extract features from the model + :param batch: batch to extract features for. Has shape BxCxHxW. + :param layers: layer to extract. A number between 0 to 11. + :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn'] + :return : tensor of features. + if facet is 'key' | 'query' | 'value' has shape Bxhxtxd + if facet is 'attn' has shape Bxhxtxt + if facet is 'token' has shape Bxtxd + """ + B, C, H, W = batch.shape + self._feats = [] + self._register_hooks(layers, facet) + _ = self.vit(batch) + self._unregister_hooks() + self.load_size = (H, W) + self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1]) + return self._feats + + def _extract_clip_feature(self, batch: torch.Tensor): + B, C, H, W = batch.shape + self.num_patches = (H // self.p[0], W // self.p[1]) + return self.vit.forward(batch) + + def extract_descriptors(self, batch: torch.Tensor, layer=None) -> torch.Tensor: + """ + extract descriptors from the model + :param batch: batch to extract descriptors for. Has shape BxCxHxW. 
+ :param layers: layer to extract. A number between 0 to 11. + :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors. + """ + if 'clip' in self.model_type: + facet = 'token' + x = self._extract_clip_feature(batch) + else: + facet = 'key' if ('mvp' not in self.model_type) else 'token' + if layer is None: + self._extract_features(batch, [self.layer], facet) + else: + self._extract_features(batch, [layer], facet) + x = self._feats[0] + + if facet == 'token': + x = x.unsqueeze(dim=1) # Bx1xtxd + + desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1) # Bx1xtx(dxh) + desc = desc / desc.norm(dim=-1, keepdim=True) + + cls, patch_embed = desc[:, :, 0], desc[:, :, 1:] + return cls, patch_embed + + def preprocess(self, obs): + with torch.no_grad(): + ret = {} + if 'tokens' in obs: + ret['tokens'], ret['cls'], ret['proprio'] = obs['tokens'], obs['cls'], obs['state'] + else: + if "rgb" in obs and 'rgb' in self.obs_shapes: + # normalize RGB image + rgb, ps = pack_one(obs["rgb"], '* c h w') # [0, 255] + rgb = rgb.float() / 255.0 + rgb = self.normalize(rgb) # use imagenet mean and std + rgb = unpack_one(rgb, ps, '* c h w') + # debug + ret['video'] = rgb + + if "depth" in obs and 'depth' in self.obs_shapes: + depth, ps = pack_one(obs["depth"], '* c h w') + # depth = torch.clip(depth, 0.0, 2.0); should already be normalized + depth = unpack_one(depth, ps, '* c h w') + ret['depth'] = depth.float() + + if 'state' in obs and 'state' in self.obs_shapes: + ret['proprio'] = obs['state'] + return ret + + def forward_vit(self, video, **kwargs): + """ Video shape: (B, F, C, H, W) + Token shape: (B, F, D, pH, pW) + """ + images, packed_shape = pack_one(video, '* c h w') + cls, patch_embed = self.extract_descriptors(images) # (*) x 1 x (pH x pW) x D + cls = cls.squeeze(dim=1) + patch_embed = patch_embed.squeeze(dim=1) + P = int(math.sqrt(patch_embed.shape[-2])) + tokens = rearrange(patch_embed, 'b (pH pW) d -> b d pH pW', pH=P, pW=P) + 
if self.downsample: + tokens = nn.AvgPool2d(5, 5)(tokens) + tokens = unpack_one(tokens, packed_shape, '* c h w') + cls = unpack_one(cls, packed_shape, '* c') + return tokens, cls + + def forward(self, obs): + # V: Number of views + # tokens: B x F x V x D x H x W + processed_obs = self.preprocess(obs) + if 'tokens' in obs: + tokens, cls = processed_obs['tokens'], processed_obs['cls'] + else: + tokens, cls = self.forward_vit(processed_obs['video']) + processed_obs['cls'] = cls + return tokens, processed_obs \ No newline at end of file diff --git a/gello/data_utils/simple_bc/eval.py b/gello/data_utils/simple_bc/eval.py new file mode 100644 index 00000000..99c770aa --- /dev/null +++ b/gello/data_utils/simple_bc/eval.py @@ -0,0 +1,235 @@ +import os + +import click +import matplotlib.pyplot as plt +import numpy as np +import torch +from einops import repeat +from omegaconf import OmegaConf + +import gdict +from simple_bc._interfaces.encoder import Encoder +from simple_bc._interfaces.policy import Policy +from simple_bc.dataset.replay_dataset import ReplayDataset +from simple_bc.utils import data_utils +from simple_bc.utils.torch_utils import to_numpy, to_torch +from simple_bc.utils.visualization_utils import make_grid_video_from_numpy + +from natsort import natsorted +import glob +from tqdm import tqdm +from einops import rearrange + + +@torch.no_grad() +def evaluate(encoder, policy, dataloader, suffix="", rgb=False): + """ + Evaluate the policy. Return a dictionary of the loss and any other metrics. 
+ """ + encoder.eval() + policy.eval() + running_mse, running_abs, tot_items = 0, 0, 0 + + all_pred_actions, all_actions, all_attn, all_rgb, all_dones = [], [], [], [], [] + for batch in dataloader: + obses, actions, dones = batch["obs"], batch["actions"], batch["dones"] + obses = gdict.GDict(obses).cuda(device="cuda") + actions = actions.to(device="cuda") + pred_actions, info = policy(encoder(obses)) + + if "attn" in info: + all_attn.append(to_numpy(info["attn"])) + if rgb: + rgb_frame = obses["rgb"] # (B, F, V, C, H, W); float32 in [0, 255] + rgb_frame = rearrange(rgb_frame, "b f v c h w -> b f v h w c")[:, -1] + all_rgb.append(to_numpy(rgb_frame)) + + all_pred_actions.append(to_numpy(pred_actions)) + all_actions.append(to_numpy(actions)) + all_dones.append(to_numpy(dones)) + first_pred_actions = pred_actions[:, 0] + first_gt_actions = actions[:, 0] + running_mse += ( + ((first_pred_actions - first_gt_actions) ** 2).sum(0).mean() + ).item() + running_abs += ( + (torch.abs(first_pred_actions - first_gt_actions)).sum(0).mean() + ).item() + B = actions.shape[0] + tot_items += B + + metrics = { + f"val/mse{suffix}": running_mse / tot_items, + f"val/abs{suffix}": running_abs / tot_items, + } + + batch_list = {"gt_actions": all_actions, "pred_actions": all_pred_actions} + if len(all_attn) > 0: + batch_list["attn"] = all_attn + if len(all_rgb) > 0: + batch_list["rgb"] = all_rgb + + episodic_list = data_utils.list_batch_to_episodic(all_dones, batch_list) + return metrics, episodic_list + + +def make_rgb_attn_video(rgb, attn, save_name): + """ + rgb: (T, V, H, W, C) + attn: (T, F, V, G, pH, pW) + """ + + T, F, V, G = attn.shape[:4] + H, W = rgb.shape[2:4] + # rgb pad the first frame by F-1 + rgb = np.concatenate([np.expand_dims(rgb[0], 0)] * (F - 1) + [rgb], axis=0) + # Normalize the attention map per image + attn_min = np.min(attn, axis=(-1, -2), keepdims=True) + attn_max = np.max(attn, axis=(-1, -2), keepdims=True) + attn = (attn - attn_min) / (attn_max - attn_min + 
1e-8) + + all_videos = [] + for v in range(V): + for f in range(F): + if f != F - 1: + continue # Only keep the last frame for visualization + rgb_t = rgb[f : f + T, v] # T H W C + attn_t = attn[:, f, v] # Take the last frame, (T G pH pW) + attn_t_torch = to_torch(attn_t) + attn_t = torch.nn.functional.interpolate( + attn_t_torch, size=(224, 224), mode="nearest" + ) + rgb_t = to_torch(rearrange(rgb_t, "t h w c -> t c h w")) / 255.0 + rgb_t = torch.nn.functional.interpolate( + rgb_t, size=(224, 224), mode="nearest" + ) + + rgb_t = to_numpy(rgb_t) # T, C, H, W + rgb_t = rearrange(rgb_t, "t c h w -> t h w c") * 255.0 + + attn_t = to_numpy(attn_t) # T, G, H, W + attn_t = repeat(attn_t, "t g h w -> g t h w c", c=3) + + attn_t *= 255.0 # T, G, H, W, C + + import cv2 + + attn_t = attn_t.astype(np.uint8) + attn_t = [ + [ + cv2.applyColorMap(attn_t[g_index][t_index], cv2.COLORMAP_VIRIDIS) + for t_index in range(T) + ] + for g_index in range(G) + ] + + attn_t = np.array(attn_t)[..., ::-1].astype(np.float32) + + attn_t = attn_t * 0.5 + repeat(rgb_t, "t h w c -> g t h w c", g=G) * 0.5 + + all_videos.append(rgb_t) + all_videos.extend(attn_t) + + combined = make_grid_video_from_numpy(all_videos, ncol=G + 1, output_name=save_name) + return combined + + +def make_action_comparison(pred_actions, gt_actions, save_name): + """ + pred_actions: (T, H, D) + gt_actions: (T, H, D) + """ + T, H, D = pred_actions.shape + + row_size = (D // 4) + int(D % 4 == 0) + + plt.switch_backend("agg") + fig, ax = plt.subplots(row_size, 4, figsize=(16, 8)) + plt.tight_layout() + + for d in range(D): + d1, d2 = d // 4, d % 4 + if H > 1: + for t in range(0, T, 10): + if t == 0: + ax[d1, d2].plot( + range(t, t + H), pred_actions[t, :, d], label=f"pred", color="b" + ) + else: + ax[d1, d2].plot(range(t, t + H), pred_actions[t, :, d], color="b") + else: + ax[d1, d2].plot(pred_actions[:, 0, d], label="pred", color="b") + + ax[d1, d2].plot(gt_actions[:, 0, d], label="gt", color="orange") + ax[d1, 
d2].set_title(f"Action {d}") + ax[d1, d2].legend() + plt.savefig(save_name) + plt.close() + + +def make_visualization(val_replay, info, save_dir, prefix=""): + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + attn_list = info["attn"] if "attn" in info else None + rgb_frames = info["rgb"] if "rgb" in info else None + gt_actions_list, pred_actions_list = info["gt_actions"], info["pred_actions"] + for idx, (gt_actions, pred_actions) in enumerate( + zip(gt_actions_list, pred_actions_list) + ): + if attn_list is not None: + rgb = rgb_frames[idx] + make_rgb_attn_video( + rgb, + attn_list[idx], + save_name=os.path.join(save_dir, f"{prefix}traj_{idx}.mp4"), + ) + make_action_comparison( + pred_actions, + gt_actions, + save_name=os.path.join(save_dir, f"{prefix}traj_{idx}.png"), + ) + + +@click.command() +@click.option( + "--dataset", + type=str, + default=None, + help="Dataset for evaluation. If in simulation, must contain .json files organized by instance!", +) +@click.option("--policy_path", "-a", type=str, default=None) +def main(dataset, policy_path): + data_dir = os.environ["DATASET_DIR"] + eval_dir = os.path.join(data_dir, "eval") + if not os.path.exists(eval_dir): + os.makedirs(eval_dir) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + exp_dir = os.path.dirname(policy_path) + encoder_path = policy_path.replace("policy", "encoder") + cfg_path = os.path.join(exp_dir, "config.yaml") + train_cfg = OmegaConf.load(cfg_path) + + print(f"Loading encoder from {encoder_path}") + encoder = Encoder.build_encoder(train_cfg.encoder).to(device) + policy = Policy.build_policy( + encoder.out_shape, train_cfg.policy, train_cfg.encoder + ).to( + device + ) # Updated shape + encoder.load(encoder_path) + policy.load(policy_path) + + eval_replay = ReplayDataset(dataset_dir=dataset, **train_cfg.dataset, shuffle=False) + # Have to use worker = 1 to reserve the order + eval_dataloader = data_utils.get_dataloader( + eval_replay, "val", 1, 
train_cfg.batch_size + ) + + _, info = evaluate(encoder, policy, eval_dataloader) + make_visualization(eval_replay, info, eval_dir, prefix="eval_") + + +if __name__ == "__main__": + main() diff --git a/gello/data_utils/simple_bc/policy/__init__.py b/gello/data_utils/simple_bc/policy/__init__.py new file mode 100644 index 00000000..0f2e1bf6 --- /dev/null +++ b/gello/data_utils/simple_bc/policy/__init__.py @@ -0,0 +1 @@ +from .mlp import MLP diff --git a/gello/data_utils/simple_bc/policy/mlp.py b/gello/data_utils/simple_bc/policy/mlp.py new file mode 100644 index 00000000..4bcc175c --- /dev/null +++ b/gello/data_utils/simple_bc/policy/mlp.py @@ -0,0 +1,101 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.parameter as P +from einops import rearrange, repeat + +import simple_bc.utils.torch_utils as utils +from simple_bc._interfaces.policy import Policy + + +class MLP(Policy): + """ + Simple MLP. + """ + + def __init__( + self, + obs_shape, + act_shape, + act_bounds, + layers, + activation, + normalization, + zero_init_output, + final_act, + num_frames, + num_views=2, + action_horizon=1, + encoder_cfg=None, + ): + super().__init__() + self.obs_shape, self.act_shape = eval(str(obs_shape)), act_shape + self.act_bounds = act_bounds + self.num_frames = num_frames + self.num_views = num_views + if isinstance(self.obs_shape, tuple) or isinstance( + self.obs_shape, list + ): # For ViT + in_dim = int(self.obs_shape[-1]) * num_frames * num_views + else: + in_dim = int(self.obs_shape) + + out_dim = int(np.prod(self.act_shape)) * action_horizon + self.action_horizon = action_horizon + + ### build trunk + act = eval(activation) + norm = eval(normalization) + layers = [in_dim] + layers + [out_dim] + trunk_layers = [] + for i in range(len(layers) - 2): + trunk_layers += [ + nn.Linear(layers[i], layers[i + 1]), + ] + if act is not None: + trunk_layers += [act()] + if norm is not None: + trunk_layers += [norm(layers[i + 1])] + + last_layer = 
nn.Linear(layers[-2], layers[-1]) + if zero_init_output: + self._zero_init_output(last_layer) + trunk_layers += [last_layer] + if final_act: + trunk_layers += [act()] + + self.trunk = nn.Sequential(*trunk_layers) + + ### build head + if self.act_bounds is None: + self.scale, self.bias = 1.0, 0.0 + else: + lb = utils.to_torch(self.act_bounds[0]) + ub = utils.to_torch(self.act_bounds[1]) + self.scale = P.Parameter((ub - lb) / 2.0, requires_grad=False) + self.bias = P.Parameter((ub + lb) / 2.0, requires_grad=False) + + def forward(self, obs): + feat, enc_info = obs + if "cls" in enc_info: + cls, proprio = enc_info["cls"], enc_info["proprio"] + B, F, V, C = enc_info["cls"].shape + proprio = repeat(proprio, "b f c -> b f v c", v=V) + feat = torch.cat([cls, proprio], dim=-1) + feat = rearrange(feat, "b f v c -> b (f v c)") + + mean = self.trunk(feat) # This includes MLP + info = {} + + action = rearrange(mean, "b (t a) -> b t a", t=self.action_horizon) + action = action * self.scale + self.bias + + if "attn" in enc_info: + info["attn"] = enc_info["attn"] + + return action, info + + def _zero_init_output(self, m): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + m.weight.data.copy_(0.01 * m.weight.data) diff --git a/gello/data_utils/simple_bc/profile.py b/gello/data_utils/simple_bc/profile.py new file mode 100644 index 00000000..63669c60 --- /dev/null +++ b/gello/data_utils/simple_bc/profile.py @@ -0,0 +1,139 @@ +import os + +import hydra +import torch +import torch.nn as nn +import wandb +from einops import rearrange +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf + +import gdict +import simple_bc +from simple_bc._interfaces.encoder import Encoder +from simple_bc._interfaces.policy import Policy +from simple_bc.dataset.replay_dataset import ReplayDataset +from simple_bc.utils import log_utils, data_utils, torch_utils +import time +import tqdm + + +def update_dataset(cfg): + from simple_bc.constants import BC_DATASET 
+ from omegaconf import open_dict # Allows to modify cfg in place + + with open_dict(cfg): + cfg.train_dataset = BC_DATASET[cfg.task]["train_dataset"] + cfg.val_dataset_l1 = BC_DATASET[cfg.task]["val_dataset_l1"] + cfg.val_dataset_l2 = BC_DATASET[cfg.task]["val_dataset_l2"] + + +@hydra.main(config_path="../conf", version_base="1.3") +def main(cfg: DictConfig): + work_dir = HydraConfig.get().runtime.output_dir + setup(cfg) + encoder = Encoder.build_encoder(cfg.encoder).to(cfg.device) + policy = Policy.build_policy(encoder.out_shape, cfg.policy, cfg.encoder).to( + cfg.device + ) # Updated shape + + encoder_trainable_params = torch_utils.get_named_trainable_params(encoder) + print( + "Encoder trainable parameters:", + sum(p.numel() for (name, p) in encoder_trainable_params) / 1e6, + "M", + ) + print( + "Policy trainable parameters:", + sum(p.numel() for p in policy.parameters()) / 1e6, + "M", + ) + token_name = data_utils.get_token_name(cfg.encoder) + cfg.dataset.token_name = token_name + update_dataset(cfg) + OmegaConf.save(config=cfg, f=os.path.join(work_dir, "config.yaml")) + + val_replay_l2 = ReplayDataset(dataset_dir=cfg.val_dataset_l2, **cfg.dataset) + val_dataloader_l2 = data_utils.get_dataloader( + val_replay_l2, "val", cfg.num_workers, 1 + ) + + print(f"profiling run time for x1000: no cache") + all_times = [] + enc_times = [] + pol_times = [] + batch = next(iter(val_dataloader_l2)) + obs, act = batch["obs"], batch["actions"] + obs = gdict.GDict(obs).cuda(device="cuda") + act = act.to(device="cuda") + + with torch.no_grad(): + for _ in tqdm.tqdm(range(1000)): + start = time.time() + enc_start = time.time() + feats = encoder(obs) + enc_end = time.time() + pol_start = time.time() + _ = policy(feats) + pol_end = time.time() + end = time.time() + + all_times += [end - start] + enc_times += [enc_end - enc_start] + pol_times += [pol_end - pol_start] + + # report the mean and std of the times in ms + print(f"all: {torch.tensor(all_times).mean() * 1000} +- 
{torch.tensor(all_times).std() * 1000}ms") + print(f"enc: {torch.tensor(enc_times).mean() * 1000} +- {torch.tensor(enc_times).std() * 1000}ms") + print(f"pol: {torch.tensor(pol_times).mean() * 1000} +- {torch.tensor(pol_times).std() * 1000}ms") + + if isinstance(encoder, simple_bc.encoder.SpawnNet): + print(f"profiling run time for x1000: with cache") + + all_times = [] + enc_times = [] + pol_times = [] + + batch = next(iter(val_dataloader_l2)) + obs, act = batch["obs"], batch["actions"] + obs = gdict.GDict(obs).cuda(device="cuda") + act = act.to(device="cuda") + + pretrained_feats = encoder.get_pretrained_feats(obs) + with torch.no_grad(): + for _ in tqdm.tqdm(range(1000)): + start = time.time() + enc_start = time.time() + feats = encoder(obs, pretrained_feats) + enc_end = time.time() + pol_start = time.time() + _ = policy(feats) + pol_end = time.time() + end = time.time() + all_times += [end - start] + enc_times += [enc_end - enc_start] + pol_times += [pol_end - pol_start] + + # report in ms + print(f"all: {torch.tensor(all_times).mean() * 1000} +- {torch.tensor(all_times).std() * 1000}ms") + print(f"enc: {torch.tensor(enc_times).mean() * 1000} +- {torch.tensor(enc_times).std() * 1000}ms") + print(f"pol: {torch.tensor(pol_times).mean() * 1000} +- {torch.tensor(pol_times).std() * 1000}ms") + +def setup(cfg): + if cfg.gpu is not None: + print(f"Using GPU {cfg.gpu}") + torch.cuda.set_device(cfg.gpu) + + import warnings + + warnings.simplefilter("ignore") + + from simple_bc.utils.log_utils import set_random_seed + + set_random_seed(cfg.seed) + torch.manual_seed(cfg.seed) + torch.cuda.manual_seed(cfg.seed) + + +if __name__ == "__main__": + main() diff --git a/gello/data_utils/simple_bc/train.py b/gello/data_utils/simple_bc/train.py new file mode 100644 index 00000000..94d90b4c --- /dev/null +++ b/gello/data_utils/simple_bc/train.py @@ -0,0 +1,289 @@ +import os + +import hydra +import torch +import torch.nn as nn +import wandb +from einops import rearrange +from 
hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, OmegaConf + +import gdict +from simple_bc.eval import make_visualization, evaluate +from simple_bc._interfaces.encoder import Encoder +from simple_bc._interfaces.policy import Policy +from simple_bc.dataset.replay_dataset import ReplayDataset +from simple_bc.utils import log_utils, data_utils, torch_utils +import time + + +def update_dataset(cfg): + from simple_bc.constants import BC_DATASET + from omegaconf import open_dict # Allows to modify cfg in place + + with open_dict(cfg): + cfg.train_dataset = BC_DATASET[cfg.task]["train_dataset"] + cfg.val_dataset_l1 = BC_DATASET[cfg.task]["val_dataset_l1"] + cfg.val_dataset_l2 = BC_DATASET[cfg.task]["val_dataset_l2"] + + +@hydra.main(config_path="../conf", version_base="1.3") +def main(cfg: DictConfig): + work_dir = HydraConfig.get().runtime.output_dir + setup(cfg) + encoder = Encoder.build_encoder(cfg.encoder).to(cfg.device) + policy = Policy.build_policy(encoder.out_shape, cfg.policy, cfg.encoder).to( + cfg.device + ) # Updated shape + + val_vis_stride = cfg.get("val_vis_stride", 1) + + token_name = data_utils.get_token_name(cfg.encoder) + cfg.dataset.token_name = token_name + update_dataset(cfg) + OmegaConf.save(config=cfg, f=os.path.join(work_dir, "config.yaml")) + + train_replay = ReplayDataset(dataset_dir=cfg.train_dataset, **cfg.dataset) + train_dataloader = data_utils.get_dataloader( + train_replay, "train", cfg.num_workers, cfg.batch_size + ) + + val_replay_l1 = ReplayDataset(dataset_dir=cfg.val_dataset_l1, **cfg.dataset) + val_dataloader_l1 = data_utils.get_dataloader( + val_replay_l1, "val", cfg.num_workers, cfg.batch_size * 2 + ) + + val_replay_l2 = ReplayDataset(dataset_dir=cfg.val_dataset_l2, **cfg.dataset) + val_dataloader_l2 = data_utils.get_dataloader( + val_replay_l2, "val", cfg.num_workers, cfg.batch_size * 2 + ) + + log_utils.init_wandb(cfg) + + optimizer = setup_optimizer(cfg.optimizer_cfg, encoder, policy) + scheduler = 
setup_lr_scheduler(optimizer, cfg.scheduler_cfg) + + # Pick ckpt based on the average of the last 5 epochs + metric_logger = log_utils.MetricLogger(delimiter=" ") + best_loss_logger = log_utils.BestAvgLoss(window_size=5) + + for epoch in metric_logger.log_every(range(cfg.epochs), 1, ""): + train_metrics = run_one_epoch( + encoder, + policy, + train_dataloader, + optimizer, + scheduler, + clip_grad=cfg.clip_grad, + ) + + train_metrics["train/lr"] = optimizer.param_groups[0]["lr"] + metric_logger.update(**train_metrics) + wandb.log(train_metrics, step=epoch) + + if epoch % cfg.val_freq == 0: + val_metrics_l1, _ = evaluate(encoder, policy, val_dataloader_l1, "_L1") + val_metrics_l2, _ = evaluate(encoder, policy, val_dataloader_l2, "_L2") + combined_metrics = dict(**val_metrics_l1, **val_metrics_l2) + + # Save best checkpoint + metric_logger.update(**combined_metrics) + + loss_metric = combined_metrics["val/mse_L1"] + is_best = best_loss_logger.update_best(loss_metric, epoch) + + if is_best: + encoder.save(f"{work_dir}/encoder_best.ckpt") + policy.save(f"{work_dir}/policy_best.ckpt") + with open(f"{work_dir}/best_epoch.txt", "w") as f: + f.write( + "Best epoch: %d, Best %s: %.4f" + % (epoch, "loss", best_loss_logger.best_loss) + ) + wandb.log(combined_metrics, step=epoch) + + if epoch % cfg.save_freq == 0: + encoder.save(f"{work_dir}/encoder_{epoch}.ckpt") + policy.save(f"{work_dir}/policy_{epoch}.ckpt") + + def visualize(): + val_replay_l1 = ReplayDataset( + dataset_dir=cfg.val_dataset_l1, + **cfg.dataset, + shuffle=False, + stride=val_vis_stride, + ) + val_replay_l1.aug_transform = None + val_dataloader_l1 = data_utils.get_dataloader( + val_replay_l1, "val", 1, cfg.batch_size + ) + + val_replay_l2 = ReplayDataset( + dataset_dir=cfg.val_dataset_l2, + **cfg.dataset, + shuffle=False, + stride=val_vis_stride, + ) + val_replay_l2.aug_transform = None + val_dataloader_l2 = data_utils.get_dataloader( + val_replay_l2, "val", 1, cfg.batch_size + ) + + _, info_l1 = evaluate( + 
encoder, policy, val_dataloader_l1, "_L1", rgb=True + ) + _, info_l2 = evaluate( + encoder, policy, val_dataloader_l2, "_L2", rgb=True + ) + save_dir = os.path.join(work_dir, f"visualization_{epoch}") + make_visualization( + val_replay_l1, info_l1, save_dir=save_dir, prefix="l1_" + ) + make_visualization( + val_replay_l2, info_l2, save_dir=save_dir, prefix="l2_" + ) + + print("Visualizing current epoch...") + + visualize() + + encoder.save(f"{work_dir}/encoder_final.ckpt") + policy.save(f"{work_dir}/policy_final.ckpt") + print(f"finished training in {wandb.run.dir}") + + def visualize_best(): + encoder.load(f"{work_dir}/encoder_best.ckpt") + policy.load(f"{work_dir}/policy_best.ckpt") + + val_replay_l1 = ReplayDataset( + dataset_dir=cfg.val_dataset_l1, + **cfg.dataset, + shuffle=False, + stride=val_vis_stride, + ) + val_replay_l1.aug_transform = None + val_dataloader_l1 = data_utils.get_dataloader( + val_replay_l1, "val", 1, cfg.batch_size + ) + + val_replay_l2 = ReplayDataset( + dataset_dir=cfg.val_dataset_l2, + **cfg.dataset, + shuffle=False, + stride=val_vis_stride, + ) + val_replay_l2.aug_transform = None + val_dataloader_l2 = data_utils.get_dataloader( + val_replay_l2, "val", 1, cfg.batch_size + ) + _, info_l1 = evaluate(encoder, policy, val_dataloader_l1, "_L1", rgb=True) + _, info_l2 = evaluate(encoder, policy, val_dataloader_l2, "_L2", rgb=True) + save_dir = os.path.join(work_dir, "visualization_best_epoch") + make_visualization(val_replay_l1, info_l1, save_dir=save_dir, prefix="l1_") + make_visualization(val_replay_l2, info_l2, save_dir=save_dir, prefix="l2_") + + print("Visualizing the best epoch...") + visualize_best() + + wandb.finish() + + +def setup_optimizer(optim_cfg, encoder, policy): + """ + Setup the optimizer. Return the optimizer. 
+ """ + from torch import optim + + optimizer = eval(optim_cfg.type) + encoder_trainable_params = torch_utils.get_named_trainable_params(encoder) + # Print size of trainable parameters + print( + "Encoder trainable parameters:", + sum(p.numel() for (name, p) in encoder_trainable_params) / 1e6, + "M", + ) + print( + "Policy trainable parameters:", + sum(p.numel() for p in policy.parameters()) / 1e6, + "M", + ) + if len(encoder_trainable_params) > 0: + return optimizer( + list(encoder.parameters()) + list(policy.parameters()), **optim_cfg.params + ) + else: + return optimizer(list(policy.parameters()), **optim_cfg.params) + + +def setup_lr_scheduler(optimizer, scheduler_cfg): + import torch.optim as optim + import torch.optim.lr_scheduler as lr_scheduler + from simple_bc.utils.lr_scheduler import CosineAnnealingLRWithWarmup + + sched = eval(scheduler_cfg.type) + if sched is None: + return None + return sched(optimizer, **scheduler_cfg.params) + + +def run_one_epoch( + encoder, policy, dataloader, optimizer, scheduler=None, clip_grad=None +): + """ + Optimize the policy. Return a dictionary of the loss and any other metrics. 
+ """ + running_loss, running_mse, running_abs, tot_items = 0, 0, 0, 0 + + encoder.train() + policy.train() + loss_fn = nn.MSELoss(reduction="none") + + for batch in dataloader: + obs, act = batch["obs"], batch["actions"] + obs = gdict.GDict(obs).cuda(device="cuda") + act = act.to(device="cuda") + + optimizer.zero_grad() + pred, _ = policy(encoder(obs)) # pred: (B, H, A) + loss = loss_fn(pred, act).mean() + + loss.backward() + + if clip_grad is not None: + torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5) + + optimizer.step() + running_mse += ((pred - act) ** 2).sum(0).mean().item() + running_abs += (torch.abs(pred - act)).sum(0).mean().item() + running_loss += loss.item() * act.shape[0] + tot_items += act.shape[0] + + out_dict = { + "train/mse": running_mse / tot_items, + "train/abs": running_abs / tot_items, + "train/loss": running_loss / tot_items, + } + + if scheduler is not None: + scheduler.step() + + return out_dict + +def setup(cfg): + if cfg.gpu is not None: + print(f"Using GPU {cfg.gpu}") + torch.cuda.set_device(cfg.gpu) + + import warnings + + warnings.simplefilter("ignore") + + from simple_bc.utils.log_utils import set_random_seed + + set_random_seed(cfg.seed) + torch.manual_seed(cfg.seed) + torch.cuda.manual_seed(cfg.seed) + + +if __name__ == "__main__": + main() diff --git a/gello/data_utils/simple_bc/utils/data_utils.py b/gello/data_utils/simple_bc/utils/data_utils.py new file mode 100644 index 00000000..6bbd6f54 --- /dev/null +++ b/gello/data_utils/simple_bc/utils/data_utils.py @@ -0,0 +1,123 @@ +import os +import torch +import cv2 as cv +import numpy as np +import gdict as gd +import glob +from natsort import natsorted +from functools import partial + + +# worker functions +def _worker_init_fn(worker_id, rank=0, world_size=1): + """ + For each process, each worker will get a slice of the buffer files (worker_start, worker_end) + For each epoch, the buffer_files will be permuted (with the same random seed across workers) + When there are 
multiple processes, each global worker will get a slice of the permuted buffer files + """ + worker_info = torch.utils.data.get_worker_info() + num_workers = worker_info.num_workers + num_global_workers = num_workers * world_size + global_worker_id = rank * num_workers + worker_id + dataset = worker_info.dataset + + N = len(dataset.buffer_filenames) + per_worker = N // num_global_workers + + dataset.worker_start = global_worker_id * per_worker + dataset.worker_end = min((global_worker_id + 1) * per_worker, N) + dataset.world_rng = np.random.RandomState(41248) + + +# collate functions +def collate_fn(batch): + return gd.GDict.stack(batch, axis=0).to_torch( + non_blocking=True, dtype="float32", use_copy=True + ) + + +def get_token_name(encoder_cfg): + return "none" + + +def normalize_quat(quats, dim): + assert quats.shape[1] == 4 + quats = quats.copy() + signs = np.sign(quats[:, dim]) + return quats * signs[:, None] + + +def load_traj_from_memory(filename): + trajs = gd.GDict.from_hdf5(filename) + key = natsorted(trajs.keys())[ + 0 + ] # Take the first key, assuming there is only one traj per file + traj = gd.DictArray(trajs[key]) + traj = traj.select_by_keys(["obs", "actions"]) + traj = traj.to_two_dims() + + return traj + + +def load_traj_files(folder, token_name="none", stride=1): + """ + Load the trajectory filenames from the folders. 
+ """ + # Recursive + buffer_filenames = glob.glob( + os.path.join(folder, token_name, "**/*.h5"), recursive=True + ) + buffer_filenames = natsorted(buffer_filenames) + + buffer_filenames = [ + buffer_filenames[i] for i in range(len(buffer_filenames)) if i % stride == 0 + ] + print(f"ReplayDataset: loading {len(buffer_filenames)} trajectories from {folder}") + + return buffer_filenames + + +def replace_abs_dataset_path(dataset_path): + data_dir = os.environ.get("DATASET_DIR", None) + if data_dir is None: + return dataset_path + else: + dataset_name = os.path.basename(dataset_path) + return os.path.join(data_dir, dataset_name) + + +def get_dataloader(replay, mode, num_workers, batch_size): + if mode == "val": + num_workers = min(num_workers, 3) + context = ( + None if num_workers == 1 else torch.multiprocessing.get_context("forkserver") + ) + world_rank, world_size = 0, 1 + loader = torch.utils.data.DataLoader( + replay, + batch_size=batch_size, + num_workers=num_workers, + collate_fn=collate_fn, + worker_init_fn=partial(_worker_init_fn, rank=world_rank, world_size=world_size), + persistent_workers=True if num_workers > 1 else False, + multiprocessing_context=context if num_workers > 1 else None, + drop_last=mode == "train", + prefetch_factor=5, + ) + return loader + + +def list_batch_to_episodic(all_dones, all_list): + """ + Convert a list of batched data to a list of episodes based on the dones. 
+ all_dones: list of dones, each dones is a 1D array of shape (episode_length,) + all_list: A dictionary, where each value is a list of batched data + """ + all_dones = np.concatenate(all_dones, axis=0) + for key in all_list: + all_list[key] = np.concatenate(all_list[key], axis=0) + split_idx = np.where(all_dones == 1)[0] + 1 + episodic_list = {} + for key in all_list: + episodic_list[key] = np.split(all_list[key], split_idx, axis=0)[:-1] + return episodic_list diff --git a/gello/data_utils/simple_bc/utils/log_utils.py b/gello/data_utils/simple_bc/utils/log_utils.py new file mode 100644 index 00000000..7fc9911d --- /dev/null +++ b/gello/data_utils/simple_bc/utils/log_utils.py @@ -0,0 +1,282 @@ +import wandb +import os +import torch +import torch.distributed as dist +import numpy as np +import random +import builtins +import datetime +import time +from collections import defaultdict, deque +from omegaconf import DictConfig, OmegaConf +import json + + +def init_wandb(cfg): + cfg = OmegaConf.to_container(cfg, resolve=True) + cfg = OmegaConf.create(cfg) + pretty_print_cfg(cfg) + wandb_cfg = prepare_wandb_cfg(cfg) + + wandb.init( + config=wandb_cfg, + project=cfg.wandb.project, + name=cfg.wandb.name, + group=cfg.wandb.group, + ) + OmegaConf.save(cfg, f"{wandb.run.dir}/config.yaml") + + +def log_wandb(metrics, step): + """ + Log the metrics to wandb. The metrics should be hierarchical, with the top level keys being "train" and "val". + """ + for key, value in metrics.items(): + wandb.log(value, step=step, commit=False) + + +def pretty_print_cfg(cfg): + """ + Pretty print the config as cascading bullet points. 
+ """ + print("Config:") + for key, value in cfg.items(): + print(f"- {key}:") + if isinstance(value, DictConfig) or isinstance(value, dict): + for k, v in value.items(): + print(f" - {k}: {v}") + else: + print(f" - {value}") + + +def prepare_wandb_cfg(cfg): + wandb_cfg = {} + for key, value in cfg.items(): + if isinstance(value, DictConfig): + wandb_cfg[key] = prepare_wandb_cfg(value) + else: + wandb_cfg[key] = value + + return wandb_cfg + + +def set_random_seed(seed): + if seed is not None: + random.seed(seed) + np.random.seed(seed) + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print("[{}] ".format(now), end="") # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{value:.4f} ({avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class BestAvgLoss(object): + # Best loss is + def __init__(self, window_size=20): + self.deque = deque(maxlen=window_size) + self.best_loss, self.best_epoch = None, None + + def update_best(self, loss, epoch): + self.deque.append(loss) + if self.best_loss is None or loss < self.best_loss: + self.best_loss = loss + self.best_epoch = epoch + is_best = True + else: + is_best = False + return is_best + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + """Window size is for the best key""" + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + 
+ def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f}s/it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +def save_cfg(cfg, save_name): + # cfg is a dictionary + with open(save_name, "w") as f: + json.dump(cfg, f, indent=2) diff --git a/gello/data_utils/simple_bc/utils/lr_scheduler.py b/gello/data_utils/simple_bc/utils/lr_scheduler.py new file mode 100644 index 00000000..457eba9a --- /dev/null +++ b/gello/data_utils/simple_bc/utils/lr_scheduler.py 
# ==============================================================
# file: gello/data_utils/simple_bc/utils/lr_scheduler.py (new)
# ==============================================================
import math

import torch
from torch import nn


class CosineAnnealingLRWithWarmup(torch.optim.lr_scheduler._LRScheduler):
    """Cosine-annealing LR schedule with an initial linear warmup.

    For epochs < ``warmup_epoch`` the LR ramps linearly from ``warmup_lr``
    up to each param group's base LR; afterwards it follows a half-cosine
    decay from the base LR down to 0 at epoch ``T_max``.
    """

    def __init__(self, optimizer, warmup_lr, warmup_epoch, T_max, last_epoch=-1):
        self.warmup_lr = warmup_lr
        self.warmup_epoch = warmup_epoch
        self.T_max = T_max
        # The parent __init__ invokes self.get_lr() once, so the attributes
        # above must be assigned before this call.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the LR for every param group at ``self.last_epoch``.

        BUGFIX: the original appended only after the loop over
        ``self.base_lrs`` finished, so optimizers with multiple param groups
        got a single LR back; this returns one LR per group.
        """
        if self.last_epoch < self.warmup_epoch:
            # Linear warmup from warmup_lr to each group's base LR.
            frac = self.last_epoch / self.warmup_epoch
            return [
                self.warmup_lr + (base - self.warmup_lr) * frac
                for base in self.base_lrs
            ]
        # Cosine decay over the remaining (T_max - warmup_epoch) epochs.
        # NOTE(review): assumes T_max > warmup_epoch; equal values would
        # divide by zero — confirm against the scheduler configs.
        progress = (self.last_epoch - self.warmup_epoch) / (
            self.T_max - self.warmup_epoch
        )
        return [
            0.5 * base * (1 + math.cos(math.pi * progress))
            for base in self.base_lrs
        ]


if __name__ == "__main__":
    # Smoke test: plot the LR curve produced by the scheduler.
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn

    simple_nn = nn.Sequential(nn.Linear(1, 1), nn.ReLU(), nn.Linear(1, 1))

    optimizer = torch.optim.SGD(simple_nn.parameters(), lr=0.1)
    scheduler = CosineAnnealingLRWithWarmup(
        optimizer, warmup_lr=0.01, warmup_epoch=10, T_max=100
    )
    lrs = []
    plt.figure(figsize=(10, 5))
    for i in range(100):
        optimizer.step()
        scheduler.step()
        # get_last_lr() instead of get_lr(): calling get_lr() from outside
        # the scheduler is deprecated and warns in recent torch versions.
        lrs.append(scheduler.get_last_lr()[0])
    plt.plot(lrs)
    # Save before show(); under the Agg backend show() is a no-op anyway,
    # and saving first guarantees the figure is written.
    plt.savefig("lr.png")
    plt.show()


# ==============================================================
# file: gello/data_utils/simple_bc/utils/torch_utils.py (new)
# ==============================================================
import torch
import torch.nn as nn
import numpy as np
from einops import pack, unpack, repeat, reduce, rearrange


def exists(val):
    """Return True iff val is not None."""
    return val is not None


def default(val, d):
    """Return val when it is not None, otherwise the fallback d."""
    return val if exists(val) else d


def cast_tuple(val, length=1):
    """Return val unchanged if it is a tuple, else repeat it `length` times."""
    return val if isinstance(val, tuple) else ((val,) * length)


def pack_one(x, pattern):
    """einops.pack specialized to a single tensor."""
    return pack([x], pattern)
def unpack_one(x, ps, pattern):
    """einops.unpack specialized to a single tensor."""
    return unpack(x, ps, pattern)[0]


def to_torch(array, device="cpu"):
    """Convert a tensor / ndarray / array-like to a torch.Tensor on `device`."""
    if isinstance(array, torch.Tensor):
        return array.to(device)
    if isinstance(array, np.ndarray):
        return torch.from_numpy(array).to(device)
    return torch.tensor(array).to(device)


def to_numpy(array):
    """Tensor -> ndarray; anything else passes through unchanged."""
    if isinstance(array, torch.Tensor):
        return array.cpu().numpy()
    return array


def to_cpu(array):
    """Detach tensors (recursing through tuples) and move them to the CPU."""
    if isinstance(array, torch.Tensor):
        return array.detach().cpu()
    if isinstance(array, tuple):
        return tuple(to_cpu(a) for a in array)
    # BUGFIX: the original fell through here and implicitly returned None
    # for any non-tensor/non-tuple input; pass such values through instead.
    return array


@torch.no_grad()
def batch_pred(func, kwargs, batch_size=2048, collate_fn=None, get_cpu=False):
    """Call ``func(**kwargs)`` in slices of `batch_size` along dim 0.

    Tensor/ndarray values in `kwargs` are sliced per batch; all other values
    are forwarded unchanged to every batch. Batch results are concatenated
    with ``torch.cat`` unless `collate_fn` is given. With ``get_cpu=True``
    each batch result is moved to the CPU before accumulation (bounds GPU
    memory use).
    """
    rand_key = next(iter(kwargs))
    N = len(kwargs[rand_key])
    if N <= batch_size:
        return func(**kwargs)
    all_pred = []
    for i in range(0, N, batch_size):
        # BUGFIX: the original tangled the non-sliceable fallback into the
        # prediction path (re-calling `func` inside the per-key loop).
        # Build the complete batch kwargs first, then call `func` once.
        new_kwargs = {
            key: (
                val[i : min(i + batch_size, N)]
                if isinstance(val, (torch.Tensor, np.ndarray))
                else val
            )
            for key, val in kwargs.items()
        }
        pred = func(**new_kwargs)
        if get_cpu:
            pred = to_cpu(pred)
        all_pred.append(pred)
    if collate_fn is None:
        return torch.cat(all_pred, dim=0)
    return collate_fn(all_pred)


def get_named_trainable_params(model):
    """Return (name, parameter) pairs whose parameter has requires_grad=True."""
    return [
        (name, param)
        for name, param in model.named_parameters()
        if param.requires_grad
    ]


# ==============================================================
# file: gello/data_utils/simple_bc/utils/visualization_utils.py (new)
# ==============================================================
import os

import cv2
import numpy as np
from moviepy.editor import ImageSequenceClip, VideoFileClip, concatenate_videoclips
from moviepy.video.fx.speedx import speedx


def merge_videos(video_files, output_name, speed=1):
    """Concatenate video files, optionally speed up, and write to output_name."""
    videos = [VideoFileClip(video_file) for video_file in video_files]
    final_clip = concatenate_videoclips(videos)
    final_clip = speedx(final_clip, speed)  # speed up
    final_clip.write_videofile(output_name, fps=30)


def cv_render(img, name="GoalEnvExt", scale=5):
    """Take an image in ndarray format and show it with opencv."""
    if len(img.shape) == 2:
        img = img[:, :, None]
    if img.shape[2] == 1:  # Depth. Normalize.
        img = np.tile(img, [1, 1, 3])
        img = (img - np.min(img)) / (np.max(img) - np.min(img))
    elif img.shape[2] > 3:
        img = img[:, :, :3]
    new_img = img[:, :, (2, 1, 0)]  # RGB -> BGR for OpenCV
    h, w = new_img.shape[:2]
    new_img = cv2.resize(new_img, (w * scale, h * scale))
    cv2.imshow(name, new_img)
    cv2.waitKey(20)


def save_rgb(path, img):
    """Write an RGB image (float in [0, 1] or [0, 255]) to `path` via OpenCV."""
    if np.max(img) <= 1.0:
        img = img * 255.0
    # BUGFIX: cv2.imwrite expects 8-bit data; the original wrote float32,
    # which most output formats reject or mangle.
    img = np.clip(img, 0, 255).astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(path, img)


def make_grid(array, ncol=5, padding=0, pad_value=120):
    """numpy version of the make_grid function in torch. Dimension of array: NHWC."""
    if np.max(array) < 2.0:
        array = array * 255.0
    if len(array.shape) == 3:  # In case there is only one channel
        array = np.expand_dims(array, 3)
    N, H, W, C = array.shape
    if N % ncol > 0:
        # Pad with blank frames so N is a multiple of ncol.
        res = ncol - N % ncol
        array = np.concatenate([array, np.ones([res, H, W, C])])
        N = array.shape[0]
    nrow = N // ncol
    idx = 0
    grid_img = None
    for i in range(nrow):
        row = np.pad(
            array[idx],
            [[padding if i == 0 else 0, padding], [padding, padding], [0, 0]],
            constant_values=pad_value,
            mode="constant",
        )
        for j in range(1, ncol):
            idx += 1
            cur_img = np.pad(
                array[idx],
                [[padding if i == 0 else 0, padding], [0, padding], [0, 0]],
                constant_values=pad_value,
                mode="constant",
            )
            row = np.hstack([row, cur_img])
        idx += 1
        if i == 0:
            grid_img = row
        else:
            grid_img = np.vstack([grid_img, row])
    return grid_img.astype(np.float32)


def save_numpy_as_gif(array, filename, fps=20, scale=1.0):
    """Create a gif from a stack of images (T x H x W [x C]) using moviepy.

    Usage:
    >>> X = randn(100, 64, 64)
    >>> save_numpy_as_gif(X, 'test.gif')

    Parameters
    ----------
    array : array_like
        A numpy array that contains a sequence of images.
    filename : string
        The filename of the gif to write to (extension forced to .gif).
    fps : int
        Frames per second.
    scale : float
        How much to rescale each image by.
    """
    if np.max(array) <= 2.0:
        # BUGFIX: was `array *= 255.0`, which mutates the caller's array
        # in place and raises for integer dtypes.
        array = array * 255.0
    # Ensure that the file has the .gif extension.
    fname, _ = os.path.splitext(filename)
    filename = fname + ".gif"

    # Copy into the color dimension if the images are black and white.
    if array.ndim == 3:
        array = array[..., np.newaxis] * np.ones(3)

    clip = ImageSequenceClip(list(array), fps=fps).resize(scale)
    clip.write_gif(filename, fps=fps)
    return clip


def save_numpy_as_video(array, filename, fps=20):
    """Create an mp4 from a stack of images (T x H x W [x C]) using moviepy."""
    folder = os.path.dirname(filename)
    # BUGFIX: guard against a bare filename — os.makedirs("") raises.
    if folder and not os.path.exists(folder):
        os.makedirs(folder)

    if np.max(array) <= 2.0:
        # BUGFIX: was in-place `*=` (mutates caller's array, fails on ints).
        array = array * 255.0
    array = array.astype(np.uint8)
    # Ensure that the file has the .mp4 extension.
    fname, _ = os.path.splitext(filename)
    filename = fname + ".mp4"

    # Copy into the color dimension if the images are black and white.
    # (The original duplicated this block after the commented-out section;
    # the second copy was a no-op and has been removed.)
    if array.ndim == 3:
        array = array[..., np.newaxis] * np.ones(3)

    # Alternative cv2/ffmpeg path kept for reference:
    # import uuid
    # temp_filename = f'/tmp/{str(uuid.uuid4())}.mp4'
    # CV_VIDEO_CODES = {"mp4": cv2.VideoWriter_fourcc(*"mp4v"), }
    # img = array[0]
    # video_writer = cv2.VideoWriter(temp_filename, CV_VIDEO_CODES['mp4'], fps, (img.shape[1], img.shape[0]))
    #
    # # Save
    # for frame in list(array):
    #     frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    #     video_writer.write(frame)
    # video_writer.release()
    # # NOTE(review): the output argument was garbled in the source; it is
    # # presumably the target `{filename}` — confirm before reviving this path.
    # os.system(f"ffmpeg -i {temp_filename} -vcodec libx264 {filename} -y -hide_banner -loglevel error")
    # os.system(f"rm -rf {temp_filename}")

    clip = ImageSequenceClip(list(array), fps=fps)
    clip.write_videofile(filename, fps=fps, logger=None)
    return clip


def save_numpy_as_img(img, filename):
    """Write an RGB float image (assumed in [0, 1] — TODO confirm) to disk."""
    img = img * 255.0
    img = img.astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, img)


def save_numpy_to_gif_matplotlib(array, filename, interval=50):
    """Render frames with matplotlib animation, then convert mp4 -> gif.

    NOTE(review): depends on the third-party `ffmpy` package, which is not
    listed in requirements_data_process.txt — confirm it is installed.
    """
    from matplotlib import animation
    from matplotlib import pyplot as plt

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    def img_show(i):
        plt.imshow(array[i])
        print("showing image {}".format(i))
        return

    ani = animation.FuncAnimation(fig, img_show, len(array), interval=interval)

    ani.save("{}.mp4".format(filename))

    import ffmpy

    ff = ffmpy.FFmpeg(
        inputs={"{}.mp4".format(filename): None},
        outputs={"{}.gif".format(filename): None},
    )

    ff.run()
    # plt.show()


def video_pad_time(videos):
    """Pad every video with repeats of its last frame to the max length."""
    nframe = np.max([video.shape[0] for video in videos])
    padded = []
    for video in videos:
        npad = nframe - len(video)
        padded_frame = video[[-1], :, :, :].copy()
        video = np.vstack([video, np.tile(padded_frame, [npad, 1, 1, 1])])
        padded.append(video)
    return np.array(padded)


def make_grid_video_from_numpy(
    video_array, ncol, output_name="./output.mp4", speedup=1, fps=24
):
    """Tile N videos (each T x H x W x 3) into an ncol-wide grid mp4."""
    videos = []
    for video in video_array:
        if speedup != 1:
            video = video[::speedup]
        videos.append(video)
    videos = video_pad_time(videos)  # N x T x H x W x 3
    grid_frames = []
    for t in range(videos.shape[1]):
        grid_frame = make_grid(videos[:, t], ncol=ncol, padding=5)
        # save_numpy_as_img(grid_frame / 255.0, output_name.replace('.mp4', f'_{t}.jpg'))
        grid_frames.append(grid_frame)
    save_numpy_as_video(np.array(grid_frames), output_name, fps=fps)


def make_grid_gif_from_numpy(
    video_array, ncol, output_name="./output.gif", speedup=1, fps=10
):
    """Tile N videos (each T x H x W x 3) into an ncol-wide grid gif."""
    videos = []
    for video in video_array:
        if speedup != 1:
            video = video[::speedup]
        videos.append(video)
    videos = video_pad_time(videos)  # N x T x H x W x 3
    grid_frames = []
    for t in range(videos.shape[1]):
        grid_frame = make_grid(videos[:, t], ncol=ncol, padding=5)
        grid_frames.append(grid_frame)
    save_numpy_as_gif(np.array(grid_frames), output_name, fps=fps)


def make_grid_video(video_list, ncol, output_name="./output.mp4", speedup=1):
    """Load videos from disk, tile them into an ncol-wide grid, write mp4."""
    videos = []
    for video_path in video_list:
        myclip = VideoFileClip(video_path)
        if myclip.size[0] > 256:
            myclip = myclip.resize(height=256)
        if speedup != 1:
            myclip = myclip.speedx(speedup)
        frames = []
        for frame in myclip.iter_frames():
            frames.append(frame)
        videos.append(np.array(frames))
    videos = video_pad_time(videos)  # N x T x H x W x 3
    grid_frames = []
    for t in range(videos.shape[1]):
        grid_frame = make_grid(videos[:, t], ncol=ncol, padding=5)
        grid_frames.append(grid_frame)
    save_numpy_as_video(np.array(grid_frames), output_name, fps=24)


def visualize_traj_opencv(imgs):
    """Play back a trajectory of images in an OpenCV window."""
    import cv2 as cv

    for i in range(len(imgs)):
        cv.imshow("x", imgs[i])
        cv.waitKey(20)


# ==============================================================
# file: requirements_data_process.txt (new) — original contents:
#   natsort
#   matplotlib
#   mediapy
#   moviepy==1.0.3
#   opencv-python
#   torch
#   transforms3d
#   PyYAML
#   h5py
# NOTE(review): `ffmpy` (used by save_numpy_to_gif_matplotlib above) is not
# listed — confirm whether it should be added.
# ==============================================================