Skip to content

Commit

Permalink
sort bins by occupation
Browse files Browse the repository at this point in the history
  • Loading branch information
florian-huber committed Oct 15, 2024
1 parent 345cf0d commit eb79dc6
Showing 1 changed file with 35 additions and 14 deletions.
49 changes: 35 additions & 14 deletions ms2deepscore/train_new_model/inchikey_pair_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,24 @@ def select_compound_pairs_wrapper(
settings.same_prob_bins,
settings.include_diagonal)

aimed_nr_of_pairs_per_bin = determine_aimed_nr_of_pairs_per_bin(available_pairs_per_bin_matrix,
settings,
nr_of_inchikeys=len(inchikeys14_unique))

pair_frequency_matrixes = balanced_selection_of_pairs_per_bin(available_pairs_per_bin_matrix,
settings.max_pair_resampling,
aimed_nr_of_pairs_per_bin)

selected_pairs_per_bin = convert_to_selected_pairs_list(pair_frequency_matrixes, available_pairs_per_bin_matrix,
available_scores_per_bin_matrix, inchikeys14_unique)
aimed_nr_of_pairs_per_bin, bin_priorities = determine_aimed_nr_of_pairs_per_bin(
available_pairs_per_bin_matrix,
settings,
nr_of_inchikeys=len(inchikeys14_unique)
)

pair_frequency_matrixes = balanced_selection_of_pairs_per_bin(
available_pairs_per_bin_matrix,
settings.max_pair_resampling,
aimed_nr_of_pairs_per_bin
)

selected_pairs_per_bin = convert_to_selected_pairs_list(
pair_frequency_matrixes,
available_pairs_per_bin_matrix,
available_scores_per_bin_matrix,
inchikeys14_unique
)
return [pair for pairs in selected_pairs_per_bin for pair in pairs]


Expand Down Expand Up @@ -143,12 +151,19 @@ def compute_jaccard_similarity_per_bin(

def determine_aimed_nr_of_pairs_per_bin(available_pairs_per_bin_matrix, settings, nr_of_inchikeys):
"""Determines the aimed_nr_of_pairs_per_bin.
If the settings given are higher than the highest possible number of pairs it is lowered to that"""
If the settings given are higher than the highest possible number of pairs it is lowered to that.
"""

# Select the nr_of_pairs_per_bin to use
nr_of_available_pairs_per_bin = get_nr_of_available_pairs_in_bin(available_pairs_per_bin_matrix)
lowest_max_number_of_pairs = min(nr_of_available_pairs_per_bin) * settings.max_pair_resampling
print(f"The available nr of pairs per bin are: {nr_of_available_pairs_per_bin}")

# Set bin priority from lowest to highest no. of available pairs
bin_priority = np.argsort(nr_of_available_pairs_per_bin)
print(f"Bin priorities will be orderd accordingly: {[settings.same_prob_bins[i] for i in bin_priority]}")

aimed_nr_of_pairs_per_bin = settings.average_pairs_per_bin * nr_of_inchikeys
if lowest_max_number_of_pairs < aimed_nr_of_pairs_per_bin:
print(f"Warning: The average_pairs_per_bin: {settings.average_pairs_per_bin} cannot be reached, "
Expand All @@ -158,13 +173,14 @@ def determine_aimed_nr_of_pairs_per_bin(available_pairs_per_bin_matrix, settings
f"Instead the lowest number of available pairs in a bin times the resampling is used, "
f"which is: {lowest_max_number_of_pairs}")
aimed_nr_of_pairs_per_bin = lowest_max_number_of_pairs
return aimed_nr_of_pairs_per_bin
return aimed_nr_of_pairs_per_bin, bin_priority


def balanced_selection_of_pairs_per_bin(
available_pairs_per_bin_matrix: np.ndarray,
max_pair_resampling: int,
nr_of_pairs_per_bin: int
nr_of_pairs_per_bin: int,
bin_priority: np.ndarray = None,
) -> np.ndarray:
"""From the available_pairs_per_bin_matrix a balanced selection is made to have a balanced distribution.
Expand All @@ -190,11 +206,16 @@ def balanced_selection_of_pairs_per_bin(
Resampling means that the exact same inchikey pair is added multiple times to the list of pairs.
nr_of_pairs_per_bin:
The number of pairs that should be sampled for each tanimoto bin.
bin_priority:
Order (array of bin indices) in which the bins are processed. Defaults to None, in which case the bins
are processed in their original order.
"""
if bin_priority is None:
bin_priority = np.arange(0, available_pairs_per_bin_matrix.shape[0])

inchikey_count = np.zeros(available_pairs_per_bin_matrix.shape[1])
pair_frequency_matrixes = []
for pairs_in_bin in available_pairs_per_bin_matrix:
for pairs_in_bin in available_pairs_per_bin_matrix[bin_priority]:
pair_frequencies, inchikey_count = select_balanced_pairs(pairs_in_bin,
inchikey_count,
nr_of_pairs_per_bin,
Expand Down

0 comments on commit eb79dc6

Please sign in to comment.