diff --git a/containers/azimuth/context/main.py b/containers/azimuth/context/main.py index 2fc5a65..d805099 100644 --- a/containers/azimuth/context/main.py +++ b/containers/azimuth/context/main.py @@ -17,6 +17,7 @@ class AzimuthOrganMetadata(t.TypedDict): class AzimuthOptions(t.TypedDict): reference_data_dir: Path + query_layers_key: t.Optional[str] class AzimuthAlgorithm(Algorithm[AzimuthOrganMetadata, AzimuthOptions]): @@ -37,15 +38,18 @@ def do_run( # obs columns of dtype 'object'. As a workaround we create a # clean matrix without obs columns on which azimuth is run # after which the annotations are copied back to the original matrix + temp_index = self.create_temp_obs_index(data) clean_matrix_path = Path("clean_matrix.h5ad") - clean_matrix = self.create_clean_matrix(data) + clean_matrix = self.create_clean_matrix(data, temp_index) + + self.set_data_layer(clean_matrix, options["query_layers_key"]) clean_matrix.write_h5ad(clean_matrix_path) annotated_matrix_path = self.run_azimuth_scripts( clean_matrix_path, reference_data ) annotated_matrix = anndata.read_h5ad(annotated_matrix_path) - self.copy_annotations(data, annotated_matrix) + self.copy_annotations(data, annotated_matrix, temp_index) return { "data": data, @@ -53,30 +57,73 @@ def do_run( "prediction_column": "predicted." + metadata["prediction_column"], } - def create_clean_matrix(self, matrix: anndata.AnnData) -> anndata.AnnData: + def create_temp_obs_index(self, matrix: anndata.AnnData) -> pandas.Index: + """Creates a new index by adding a prefix to each index name. + Used as a workaround for: https://github.com/satijalab/azimuth/issues/178 + and https://github.com/satijalab/azimuth/issues/138 + + Args: + matrix (anndata.AnnData): Original data + + Returns: + pandas.Index: A new index + """ + return matrix.obs.index.map(lambda name: f"QUERY:{name}") + + def create_clean_matrix( + self, + matrix: anndata.AnnData, + temp_index: pandas.Index, + ) -> anndata.AnnData: """Creates a copy of the data with all observation columns removed. Args: matrix (anndata.AnnData): Original data + temp_index (pandas.Index): Temporary index generated by `create_temp_obs_index` Returns: anndata.AnnData: Cleaned data """ - clean_obs = pandas.DataFrame(index=matrix.obs.index) + clean_obs = pandas.DataFrame(index=temp_index) clean_matrix = matrix.copy() clean_matrix.obs = clean_obs + return clean_matrix + def set_data_layer( + self, matrix: anndata.AnnData, query_layers_key: t.Optional[str] + ) -> None: + """Set the data layer to use for annotating. + + Args: + matrix (anndata.AnnData): Matrix to update + query_layers_key (t.Optional[str]): A layer name or 'raw' + """ + if query_layers_key == "raw": + matrix.X = matrix.raw.X + elif query_layers_key is not None: + matrix.X = matrix.layers[query_layers_key].copy() + def copy_annotations( - self, matrix: anndata.AnnData, annotated_matrix: anndata.AnnData + self, + matrix: anndata.AnnData, + annotated_matrix: anndata.AnnData, + temp_index: pandas.Index, ) -> None: """Copies annotations from one matrix to another. Args: matrix (anndata.AnnData): Matrix to copy to annotated_matrix (anndata.AnnData): Matrix to copy from + temp_index (pandas.Index): Temporary index generated by `create_temp_obs_index` """ - matrix.obs = matrix.obs.join(annotated_matrix.obs, rsuffix="_azimuth") + matrix.obs = matrix.obs.merge( + annotated_matrix.obs, + how="left", + left_on=temp_index, + right_index=True, + suffixes=(None, "_azimuth"), + ) def run_azimuth_scripts(self, matrix_path: Path, reference_data: Path) -> str: """Creates a subprocess running the Azimuth annotation R script. @@ -156,6 +203,7 @@ def _get_arg_parser(): required=True, help="Path to directory with reference data", ) + parser.add_argument("--query-layers-key", help="Data layer to use") return parser diff --git a/containers/azimuth/options.yml b/containers/azimuth/options.yml index d78b252..314145a 100644 --- a/containers/azimuth/options.yml +++ b/containers/azimuth/options.yml @@ -7,3 +7,8 @@ fields: label: Directory with reference data directories inputBinding: prefix: --reference-data-dir + queryLayersKey: + type: string? + label: Data layer to use + inputBinding: + prefix: --query-layers-key diff --git a/containers/celltypist/context/main.py b/containers/celltypist/context/main.py index 0498deb..3a23387 100644 --- a/containers/celltypist/context/main.py +++ b/containers/celltypist/context/main.py @@ -16,6 +16,7 @@ class CelltypistOrganMetadata(t.TypedDict): class CelltypistOptions(t.TypedDict): ensemble_lookup: Path + query_layers_key: t.Optional[str] class CelltypistAlgorithm(Algorithm[CelltypistOrganMetadata, CelltypistOptions]): @@ -31,6 +32,7 @@ def do_run( ) -> RunResult: """Annotate data using celltypist.""" data = scanpy.read_h5ad(matrix) + self.set_data_layer(data, options["query_layers_key"]) data = self.normalize(data) data, var_names = self.normalize_var_names(data, options) data = celltypist.annotate( @@ -40,6 +42,20 @@ def do_run( return {"data": data, "organ_level": metadata["model"].replace(".", "_")} + def set_data_layer( + self, matrix: scanpy.AnnData, query_layers_key: t.Optional[str] + ) -> None: + """Set the data layer to use for annotating. + + Args: + matrix (anndata.AnnData): Matrix to update + query_layers_key (t.Optional[str]): A layer name or 'raw' + """ + if query_layers_key == "raw": + matrix.X = matrix.raw.X + elif query_layers_key is not None: + matrix.X = matrix.layers[query_layers_key].copy() + def normalize(self, data: scanpy.AnnData) -> scanpy.AnnData: """Normalizes data according to celltypist requirements. @@ -115,6 +131,7 @@ def _get_arg_parser(): default="/ensemble-lookup.csv", help="Ensemble id to gene name csv", ) + parser.add_argument("--query-layers-key", help="Data layer to use") return parser diff --git a/containers/celltypist/options.yml b/containers/celltypist/options.yml index 6894767..dfddd40 100644 --- a/containers/celltypist/options.yml +++ b/containers/celltypist/options.yml @@ -1,4 +1,9 @@ type: record name: options label: Celltypist specific options -fields: {} +fields: + queryLayersKey: + type: string? + label: Data layer to use + inputBinding: + prefix: --query-layers-key diff --git a/containers/popv/context/main.py b/containers/popv/context/main.py index 4acef4b..12edabb 100644 --- a/containers/popv/context/main.py +++ b/containers/popv/context/main.py @@ -299,9 +299,7 @@ def _get_arg_parser(): required=True, help="Path to models directory", ) - parser.add_argument( - "--query-layers-key", required=True, help="Name of layer with raw counts" - ) + parser.add_argument("--query-layers-key", help="Name of layer with raw counts") parser.add_argument("--prediction-mode", default="fast", help="Prediction mode") parser.add_argument( "--cell-ontology-dir", diff --git a/containers/popv/options.yml b/containers/popv/options.yml index 6c3ab73..fef5e06 100644 --- a/containers/popv/options.yml +++ b/containers/popv/options.yml @@ -13,7 +13,7 @@ fields: inputBinding: prefix: --models-dir queryLayersKey: - type: string + type: string? inputBinding: prefix: --query-layers-key predictionMode: