Skip to content

Commit

Permalink
feat(encoder): add image pca encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
HiroIshida committed Feb 2, 2024
1 parent f6c6ea1 commit 1929a0a
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion mohou/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
from mohou.model.autoencoder import AutoEncoderBase
from mohou.model.common import ModelBase
from mohou.trainer import TrainCache
from mohou.types import ElementT, EpisodeBundle, ImageT, VectorT, get_element_type
from mohou.types import (
ElementT,
EpisodeBundle,
ImageT,
PrimitiveImageBase,
PrimitiveImageT,
VectorT,
get_element_type,
)
from mohou.utils import assert_equal_with_message, assert_isinstance_with_message

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -96,6 +104,54 @@ def set_device(self, device: torch.device) -> None:
self._get_model().put_on_device(device)


@dataclass
class PCAImageEncoder(EncoderBase[PrimitiveImageT]):
pca: PCA

@classmethod
def from_bundle(
cls, bundle: EpisodeBundle, image_type: Type[PrimitiveImageT], n_out: int
) -> "PCAImageEncoder":
assert issubclass(image_type, PrimitiveImageBase), "currently only support PrimitiveImage"
vector_list = []
image_shape = None
for episode in bundle.get_touch_bundle():
image_seq = episode.get_sequence_by_type(image_type)
vecs = [image.numpy().flatten() for image in image_seq.elem_list]
vector_list.extend(vecs)
if image_shape is None:
image_shape = image_seq.elem_list[0].shape
else:
assert image_shape == image_seq.elem_list[0].shape
assert image_shape is not None
mat = np.array(vector_list)
pca = PCA(n_components=n_out)
pca.fit(mat)
return cls(image_type, image_shape, n_out, pca)

def _forward_impl(self, inp: ImageT) -> np.ndarray:
assert isinstance(inp, PrimitiveImageBase)
inp_as_2d = inp.numpy().flatten().reshape(1, -1)
out = self.pca.transform(inp_as_2d)
return out.flatten()

def _backward_impl(self, inp: np.ndarray) -> PrimitiveImageT:
out = self.pca.inverse_transform(inp.reshape(1, -1))
out_reshaped = out.reshape(self.input_shape)
out_uint8 = out_reshaped.astype(np.uint8)
return self.elem_type(out_uint8)

def save(self, project_path: Path) -> None:
with (project_path / "image_pca.pkl").open("wb") as f:
pickle.dump(self, f)

@classmethod
def load(cls, project_path: Path) -> "PCAImageEncoder":
with (project_path / "image_pca.pkl").open("rb") as f:
pca = pickle.load(f)
return pca


@dataclass(eq=False)
class ImageEncoder(EncoderBase[ImageT], HasAModel):
input_shape: Tuple[int, int, int]
Expand Down

0 comments on commit 1929a0a

Please sign in to comment.