from collections import defaultdict
from typing import Any, Dict, List, Union
from cas.file_utils import read_json_file, write_dict_to_json_file
from cas.utils.conversion_utils import (
ANNOTATIONS,
CELL_SET_ACCESSION,
LABELSET,
LABELSET_NAME,
LABELSETS,
NT_ACCESSION,
PARENT_CELL_SET_ACCESSION,
)
[docs]
def split_cas_to_file(
cas_json_path: str, split_terms: Union[List[str], str], multiple_outputs: bool
):
"""
Splits a CAS JSON file into files based on provided terms, and writes them to disk.
Args:
cas_json_path: Path to the CAS JSON file.
split_terms: Terms used to determine how to split the CAS file; can be a string or a list of strings.
multiple_outputs: If True, outputs multiple files, one for each split term; otherwise, outputs a single file.
"""
cas = read_json_file(cas_json_path)
result = split_cas(cas, split_terms, multiple_outputs)
if isinstance(result, dict):
result = [result]
for idx, cas_item in enumerate(result):
write_dict_to_json_file(
(
f"cas_{split_terms[idx].replace(':', '_')}.json"
if multiple_outputs
else "split_cas.json"
),
cas_item,
)
[docs]
def split_cas(
cas: Dict[str, Any], split_terms: Union[List[str], str], multiple_outputs: bool
) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
"""
Splits a CAS dictionary into multiple or single dictionary based on split terms.
Args:
cas: Dictionary representing the CAS data.
split_terms: Terms used to filter and split the CAS data; can be a string or a list of strings.
multiple_outputs: Determines if the output should be multiple dictionaries or a single dictionary.
Returns:
A list of dictionaries if multiple_outputs is True, otherwise a single dictionary.
Raises:
ValueError: If any split_terms do not exist in the CAS data under 'parent_cell_set_name'.
"""
parent_cell_dict = defaultdict(list)
for annotation in cas[ANNOTATIONS]:
child = annotation[CELL_SET_ACCESSION]
# Map parent cell accession to the child cell
if PARENT_CELL_SET_ACCESSION in annotation:
parent_cell = annotation[PARENT_CELL_SET_ACCESSION]
parent_cell_dict[parent_cell].append(child)
# Map NT accession to the child cell
if NT_ACCESSION in annotation:
nt_parent = annotation[NT_ACCESSION]
parent_cell_dict[child].append(nt_parent)
if isinstance(split_terms, str):
split_terms = [split_terms]
keys_and_values = list(parent_cell_dict.keys()) + [
item
for sublist in parent_cell_dict.values()
if isinstance(sublist, list)
for item in sublist
]
missing_terms = [term for term in split_terms if term not in keys_and_values]
if missing_terms:
raise ValueError(
f"{', '.join(missing_terms)} do not exist in CAS as 'cell_set_name'"
)
if multiple_outputs:
splitted_cas_list = []
for term in split_terms:
label_to_copy_list = get_split_terms(parent_cell_dict, term)
splitted_cas_list.append(
filter_and_copy_cas_entries(cas, label_to_copy_list)
)
return splitted_cas_list
else:
label_to_copy_list = get_split_terms(parent_cell_dict, split_terms)
return filter_and_copy_cas_entries(cas, label_to_copy_list)
[docs]
def filter_and_copy_cas_entries(
cas: Dict[str, Any], label_to_copy_list: List[str]
) -> Dict[str, Any]:
"""
Copies entries from the CAS based on a list of labels to copy.
Args:
cas: Dictionary representing the original CAS data.
label_to_copy_list: List of labels indicating which entries to copy.
Returns:
A dictionary with filtered CAS entries.
"""
output_dict = dict(cas)
output_dict[ANNOTATIONS] = []
output_dict[LABELSETS] = []
labelset_dict = set()
for annotation in cas[ANNOTATIONS]:
if annotation[CELL_SET_ACCESSION] in label_to_copy_list:
output_dict[ANNOTATIONS].append(annotation)
labelset_dict.add(annotation[LABELSET])
for labelset in cas[LABELSETS]:
if labelset[LABELSET_NAME] in labelset_dict:
output_dict[LABELSETS].append(labelset)
return output_dict
[docs]
def get_split_terms(
parent_dict: Dict[str, List[str]], split_terms: Union[List[str], str]
) -> List[str]:
"""
Resolves split terms into a comprehensive list of terms based on a parent-child relationship dictionary.
Args:
parent_dict: Dictionary mapping parent terms to lists of child terms.
split_terms: Initial terms to resolve, can be a string or a list of strings.
Returns:
A list of all terms, resolved from the parent_dict.
"""
if isinstance(split_terms, str):
split_terms = [split_terms]
result = set()
stack = list(split_terms)
while stack:
term = stack.pop()
if term in parent_dict:
stack.extend(parent_dict[term])
result.add(term)
return list(result)