from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
from anndata import AnnData
from cas.file_utils import read_anndata_file, read_json_file
from cas.utils.conversion_utils import ANNOTATIONS, CELL_IDS, fetch_anndata
# [docs]
def split_anndata_to_file(
    anndata_file_path: Optional[str],
    cas_json_paths: List[str],
    multiple_outputs: bool,
    compression_method: Optional[Literal["gzip", "lzf"]] = "gzip",
):
    """
    Split an AnnData file according to CAS JSON files and write the result as ``.h5ad``.

    The CAS JSON files are loaded and keyed by their base file name, the
    AnnData is split with ``split_anndata``, and every resulting subset is
    written with ``write_h5ad`` to a relative path (so it is resolved
    against the current working directory).

    Args:
        anndata_file_path: Path to the AnnData file. When falsy, the path is
            obtained by calling ``fetch_anndata`` on the first CAS JSON path.
        cas_json_paths: List of CAS JSON file paths.
        multiple_outputs: If True, one file named ``split_<name>.h5ad`` is
            written per CAS JSON, where ``<name>`` is the CAS file name up to
            its first dot; otherwise a single ``split_anndata.h5ad`` is written.
        compression_method: Value forwarded as ``compression`` to
            ``write_h5ad``. Default is "gzip".

    Note:
        CAS files are keyed by base name only, so two inputs sharing a file
        name (in different directories) collapse into one entry. Likewise,
        names that differ only after their first dot map to the same output
        file.
    """
    source_path = anndata_file_path
    if not source_path:
        # Only the first CAS JSON is consulted; this assumes every CAS JSON
        # passed in refers to the same AnnData source.
        source_path = fetch_anndata(cas_json_paths[0])
    adata = read_anndata_file(source_path)

    # Map "<file name>" -> parsed CAS content, preserving the input order.
    cas_by_name = {}
    for json_path in cas_json_paths:
        cas_by_name[Path(json_path).name] = read_json_file(json_path)
    json_names = list(cas_by_name)

    subsets = split_anndata(adata, cas_by_name, multiple_outputs)
    for position, subset in enumerate(subsets):
        if multiple_outputs:
            # Subsets come back in the same order as the CAS mapping.
            base_name = json_names[position].split(".")[0]
            output_name = f"split_{base_name}.h5ad"
        else:
            output_name = "split_anndata.h5ad"
        subset.write_h5ad(Path(output_name), compression=compression_method)
# [docs]
def split_anndata(
    adata: AnnData, cas: Dict[str, Dict[str, Any]], multiple_outputs: bool
) -> List[AnnData]:
    """
    Split an AnnData object into subsets holding only the cells referenced by CAS data.

    Cells are selected by matching ``adata.obs.index`` against the cell IDs
    listed in the annotations of the given CAS objects.

    Args:
        adata: AnnData object to take the cells from.
        cas: CAS objects keyed by their file name.
        multiple_outputs: If True, one subset is produced per CAS object (in
            the iteration order of ``cas``); otherwise a single subset is
            produced from the union of the cell IDs of all CAS objects.

    Returns:
        Always a list of AnnData objects: one per entry of ``cas`` when
        ``multiple_outputs`` is True, otherwise a list with exactly one element.

    Raises:
        KeyError: If a CAS object lacks the key named by ``ANNOTATIONS``.
    """

    def get_cell_ids(cas_data: Dict[str, Dict[str, Any]]) -> List[str]:
        """Extract unique cell IDs from CAS data."""
        return list(
            {
                cell_id
                for cas_obj in cas_data.values()
                for annotation in cas_obj[ANNOTATIONS]
                # `or ()` also covers a cell-ID entry that is present but
                # null, which would otherwise fail with a TypeError.
                for cell_id in (annotation.get(CELL_IDS) or ())
            }
        )

    def subset_for(cas_data: Dict[str, Dict[str, Any]]) -> AnnData:
        """Return the in-memory subset of ``adata`` for the cells in ``cas_data``."""
        mask = adata.obs.index.isin(get_cell_ids(cas_data))
        return adata[mask, :].to_memory()

    if multiple_outputs:
        return [
            subset_for({cas_file_name: cas_object})
            for cas_file_name, cas_object in cas.items()
        ]
    return [subset_for(cas)]