diff --git a/evaluation/eval_percep.py b/evaluation/eval_percep.py
index 53377a7..f49bb88 100644
--- a/evaluation/eval_percep.py
+++ b/evaluation/eval_percep.py
@@ -52,7 +52,7 @@ def load_dreamsim_model(args, device="cuda"):
     with open(os.path.join(args.eval_checkpoint_cfg), "r") as f:
         cfg = yaml.load(f, Loader=yaml.Loader)
-    model_cfg = vars(cfg)
+    model_cfg = cfg
     model_cfg['load_dir'] = args.load_dir
     model = LightningPerceptualModel(**model_cfg)
     model.load_lora_weights(args.eval_checkpoint)
@@ -141,4 +141,4 @@
     args = parse_args()
     device = "cuda" if torch.cuda.is_available() else "cpu"
     run(args, device)
-    
\ No newline at end of file
+
diff --git a/training/train.py b/training/train.py
index f8105ef..e25aabb 100644
--- a/training/train.py
+++ b/training/train.py
@@ -184,7 +184,13 @@ def load_lora_weights(self, checkpoint_root, epoch_load=None):
         if self.save_mode in {'adapter_only', 'all'}:
             if epoch_load is not None:
                 checkpoint_root = os.path.join(checkpoint_root, f'epoch_{epoch_load}')
-
+
+            with open(os.path.join(checkpoint_root, 'adapter_config.json'), 'r') as f:
+                adapter_config = json.load(f)
+            lora_keys = ['r', 'lora_alpha', 'lora_dropout', 'bias', 'target_modules']
+            lora_config = LoraConfig(**{k: adapter_config[k] for k in lora_keys})
+            self.perceptual_model = get_peft_model(self.perceptual_model, lora_config)
+
             logging.info(f'Loading adapter weights from {checkpoint_root}')
             self.perceptual_model = PeftModel.from_pretrained(self.perceptual_model.base_model.model, checkpoint_root).to(self.device)
         else:
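
For reference, below is a minimal standalone sketch of the adapter-loading flow that the `training/train.py` hunk introduces: rebuild a `LoraConfig` from the saved `adapter_config.json`, wrap the backbone with `get_peft_model` so its module layout matches the saved adapter, then load the trained adapter weights with `PeftModel.from_pretrained`. The function name `attach_lora_adapter` and the `backbone`/`device` arguments are illustrative and not part of the repository; the sketch assumes a checkpoint directory produced with `save_mode='adapter_only'` that contains `adapter_config.json`.

```python
import json
import os

from peft import LoraConfig, PeftModel, get_peft_model


def attach_lora_adapter(backbone, checkpoint_root, device="cuda"):
    # Read the LoRA hyperparameters that were saved alongside the adapter weights.
    with open(os.path.join(checkpoint_root, "adapter_config.json"), "r") as f:
        adapter_config = json.load(f)

    # Rebuild the LoraConfig from the same subset of fields the diff uses.
    lora_keys = ["r", "lora_alpha", "lora_dropout", "bias", "target_modules"]
    lora_config = LoraConfig(**{k: adapter_config[k] for k in lora_keys})

    # Wrap the backbone so its module names match the saved adapter layout,
    # then load the trained adapter weights on top of the unwrapped base model.
    peft_model = get_peft_model(backbone, lora_config)
    return PeftModel.from_pretrained(peft_model.base_model.model, checkpoint_root).to(device)
```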