Source code for cas.populate_cell_ids

import json
import logging
from typing import List, Optional, Union

from cap_anndata import CapAnnDataDF
from pandas import DataFrame

from cas.file_utils import read_anndata_file, read_json_file
from cas.utils.validation_utils import validate_labelset_hierarchy

# Configure logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


[docs] def populate_cell_ids( cas_json_path: str, anndata_path: str, labelsets: Optional[List[str]] = None, validate: bool = False, ): """ Add/update CellIDs in a CAS JSON file using matching data from an AnnData file. This function reads a CAS JSON file and an AnnData file, validates their consistency, and updates CellIDs in CAS based on matching labelsets from the AnnData `obs` DataFrame. The modified CAS JSON is then saved back to the original file. Parameters: cas_json_path (str): Path to the CAS JSON file. anndata_path (str): Path to the AnnData file. labelsets (list, optional): A list of labelsets to update with CellIDs from AnnData. If None, the labelset with rank '0' is used by default. validate (bool, optional): If True, runs validation checks to ensure labelset consistency. The program will exit with an error if validation fails. Defaults to False. Raises: Exception: If the AnnData file cannot be read. Returns: None """ ad = read_anndata_file(anndata_path) if ad is None: raise Exception(f"AnnData read operation failed: {anndata_path}") ad_obs = ad.obs cas = read_json_file(cas_json_path) cas = add_cell_ids(cas, ad_obs, labelsets, validate) if cas: with open(cas_json_path, "w") as json_file: json.dump(cas, json_file, indent=2)
[docs] def update_cas_with_cell_ids( cas_json: dict, anndata_obs: CapAnnDataDF, labelsets: list = None ) -> dict: """ Update a CAS dictionary by adding or modifying CellIDs using matching AnnData observations. This function takes a CAS dictionary and an AnnData `obs` DataFrame and updates the CAS with cell IDs extracted from the specified labelsets in the AnnData. Parameters: cas_json (dict): The CAS dictionary to update with cell IDs from AnnData. anndata_obs (CapAnnDataDF): The `obs` DataFrame extracted from an AnnData object. labelsets (list, optional): A list of labelsets to update with IDs from AnnData. If None, the labelset with rank '0' is used. Returns: dict: The updated CAS dictionary with CellIDs populated. """ # validation step validate_labelset_hierarchy(cas_json, anndata_obs) cas = add_cell_ids(cas_json, anndata_obs, labelsets) return cas
[docs] def add_cell_ids( cas: dict, ad_obs: Union[DataFrame, CapAnnDataDF], labelsets: list = None, validate: bool = False, ): """ Add/update CellIDs to CAS from matching AnnData file. Parameters: cas: CAS JSON object ad_obs: Obs DataFrame extracted from an AnnData object. labelsets: List of labelsets to update with IDs from AnnData. If value is null, rank '0' labelset is used. The labelsets should be provided in order, starting from rank 0 (leaf nodes) and ascending to higher ranks. validate (bool, optional): If True, runs validation checks to ensure labelset consistency. The program wil`l exit with an error if validation fails. Defaults to False. """ # Run validation checks (exit if validation fails and validate=True) validate_labelset_hierarchy(cas, ad_obs, validate) rank_zero_labelset = [ lbl_set["name"] for lbl_set in cas["labelsets"] if "rank" in lbl_set and lbl_set.get("rank") in (0, "0") ][0] if not labelsets: labelsets = [rank_zero_labelset] obs_keys = ad_obs.columns.tolist() cluster_identifier_column = get_obs_cluster_identifier_column( obs_keys, labelsets, rank_zero_labelset ) if cluster_identifier_column: cid_lookup = {} for anno in cas["annotations"]: if anno["labelset"] == rank_zero_labelset and anno["labelset"] in labelsets: cell_ids = [] if cluster_identifier_column.endswith( "_id" ) or cluster_identifier_column.lower().endswith(" id"): # cluster column value is integer cluster id cluster_id = anno.get("author_annotation_fields", {}).get( cluster_identifier_column ) if not cluster_id: raise ValueError( "AnnData cluster identifier column ({}) couldn't be find found in" "CAS author_annotation_fields.".format( cluster_identifier_column ) ) cell_ids = list( ad_obs.loc[ ad_obs[cluster_identifier_column] == int(cluster_id), cluster_identifier_column, ].index ) else: # cluster column value is cell label cluster_label = anno["cell_label"] cell_ids = list( ad_obs.loc[ ad_obs[cluster_identifier_column] == cluster_label ].index ) anno["cell_ids"] = cell_ids if "parent_cell_set_name" in anno: lookup_key = anno["parent_cell_set_name"] if lookup_key in cid_lookup: cid_lookup[lookup_key].update(cell_ids) else: cid_lookup[lookup_key] = set(cell_ids) for anno in cas["annotations"]: if anno["labelset"] in labelsets and anno["cell_label"] in cid_lookup: cell_ids = list(cid_lookup.get(anno["cell_label"], [])) anno["cell_ids"] = cell_ids return cas else: logger.warning( "Cluster identifier column couldn't be identified in OBS. Populate cell ids operation is aborted." ) return None
[docs] def get_obs_cluster_identifier_column( obs_keys: List[str], labelsets: list = None, rank_zero_labelset: str = None ): """ Anndata files may use different column names to uniquely identify Clusters. Get the cluster identifier column name for the current file. Args: obs_keys: Anndata observation keys. labelsets: List of labelsets to update with IDs from AnnData. The labelsets should be provided in order, starting from rank 0 (leaf nodes) and ascending to higher ranks. rank_zero_labelset: rank 0 labelset name Returns: cluster identifier column name """ # obs_keys = ad.obs_keys() cluster_identifier_column = "" if labelsets and labelsets[0] in obs_keys: cluster_identifier_column = labelsets[0] else: # check id variants if rank_zero_labelset.title() + "_id" in obs_keys: cluster_identifier_column = rank_zero_labelset.title() + "_id" elif rank_zero_labelset.lower() + "_id" in obs_keys: cluster_identifier_column = rank_zero_labelset.lower() + "_id" elif rank_zero_labelset.title() + " id" in obs_keys: cluster_identifier_column = rank_zero_labelset.title() + " id" elif rank_zero_labelset.lower() + " id" in obs_keys: cluster_identifier_column = rank_zero_labelset.lower() + " id" return cluster_identifier_column