-
Notifications
You must be signed in to change notification settings - Fork 2
/
experiments_run.py
1544 lines (1415 loc) · 87 KB
/
experiments_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import logging
import os
import pandas as pd
import copy
import pickle
import numpy as np
import random
import json
from model import mafft_integration, model_tools
from model.helpers import misc, import_data, stats
from model.orientation import orientation_tools
from model.bdm_likelihood_computation import symbolic_lh_computation
from model.experiments_tools import expand_sim_parameters, multiple_selection_fcts, construct_tree, filter_fct, \
load_align_pickled_data, load_arrays_from_pkl, align_crispr_groups_for_sim_as_rec, pooling_for_parameter_estimation, \
load_align_single_fasta, expand_sim_parameters_based_on_import, tree_handling, name_inner_nodes_if_unnamed
from additional_data.additional_scripts.model.simulation_tree import SimulationTree
from model.reconstruction_tree import ReconstructionTree
from model.data_classes.advanced_tree import AdvancedTree
# Working directory used for the simulation alignments.
WORK_PATH = os.path.join('data', 'simulation_alignment')

# Default rate parameters for tree generation. The 'simple_*' and 'precise_*'
# entries presumably correspond to the two likelihood models used elsewhere
# in this module ('simple' vs. 'ode_based'/'precise') — TODO confirm.
TREE_GENERATION_PARAMETERS = {
    'simple_insertion_rate': 6522.48,
    'simple_deletion_rate': 137.20,
    'simple_alpha': 2.7278,
    'precise_insertion_rate': 0.6108,
    'precise_deletion_rate': 0.1830,
    'precise_alpha': 3.3377,
}

# Pickled, lambdified likelihood function loaded on demand by the run_* functions.
LH_FCT_PATH = os.path.join(
    'additional_data', '0_lh', '230329_death_lh_up_to_68_lambdifyed.pickle'
)
def run_multiple_groups(ls_data_path, save_path, rec_parameter_dict, lh_fct=None, logger=None, plot_tree=True,
                        do_show=False, combine_non_unique_arrays=False, determine_orientation=True,
                        orientation_decision_boundary=10,
                        tree_path=None, plot_order=True, significance_level=0.05, extend_branches=False,
                        tree_distance_fct='likelihood', tree_construction_method='upgma',
                        tree_lh_fct=None, tree_insertion_rate=None, tree_deletion_rate=None, tree_alpha=None,
                        alpha_bias_correction=True, rho_bias_correction=True,
                        seed=None,
                        save_reconstructed_events=False,
                        dpi=90,
                        figsize_rec=(None, None, 'px'),
                        ):
    """
    Load, align and reconstruct several CRISPR groups (one input file per group), optionally in
    both orientations, and write the result protocols and trees below ``save_path``.

    :param ls_data_path: list of paths to the input files; the group name is the file name
        without its extension.
    :param save_path: root directory for all results; additional data goes to
        ``save_path/additional_data``.
    :param rec_parameter_dict: reconstruction parameters forwarded to ``run_reconstruction``.
    :param lh_fct: ``None``, an already loaded likelihood function, or one of the strings
        ``'ode_based'`` / ``'precise'`` to load the pickled likelihood function from
        ``LH_FCT_PATH``.
    :param tree_lh_fct: like ``lh_fct`` but used for tree construction; reuses the likelihood
        function loaded for ``lh_fct`` when possible.
    :param logger: logger instance used for all progress output (required).
    :param determine_orientation: if True, a second reconstruction with reversed spacer arrays is
        run and both orientations are compared via their likelihoods.
    :param orientation_decision_boundary: decision boundary forwarded to
        ``orientation_tools.compare_likelihoods_for_orientation``.
    :param seed: seeds numpy's global RNG and is forwarded to the alignment step.
    :return: DataFrame with the condensed reconstruction results (detail columns dropped).
    """
    if seed is not None:
        np.random.seed(seed)
    dict_crispr_groups = {}
    run_timer = misc.RunTimer()
    save_data_folder = os.path.join(save_path, 'additional_data')
    logger.info('Loading and aligning the data...')
    for i, data_path in enumerate(ls_data_path):
        # Group name = input file name without extension.
        group_name = os.path.split(data_path)[-1].split('.')[0]
        ind_save_path = os.path.join(save_data_folder, 'work_folder', group_name)
        if not os.path.exists(ind_save_path):
            os.makedirs(ind_save_path)
        logger.info(f'{i + 1} / {len(ls_data_path)} aligning group {group_name}... \n')
        crispr_group = load_align_single_fasta(data_path, ind_save_path,
                                               group_name=group_name,
                                               logger=logger,
                                               mafft_options=None,
                                               save_path=None,
                                               seed=seed)
        dict_crispr_groups[group_name] = crispr_group
    print_save_path = os.path.join(save_data_folder, 'work_folder')
    logger.debug(f'Saving CRISPR groups to: {print_save_path}')
    if not os.path.exists(save_data_folder):
        os.makedirs(save_data_folder)
    with open(os.path.join(save_data_folder, 'dict_crispr_groups.pkl'), 'wb') as f:
        pickle.dump(dict_crispr_groups, f)
    # Load the precise (ODE-based) likelihood function only when requested by name.
    # NOTE: the original called lh_fct.lower() unconditionally, which raised AttributeError
    # for the default lh_fct=None.
    lh_fct_is_precise = isinstance(lh_fct, str) and lh_fct.lower() in ('ode_based', 'precise')
    if lh_fct_is_precise:
        logger.info(f'Loading lh function from {LH_FCT_PATH}')
        # Implement loading a shortened likelihood function depending on something.
        lh_fct = symbolic_lh_computation.load_pickled_lambdified_lh_fct(LH_FCT_PATH)
        logger.info(f'Loading time: {run_timer.time_from_last_checkpoint()}')
    if tree_lh_fct in ['ode_based', 'precise']:
        # Reuse the function loaded above instead of unpickling it a second time. The original
        # compared the (already replaced) function object against the string list, so the
        # "reuse" branch could never trigger.
        tree_lh_fct = lh_fct if lh_fct_is_precise \
            else symbolic_lh_computation.load_pickled_lambdified_lh_fct(LH_FCT_PATH)
    # Keep a deep copy for the reversed reconstruction; spacer arrays are reversed in place below.
    dict_crispr_groups_for_reverse = copy.deepcopy(dict_crispr_groups) if determine_orientation else {}
    logger.info('Beginning reconstruction of (forward oriented) group(s)...\n')
    (df_rec_protocol,
     df_rec_protocol_trivial, _,
     dict_provided_numbering, dict_trees_forward,
     dict_provided_aligned_arrays,
     dict_provided_duplicated_spacers) = run_reconstruction(rec_parameter_dict,
                                                            dict_crispr_groups,
                                                            save_path=os.path.join(save_path, '0_forward'),
                                                            plot_tree=plot_tree,
                                                            lh_fct=lh_fct,
                                                            logfile_path=None,
                                                            logger=logger, do_show=do_show,
                                                            combine_non_unique_arrays=combine_non_unique_arrays,
                                                            tree_path=tree_path,
                                                            use_provided_trees=tree_path is not None,
                                                            tree_save_path=None,
                                                            hide_unobserved_spacers=False,
                                                            plot_order=plot_order,
                                                            finite_bl_at_root=True,
                                                            group_by=None,
                                                            significance_level=significance_level,
                                                            extend_branches=extend_branches,
                                                            tree_lh_fct=tree_lh_fct,
                                                            tree_construction_method=tree_construction_method,
                                                            tree_distance_function=tree_distance_fct,
                                                            tree_gain_rate=tree_insertion_rate,
                                                            tree_loss_rate=tree_deletion_rate,
                                                            tree_alpha=tree_alpha,
                                                            alpha_bias_correction=alpha_bias_correction,
                                                            rho_bias_correction=rho_bias_correction,
                                                            core_genome_trees=False,
                                                            metadata=False,
                                                            save_reconstructed_events=save_reconstructed_events,
                                                            dpi=dpi,
                                                            figsize_rec=figsize_rec,
                                                            )
    dict_trees = dict_trees_forward
    df_rec_protocol = df_rec_protocol.set_index('name')
    df_results_wo_details = df_rec_protocol
    if not df_rec_protocol_trivial.empty:
        df_rec_protocol_trivial = df_rec_protocol_trivial.set_index('name')
    # Persist the spacer-name -> spacer-number mapping of each group for later runs.
    for gn, provided_numbering in dict_provided_numbering.items():
        dict_spacer_names_to_numbers = provided_numbering[0]
        p = os.path.join(save_data_folder, 'spacer_fasta', gn + '_spacer_name_to_sp_number.json')
        if not os.path.exists(os.path.dirname(p)):
            os.makedirs(os.path.dirname(p))
        with open(p, 'w') as f:
            json.dump(dict_spacer_names_to_numbers, f)
    if determine_orientation:
        logger.info('Beginning reconstruction of reversed group(s)...\n')
        if not os.path.exists(os.path.join(save_path, '0_reversed')):
            os.makedirs(os.path.join(save_path, '0_reversed'))
        # Reverse every spacer array (and the provided alignments) for the backward run.
        for gn, crispr_group in dict_crispr_groups_for_reverse.items():
            for n, ca in crispr_group.crispr_dict.items():
                ca.spacer_array = ca.spacer_array[::-1]
        dict_provided_aligned_arrays = {k: v[v.columns[::-1]] for k, v in dict_provided_aligned_arrays.items()}
        (df_rec_protocol_reversed,
         df_rec_protocol_reversed_trivial,
         _, _, dict_trees_reversed,
         _, _) = run_reconstruction(rec_parameter_dict,
                                    dict_crispr_groups_for_reverse,
                                    save_path=os.path.join(save_path, '0_reversed'),
                                    plot_tree=plot_tree,
                                    lh_fct=lh_fct,
                                    logfile_path=None,
                                    logger=logger, do_show=do_show,
                                    combine_non_unique_arrays=combine_non_unique_arrays,
                                    tree_path=tree_path,
                                    use_provided_trees=tree_path is not None,
                                    tree_save_path=None,
                                    hide_unobserved_spacers=False,
                                    plot_order=plot_order,
                                    finite_bl_at_root=True,
                                    group_by=None,
                                    significance_level=significance_level,
                                    extend_branches=extend_branches,
                                    tree_lh_fct=tree_lh_fct,
                                    tree_construction_method=tree_construction_method,
                                    tree_distance_function=tree_distance_fct,
                                    tree_gain_rate=tree_insertion_rate,
                                    tree_loss_rate=tree_deletion_rate,
                                    tree_alpha=tree_alpha,
                                    alpha_bias_correction=alpha_bias_correction,
                                    rho_bias_correction=rho_bias_correction,
                                    core_genome_trees=False,
                                    provided_numbering=dict_provided_numbering,
                                    provided_aligned_arrays=dict_provided_aligned_arrays,
                                    provided_dict_duplicated_spacers=dict_provided_duplicated_spacers,
                                    metadata=False,
                                    save_reconstructed_events=save_reconstructed_events,
                                    dpi=dpi,
                                    figsize_rec=figsize_rec,
                                    )
        dict_trees = {'forward': dict_trees_forward,
                      'reversed': dict_trees_reversed}
        df_rec_protocol_reversed = df_rec_protocol_reversed.set_index('name')
        if not df_rec_protocol_reversed_trivial.empty:
            df_rec_protocol_reversed_trivial = df_rec_protocol_reversed_trivial.set_index('name')
        df_rec_protocol, df_oriented_rec_protocol, dict_trees = orientation_tools.compare_likelihoods_for_orientation(
            df_rec_protocol,
            df_rec_protocol_reversed,
            df_rec_protocol_trivial,
            df_rec_protocol_reversed_trivial,
            decision_boundary=orientation_decision_boundary, dict_trees=dict_trees)
        detailed_protocols_save_path = os.path.join(save_path, 'detailed_results')
        if not os.path.exists(detailed_protocols_save_path):
            os.makedirs(detailed_protocols_save_path)
        df_rec_protocol.to_csv(os.path.join(detailed_protocols_save_path,
                                            '0_detailed_results_w_orientation.csv'))
        df_rec_protocol.to_pickle(os.path.join(detailed_protocols_save_path,
                                               '0_detailed_results_w_orientation.pkl'))
        df_oriented_rec_protocol.to_csv(os.path.join(detailed_protocols_save_path,
                                                     '0_detailed_results_wo_trivial_groups.csv'))
        df_oriented_rec_protocol.to_pickle(os.path.join(detailed_protocols_save_path,
                                                        '0_detailed_results_wo_trivial_groups.pkl'))
        df_results_wo_details = df_oriented_rec_protocol
    # Condense the result table; errors='ignore' because forward-only runs lack the reversed_* /
    # orientation comparison columns.
    df_results_wo_details = df_results_wo_details.drop(
        columns=['relative deletion positions', 'nb of existent spacers',
                 'all max block deletion lengths',
                 'all max block deletion lengths (normalized)',
                 'reversed_lh_idm', 'reversed_lh_bdm',
                 'lh_idm / reversed_lh_idm',
                 'lh_bdm / reversed_lh_bdm',
                 'lh_idm', 'lh_bdm',
                 'excluding young spacers deletion_rate_idm',
                 'excluding young spacers deletion_rate_bdm',
                 'excluding young spacers alpha_bdm',
                 'containing old spacers deletion_rate_idm',
                 'containing old spacers deletion_rate_bdm',
                 'containing old spacers alpha_bdm', 'fes +m presence',
                 'les -m presence',
                 'ls -m presence', 'fs +m presence', 'global fes +m presence',
                 'global les -m presence', 'global ls -m presence',
                 'combined array names',
                 ],
        errors='ignore')
    df_results_wo_details.to_csv(os.path.join(save_path, '0_results.csv'))
    # Save the newick trees of the final (possibly re-oriented) reconstructions; the directory
    # was created above. (The original guarded on `tree_save_path is not None` immediately after
    # assigning it — a dead check, removed.)
    tree_save_path = os.path.join(save_path, 'additional_data')
    with open(os.path.join(tree_save_path, 'dict_nwk_trees.json'), 'w') as f:
        json.dump(dict_trees, f)
    return df_results_wo_details
def run_pickled_data(rec_parameter_dict, data_path, save_path=None, plot_tree=True,
                     logfile_path=None,
                     do_show=False,
                     combine_non_unique_arrays=True,
                     tree_path=None,
                     tree_save_path=None,
                     hide_unobserved_spacers=False,
                     selection_fct=None,
                     plot_order=True,
                     lh_fct=None,
                     significance_level=0.05,
                     finite_bl_at_root=True,
                     group_by=None,
                     alignment_options_dict=None,
                     extend_branches=False,
                     tree_distance_function='breakpoint',
                     tree_gain_rate=None,
                     tree_loss_rate=None,
                     tree_alpha=None,
                     determine_orientation=False,
                     orient_boundary=0,
                     alpha_bias_correction=False,
                     rho_bias_correction=False,
                     core_genome_trees=False,
                     save_reconstructed_events=False,
                     dpi=90,
                     figsize_rec=(None, None, 'px'),
                     ):
    """
    Reconstruct CRISPR groups loaded from a pickle (``.pkl``/``.pickle``) or aligned from raw
    files, optionally in both orientations, and write protocols and trees to disk.

    :param rec_parameter_dict: reconstruction parameters forwarded to ``run_reconstruction``.
    :param data_path: path to a pickle of CRISPR groups, or a directory/file handled by
        ``load_align_pickled_data`` (then ``alignment_options_dict`` must be provided).
    :param logfile_path: either a path to a logfile (a logger is created, its directory is
        created if necessary) or an already configured logger instance.
    :param group_by: ``None``, an int (chunk size), a pandas Series, or a protocol column name
        used to pool trees for parameter estimation.
    :param determine_orientation: if True, a second reconstruction with reversed arrays is run
        and both orientations are compared via their likelihoods.
    :return: None; all results are written below ``save_path`` / ``tree_save_path``.
    """
    # Accept either a logfile path or a ready logger. NOTE: the original used logfile_path as a
    # path *before* checking isinstance(str), crashing whenever a logger instance was passed,
    # and os.makedirs('') raised for a bare file name without a directory part.
    if isinstance(logfile_path, str):
        log_dir = os.path.split(logfile_path)[0]
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir)
        logger = misc.create_logger('pickled data based reconstruction', logging.INFO, outfile=logfile_path)
    else:
        logger = logfile_path
    if data_path.split('.')[-1] in ['pkl', 'pickle']:
        logger.info(f'Loading data from pickle {data_path}')
        dict_crispr_groups = load_arrays_from_pkl(data_path)
        logger.info(f'Number of groups before reconstruction: {len(dict_crispr_groups)}')
        nb_arrays = sum(len(g.crispr_dict) for g in dict_crispr_groups.values())
        logger.info(f'Number of arrays before reconstruction: {nb_arrays}')
    else:
        logger.info(f'Loading and aligning from files {data_path}')
        dict_crispr_groups = load_align_pickled_data(data_path, mafft_options=alignment_options_dict['mafft_options'],
                                                     exclude_files=alignment_options_dict['exclude_files'],
                                                     save_path=alignment_options_dict['save_path'])
    # Keep a deep copy for the reversed reconstruction; spacer arrays are reversed in place below.
    dict_crispr_groups_for_reverse = copy.deepcopy(dict_crispr_groups) if determine_orientation else {}
    (df_rec_protocol, df_rec_protocol_boring, (dict_data_for_group_by, lh_fct), dict_provided_numbering,
     dict_trees_forward,
     dict_provided_aligned_arrays,
     dict_provided_duplicated_spacers) = run_reconstruction(rec_parameter_dict,
                                                            dict_crispr_groups,
                                                            save_path=os.path.join(save_path, '0_forward'),
                                                            plot_tree=plot_tree,
                                                            lh_fct=lh_fct,
                                                            logfile_path=logfile_path,
                                                            logger=logger, do_show=do_show,
                                                            combine_non_unique_arrays=combine_non_unique_arrays,
                                                            tree_path=tree_path,
                                                            use_provided_trees=tree_path is not None,
                                                            tree_save_path=os.path.join(tree_save_path,
                                                                                        '0_forward'),
                                                            hide_unobserved_spacers=hide_unobserved_spacers,
                                                            selection_fct=selection_fct,
                                                            plot_order=plot_order,
                                                            finite_bl_at_root=finite_bl_at_root,
                                                            group_by=group_by,
                                                            significance_level=significance_level,
                                                            extend_branches=extend_branches,
                                                            tree_distance_function=tree_distance_function,
                                                            tree_gain_rate=tree_gain_rate,
                                                            tree_loss_rate=tree_loss_rate,
                                                            tree_alpha=tree_alpha,
                                                            alpha_bias_correction=alpha_bias_correction,
                                                            rho_bias_correction=rho_bias_correction,
                                                            core_genome_trees=core_genome_trees,
                                                            save_reconstructed_events=save_reconstructed_events,
                                                            dpi=dpi,
                                                            figsize_rec=figsize_rec,
                                                            )
    dict_trees = dict_trees_forward
    df_rec_protocol = df_rec_protocol.set_index('name')
    if not df_rec_protocol_boring.empty:
        df_rec_protocol_boring = df_rec_protocol_boring.set_index('name')
    # Fallbacks so the group_by / saving code below also works for the default
    # determine_orientation=False. The original referenced df_oriented_rec_protocol and
    # dict_data_for_group_by_reversed unconditionally and raised NameError in that case.
    df_oriented_rec_protocol = df_rec_protocol
    dict_data_for_group_by_reversed = {}
    if determine_orientation:
        if not os.path.exists(os.path.join(save_path, '0_reversed')):
            os.makedirs(os.path.join(save_path, '0_reversed'))
        # Reverse every spacer array (and the provided alignments) for the backward run.
        for gn, crispr_group in dict_crispr_groups_for_reverse.items():
            for n, ca in crispr_group.crispr_dict.items():
                ca.spacer_array = ca.spacer_array[::-1]
        dict_provided_aligned_arrays = {k: v[v.columns[::-1]] for k, v in dict_provided_aligned_arrays.items()}
        (df_rec_protocol_reversed,
         df_rec_protocol_reversed_boring,
         (dict_data_for_group_by_reversed, lh_fct), _, dict_trees_reversed,
         _, _) = run_reconstruction(rec_parameter_dict,
                                    dict_crispr_groups_for_reverse,
                                    save_path=os.path.join(save_path, '0_reversed'),
                                    plot_tree=plot_tree,
                                    lh_fct=lh_fct,
                                    logfile_path=logfile_path,
                                    logger=logger,
                                    do_show=do_show,
                                    combine_non_unique_arrays=combine_non_unique_arrays,
                                    tree_path=tree_path,
                                    use_provided_trees=tree_path is not None,
                                    tree_save_path=os.path.join(tree_save_path,
                                                                '0_reversed'),
                                    hide_unobserved_spacers=hide_unobserved_spacers,
                                    selection_fct=selection_fct,
                                    plot_order=plot_order,
                                    finite_bl_at_root=finite_bl_at_root,
                                    group_by=group_by,
                                    significance_level=significance_level,
                                    extend_branches=extend_branches,
                                    tree_distance_function=tree_distance_function,
                                    tree_gain_rate=tree_gain_rate,
                                    tree_loss_rate=tree_loss_rate,
                                    tree_alpha=tree_alpha,
                                    provided_numbering=dict_provided_numbering,
                                    alpha_bias_correction=alpha_bias_correction,
                                    rho_bias_correction=rho_bias_correction,
                                    core_genome_trees=core_genome_trees,
                                    provided_aligned_arrays=dict_provided_aligned_arrays,
                                    provided_dict_duplicated_spacers=dict_provided_duplicated_spacers,
                                    save_reconstructed_events=save_reconstructed_events,
                                    dpi=dpi,
                                    figsize_rec=figsize_rec,
                                    )
        dict_trees = {'forward': dict_trees_forward,
                      'reversed': dict_trees_reversed}
        df_rec_protocol_reversed = df_rec_protocol_reversed.set_index('name')
        if not df_rec_protocol_reversed_boring.empty:
            df_rec_protocol_reversed_boring = df_rec_protocol_reversed_boring.set_index('name')
        df_rec_protocol, df_oriented_rec_protocol, dict_trees = orientation_tools.compare_likelihoods_for_orientation(
            df_rec_protocol,
            df_rec_protocol_reversed,
            df_rec_protocol_boring,
            df_rec_protocol_reversed_boring,
            decision_boundary=orient_boundary, dict_trees=dict_trees)
    if group_by is not None:
        logger.info(f'Grouping trees for parameter estimation by: {group_by}')
        if isinstance(group_by, int):
            # Fixed-size chunks; remaining rows join the last group.
            val_group_by = []
            for i in range(df_oriented_rec_protocol.shape[0] // group_by):
                val_group_by += [i] * group_by
            remainder = df_oriented_rec_protocol.shape[0] % group_by
            if remainder:
                # Group 0 if there are fewer rows than group_by (original max([]) raised).
                val_group_by += [val_group_by[-1] if val_group_by else 0] * remainder
        elif isinstance(group_by, pd.Series):
            # solve this issue here pooling and group_by do not agree in dimensions!!!!
            val_group_by = group_by
        else:
            val_group_by = df_oriented_rec_protocol[group_by]
        # Pick the data of the predicted orientation for each group; the orientation column only
        # exists after an orientation comparison.
        has_orientation_col = 'predicted orientation' in df_oriented_rec_protocol.columns
        ls_data_for_lh_ratio = []
        for idx in df_oriented_rec_protocol.index:
            use_reversed = has_orientation_col and \
                df_oriented_rec_protocol.loc[idx, 'predicted orientation'] == 'Reverse'
            ls_data_for_lh_ratio.append(dict_data_for_group_by_reversed[idx] if use_reversed
                                        else dict_data_for_group_by[idx])
        (group_names, groups_lh_0, groups_lh_1,
         groups_result_0, groups_result_1) = pooling_for_parameter_estimation(ls_data_for_lh_ratio,
                                                                              group_by=val_group_by,
                                                                              give_lh_fct=lh_fct)
        # Model 0 (IDM) fits the deletion rate only; model 1 (BDM) fits deletion rate and alpha.
        ls_loss_rates_0 = [g.x for g in groups_result_0]
        ls_loss_rates_1 = [g.x[0] for g in groups_result_1]
        ls_alphas_1 = [g.x[1] for g in groups_result_1]
        # Likelihood-ratio test statistic between the two nested models.
        ls_ln_lh_ratio = [2 * (g_0.fun - g_1.fun) for g_0, g_1 in zip(groups_result_0, groups_result_1)]
        ls_test_results = []
        ls_confidences = []
        ls_sig_values = []
        for llh in ls_ln_lh_ratio:
            test_result, confidence = stats.test_significance_ratio_chi2(llh, significance_level)
            ls_test_results.append(test_result)
            ls_confidences.append(confidence)
            ls_sig_values.append(significance_level)
        df_group_protocol = pd.DataFrame(list(zip(group_names, groups_lh_0, groups_lh_1,
                                                  ls_loss_rates_0, ls_loss_rates_1,
                                                  ls_alphas_1, ls_ln_lh_ratio, ls_test_results,
                                                  ls_confidences, ls_sig_values,
                                                  )),
                                         columns=['group_by', 'lh_idm', 'lh_bdm', 'deletion_rate_idm',
                                                  'deletion_rate_bdm',
                                                  'alpha_bdm', 'test statistic (-2*ln_lh_ratio)',
                                                  'test result', 'chi2_quantile',
                                                  'significance level'])
        df_group_protocol.to_csv(os.path.join(save_path, '0_final_group_by_protocol.csv'))
    df_rec_protocol.to_csv(os.path.join(save_path, '0_protocol_w_orientation.csv'))
    df_rec_protocol.to_pickle(os.path.join(save_path, '0_protocol_w_orientation.pkl'))
    df_oriented_rec_protocol.to_csv(os.path.join(save_path, '0_final_oriented_protocol.csv'))
    df_oriented_rec_protocol.to_pickle(os.path.join(save_path, '0_final_oriented_protocol.pkl'))
    if tree_save_path is not None:
        with open(os.path.join(tree_save_path, 'final_dict_nwk_trees.json'), 'w') as f:
            json.dump(dict_trees, f)
    return
def run_simulation_and_reconstruction(sim_parameter_dict, rec_parameter_dict, save_path=None,
                                      rec_save_path=None, plot_tree=True,
                                      sim_save_path=None, sim_plot_tree=True,
                                      sim_logfile_path=None,
                                      logfile_path=None,
                                      do_show=False,
                                      sim_as_rec=False,
                                      hide_unobserved_spacers=False,
                                      selection_fct=None,
                                      plot_order=True,
                                      significance_level=0.05,
                                      alignment_options=None,
                                      load_sim_from_pkl=False,
                                      finite_bl_at_root=True,
                                      sim_simple_alignment=False,
                                      group_by=None,
                                      extend_branches=False,
                                      use_sim_tree=True,
                                      tree_distance_function='breakpoint',
                                      tree_gain_rate=None,
                                      tree_loss_rate=None,
                                      tree_alpha=None,
                                      tree_save_path=None,
                                      determine_orientation=False,
                                      randomize_orientation=False,
                                      orient_decision_boundary=0,
                                      load_sim_param_from_pkl=None,
                                      alpha_bias_correction=False,
                                      rho_bias_correction=False,
                                      dpi=90,
                                      figsize_rec=(None, None, 'px'),
                                      ):
    """
    Run (or load) a simulation, align the simulated arrays, reconstruct them (optionally in both
    orientations) and write a combined simulation + reconstruction protocol.

    :param sim_parameter_dict: parameters for ``run_simulation``.
    :param rec_parameter_dict: parameters forwarded to ``run_reconstruction``.
    :param save_path: directory for the combined sim+rec protocol (and group_by protocol).
    :param rec_save_path: directory for reconstruction outputs ('0_forward' / '0_reversed').
    :param sim_save_path: directory simulations are saved to / loaded from.
    :param logfile_path: path to the logfile; its directory is created if necessary.
    :param sim_as_rec: treat the simulation itself as the reconstruction (alignment from the
        simulated order).
    :param load_sim_from_pkl: load a previously saved simulation instead of simulating.
    :param group_by: None, an int (chunk size), a pandas Series, the name of a protocol column,
        or 'sim|<column>' to group by a simulation-protocol column.
    :param determine_orientation: if True, reconstruct reversed arrays too and compare
        likelihoods to predict the orientation.
    :return: None; all results are written to disk.
    """
    # Create the logfile directory only when there is a directory component;
    # the original os.makedirs('') raised for bare file names.
    log_dir = os.path.split(logfile_path)[0]
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logger = misc.create_logger('Simulation -> Reconstruction', logging.INFO, outfile=logfile_path)
    if load_sim_from_pkl:
        logger.info(f'Loading simulations from {sim_save_path}...')
        with open(os.path.join(sim_save_path, 'dict_crispr_groups.pkl'), 'rb') as f:
            dict_crispr_groups = pickle.load(f)
        df_sim_protocol = pd.read_pickle(os.path.join(sim_save_path, '0_sim_protocol.pkl'))
    else:
        logger.info('Starting simulations...')
        df_sim_protocol, dict_crispr_groups, _ = run_simulation(sim_parameter_dict, save_path=sim_save_path,
                                                                plot_tree=sim_plot_tree,
                                                                logfile_path=sim_logfile_path,
                                                                sim_as_rec=sim_as_rec,
                                                                randomize_orientation=randomize_orientation,
                                                                load_sim_param_from_pkl=load_sim_param_from_pkl,
                                                                )
    # Alignment: with sim_as_rec / sim_simple_alignment the simulated order is used directly,
    # otherwise mafft aligns the simulated arrays.
    if sim_as_rec or sim_simple_alignment:
        logger.info('Simulation is used as Reconstruction or sim_simple_alignment was chosen -> '
                    'alignment is done based on simulated top_order.')
        dict_crispr_groups = align_crispr_groups_for_sim_as_rec(dict_crispr_groups)
    else:
        logger.info('Starting alignment...')
        dict_crispr_groups = mafft_integration.align_crispr_groups(WORK_PATH, dict_crispr_groups,
                                                                   mafft_options=alignment_options,
                                                                   logger=logger, )
    if sim_save_path:
        with open(os.path.join(sim_save_path, 'dict_crispr_groups.pkl'), 'wb') as f:
            pickle.dump(dict_crispr_groups, f)
    # 'sim|<column>' selects a simulation-protocol column as the grouping key.
    if group_by is not None:
        if isinstance(group_by, str):
            if group_by.split('|')[0] == 'sim':
                group_by = df_sim_protocol.set_index('name')[''.join(group_by.split('|')[1:])]
    logger.info('Starting reconstruction...')
    if not os.path.exists(os.path.join(rec_save_path, '0_forward')):
        os.makedirs(os.path.join(rec_save_path, '0_forward'))
    # Keep a deep copy for the reversed reconstruction; spacer arrays are reversed in place below.
    dict_crispr_groups_for_reverse = copy.deepcopy(dict_crispr_groups) if determine_orientation else {}
    (df_rec_protocol, df_rec_protocol_boring,
     (dict_data_for_group_by, lh_fct),
     dict_provided_numbering, dict_trees_forward,
     dict_provided_aligned_arrays,
     dict_provided_duplicated_spacers) = run_reconstruction(rec_parameter_dict,
                                                            dict_crispr_groups,
                                                            save_path=os.path.join(rec_save_path, '0_forward'),
                                                            plot_tree=plot_tree,
                                                            logfile_path=logfile_path,
                                                            do_show=do_show,
                                                            combine_non_unique_arrays=False,
                                                            tree_path=os.path.join(sim_save_path,
                                                                                   'dict_nwk_trees.pkl'),
                                                            use_provided_trees=use_sim_tree,
                                                            tree_save_path=os.path.join(
                                                                tree_save_path,
                                                                '0_forward') if tree_save_path is not None else None,
                                                            sim_as_rec=sim_as_rec,
                                                            hide_unobserved_spacers=hide_unobserved_spacers,
                                                            selection_fct=selection_fct,
                                                            plot_order=plot_order,
                                                            significance_level=significance_level,
                                                            finite_bl_at_root=finite_bl_at_root,
                                                            logger=logger,
                                                            group_by=group_by,
                                                            extend_branches=extend_branches,
                                                            tree_distance_function=tree_distance_function,
                                                            tree_gain_rate=tree_gain_rate,
                                                            tree_loss_rate=tree_loss_rate,
                                                            tree_alpha=tree_alpha,
                                                            alpha_bias_correction=alpha_bias_correction,
                                                            rho_bias_correction=rho_bias_correction,
                                                            dpi=dpi,
                                                            figsize_rec=figsize_rec,
                                                            )
    dict_trees = dict_trees_forward
    df_rec_protocol = df_rec_protocol.set_index('name')
    if not df_rec_protocol_boring.empty:
        df_rec_protocol_boring = df_rec_protocol_boring.set_index('name')
    # Fallbacks so the group_by / saving code below also works when determine_orientation=False;
    # the original referenced these names unconditionally and raised NameError in that case.
    df_oriented_rec_protocol = df_rec_protocol
    dict_data_for_group_by_reversed = {}
    if determine_orientation:
        if not os.path.exists(os.path.join(rec_save_path, '0_reversed')):
            os.makedirs(os.path.join(rec_save_path, '0_reversed'))
        # Reverse every spacer array (and the provided alignments) for the backward run.
        for gn, crispr_group in dict_crispr_groups_for_reverse.items():
            for n, ca in crispr_group.crispr_dict.items():
                ca.spacer_array = ca.spacer_array[::-1]
        dict_provided_aligned_arrays = {k: v[v.columns[::-1]] for k, v in dict_provided_aligned_arrays.items()}
        (df_rec_protocol_reversed,
         df_rec_protocol_reversed_boring,
         (dict_data_for_group_by_reversed, lh_fct), _, dict_tree_reversed,
         _, _) = run_reconstruction(rec_parameter_dict,
                                    dict_crispr_groups_for_reverse,
                                    save_path=os.path.join(rec_save_path, '0_reversed'),
                                    plot_tree=plot_tree,
                                    logfile_path=logfile_path,
                                    do_show=do_show,
                                    combine_non_unique_arrays=False,
                                    tree_path=os.path.join(sim_save_path,
                                                           'dict_nwk_trees.pkl'),
                                    use_provided_trees=use_sim_tree,
                                    tree_save_path=os.path.join(
                                        tree_save_path,
                                        '0_reversed') if tree_save_path is not None else None,
                                    sim_as_rec=sim_as_rec,
                                    hide_unobserved_spacers=hide_unobserved_spacers,
                                    selection_fct=selection_fct,
                                    plot_order=plot_order,
                                    significance_level=significance_level,
                                    finite_bl_at_root=finite_bl_at_root,
                                    logger=logger,
                                    group_by=group_by,
                                    extend_branches=extend_branches,
                                    tree_distance_function=tree_distance_function,
                                    tree_gain_rate=tree_gain_rate,
                                    tree_loss_rate=tree_loss_rate,
                                    tree_alpha=tree_alpha,
                                    alpha_bias_correction=alpha_bias_correction,
                                    rho_bias_correction=rho_bias_correction,
                                    provided_aligned_arrays=dict_provided_aligned_arrays,
                                    provided_dict_duplicated_spacers=dict_provided_duplicated_spacers,
                                    dpi=dpi,
                                    figsize_rec=figsize_rec,
                                    )
        dict_trees = {'forward': dict_trees_forward, 'reversed': dict_tree_reversed}
        df_rec_protocol_reversed = df_rec_protocol_reversed.set_index('name')
        if not df_rec_protocol_reversed_boring.empty:
            df_rec_protocol_reversed_boring = df_rec_protocol_reversed_boring.set_index('name')
        df_rec_protocol, df_oriented_rec_protocol, dict_trees = orientation_tools.compare_likelihoods_for_orientation(
            df_rec_protocol,
            df_rec_protocol_reversed,
            df_rec_protocol_boring,
            df_rec_protocol_reversed_boring,
            decision_boundary=orient_decision_boundary, dict_trees=dict_trees)
    if group_by is not None:
        logger.info(f'Grouping trees for parameter estimation by: {group_by}')
        if isinstance(group_by, int):
            # Fixed-size chunks; remaining rows join the last group.
            val_group_by = []
            for i in range(df_oriented_rec_protocol.shape[0] // group_by):
                val_group_by += [i] * group_by
            remainder = df_oriented_rec_protocol.shape[0] % group_by
            if remainder:
                # Group 0 if there are fewer rows than group_by (original max([]) raised).
                val_group_by += [val_group_by[-1] if val_group_by else 0] * remainder
        elif isinstance(group_by, pd.Series):
            val_group_by = [group_by[idx] for idx in df_oriented_rec_protocol.index]
        else:
            val_group_by = df_oriented_rec_protocol[group_by]
        # NOTE(review): this function tests column 'our predicted orientation' while
        # run_pickled_data uses 'predicted orientation' — confirm which name
        # compare_likelihoods_for_orientation actually produces.
        has_orientation_col = 'our predicted orientation' in df_oriented_rec_protocol.columns
        ls_data_for_lh_ratio = []
        filtered_val_group_by = []
        for i, idx in enumerate(df_oriented_rec_protocol.index):
            use_reversed = has_orientation_col and \
                df_oriented_rec_protocol.loc[idx, 'our predicted orientation'] == 'Reverse'
            source = dict_data_for_group_by_reversed if use_reversed else dict_data_for_group_by
            # Groups missing from the data dict (e.g. trivial groups) are skipped together with
            # their grouping value.
            if idx in source:
                ls_data_for_lh_ratio.append(source[idx])
                filtered_val_group_by.append(val_group_by[i])
        (group_names, groups_lh_0, groups_lh_1,
         groups_result_0, groups_result_1) = pooling_for_parameter_estimation(ls_data_for_lh_ratio,
                                                                              group_by=filtered_val_group_by,
                                                                              give_lh_fct=lh_fct)
        # Model 0 (IDM) fits the deletion rate only; model 1 (BDM) fits deletion rate and alpha.
        ls_loss_rates_0 = [g.x for g in groups_result_0]
        ls_loss_rates_1 = [g.x[0] for g in groups_result_1]
        ls_alphas_1 = [g.x[1] for g in groups_result_1]
        ls_ln_lh_ratio = [2 * (g_0.fun - g_1.fun) for g_0, g_1 in zip(groups_result_0, groups_result_1)]
        ls_test_results = []
        ls_confidences = []
        ls_sig_values = []
        for llh in ls_ln_lh_ratio:
            test_result, confidence = stats.test_significance_ratio_chi2(llh, significance_level)
            ls_test_results.append(test_result)
            ls_confidences.append(confidence)
            ls_sig_values.append(significance_level)
        df_group_protocol = pd.DataFrame(list(zip(group_names, groups_lh_0, groups_lh_1,
                                                  ls_loss_rates_0, ls_loss_rates_1,
                                                  ls_alphas_1, ls_ln_lh_ratio, ls_test_results,
                                                  ls_confidences, ls_sig_values,
                                                  )),
                                         columns=['group_by', 'lh_0', 'lh_1', 'loss_rate_0', 'loss_rate_1',
                                                  'alpha_1', '-2*ln_lh_ratio', 'test result', 'test_confidence',
                                                  'significance value'])
        df_group_protocol.to_csv(os.path.join(save_path, '0_final_group_by_protocol.csv'))
    df_rec_protocol.to_csv(os.path.join(rec_save_path, '0_protocol_orientation.csv'))
    df_rec_protocol.to_pickle(os.path.join(rec_save_path, '0_protocol_orientation.pkl'))
    df_oriented_rec_protocol.to_csv(os.path.join(rec_save_path, '0_final_oriented_protocol.csv'))
    df_oriented_rec_protocol.to_pickle(os.path.join(rec_save_path, '0_final_oriented_protocol.pkl'))
    # trying concat for now, might want to merge/join, especially, if keys overlap
    df_sim_protocol = df_sim_protocol.set_index('name')
    if not determine_orientation:
        # df_rec_protocol_boring was already re-indexed by 'name' above; the original called
        # set_index('name') a second time here, which raised KeyError.
        if not df_rec_protocol_boring.empty:
            df_rec_protocol = pd.concat([df_rec_protocol, df_rec_protocol_boring], axis=0)
    df_rec_protocol = df_oriented_rec_protocol if determine_orientation else df_rec_protocol
    df_sim_rec_protocol = pd.concat([df_sim_protocol, df_rec_protocol], axis=1, join='inner')
    protocol_path = os.path.join(save_path, '0_sim_rec_protocol')
    logger.info(f'Finished simulation and reconstruction, saving to {protocol_path}')
    df_sim_rec_protocol.to_csv(protocol_path + '.csv')
    df_sim_rec_protocol.to_pickle(protocol_path + '.pkl')
    if tree_save_path is not None:
        with open(os.path.join(tree_save_path, 'final_dict_nwk_trees.pkl'), 'wb') as f:
            pickle.dump(dict_trees, f)
def run_reconstruction(rec_parameter_dict, dict_crispr_groups, save_path=None, plot_tree=True,
                       lh_fct=None,
                       logfile_path=None,
                       do_show=False,
                       sim_as_rec=False,
                       combine_non_unique_arrays=False,
                       tree_path=None,
                       tree_save_path=None,
                       use_provided_trees=True,
                       hide_unobserved_spacers=False,
                       selection_fct=None,
                       plot_order=True,
                       significance_level=0.05,
                       group_by=None,
                       finite_bl_at_root=True,
                       logger=None,
                       extend_branches=False,
                       tree_lh_fct=None,
                       tree_distance_function='likelihood',
                       tree_construction_method='upgma',
                       tree_gain_rate=None,
                       tree_loss_rate=None,
                       tree_alpha=None,
                       visualization_order='nb_leafs',
                       provided_aligned_arrays=None,
                       provided_dict_duplicated_spacers=None,
                       provided_numbering=None,
                       spacer_labels_num=True,
                       alpha_bias_correction=False,
                       rho_bias_correction=False,
                       core_genome_trees=False,
                       metadata=True,
                       alternative_parameter_estimation=False,
                       save_reconstructed_events=False,
                       dpi=90,
                       figsize_rec=(None, None, 'px'),
                       ):
    """
    Run the reconstruction pipeline over every CRISPR group in ``dict_crispr_groups``.

    Per group: apply the optional selection filter, optionally collapse completely
    identical spacer arrays, obtain a tree (loaded from ``tree_path`` or built with
    ``construct_tree``), optionally prune it to the core genome, extend branch
    lengths, and align the spacer arrays. Groups that cannot be processed are
    recorded in a skip protocol together with the reason.

    :param rec_parameter_dict: dictionary of reconstruction parameters; its
        ``'give_lh_fct'`` entry is used when ``lh_fct`` is not given.
    :param dict_crispr_groups: mapping of group name -> CRISPR group object to process.
    :param save_path: output directory; created if missing, also the default
        location of the log file.
    :param plot_tree: presumably toggles tree plotting — used further downstream,
        confirm there.
    :param lh_fct: likelihood function selector; ``'simple'`` selects the simple
        model, any other value is used as the (path to the) precise lh function.
    :param logfile_path: log file path; defaults to ``save_path/0_logger.log``
        when a logger is created here.
    :param do_show: presumably shows plots interactively — confirm downstream.
    :param sim_as_rec: NOTE(review): assumed to treat simulated states as the
        reconstruction — not used in this part of the function, confirm downstream.
    :param combine_non_unique_arrays: if True, completely identical arrays are
        represented by a single array for reconstruction and visualization.
    :param tree_path: file with trees keyed by group name; ``.pkl``/``.pickle``
        are unpickled, ``.json``/``.txt`` are read as JSON.
    :param tree_save_path: presumably where constructed trees are saved — confirm
        downstream.
    :param use_provided_trees: if False, ``tree_path`` is ignored and trees are
        constructed from the arrays.
    :param hide_unobserved_spacers: used downstream — confirm there.
    :param selection_fct: pair (selection_criteria, selection_parameters) handed to
        ``multiple_selection_fcts``; failing groups are skipped.
    :param plot_order: used downstream — confirm there.
    :param significance_level: significance level, presumably for downstream
        likelihood-ratio tests.
    :param group_by: used downstream — confirm there.
    :param finite_bl_at_root: used downstream — confirm there.
    :param logger: logger to use; one is created (file + INFO level) when None.
    :param extend_branches: numeric amount added to every non-root branch length
        (falsy disables the extension; despite the ``False`` default this is used
        as a number).
    :param tree_lh_fct: None/'simple' selects the 'simple_' default tree
        parameters, anything else the 'precise_' ones.
    :param tree_distance_function: distance function used by ``construct_tree``.
    :param tree_construction_method: tree construction algorithm, e.g. 'upgma'.
    :param tree_gain_rate: insertion rate for tree construction; defaulted from
        TREE_GENERATION_PARAMETERS when None.
    :param tree_loss_rate: deletion rate for tree construction; defaulted from
        TREE_GENERATION_PARAMETERS when None.
    :param tree_alpha: alpha for tree construction; defaulted from
        TREE_GENERATION_PARAMETERS when None.
    :param visualization_order: custom visualization order, intended for toy examples.
    :param provided_aligned_arrays: precomputed alignments, presumably keyed by
        group name — confirm downstream.
    :param provided_dict_duplicated_spacers: precomputed duplicated-spacer info —
        confirm downstream.
    :param provided_numbering: dictionary with spacer numbers and respective colors
        in tuples for each crispr_group_name; fixes numbering and colors for
        reverse reconstructions (or in general).
    :param spacer_labels_num: used downstream — confirm there.
    :param alpha_bias_correction: used downstream — confirm there.
    :param rho_bias_correction: used downstream — confirm there.
    :param core_genome_trees: if True, trees are adjusted via ``tree_handling``
        and all-gap alignment columns are dropped.
    :param metadata: used downstream — confirm there.
    :param alternative_parameter_estimation: based on reduced number of events
        (only considers spacers existing at root or only events that are in blocks
        with spacers existing at root).
    :param save_reconstructed_events: used downstream — confirm there.
    :param dpi: figure resolution for reconstruction plots.
    :param figsize_rec: figure size tuple (width, height, unit) for reconstruction plots.
    :return: not visible in this portion of the function.
    """
    # Groups with fewer arrays than this cannot be aligned/tree-built and are skipped.
    minimum_nb_of_arrays = 2
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if logger is None:
        if logfile_path is None:
            logfile_path = os.path.join(save_path, '0_logger.log')
        logger = misc.create_logger('reconstruction', logging.INFO, outfile=logfile_path)
    logger.info('Running Reconstruction with parameter dictionary: %s', rec_parameter_dict)
    run_timer = misc.RunTimer()
    # Resolve the likelihood function: explicit argument wins over the parameter dict;
    # 'simple' means "no provided lh function".
    if lh_fct is not None:
        give_lh_fct = None if lh_fct == 'simple' else lh_fct
    else:
        give_lh_fct = rec_parameter_dict.get('give_lh_fct', None)
    if isinstance(give_lh_fct, str):
        # A string is interpreted as a path to a pickled, lambdified lh function.
        logger.info(f'Loading lh function from {give_lh_fct}')
        # logger.info(f'Lambdifying lh fct...')
        # give_lh_fct = symbolic_lh_computation.load_lambdify_lh_fct(give_lh_fct, save_lambdified_lh_fct=True)
        # for later when I saved the lambdifyed fct one time
        give_lh_fct = symbolic_lh_computation.load_pickled_lambdified_lh_fct(give_lh_fct)
        logger.info(f'Loading time: {run_timer.time_from_last_checkpoint()}')
    # Protocol accumulators: processed, "boring", and skipped groups.
    ls_dict_protocol = []
    ls_boring_protocol = []
    ls_skipped_protocol = []
    if tree_path and use_provided_trees:
        logger.info(f'Loading trees from {tree_path} ...')
        file_ext = os.path.splitext(tree_path)[1]
        if file_ext in ['.pkl', '.pickle']:
            with open(tree_path, 'rb') as f:
                dict_trees = pickle.load(f)
        elif file_ext in ['.json', '.txt']:
            with open(tree_path, 'r') as f:
                dict_trees = json.load(f)
        else:
            # NOTE(review): logging.Logger.error returns None, so this `raise`
            # actually raises "TypeError: exceptions must derive from BaseException"
            # instead of a descriptive error — TODO replace with a real exception.
            raise logger.error(f'Unknown file extension of tree file: {file_ext} !')
    else:
        dict_trees = {}
    # Collected per-group outputs of this run (trees and provided/derived data).
    new_dict_trees = {}
    dict_data_for_lh_ratio = dict()
    new_dict_provided_numberings = dict()
    new_dict_provided_aligned_arrays = dict()
    new_dict_provided_duplicated_spacers = dict()
    for i, crispr_group in enumerate(dict_crispr_groups.values()):
        dict_bg_colors = dict()
        in_between_runs_time = run_timer.time_from_last_checkpoint()
        logger.info(f'Working on: {crispr_group.name}')
        if crispr_group.repeat:
            logger.info(f'Repeat: {crispr_group.repeat}')
        logger.info(f'Progress (group/total): {i + 1} / {len(dict_crispr_groups)}')
        # Selection filter: groups failing the criteria are recorded and skipped.
        skip, reason = multiple_selection_fcts(crispr_group,
                                               selection_criteria=selection_fct[0] if selection_fct else None,
                                               selection_parameters=selection_fct[1] if selection_fct else None)
        if skip:
            logger.warning('Skipped\n for reason: %s', reason)
            ls_skipped_protocol.append({'name': crispr_group.name, 'repeat': crispr_group.repeat,
                                        'reason': reason})
            continue
        ls_array_names = list(crispr_group.crispr_dict.keys())
        ls_arrays = [crispr_array.spacer_array for crispr_array in crispr_group.crispr_dict.values()]
        prev_len = len(ls_arrays)
        dict_combined_arrays = dict()
        if combine_non_unique_arrays:
            # Collapse completely identical arrays to one representative each.
            ls_arrays, ls_array_names, dict_combined_arrays = misc.remove_completely_same_arrays(ls_arrays,
                                                                                                ls_array_names)
            nb_unique_spacer_arrays = len(ls_arrays)
            # logger.info(f'Arrays that were combined to one array {dict_combined_arrays}')
            logger.info(f'Note: For completely same arrays only one representative is used for reconstruction and '
                        f'visualization, due to provided option: --combine_non_unique_arrays. '
                        f'The classes of combined arrays are found in detailed results csv under '
                        f'"combined array names".')
            logger.info(f'Number of arrays before combination: {prev_len} / after: {len(ls_arrays)}')
            # Be aware: Tree generation or alignment produces an error for single arrays (which are useless anyway)
            if len(ls_arrays) < minimum_nb_of_arrays:
                logger.warning('Skipped because arrays were combined. Too few arrays after combining non-uniques.\n')
                ls_skipped_protocol.append({'name': crispr_group.name, 'repeat': crispr_group.repeat,
                                            'reason': 'too few after combining uniques'})
                continue
        else:
            # Still count the unique arrays for the protocol, without combining them.
            ls_unique_arrays, _, _ = misc.remove_completely_same_arrays(ls_arrays, ls_array_names)
            nb_unique_spacer_arrays = len(ls_unique_arrays)
        if dict_trees:
            # Use the provided tree for this group; skip groups without one.
            logger.info(f'Loading tree from imported tree dictionary...')
            if crispr_group.name not in dict_trees:
                skip, reason = True, 'No tree provided for this group'
                logger.warning('Skipped\n for reason: %s', reason)
                ls_skipped_protocol.append({'name': crispr_group.name, 'repeat': crispr_group.repeat,
                                            'reason': reason})
                continue
            tree = import_data.load_single_tree_from_string(dict_trees[crispr_group.name] + '\n')
        else:
            # Construct a tree; fill unset rates/alpha with the defaults matching
            # the chosen tree lh function (simple vs. precise).
            if tree_lh_fct is None or tree_lh_fct == 'simple':
                modifier = 'simple_'
            else:
                modifier = 'precise_'
            # NOTE(review): these assignments persist across loop iterations, so
            # the defaults are only looked up for the first group lacking them.
            if tree_gain_rate is None:
                tree_gain_rate = TREE_GENERATION_PARAMETERS[modifier + 'insertion_rate']
            if tree_loss_rate is None:
                tree_loss_rate = TREE_GENERATION_PARAMETERS[modifier + 'deletion_rate']
            if tree_alpha is None:
                tree_alpha = TREE_GENERATION_PARAMETERS[modifier + 'alpha']
            logger.info(f'Constructing tree with distance function: {tree_distance_function}, '
                        f'gain rate: {tree_gain_rate}, '
                        f'loss rate: {tree_loss_rate}, '
                        f'alpha: {tree_alpha}, '
                        f'provided_lh_fct: {tree_lh_fct} ...')
            tree = construct_tree(ls_array_names, ls_arrays, crispr_group.name,
                                  logger=logger, distance_fct=tree_distance_function,
                                  gain_rate=tree_gain_rate, loss_rate=tree_loss_rate, alpha=tree_alpha,
                                  provided_lh_fct=give_lh_fct,
                                  tree_save_path=None,
                                  tree_construction_method=tree_construction_method)
        new_tree = AdvancedTree(tree, True, model_name=crispr_group.name)
        # Keep the newick string so all trees of this run can be saved together.
        new_dict_trees[crispr_group.name] = tree.format('newick')
        # print(ls_array_names)
        if core_genome_trees:
            # Bio.Phylo.draw_ascii(tree)
            # Adjust the core-genome tree (names inner nodes, may prune arrays),
            # then refresh the array lists from the possibly-modified group.
            tree, crispr_group = tree_handling(tree, crispr_group, name_inner_nodes=True, bl_eps=0)
            ls_arrays = [crispr_array.spacer_array for crispr_array in crispr_group.crispr_dict.values()]
            ls_array_names = list(crispr_group.crispr_dict.keys())
            if len(ls_arrays) < minimum_nb_of_arrays:
                logger.warning('Skipped because too few arrays after pruning.\n')
                ls_skipped_protocol.append({'name': crispr_group.name, 'repeat': crispr_group.repeat,
                                            'reason': 'Skipped because too few arrays after pruning.'})
                continue
        else:
            tree = name_inner_nodes_if_unnamed(tree)
        if len(ls_arrays) < minimum_nb_of_arrays:
            logger.warning(f'{crispr_group.name} was skipped because there is only one array.')
            ls_skipped_protocol.append({'name': crispr_group.name, 'repeat': crispr_group.repeat,
                                        'reason': 'Skipped because there is only one array.'})
            continue
        crispr_group.set_tree(tree)
        # To prevent branches from having 0 branch length
        if extend_branches:
            for node in crispr_group.tree.find_clades():
                if node.up is None:
                    # Root: keep branch length at 0, do not extend.
                    node.branch_length = 0
                    continue
                if node.branch_length is None:
                    node.branch_length = 0
                node.branch_length += extend_branches
        aligned = model_tools.create_df_alignment(ls_arrays, ls_array_names)
        if core_genome_trees:
            # Pruning may leave alignment columns that are all gaps; drop them.
            ls_drop_cols = []
            for j, col in enumerate(aligned.columns):
                if all([c == '-' for c in aligned[col]]):
                    ls_drop_cols.append(col)
            aligned = aligned.drop(columns=ls_drop_cols)