From ea7e0e8775970ab383b9e614bf27016d00d7976a Mon Sep 17 00:00:00 2001 From: Benjamin-Walker Date: Thu, 23 Oct 2025 14:39:44 +0100 Subject: [PATCH 1/2] Length gen experiments --- length_gen.py | 1 + train.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/length_gen.py b/length_gen.py index f31ee0f..b61d784 100644 --- a/length_gen.py +++ b/length_gen.py @@ -94,6 +94,7 @@ def build_model(model_name, config, data_dim, label_dim, device): dropout_rate=config.get("dropout_rate", 0.01), use_glu=config.get("use_glu", False), diagonal=config.get("diagonal", False), + diagonal_dense=config.get("diagonal_dense", False), fwht=config.get("fwht", False), second_embedding=config.get("second_embedding", False), rank=config.get("rank", 0), diff --git a/train.py b/train.py index 58ec0c5..88d7670 100644 --- a/train.py +++ b/train.py @@ -118,6 +118,7 @@ def train_model( criterion = nn.CrossEntropyLoss() step = 0 total_loss = 0 + best_val_acc = 0 steps = [] val_accs = [] start_time = time.time() @@ -232,8 +233,10 @@ def train_model( checkpoint_filename = f"checkpoint_{task}_{model_name}_{time_str}.pt" checkpoint_path = os.path.join("checkpoints", checkpoint_filename) - torch.save(checkpoint, checkpoint_path) - print(f"Saved model checkpoint to: {checkpoint_path}") + if accuracy > best_val_acc: + best_val_acc = accuracy + torch.save(checkpoint, checkpoint_path) + print(f"Saved model checkpoint to: {checkpoint_path}") early_stop = accuracy > early_stop_threshold @@ -336,9 +339,15 @@ def train_dataloader_multilength(): dataloader = {"train": train_dataloader_multilength(), "val": val_dataloader} elif task == "A5_generalise": - train_padding_length = 128 - if model_name == "lcde": - train_padding_length = 20 + if model_name[:8] == "deltanet" or model_name == "deltaproduct": + train_padding_length = 65 + elif model_name == "xlstm": + if slstm_at == []: + train_padding_length = 128 + else: + train_padding_length = 40 + else: + train_padding_length = 40 val_padding_length = 128 train_dataloader, _, data_dim, label_dim = create_group_dataloaders( group="A5", @@ -360,6 +369,7 @@ def train_dataloader_multilength(): train_split=1.0, seed=2 * seed, ) + dataloader = {"train": train_dataloader, "val": val_dataloader} else: From 295b09a19fc4ecd9be34ca65ad60515aefee6927 Mon Sep 17 00:00:00 2001 From: Benjamin-Walker Date: Fri, 24 Oct 2025 11:22:20 +0100 Subject: [PATCH 2/2] Removed padding length for mLSTM on A5 length gen --- train.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/train.py b/train.py index 88d7670..c292a1a 100644 --- a/train.py +++ b/train.py @@ -341,11 +341,6 @@ def train_dataloader_multilength(): elif task == "A5_generalise": if model_name[:8] == "deltanet" or model_name == "deltaproduct": train_padding_length = 65 - elif model_name == "xlstm": - if slstm_at == []: - train_padding_length = 128 - else: - train_padding_length = 40 else: train_padding_length = 40 val_padding_length = 128