
Commit 73dce8c

Add files via upload
1 parent eacabd7 commit 73dce8c

File tree

3 files changed: +836 -0 lines changed

regclip_ssr.py

+292
@@ -0,0 +1,292 @@
import os.path as osp

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import torchvision.models as models
from clip import clip

from ordinalclip.utils import get_logger

from . import image_encoders
from .builder import MODELS
from .prompt_leaners import PROMPT_LEARNERS
from .prompt_leaners.plain_prompt_learner import PlainPromptLearner

import sys

logger = get_logger(__name__)


# for age estimation
bin_list_a = [0, 13, 19, 35, 65]
bin_list_b = [0, 13, 19, 35, 65]

bin_width_a = [13, 6, 16, 30, 36]
bin_width_b = [13, 6, 16, 30, 36]
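# Note: bin_list_* appear to hold the lower edges of the five age groups
# (0-13, 13-19, 19-35, 35-65, 65+) and bin_width_* their widths; with
# class_range=101 in SSRModule, the last bin spans 65-101.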
# for image aesthetics
# bin_list_a = [0, 1, 2, 3, 4]
# bin_list_b = [0, 1, 2, 3, 4]

# bin_width_a = [1, 1, 1, 1, 1]
# bin_width_b = [1, 1, 1, 1, 1]


# for historical image dating
# bin_list_a = [0, 1, 2, 3, 4]
# bin_list_b = [0, 1, 2, 3, 4]

# bin_width_a = [1, 1, 1, 1, 1]
# bin_width_b = [1, 1, 1, 1, 1]

@MODELS.register_module()
class RegCLIPSSR(nn.Module):
    def __init__(
        self,
        text_encoder_name,
        image_encoder_name,
        prompt_learner_cfg,
        d=512,
        **kwargs,
    ) -> None:
        super().__init__()

        if kwargs:
            logger.info(f"irrelevant kwargs: {kwargs}")

        clip_model = load_clip_to_cpu(
            text_encoder_name,
            image_encoder_name,
            root=osp.join(osp.dirname(osp.realpath(__file__)), "..", "..", ".cache", "clip"),
        )
        clip_model.float()
        logger.info("Converted `clip_model` to float32. If an fp16 model is needed, call `clip.model.convert_weights`.")

        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        prompt_learner_cfg.update(dict(clip_model=clip_model))
        self.prompt_learner: PlainPromptLearner = PROMPT_LEARNERS.build(prompt_learner_cfg)
        self.psudo_sentence_tokens = self.prompt_learner.psudo_sentence_tokens
        self.logit_scale = clip_model.logit_scale

        self.embed_dims = clip_model.text_projection.shape[1]
        self.num_ranks = self.prompt_learner.num_ranks
        self.d = d

        # We first adopted a CLIP-Adapter style adaptation; in our experiments,
        # fully fine-tuning the image encoder gave better performance.
        self.image_adapter = Adapter(self.d, 4)

        self.regressor = SSRModule()

    def forward(self, images):
        sentence_embeds = self.prompt_learner()
        psudo_sentence_tokens = self.psudo_sentence_tokens
        text_features = self.text_encoder(sentence_embeds, psudo_sentence_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        image_features = self.image_encoder(images)
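        # CLIP-Adapter-style residual blend: mix the adapted image features with
        # the raw CLIP features using a fixed ratio (y_ratio below).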
        y = self.image_adapter(image_features)
        y_ratio = 0.8
        image_features = y_ratio * y + (1 - y_ratio) * image_features

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

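        # The SSR head maps the rank logits to a continuous prediction
        # (an age value in the age-estimation setting).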
        regress_age = self.regressor(logits)

        return logits, regress_age, image_features, text_features

    def forward_text_only(self):
        sentence_embeds = self.prompt_learner()
        psudo_sentence_tokens = self.psudo_sentence_tokens
        text_features = self.text_encoder(sentence_embeds, psudo_sentence_tokens)

        return text_features

    def encode_image(self, x):
        return self.image_encoder(x)


class TextEncoder(nn.Module):
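    """CLIP text branch driven by learned prompt embeddings.

    Reuses CLIP's transformer, positional embedding, final LayerNorm, and text
    projection, but consumes already-embedded prompts; the pooled token is
    selected via `tokenized_prompts.argmax(dim=-1)` (the EOT position).
    """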
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        x = prompts.type(self.dtype) + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x

    @property
    def dtype(self):
        return self.transformer.resblocks[0].mlp.c_fc.weight.dtype


class Adapter(nn.Module):
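    """Bottleneck MLP adapter (CLIP-Adapter style): c_in -> c_in // reduction -> c_in."""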
    def __init__(self, c_in, reduction=4):
        super(Adapter, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(c_in, c_in // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c_in // reduction, c_in, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.fc(x)
        return x


class SSRModule(nn.Module):
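    """SSR-Net style soft-regression head.

    Converts the rank logits into a continuous value as a softmax-weighted sum
    of per-bin values, each modulated by a learned delta. Only the stage-1
    branch is used in `forward`; the stage-2 layers are defined but unused here.
    """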
    def __init__(self, stage_num=[5, 3], d=512,
                 class_range=101, lambda_index=1., lambda_delta=1.):
        super(SSRModule, self).__init__()

        self.stage_num = stage_num
        self.lambda_index = lambda_index
        self.lambda_delta = lambda_delta
        self.class_range = class_range
        self.d = d

        self.stream1_stage2 = Adapter(self.d, 4)
        self.funsion_block_stream1_stage_2_prediction_block = nn.Linear(d, self.stage_num[1])
        self.funsion_block_stream1_stage_1_prediction_block = nn.Linear(d, self.stage_num[0])

        self.stream2_stage2 = Adapter(self.d, 4)
        self.funsion_block_stream2_stage_2_prediction_block = nn.Linear(d, self.stage_num[1])
        self.funsion_block_stream2_stage_1_prediction_block = nn.Linear(d, self.stage_num[0])

        self.stage2_FC_after_PB = nn.Sequential(
            nn.Linear(self.stage_num[1], 2 * self.stage_num[1]),
            nn.ReLU()
        )
        self.stage2_prob = nn.Sequential(
            nn.Linear(2 * self.stage_num[1], self.stage_num[1]),
            nn.ReLU()
        )
        self.stage2_index_offsets = nn.Sequential(
            nn.Linear(2 * self.stage_num[1], self.stage_num[1]),
            nn.Tanh()
        )
        self.stage2_delta_k = nn.Sequential(
            nn.Linear(2 * self.stage_num[1], 1),
            nn.Tanh()
        )
        self.stage1_FC_after_PB = nn.Sequential(
            nn.Linear(self.stage_num[0], 2 * self.stage_num[0]),
            nn.ReLU()
        )
        self.stage1_prob = nn.Sequential(
            nn.Linear(2 * self.stage_num[0], self.stage_num[0]),
            nn.ReLU()
        )
        self.stage1_index_offsets = nn.Sequential(
            nn.Linear(2 * self.stage_num[0], self.stage_num[0]),
            nn.Tanh()
        )
        self.stage1_delta_k = nn.Sequential(
            nn.Linear(2 * self.stage_num[0], self.stage_num[0]),
            nn.Tanh()
        )
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0.0)

    def forward(self, logits):

        prob_stage_1 = F.softmax(logits, dim=1)
        embedding_stage1_after_PB = self.stage1_FC_after_PB(logits)
        stage1_delta_k = self.stage1_delta_k(embedding_stage1_after_PB)

        stage1_regress_a = prob_stage_1[:, 0] * 0

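        # Soft regression: probability-weighted sum over bins, with each bin
        # value rescaled by its learned delta (weighted by lambda_delta).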
        for index in range(self.stage_num[0]):
            width = (bin_list_a[index] / (1 + self.lambda_delta * stage1_delta_k[:, index]))
            stage1_regress_a = stage1_regress_a + prob_stage_1[:, index] * width
        stage1_regress_a = torch.unsqueeze(stage1_regress_a, 1)

        regress_age_a = stage1_regress_a
        regress_age_a = regress_age_a.squeeze(1)

        regress_age = regress_age_a

        return regress_age


def load_clip_to_cpu(
    text_encoder_name,
    image_encoder_name,
    root=osp.join(osp.expanduser("~/.cache/clip")),
):
    # text backbone
    if logger is not None:
        print_func = logger.info
    else:
        print_func = print

    print_func("Building CLIP model...")
    text_backbone_name = text_encoder_name
    print_func(f"Text backbone: {text_backbone_name}'s counterpart.")
    url = clip._MODELS[text_backbone_name]
    model_path = clip._download(url, root=root)

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict())

    # image backbone
    embed_dim = model.text_projection.shape[1]
    input_resolution = model.visual.input_resolution
    image_backbone_name = image_encoder_name
    print_func(f"Image backbone: {image_backbone_name}")

    if image_backbone_name != text_backbone_name:
        # remove the stochastic back-prop in vgg and alexnet
        MODEL = getattr(image_encoders, image_backbone_name, None)
        if MODEL is None:
            MODEL = getattr(models, image_backbone_name, None)
            logger.warning(f"Try PyTorch official image model: {image_backbone_name}")
        else:
            logger.info(f"Try custom image model: {image_backbone_name}")
        if MODEL is None:
            raise ValueError(f"Invalid torchvision model name: {image_backbone_name}")
        model.visual = MODEL(num_classes=embed_dim)
        model.visual.input_resolution = input_resolution
    else:
        print_func(f"CLIP Image encoder: {image_backbone_name}!")

    return model
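# Usage sketch (illustrative only; the exact prompt_learner_cfg fields depend on
# the registered PlainPromptLearner config and are assumptions here):
#
#   model = RegCLIPSSR(
#       text_encoder_name="ViT-B/16",
#       image_encoder_name="ViT-B/16",
#       prompt_learner_cfg=dict(type="PlainPromptLearner", ...),
#   )
#   logits, regress_age, image_feats, text_feats = model(images)  # images: (B, 3, H, W)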

run_regclipssr.sh

+67
@@ -0,0 +1,67 @@
# ----------------------------------clip regression ssrnet evaluation for morph dataset (few shot training)-------------------------------


# CUDA_VISIBLE_DEVICES='7' python scripts/run.py \
# --config configs/default_test.yaml \
# --config configs/base_cfgs/data_cfg/datasets/morph/morph.yaml \
# --config configs/base_cfgs/runner_cfg/model/image_encoder/clip-vitb16.yaml \
# --config configs/base_cfgs/runner_cfg/model/text_encoder/clip-vitb16-cntprt.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/init_context/init-context-morph_class.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/num_ranks_5.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/plain-prompt-learner.yaml \
# --config configs/base_cfgs/runner_cfg/optim_sched/prompt_learner/tune-rank.yaml \
# --config configs/base_cfgs/runner_cfg/optim_sched/image_encoder/tune-image.yaml \
# --config configs/base_cfgs/runner_cfg/model/regclipssr.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/init_rank/init-rank-morph_class.yaml
# --config configs/base_cfgs/runner_cfg/init_weight.yaml
# --config configs/base_cfgs/data_cfg/few_shots/num-shots-1.yaml
# --config configs/base_cfgs/data_cfg/label_distribution_shift/num_topk_scaled_class/num-topk-scaled-class-40.yaml \
# --config configs/base_cfgs/data_cfg/label_distribution_shift/scale_factor/scale-factor-01.yaml

# ----------------------------------clip regression ssrnet evaluation for aesthetics dataset (few shot training)-------------------------------
# CUDA_VISIBLE_DEVICES='7' python scripts/run.py \
# --config configs/default_aesthetics.yaml \
# --config configs/base_cfgs/data_cfg/datasets/aesthetics/aesthetics.yaml \
# --config configs/base_cfgs/runner_cfg/model/image_encoder/clip-vitb16.yaml \
# --config configs/base_cfgs/runner_cfg/model/text_encoder/clip-vitb16-cntprt.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/num_ranks_5.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/plain-prompt-learner.yaml \
# --config configs/base_cfgs/runner_cfg/optim_sched/prompt_learner/tune-ctx-rank.yaml \
# --config configs/base_cfgs/runner_cfg/optim_sched/image_encoder/tune-image.yaml \
# --config configs/base_cfgs/runner_cfg/model/regclipssr.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/init_context/init-context-aesthetics_class.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/init_rank/init-rank-aesthetics_class.yaml

# --config configs/base_cfgs/runner_cfg/init_weight.yaml
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/init_rank/init-rank-aesthetics.yaml
# --config configs/base_cfgs/runner_cfg/init_weight.yaml
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/init_rank/init-rank-aesthetics_class.yaml
# --config configs/base_cfgs/data_cfg/few_shots/num-shots-1.yaml

# ----------------------------------clip regression ssrnet evaluation for historical dataset (few shot training)-------------------------------
# CUDA_VISIBLE_DEVICES='3' python scripts/run.py \
# --config configs/default_historical.yaml \
# --config configs/base_cfgs/data_cfg/datasets/historical/historical.yaml \
# --config configs/base_cfgs/runner_cfg/model/image_encoder/clip-vitb16.yaml \
# --config configs/base_cfgs/runner_cfg/model/text_encoder/clip-vitb16-cntprt.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/init_context/init-context-historical_class.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/num_ranks_5.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/plain-prompt-learner.yaml \
# --config configs/base_cfgs/runner_cfg/optim_sched/prompt_learner/tune-rank.yaml \
# --config configs/base_cfgs/runner_cfg/optim_sched/image_encoder/tune-image.yaml \
# --config configs/base_cfgs/runner_cfg/model/regclipssr.yaml \
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/init_rank/init-rank-historical_class.yaml
# --config configs/base_cfgs/runner_cfg/model/prompt_learner/init_rank/init-rank-historical.yaml
# --config configs/base_cfgs/runner_cfg/init_weight.yaml
# --config configs/base_cfgs/data_cfg/few_shots/num-shots-1.yaml
