Skip to content

Commit f39eb87

Browse files
Solved minor bugs in inverse canonicalization for discrete groups (#18)
* fixed inverse canonicalization key issue group -> group_element * fix scalar group element key: 0 -> rotation * test for discrete invert canonicalization, fixing minor bugs --------- Co-authored-by: olayasturias <[email protected]>
1 parent 1e66680 commit f39eb87

File tree

4 files changed

+78
-8
lines changed

4 files changed

+78
-8
lines changed

equiadapt/images/canonicalization/discrete_group.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
len(in_shape) == 3
5656
), "Input shape should be in the format (channels, height, width)"
5757

58-
# DEfine all the image transformations here which are used during canonicalization
58+
# Define all the image transformations here which are used during canonicalization
5959
# pad and crop the input image if it is not rotated MNIST
6060
is_grayscale = in_shape[0] == 1
6161

equiadapt/images/utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,19 @@ def get_action_on_image_features(
5353
batch_size, C, H, W = feature_map.shape
5454
if induced_rep_type == "regular":
5555
assert feature_map.shape[1] % num_group == 0
56-
angles = group_element_dict["group"]["rotation"]
56+
angles = group_element_dict["rotation"]
5757
x_out = K.geometry.rotate(feature_map, angles)
5858

59-
if "reflection" in group_element_dict["group"]:
60-
reflect_indicator = group_element_dict["group"]["reflection"]
59+
if "reflection" in group_element_dict:
60+
reflect_indicator = group_element_dict["reflection"]
6161
x_out_reflected = K.geometry.hflip(x_out)
6262
x_out = x_out * reflect_indicator[:, None, None, None] + x_out_reflected * (
6363
1 - reflect_indicator[:, None, None, None]
6464
)
6565

6666
x_out = x_out.reshape(batch_size, C // num_group, num_group, H, W)
6767
shift = angles / 360.0 * num_rotations
68-
if "reflection" in group_element_dict["group"]:
68+
if "reflection" in group_element_dict:
6969
x_out = torch.cat(
7070
[
7171
roll_by_gather(x_out[:, :, :num_rotations], shift),
@@ -78,10 +78,10 @@ def get_action_on_image_features(
7878
x_out = x_out.reshape(batch_size, -1, H, W)
7979
return x_out
8080
elif induced_rep_type == "scalar":
81-
angles = group_element_dict["group"][0]
81+
angles = group_element_dict["rotation"]
8282
x_out = K.geometry.rotate(feature_map, angles)
83-
if "reflection" in group_element_dict["group"]:
84-
reflect_indicator = group_element_dict["group"]["reflection"]
83+
if "reflection" in group_element_dict:
84+
reflect_indicator = group_element_dict["reflection"]
8585
x_out_reflected = K.geometry.hflip(x_out)
8686
x_out = x_out * reflect_indicator[:, None, None, None] + x_out_reflected * (
8787
1 - reflect_indicator[:, None, None, None]

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ norecursedirs =
8383
build
8484
.tox
8585
testpaths = tests
86+
filterwarnings = ignore::Warning
8687
# Use pytest markers to select/deselect specific tests
8788
# markers =
8889
# slow: mark tests as slow (deselect with '-m "not slow"')
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pytest
2+
import torch
3+
from omegaconf import DictConfig
4+
5+
from equiadapt.images.canonicalization.discrete_group import (
6+
GroupEquivariantImageCanonicalization,
7+
)
8+
from equiadapt.images.canonicalization_networks.escnn_networks import (
9+
ESCNNEquivariantNetwork,
10+
)
11+
12+
13+
@pytest.fixture
14+
def init_args() -> dict:
15+
"""
16+
Initialize the arguments for the canonicalization function.
17+
18+
Returns:
19+
dict: A dictionary containing the initialization arguments.
20+
"""
21+
# Mock initialization arguments
22+
canonicalization_hyperparams = DictConfig(
23+
{
24+
"input_crop_ratio": 0.9,
25+
"resize_shape": (32, 32),
26+
"beta": 0.1,
27+
}
28+
)
29+
return {
30+
"canonicalization_network": ESCNNEquivariantNetwork(
31+
in_shape=(3, 64, 64),
32+
out_channels=32,
33+
kernel_size=3,
34+
group_type="rotation",
35+
num_rotations=4,
36+
num_layers=2,
37+
),
38+
"canonicalization_hyperparams": canonicalization_hyperparams,
39+
"in_shape": (3, 64, 64),
40+
}
41+
42+
43+
# try both types of induced representations (regular and scalar)
44+
@pytest.mark.parametrize("induced_rep, num_channels", [("regular", 12), ("scalar", 3)])
45+
def test_invert_canonicalization_induced_rep(
46+
induced_rep: str, num_channels: int, init_args: dict
47+
) -> None:
48+
"""
49+
Test the inversion of the canonicalization-induced representation.
50+
51+
Args:
52+
induced_rep (str): The type of induced representation.
53+
num_channels (int): The number of channels in the sample image.
54+
"""
55+
56+
# Initialize the canonicalization function
57+
dgic = GroupEquivariantImageCanonicalization(**init_args)
58+
59+
# Apply the canonicalization function
60+
image = torch.randn((1, 3, 64, 64))
61+
62+
_ = dgic(image) # to populate the canonicalization_info_dict
63+
64+
canonicalized_image = torch.randn((1, num_channels, 64, 64))
65+
66+
# Invert the canonicalization-induced representation
67+
inverted_image = dgic.invert_canonicalization(
68+
canonicalized_image, **{"induced_rep_type": induced_rep}
69+
)

0 commit comments

Comments
 (0)