Skip to content

Commit

Permalink
add gene_mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
mlin committed Jun 21, 2024
1 parent f8587ca commit 9b2bcd3
Showing 1 changed file with 11 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
special_token: bool = False,
token_dictionary_file: str = "",
gene_median_file: str = "",
gene_mapping_file: str = "",
**kwargs: Any,
) -> None:
"""- `experiment`: Census Experiment to query
Expand All @@ -74,14 +75,15 @@ def __init__(
- `token_dictionary_file`, `gene_median_file`: pickle files supplying the mapping of
Ensembl human gene IDs onto Geneformer token numbers and median expression values.
By default, these will be loaded from the Geneformer package.
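- `gene_mapping_file`: optional pickle file with mapping for Census gene IDs to model's
gene IDs (the keys of the token dictionary). Census gene IDs absent from the mapping
are looked up unchanged.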
"""
if obs_attributes: # old name of obs_column_names
obs_column_names = obs_attributes

self.max_input_tokens = max_input_tokens
self.special_token = special_token
self.obs_column_names = set(obs_column_names) if obs_column_names else set()
self._load_geneformer_data(experiment, token_dictionary_file, gene_median_file)
self._load_geneformer_data(experiment, token_dictionary_file, gene_median_file, gene_mapping_file)
super().__init__(
experiment,
measurement_name="RNA",
Expand All @@ -94,6 +96,7 @@ def _load_geneformer_data(
experiment: tiledbsoma.Experiment,
token_dictionary_file: str,
gene_median_file: str,
gene_mapping_file: str,
) -> None:
"""Load (1) the experiment's genes dataframe and (2) Geneformer's static data
files for gene tokens and median expression; then, intersect them to compute
Expand Down Expand Up @@ -121,13 +124,20 @@ def _load_geneformer_data(
with open(gene_median_file, "rb") as f:
gene_median_dict = pickle.load(f)

gene_mapping = None
if gene_mapping_file:
with open(gene_mapping_file, "rb") as f:
gene_mapping = pickle.load(f)

# compute model_gene_{ids,tokens,medians} by joining genes_df with Geneformer's
# dicts
model_gene_ids = []
model_gene_tokens = []
model_gene_medians = []
for gene_id, row in genes_df.iterrows():
ensg = row["feature_id"] # ENSG... gene id, which keys Geneformer's dicts
if gene_mapping is not None:
ensg = gene_mapping.get(ensg, ensg)
if ensg in gene_token_dict:
model_gene_ids.append(gene_id)
model_gene_tokens.append(gene_token_dict[ensg])
Expand Down

0 comments on commit 9b2bcd3

Please sign in to comment.