diff --git a/scripts/convert_lit_checkpoint.py b/scripts/convert_lit_checkpoint.py index 8a3b101a7d..1f04a9cdf5 100644 --- a/scripts/convert_lit_checkpoint.py +++ b/scripts/convert_lit_checkpoint.py @@ -247,6 +247,8 @@ def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: def convert_lit_checkpoint(checkpoint_path: Path, output_path: Path, config_path: Path) -> None: config = Config.from_json(config_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + if "falcon" in config.name: copy_fn = partial(copy_weights_falcon, config.name) elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):