@dataclass
class PCAImageEncoder(EncoderBase[PrimitiveImageT]):
    """Image encoder that compresses primitive images with a fitted PCA.

    Encoding flattens an image into a row vector and projects it with
    ``pca.transform``; decoding applies ``pca.inverse_transform`` and
    reshapes the result back to ``self.input_shape``.
    """

    # Fitted PCA model (presumably sklearn.decomposition.PCA -- confirm import).
    pca: PCA

    @classmethod
    def from_bundle(
        cls, bundle: EpisodeBundle, image_type: Type[PrimitiveImageT], n_out: int
    ) -> "PCAImageEncoder":
        """Fit a PCA on every ``image_type`` image in the bundle.

        Args:
            bundle: source of episodes; images are read from
                ``bundle.get_touch_bundle()``.
            image_type: image class to extract; must derive from
                ``PrimitiveImageBase``.
            n_out: number of PCA components, i.e. the encoded vector size.

        Returns:
            A new encoder holding the fitted PCA.
        """
        assert issubclass(image_type, PrimitiveImageBase), "currently only support PrimitiveImage"

        vector_list = []
        image_shape = None
        for episode in bundle.get_touch_bundle():
            image_seq = episode.get_sequence_by_type(image_type)
            vector_list.extend(image.numpy().flatten() for image in image_seq.elem_list)

            # Only the first image of each episode is compared against the
            # reference shape taken from the first episode.
            episode_shape = image_seq.elem_list[0].shape
            if image_shape is None:
                image_shape = episode_shape
            else:
                assert (
                    image_shape == episode_shape
                ), f"inconsistent image shape: {image_shape} != {episode_shape}"
        assert image_shape is not None, "bundle contains no episode to fit PCA on"

        # One row per image, one column per flattened pixel value.
        mat = np.array(vector_list)
        pca = PCA(n_components=n_out)
        pca.fit(mat)
        return cls(image_type, image_shape, n_out, pca)

    def _forward_impl(self, inp: ImageT) -> np.ndarray:
        """Project a single image onto the PCA components; returns a 1-D vector."""
        assert isinstance(inp, PrimitiveImageBase)
        # reshape(1, -1) already flattens; transform expects a 2-D (1, n_features) input.
        inp_as_2d = inp.numpy().reshape(1, -1)
        out = self.pca.transform(inp_as_2d)
        return out.flatten()

    def _backward_impl(self, inp: np.ndarray) -> PrimitiveImageT:
        """Reconstruct an image from an encoded vector."""
        out = self.pca.inverse_transform(inp.reshape(1, -1))
        out_reshaped = out.reshape(self.input_shape)
        # The linear reconstruction may leave the [0, 255] range; casting such
        # floats directly to uint8 wraps around instead of saturating, so clip first.
        out_uint8 = np.clip(out_reshaped, 0, 255).astype(np.uint8)
        return self.elem_type(out_uint8)

    def save(self, project_path: Path) -> None:
        """Pickle this encoder to ``project_path / "image_pca.pkl"``."""
        with (project_path / "image_pca.pkl").open("wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, project_path: Path) -> "PCAImageEncoder":
        """Unpickle the encoder stored at ``project_path / "image_pca.pkl"``."""
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only load project directories that come from a trusted source.
        with (project_path / "image_pca.pkl").open("rb") as f:
            encoder = pickle.load(f)
        return encoder