Skip to content

Commit eb79dc6

Browse files
committed
sort bins by occupation
1 parent 345cf0d commit eb79dc6

File tree

1 file changed

+35
-14
lines changed

1 file changed

+35
-14
lines changed

Diff for: ms2deepscore/train_new_model/inchikey_pair_selection.py

+35-14
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,24 @@ def select_compound_pairs_wrapper(
3939
settings.same_prob_bins,
4040
settings.include_diagonal)
4141

42-
aimed_nr_of_pairs_per_bin = determine_aimed_nr_of_pairs_per_bin(available_pairs_per_bin_matrix,
43-
settings,
44-
nr_of_inchikeys=len(inchikeys14_unique))
45-
46-
pair_frequency_matrixes = balanced_selection_of_pairs_per_bin(available_pairs_per_bin_matrix,
47-
settings.max_pair_resampling,
48-
aimed_nr_of_pairs_per_bin)
49-
50-
selected_pairs_per_bin = convert_to_selected_pairs_list(pair_frequency_matrixes, available_pairs_per_bin_matrix,
51-
available_scores_per_bin_matrix, inchikeys14_unique)
42+
aimed_nr_of_pairs_per_bin, bin_priorities = determine_aimed_nr_of_pairs_per_bin(
43+
available_pairs_per_bin_matrix,
44+
settings,
45+
nr_of_inchikeys=len(inchikeys14_unique)
46+
)
47+
48+
pair_frequency_matrixes = balanced_selection_of_pairs_per_bin(
49+
available_pairs_per_bin_matrix,
50+
settings.max_pair_resampling,
51+
aimed_nr_of_pairs_per_bin
52+
)
53+
54+
selected_pairs_per_bin = convert_to_selected_pairs_list(
55+
pair_frequency_matrixes,
56+
available_pairs_per_bin_matrix,
57+
available_scores_per_bin_matrix,
58+
inchikeys14_unique
59+
)
5260
return [pair for pairs in selected_pairs_per_bin for pair in pairs]
5361

5462

@@ -143,12 +151,19 @@ def compute_jaccard_similarity_per_bin(
143151

144152
def determine_aimed_nr_of_pairs_per_bin(available_pairs_per_bin_matrix, settings, nr_of_inchikeys):
145153
"""Determines the aimed_nr_of_pairs_per_bin.
146-
If the settings given are higher than the highest possible number of pairs it is lowered to that"""
154+
155+
If the settings given are higher than the highest possible number of pairs it is lowered to that.
156+
"""
147157

148158
# Select the nr_of_pairs_per_bin to use
149159
nr_of_available_pairs_per_bin = get_nr_of_available_pairs_in_bin(available_pairs_per_bin_matrix)
150160
lowest_max_number_of_pairs = min(nr_of_available_pairs_per_bin) * settings.max_pair_resampling
151161
print(f"The available nr of pairs per bin are: {nr_of_available_pairs_per_bin}")
162+
163+
# Set bin priority from lowest to highest no. of available pairs
164+
bin_priority = np.argsort(nr_of_available_pairs_per_bin)
165+
print(f"Bin priorities will be orderd accordingly: {[settings.same_prob_bins[i] for i in bin_priority]}")
166+
152167
aimed_nr_of_pairs_per_bin = settings.average_pairs_per_bin * nr_of_inchikeys
153168
if lowest_max_number_of_pairs < aimed_nr_of_pairs_per_bin:
154169
print(f"Warning: The average_pairs_per_bin: {settings.average_pairs_per_bin} cannot be reached, "
@@ -158,13 +173,14 @@ def determine_aimed_nr_of_pairs_per_bin(available_pairs_per_bin_matrix, settings
158173
f"Instead the lowest number of available pairs in a bin times the resampling is used, "
159174
f"which is: {lowest_max_number_of_pairs}")
160175
aimed_nr_of_pairs_per_bin = lowest_max_number_of_pairs
161-
return aimed_nr_of_pairs_per_bin
176+
return aimed_nr_of_pairs_per_bin, bin_priority
162177

163178

164179
def balanced_selection_of_pairs_per_bin(
165180
available_pairs_per_bin_matrix: np.ndarray,
166181
max_pair_resampling: int,
167-
nr_of_pairs_per_bin: int
182+
nr_of_pairs_per_bin: int,
183+
bin_priority: np.ndarray = None,
168184
) -> np.ndarray:
169185
"""From the available_pairs_per_bin_matrix a balanced selection is made to have a balanced distribution.
170186
@@ -190,11 +206,16 @@ def balanced_selection_of_pairs_per_bin(
190206
Resampling means that the exact same inchikey pair is added multiple times to the list of pairs.
191207
nr_of_pairs_per_bin:
192208
The number of pairs that should be sampled for each tanimoto bin.
209+
bin_priority:
210+
Bins will be processed in the order given in bin_priority. Default is set to None in which case no change
211+
to the order will be done.
193212
"""
213+
if bin_priority is None:
214+
bin_priority = np.arange(0, available_pairs_per_bin_matrix.shape[0])
194215

195216
inchikey_count = np.zeros(available_pairs_per_bin_matrix.shape[1])
196217
pair_frequency_matrixes = []
197-
for pairs_in_bin in available_pairs_per_bin_matrix:
218+
for pairs_in_bin in available_pairs_per_bin_matrix[bin_priority]:
198219
pair_frequencies, inchikey_count = select_balanced_pairs(pairs_in_bin,
199220
inchikey_count,
200221
nr_of_pairs_per_bin,

0 commit comments

Comments
 (0)