Skip to content

Commit

Permalink
Merge pull request #10 from salute-developers/feat/artsokol/demo
Browse files Browse the repository at this point in the history
Add a realtime demo
  • Loading branch information
artsokol authored Aug 24, 2023
2 parents 901ff2d + 40dabac commit 5e2108a
Show file tree
Hide file tree
Showing 4 changed files with 80,313 additions and 0 deletions.
Binary file not shown.
85 changes: 85 additions & 0 deletions dusha/demo/model/train.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from pathlib import Path
from core.dataset import MelEmotionsDataset, get_augm_func, adaptive_padding_collate_fn, LengthWeightedSampler
from core.model import ConvSelfAttentionMobileNet
from core.utils import load_jsonl_as_df
from torch.utils.data import DataLoader


base_path = Path('/raid/okutuzov/dusha_data_new_2/processed_dataset_0.9')
train_manifest_path = base_path / 'train' / 'podcast_train.jsonl'
val_manifest_path = base_path / 'test' / 'podcast_test.jsonl'

pt_model_path = Path('/raid/kondrat/dusha_experiments_try2/agg_0.9/crowd_lr_1e-3_try1/crowd_lr_1e-3_try1')

batch_size = 64
epoch_count = 100
learning_rate = 1e-3
optimizer_step = 5
optimizer_gamma = 1
weight_decay = 1e-6
clip_grad = False

collate_fn = adaptive_padding_collate_fn
augm_func = get_augm_func(time_mask_param=40, freq_mask_param=16, crop_augm_max_cut_size=40)

MAX_LENGTH = 16

def get_train_weights(_df):
train_weights = 1 + 9 * (_df.label.values == 0) + 19 * (_df.label.values == 1) + 4 * (_df.label.values == 3)
# train_weights = 1 + 29 * (_df.label.values == 0) + 49 * (_df.label.values == 1) + 9 * (_df.label.values == 3)
return train_weights


model_setting = [
# t, c, n, s
[1, 16, 1, 1],
[2, 32, 2, 2],
[2, 64, 6, 2],
[2, 128, 6, 2],
]

model = ConvSelfAttentionMobileNet(model_setting,
n_classes=4,
last_channel=128)


def get_train_dataset(_df, ds_base_path):
return MelEmotionsDataset(_df,
get_weights_func=get_train_weights,
augm_transform=augm_func,
base_path=ds_base_path)



def get_val_dataset(_df, ds_base_path):
return MelEmotionsDataset(_df, base_path=ds_base_path)



def get_train_dataloader(train_ds):
return DataLoader(train_ds, batch_size=batch_size, num_workers=1,
collate_fn=collate_fn,
sampler=LengthWeightedSampler(df=train_ds.df,
batch_size=batch_size,
min_length=0.3,
max_length=MAX_LENGTH,
length_delta=0.3,
decimals=1))



def get_val_dataloader(val_ds):
return DataLoader(val_ds, batch_size=1, num_workers=4, shuffle=False)


train_dataset = get_train_dataset(load_jsonl_as_df(train_manifest_path),
ds_base_path=train_manifest_path.parent)
val_dataset = get_val_dataset(load_jsonl_as_df(val_manifest_path),
ds_base_path=val_manifest_path.parent)

dataloaders = {'train': get_train_dataloader(train_ds=train_dataset),
'validate': get_val_dataloader(val_ds=val_dataset)}

DUMP_BEST_CHECKPOINTS = True
DUMP_LAST_CHECKPOINTS = True
BEST_CHECKPOINTS_WARMUP = 5
Loading

0 comments on commit 5e2108a

Please sign in to comment.