Source code for cas.anndata_conversion

import logging
import sys
from typing import Optional
import warnings

from cap_anndata import read_h5ad

from cas.file_utils import read_json_file
from cas.utils.conversion_utils import (
    ANNOTATIONS,
    CELL_IDS,
    CELL_LABEL,
    LABELSET,
    LABELSET_NAME,
    LABELSETS,
    collect_parent_cell_ids,
    copy_and_update_file_path,
    fetch_anndata,
    reformat_json,
)

# Set up logging
logger = logging.getLogger(__name__)

# Suppress warning messages from cap_anndata.cap_anndata
logging.getLogger("cap_anndata.cap_anndata").setLevel(logging.ERROR)


[docs] def merge( cas_file_path: str, anndata_path: Optional[str], validate: bool, output_file_name: str, ): """ Tests if CAS json and AnnData are compatible and merges CAS into AnnData if possible. This function performs the following checks: 1. Verifies that all cell barcodes (cell IDs) in CAS exist in AnnData and vice versa. 2. Identifies matching labelset names between CAS and AnnData. 3. Validates that cell sets associated with each annotation match between CAS and AnnData. 4. Checks if the cell labels are identical; if not, provides options to update or terminate. Args: cas_file_path: The path to the CAS json file. anndata_path: The path to the AnnData file. validate: Boolean to determine if validation checks will be performed before writing to the output AnnData file. output_file_name: Output AnnData file name. """ input_json = read_json_file(cas_file_path) merge_cas_object(input_json, anndata_path, validate, output_file_name)
[docs] def merge_cas_object( input_json: dict, anndata_file_path: Optional[str], validate: bool, output_file_path: str, download_dir: Optional[str] = None, ): """ Tests if CAS json and AnnData are compatible and merges CAS into AnnData if possible. This function performs the following checks: 1. Verifies that all cell barcodes (cell IDs) in CAS exist in AnnData and vice versa. 2. Identifies matching labelset names between CAS and AnnData. 3. Validates that cell sets associated with each annotation match between CAS and AnnData. 4. Checks if the cell labels are identical; if not, provides options to update or terminate. Args: input_json: The CAS json object. anndata_file_path: The path to the AnnData file. validate: Boolean to determine if validation checks will be performed before writing to the output AnnData file. output_file_path: Output AnnData file name. download_dir: The directory to download AnnData files. """ if not anndata_file_path: anndata_file_path = fetch_anndata(input_json, download_dir) anndata_file_path = copy_and_update_file_path(anndata_file_path, output_file_path) with read_h5ad(file_path=anndata_file_path, edit=True) as cap_adata: cap_adata.read_obs() obs = cap_adata.obs test_compatibility(obs, input_json, validate) cap_adata.read_uns() cap_adata.uns["cas"] = reformat_json(input_json) cap_adata.overwrite()
[docs] def test_compatibility(anndata_obs, input_json, validate): """ Tests if CAS and AnnData can be merged. Args: anndata_obs: The AnnData obs object. input_json: The CAS data json object. validate: Boolean to determine if validation checks will be performed before writing to the output AnnData file. """ annotations = input_json[ANNOTATIONS] obs_index = set(anndata_obs.axes[0].tolist()) validate_cell_ids(obs_index, annotations, validate) labelsets = input_json[LABELSETS] matching_obs_keys = get_matching_obs_keys(anndata_obs.columns, labelsets) check_labelsets(input_json, anndata_obs, matching_obs_keys, validate)
[docs] def check_labelsets(cas_json, input_obs, matching_obs_keys, validate): annotations = cas_json[ANNOTATIONS] derived_cell_ids = collect_parent_cell_ids(cas_json) for ann in annotations: if ann[LABELSET] in matching_obs_keys: anndata_labelset_cell_ids = ( input_obs.groupby(ann[LABELSET], observed=False) .apply(lambda group: set(group.index), include_groups=False) .to_dict() ) for cell_label, cell_list in anndata_labelset_cell_ids.items(): cell_ids = set(ann.get(CELL_IDS, [])) if cell_ids and cell_list == cell_ids: handle_matching_labelset(ann, cell_label, input_obs, validate) elif cell_list == derived_cell_ids.get( str(ann["cell_set_accession"]), ann.get(CELL_IDS, []) ): handle_matching_labelset(ann, cell_label, input_obs, validate) elif cell_label == ann[CELL_LABEL]: if cell_list == set(ann.get(CELL_IDS, [])): handle_matching_labelset(ann, cell_label, input_obs, validate) else: handle_non_matching_labelset( ann, input_obs, validate, derived_cell_ids )
[docs] def get_matching_obs_keys(obs_keys, cas_labelsets): cas_labelset_names = {item[LABELSET_NAME] for item in cas_labelsets} matching_obs_keys = cas_labelset_names.intersection(obs_keys) return list(matching_obs_keys)
[docs] def handle_matching_labelset(ann, cell_label, input_obs, validate): # Used for label changes if cell_label != ann[CELL_LABEL]: logger.warning( f"{ann[CELL_LABEL]} cell ids from CAS match with the cell ids in {cell_label} from anndata, " "but they have different cell labels." ) if validate: logger.error("Validation failed. Exiting.") sys.exit(1) # add new category to labelset column input_obs[ann[LABELSET]] = input_obs[ann[LABELSET]].cat.add_categories( ann[CELL_LABEL] ) # Overwrite the labelset value with CAS labelset input_obs.loc[ann[CELL_IDS], ann[LABELSET]] = input_obs.loc[ ann[CELL_IDS], ann[LABELSET] ].map({cell_label: ann[CELL_LABEL]})
[docs] def handle_non_matching_labelset(ann, input_obs, validate, derived_cell_ids): # Used for hierarchy changes logger.warning( f"{ann[CELL_LABEL]} cell ids from CAS do not match with the cell ids from anndata. " "Please update your CAS json." ) if validate: logger.error("Validation failed. Exiting.") sys.exit(1) # Flush the labelset from anndata # input_anndata.obs.loc[list(cell_list), cell_label] = "" # Add labelset from CAS to anndata cell_ids = derived_cell_ids.get(str(ann["cell_set_accession"]), set()) # Bad split workaround, temporary solution # Use Pandas indexing to filter the cell_ids present in obs.index valid_cell_ids = input_obs.index.intersection(cell_ids) input_obs.loc[valid_cell_ids, ann[LABELSET]] = str(ann[CELL_LABEL])
[docs] def validate_cell_ids(anndata_cell_ids, annotations, validate): # Collect cell ids from annotations cas_cell_ids = {cell_id for ann in annotations for cell_id in ann.get(CELL_IDS, [])} # Validate cas -> anndata if not cas_cell_ids <= anndata_cell_ids: logger.warning("Not all members of cell ids from cas exist in anndata.") if validate: logger.error("Validation failed. Exiting.") sys.exit(1) # Validate anndata -> cas if not anndata_cell_ids <= cas_cell_ids: logger.warning("Not all members of cell ids from anndata exist in cas.") if validate: logger.error("Validation failed. Exiting.") sys.exit(1)