Commit 4b996f8

sum mapped columns
mlin committed Jun 24, 2024
1 parent 9b2bcd3 commit 4b996f8
Showing 2 changed files with 59 additions and 21 deletions.
First changed file (the CellDatasetBuilder base module):

@@ -2,6 +2,8 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generator, Optional
 
+import numpy as np
+import numpy.typing as npt
 import scipy.sparse
 from datasets import Dataset
 from tiledbsoma import Experiment, ExperimentAxisQuery
@@ -65,7 +67,9 @@ def gen() -> Generator[Dict[str, Any], None, None]:
            for Xblock, (block_cell_joinids, _) in (
                self.X(self.layer_name).blockwise(axis=0, reindex_disable_on_axis=[1], size=self.block_size).scipy()
            ):
+               Xblock = self._map_block(block_cell_joinids, Xblock)
                assert isinstance(Xblock, scipy.sparse.csr_matrix)
+               assert Xblock.shape[0] == len(block_cell_joinids)
                for i, cell_joinid in enumerate(block_cell_joinids):
                    yield self.cell_item(cell_joinid, Xblock.getrow(i))

@@ -81,6 +85,17 @@ def cell_item(self, cell_joinid: int, Xrow: scipy.sparse.csr_matrix) -> Dict[str, Any]:
"""
...

def _map_block(
self, block_cell_joinids: npt.NDArray[np.int64], Xblock: scipy.sparse.csr_matrix
) -> scipy.sparse.csr_matrix:
"""Apply any necessary transformations to the block of X data before further processing.
This base implementation simply returns the block unchanged. Applying transformations to
the whole block at once can be more efficient than applying them to each row individually
in `cell_item()`.
"""
return Xblock


class _DatasetGeneratorPickleHack:
"""SEE: https://github.com/huggingface/datasets/issues/6194."""
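Note: the new `_map_block` hook lets a subclass transform a whole X block before it is split into per-cell rows. A minimal sketch of how a subclass could use it (a hypothetical BinarizingBuilder, for illustration only; the import path is an assumption, not taken from this diff):

    import numpy as np
    import numpy.typing as npt
    import scipy.sparse

    # assumed import path for CellDatasetBuilder; adjust for your install
    from cellxgene_census.experimental.ml.huggingface import CellDatasetBuilder


    class BinarizingBuilder(CellDatasetBuilder):
        """Hypothetical builder that binarizes counts blockwise."""

        def _map_block(
            self, block_cell_joinids: npt.NDArray[np.int64], Xblock: scipy.sparse.csr_matrix
        ) -> scipy.sparse.csr_matrix:
            # transform the whole block at once, which is cheaper than doing
            # the same work row-by-row inside cell_item()
            out = Xblock.copy()
            out.data = np.ones_like(out.data)  # any detected count becomes 1
            return out

        def cell_item(self, cell_joinid: int, Xrow: scipy.sparse.csr_matrix) -> dict:
            # trivial item: the var columns detected in this cell
            return {"soma_joinid": cell_joinid, "detected": Xrow.indices.tolist()}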
Second changed file (the GeneformerTokenizer module):

@@ -1,5 +1,5 @@
 import pickle
-from typing import Any, Dict, Optional, Sequence, Set
+from typing import Any, Dict, List, Optional, Sequence, Set
 
 import numpy as np
 import numpy.typing as npt
@@ -46,9 +46,11 @@ class GeneformerTokenizer(CellDatasetBuilder):
     max_input_tokens: int
     special_token: bool
 
-    # set of gene soma_joinids corresponding to genes modeled by Geneformer:
-    model_gene_ids: npt.NDArray[np.int64]
-    model_gene_tokens: npt.NDArray[np.int64]  # token for each model_gene_id
+    # for each Geneformer-modeled gene that matches at least one Census gene, the list of
+    # matching Census gene/var soma_joinids (one Geneformer-modeled gene may match several
+    # Census genes, based on gene_mapping_file)
+    model_gene_ids: List[List[int]]
+    model_gene_tokens: npt.NDArray[np.int64]  # Geneformer token number for each model_gene_id
     model_gene_medians: npt.NDArray[np.float64]  # float for each model_gene_id
     model_cls_token: Optional[np.int64] = None
     model_sep_token: Optional[np.int64] = None
@@ -104,7 +106,13 @@ def _load_geneformer_data(
"""
# TODO: this work could be reused for all queries on this experiment

genes_df = experiment.ms["RNA"].var.read(column_names=["soma_joinid", "feature_id"]).concat().to_pandas()
genes_df = (
experiment.ms["RNA"]
.var.read(column_names=["soma_joinid", "feature_id"])
.concat()
.to_pandas()
.set_index("soma_joinid")
)

if not (token_dictionary_file and gene_median_file):
try:
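Aside: the added `set_index("soma_joinid")` matters because the join loop below reads each gene's soma_joinid from the DataFrame index via `iterrows()`. A tiny pandas sketch with made-up IDs:

    import pandas as pd

    genes_df = pd.DataFrame(
        {"soma_joinid": [0, 1, 2], "feature_id": ["ENSG01", "ENSG02", "ENSG03"]}
    ).set_index("soma_joinid")

    # iterrows() now yields (soma_joinid, row) pairs, matching
    # `for gene_id, row in genes_df.iterrows()` in the next hunk
    for gene_id, row in genes_df.iterrows():
        print(gene_id, row["feature_id"])  # 0 ENSG01 / 1 ENSG02 / 2 ENSG03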
@@ -131,27 +139,35 @@

        # compute model_gene_{ids,tokens,medians} by joining genes_df with Geneformer's
        # dicts
-        model_gene_ids = []
-        model_gene_tokens = []
-        model_gene_medians = []
+        model_gene_id_by_ensg: Dict[str, int] = {}
+        model_gene_ids: List[List[int]] = []
+        model_gene_tokens: List[np.int64] = []
+        model_gene_medians: List[np.float64] = []
        for gene_id, row in genes_df.iterrows():
            ensg = row["feature_id"]  # ENSG... gene id, which keys Geneformer's dicts
            if gene_mapping is not None:
                ensg = gene_mapping.get(ensg, ensg)
            if ensg in gene_token_dict:
-                model_gene_ids.append(gene_id)
-                model_gene_tokens.append(gene_token_dict[ensg])
-                model_gene_medians.append(gene_median_dict[ensg])
-        self.model_gene_ids = np.array(model_gene_ids, dtype=np.int64)
+                if ensg not in model_gene_id_by_ensg:
+                    model_gene_id_by_ensg[ensg] = len(model_gene_ids)
+                    model_gene_ids.append([])
+                    model_gene_tokens.append(gene_token_dict[ensg])
+                    model_gene_medians.append(gene_median_dict[ensg])
+                model_gene_ids[model_gene_id_by_ensg[ensg]].append(gene_id)
+
+        self.model_gene_ids = model_gene_ids
        self.model_gene_tokens = np.array(model_gene_tokens, dtype=np.int64)
        self.model_gene_medians = np.array(model_gene_medians, dtype=np.float64)
 
-        assert len(np.unique(self.model_gene_ids)) == len(self.model_gene_ids)
+        assert len(self.model_gene_ids) == len(self.model_gene_tokens)
+        assert len(self.model_gene_ids) == len(self.model_gene_medians)
        assert len(np.unique(self.model_gene_tokens)) == len(self.model_gene_tokens)
        assert np.all(self.model_gene_medians > 0)
        # Geneformer models protein-coding and miRNA genes, so the intersection should
        # be somewhere a little north of 20K.
-        assert len(self.model_gene_ids) > 20_000
+        assert (
+            len(self.model_gene_ids) > 20_000
+        ), f"Mismatch between Census gene IDs and Geneformer token mappings (only {len(self.model_gene_ids)} common genes)"
 
        # Precompute a vector by which we'll multiply each cell's expression vector.
        # The denominator normalizes by Geneformer's median expression values.
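To see how the dedup-and-group loop above turns a many-to-one gene mapping into a list of joinid lists, here is a standalone sketch with made-up IDs (toy data, not from the source):

    from typing import Dict, List

    census_genes = {10: "ENSG_A", 11: "ENSG_B", 12: "ENSG_C"}  # soma_joinid -> feature_id
    gene_mapping = {"ENSG_B": "ENSG_A"}  # ENSG_B is an alias of ENSG_A
    gene_token_dict = {"ENSG_A": 5, "ENSG_C": 9}  # Geneformer-modeled genes

    model_gene_id_by_ensg: Dict[str, int] = {}
    model_gene_ids: List[List[int]] = []
    for gene_id, ensg in census_genes.items():
        ensg = gene_mapping.get(ensg, ensg)
        if ensg in gene_token_dict:
            if ensg not in model_gene_id_by_ensg:
                model_gene_id_by_ensg[ensg] = len(model_gene_ids)
                model_gene_ids.append([])
            model_gene_ids[model_gene_id_by_ensg[ensg]].append(gene_id)

    print(model_gene_ids)  # [[10, 11], [12]]: ENSG_A's column will sum joinids 10 and 11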
@@ -172,19 +188,26 @@ def __enter__(self) -> "GeneformerTokenizer":
        self.obs_df = self.obs(column_names=obs_column_names).concat().to_pandas().set_index("soma_joinid")
        return self
 
+    def _map_block(
+        self, block_cell_joinids: npt.NDArray[np.int64], Xblock: scipy.sparse.csr_matrix
+    ) -> scipy.sparse.csr_matrix:
+        model_gene_counts = [
+            scipy.sparse.csc_matrix(Xblock[:, cols].sum(axis=1)) if len(cols) > 1 else Xblock[:, cols[0]].tocsc()
+            for cols in self.model_gene_ids
+        ]
+        return scipy.sparse.hstack(model_gene_counts, format="csr")
+
    def cell_item(self, cell_joinid: int, cell_Xrow: scipy.sparse.csr_matrix) -> Dict[str, Any]:
        """Given the expression vector for one cell, compute the Dataset item providing
        the Geneformer inputs (token sequence and metadata).
        """
-        # project cell_Xrow onto model_gene_ids and normalize by row sum.
+        # project cell_Xrow onto model_gene_ids and normalize with row sum & gene medians.
        # notice we divide by the total count of the complete row (not only of the projected
        # values); this follows Geneformer's internal tokenizer.
-        model_counts = cell_Xrow[:, self.model_gene_ids].multiply(1.0 / cell_Xrow.sum())
-        assert isinstance(model_counts, scipy.sparse.csr_matrix), type(model_counts)
-        # assert len(model_counts.data) == np.count_nonzero(model_counts.data)
-        model_expr = model_counts.multiply(self.model_gene_medians_factor)
-        assert isinstance(model_expr, scipy.sparse.coo_matrix), type(model_expr)
-        # assert len(model_expr.data) == np.count_nonzero(model_expr.data)
+        assert cell_Xrow.shape == (1, len(self.model_gene_ids))
+        model_expr = cell_Xrow.multiply(self.model_gene_medians_factor / cell_Xrow.sum())
+        assert isinstance(model_expr, scipy.sparse.coo_matrix)
+        assert model_expr.shape == (1, len(self.model_gene_ids))
 
        # figure the resulting tokens in descending order of model_expr
        # (use sparse model_expr.{col,data} to naturally exclude undetected genes)
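The overridden `GeneformerTokenizer._map_block` above is the commit's namesake: for each Geneformer-modeled gene it sums the mapped Census columns, producing a block whose columns line up with `model_gene_tokens`. A standalone sketch of the same hstack-of-column-sums trick on a toy matrix:

    import numpy as np
    import scipy.sparse

    # toy block: 2 cells x 4 vars; model gene 0 sums vars [0, 2], model gene 1 is var 3
    Xblock = scipy.sparse.csr_matrix(np.array([[1, 0, 2, 3], [0, 4, 0, 5]]))
    model_gene_ids = [[0, 2], [3]]

    model_gene_counts = [
        scipy.sparse.csc_matrix(Xblock[:, cols].sum(axis=1)) if len(cols) > 1 else Xblock[:, cols[0]].tocsc()
        for cols in model_gene_ids
    ]
    mapped = scipy.sparse.hstack(model_gene_counts, format="csr")
    print(mapped.toarray())  # [[3 3] [0 5]]: column 0 holds vars 0 and 2 summed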

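Finally, a small worked sketch of the streamlined `cell_item()` normalization, with toy numbers: each mapped count is divided by the cell's total count and by the gene's median (via the precomputed `model_gene_medians_factor`), and tokens are then ranked by descending normalized expression using the sparse COO `col`/`data` attributes:

    import numpy as np
    import scipy.sparse

    # toy row already projected onto 3 model genes by _map_block
    cell_Xrow = scipy.sparse.csr_matrix(np.array([[4.0, 0.0, 8.0]]))
    model_gene_medians_factor = 1.0 / np.array([2.0, 1.0, 1.0])  # 1/median per model gene

    # one multiply normalizes by row total and per-gene median, as in the new code;
    # multiplying a sparse row by a dense vector yields a COO result in scipy
    model_expr = cell_Xrow.multiply(model_gene_medians_factor / cell_Xrow.sum())
    model_expr = scipy.sparse.coo_matrix(model_expr)  # normalize type across scipy versions

    # rank detected genes by descending normalized expression; the sparse col/data
    # pair naturally skips the undetected gene at position 1
    order = np.argsort(-model_expr.data)
    print(model_expr.col[order])  # [2 0]: (8/12)/1.0 outranks (4/12)/2.0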