diff --git a/NOTICE.md b/NOTICE.md index f811144..8bf568e 100644 --- a/NOTICE.md +++ b/NOTICE.md @@ -155,32 +155,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# NMFreg tutorial - -**Source**: https://github.com/tudaga/NMFreg_tutorial - -MIT License - -Copyright (c) 2019 Aleksandrina Goeva - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - # Sphinx-Autosummary-Recursion **Source**: https://github.com/JamesALeedham/Sphinx-Autosummary-Recursion diff --git a/docsource/index.rst b/docsource/index.rst index b159d1e..48a75f1 100644 --- a/docsource/index.rst +++ b/docsource/index.rst @@ -7,6 +7,7 @@ Home Examples API reference <_autosummary/tacco> + Release Notes References Welcome to the TACCO documentation! 
diff --git a/docsource/release_notes.rst b/docsource/release_notes.rst new file mode 100644 index 0000000..dd9d770 --- /dev/null +++ b/docsource/release_notes.rst @@ -0,0 +1,8 @@ +Release Notes +------------- + +.. toctree:: + :maxdepth: 2 + :glob: + + release_notes/release_notes_* diff --git a/docsource/release_notes/release_notes_0.3.0.rst b/docsource/release_notes/release_notes_0.3.0.rst new file mode 100644 index 0000000..0261032 --- /dev/null +++ b/docsource/release_notes/release_notes_0.3.0.rst @@ -0,0 +1,56 @@ +TACCO 0.3.0 (2024-01-10) +======================== + +Features +-------- + +- :func:`tacco.plots.subplots`: support for changing the dpi and forwarding kwargs to :func:`matplotlib.pyplot.subplots` + +- :func:`tacco.plots.dotplot`: new swap_axes argument + +- :func:`tacco.utils.split_spatial_samples`: more flexible and intuitive reimplementation for spatially splitting samples which can account explicitly for spatial correlations; deprecates the :func:`tacco.utils.spatial_split` function + +- :func:`tacco.tools.setup_orthology_converter`, :func:`tacco.tools.run_orthology_converter`: orthology conversion between species + +- :func:`tacco.get.get_data_from_key`: general getter to retrieve data from an anndata given a data path + +- :func:`tacco.get.get_positions`: support for data paths + +Fixes +-------- + +- :func:`tacco.tools.annotate`: reconstruction_key now honors max_annotation. So :func:`tacco.tools.split_observations` works with reconstruction_key as well. This fixes issue `#9 <https://github.com/simonwm/tacco/issues/9>`__.
+ +- :func:`tacco.tools.split_observations`: fixed map_obsm_keys parameter + +- :func:`tacco.plots.significances`: fix using pre-supplied ax, fix not significant annotated but significance colored cells, fix future warning, work for data with enrichment and without depletion + +- :func:`tacco.plots.dotplot`: catch edge case with gene-group combinations without a match in "marks" + +- :func:`tacco.plots.co_occurrence`: fixed bug for multiple anndatas in the input + +- :func:`tacco.plots.co_occurrence_matrix`: fixed bug for restrict_intervals=None + +- :func:`tacco.tools.annotate`: multi_center=1 changed so it now behaves the same as multi_center=0/None, fix FutureWarning from kmeans + +- :func:`tacco.tools.get_contributions`: fix FutureWarning from groupby + +- :func:`tacco.plots.co_occurrence`, :func:`tacco.plots.co_occurrence_matrix`: co-occurrence plots now follow the show_only and show_only_center order + +Documentation +------------- + +- Add release notes + +- Add visium example to address `#8 <https://github.com/simonwm/tacco/issues/8>`__ + +Miscellaneous +------------- + +- Switch from setup.cfg to pyproject.toml + +- Generalization of benchmarking to support conda-forge time + +- Expanded testing + +- Remove duplication in NOTICE.md diff --git a/docsource/release_notes/template_release_notes.rst b/docsource/release_notes/template_release_notes.rst new file mode 100644 index 0000000..1256248 --- /dev/null +++ b/docsource/release_notes/template_release_notes.rst @@ -0,0 +1,35 @@ +.. before release, ensure a correct date and version number here and in the name of the release_notes_a.b.c.rst file!!! +.. and remove all the comments here... + +TACCO a.b.c (yyyy-mm-dd) +======================== + +.. all sections are optional: remove if empty +.. link to issues/PRs/etc on github where applicable + +Features +-------- +.. include new and enhanced features here, as well as significant performance improvements + +- Add :func:`tacco.tools.new_func` to do new things.
+ +- Add :func:`tacco.tools.other_new_func` to implement feature request. + +Fixes +-------- +.. include resolution of bugs and very unintuitive behaviour here + +- Fix some issue `#9 <https://github.com/simonwm/tacco/issues/9>`__ + +Documentation +------------- +.. include documentation updates here + +- Add example `#8 <https://github.com/simonwm/tacco/issues/8>`__ + +Miscellaneous +------------- +.. include all the rest here + +- Deprecated old functionality + diff --git a/pyproject.toml b/pyproject.toml index b2c34f2..2c764bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,15 +1,32 @@ [build-system] -requires = ["setuptools<=60", "wheel", "setuptools_scm>=6.2"] #for next release: could be relaxed, maybe to >61 +requires = ["setuptools", "wheel", "setuptools_scm"] build-backend = "setuptools.build_meta" [project] name = "tacco" +description = "TACCO: Transfer of Annotations to Cells and their COmbinations" +authors = [ + {name = "Simon Mages"}, + {name = "Noa Moriel"}, + {name = "Jan Watter"}, +] +maintainers = [ + {name = "Jan Watter", email = "jan.watter@gmail.com"}, + {name = "Simon Mages", email = "smages@broadinstitute.org"}, +] readme = "README.md" license = {file = "LICENCE.txt"} requires-python = ">=3.7" + classifiers = [ + "License :: OSI Approved :: BSD License", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Natural Language :: English", "Programming Language :: Python :: 3", "Operating System :: OS Independent", + "Topic :: Scientific/Engineering :: Bio-Informatics", + "Topic :: Scientific/Engineering :: Visualization", ] dynamic = ["version"]
newline at end of file +[project.optional-dependencies] +doc = [ + "sphinx", + "sphinx-rtd-theme", + "pydata-sphinx-theme", + "sphinx-autodoc-typehints", + "nbsphinx", + "ipython", + "jupytext", + "jupyter", +] + +benchmark = ["time"] + +test = ["pytest"] + +[tool.pytest.ini_options] +testpaths = ["tests"] + +[tool.setuptools] +packages = ["tacco"] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 7b23cbc..0000000 --- a/setup.cfg +++ /dev/null @@ -1,41 +0,0 @@ -[metadata] -name = tacco -description = "TACCO: Transfer of Annotations to Cells and their COmbinations" -#url = "https://github.com/simonwm/tacco" -author = "Simon Mages" -author_email = "smages@broadinstitute.org" - -[options] -packages = find: -install_requires = - requests - importlib; python_version == "2.6" - joblib - numba>=0.51.2 - numpy - matplotlib!=3.7.0 - seaborn - sparse_dot_mkl>=0.7.3 - scanpy>=1.7.0 - statsmodels - anndata - pandas>=1.1.0 - scipy>=1.6.0 - mkl - mkl_service - scikit_learn - threadpoolctl - pyamg - -[options.extras_require] -tests = - pytest -docs = - sphinx - sphinx-rtd-theme - pydata-sphinx-theme - sphinx-autodoc-typehints - nbsphinx - ipython - jupytext - jupyter diff --git a/tacco/benchmarking/_benchmarking.py b/tacco/benchmarking/_benchmarking.py index 255baf9..637f887 100644 --- a/tacco/benchmarking/_benchmarking.py +++ b/tacco/benchmarking/_benchmarking.py @@ -9,6 +9,17 @@ import numpy as np import pandas as pd +TIME_PATH = None +BENCHMARKING_AVAILABLE = False + +if os.path.exists('/usr/bin/time'): + TIME_PATH = '/usr/bin/time' + BENCHMARKING_AVAILABLE = True +elif os.path.exists(sys.exec_prefix + '/bin/time'): + TIME_PATH = sys.exec_prefix + '/bin/time' + BENCHMARKING_AVAILABLE = True + + def _set_up_benchmark_working_directory( working_directory, ): @@ -90,7 +101,7 @@ def benchmark_shell( # run the command proc = subprocess.Popen( - ['/usr/bin/time','-f','wall_clock_time_seconds %e\nmax_memory_used_kbytes %M\nexit_status %x',command,*command_args,], + 
[TIME_PATH,'-f','wall_clock_time_seconds %e\nmax_memory_used_kbytes %M\nexit_status %x',command,*command_args,], cwd=working_directory, stderr=subprocess.PIPE, stdout=subprocess.PIPE, @@ -214,6 +225,9 @@ def benchmark_annotate( reading data under the key "benchmark_time_s". """ + + if not BENCHMARKING_AVAILABLE: + raise Exception('No /usr/bin/time or conda-forge time executable found. If on macOS or linux install conda-forge time in your current conda env to run benchmarks') if working_directory is not None and 'annotation_key' not in kw_args: print('`working_directory` is set, but `annotation_key` is not. This frequently is a mistake.\nIf you are certain that it it not, you can deactivate this message by explicitly setting `annotation_key` to `None`.') diff --git a/tacco/get/__init__.py b/tacco/get/__init__.py index 938d85e..55cdc50 100644 --- a/tacco/get/__init__.py +++ b/tacco/get/__init__.py @@ -3,5 +3,6 @@ """ # expose the API +from ._data import get_data_from_key as data_from_key from ._counts import get_counts as counts from ._positions import get_positions as positions diff --git a/tacco/get/_data.py b/tacco/get/_data.py new file mode 100644 index 0000000..6848032 --- /dev/null +++ b/tacco/get/_data.py @@ -0,0 +1,446 @@ +import pandas as pd +import anndata as ad +from scipy.sparse import issparse as _issparse + +def _as_series(series_like, index, name): + """\ + If the series_like is already a :class:`~pandas.Series`, pass through, + otherwise wrap in a :class:`~pandas.Series`. + """ + if not isinstance(series_like, pd.Series): + if _issparse(series_like): + series_like = series_like.toarray() + if len(series_like.shape) > 1: + series_like = series_like.flatten() + series_like = pd.Series(series_like,index=index,name=name) + return series_like + +def _as_frame(frame_like, index, columns): + """\ + If the frame_like is already a :class:`~pandas.DataFrame`, pass through, + otherwise wrap in a :class:`~pandas.DataFrame`. 
+ """ + if not isinstance(frame_like, pd.DataFrame): + if _issparse(frame_like): + frame_like = frame_like.toarray() + col_args = {} if columns == ... else {'columns':columns} + frame_like = pd.DataFrame(frame_like,index=index,**col_args) + return frame_like + +def _as_list(list_like_or_element): + """\ + Wrap the list_like_or_element in a list. + """ + _list = list_like_or_element + if isinstance(_list,str): # wrap strings + _list = [_list] + _list = list(_list) # convert tuples etc. + return _list + +def get_data_from_key( + adata, + key, + default_key=None, + key_name='key', + check_uns=False, + search_order=('X','layer','obs','var','obsm','varm'), + check_uniqueness=True, + result_type='obs', + verbose=1, + raise_on_error=True, +): + + """\ + Given a short key for data stored somewhere in an :class:`~anndata.AnnData` + find the data associated with it following a deterministic scheme. + + Parameters + ---------- + adata + A :class:`~anndata.AnnData`; if it is a :class:`~pandas.DataFrame`, use + it in place of a `adata.obs` or `adata.var` data frame. + key + The short key or "path" in `adata` where the data is stored as a string + (e.g. "X") or as a list-like of strings, e.g. `("raw","obsm","counts")`, + where the last element can be a list-like of stings, e.g. + `("obsm","types",["B","T"])`. + Possible parts of paths are: + + - "X" + - "layer" + - "obs" + - "var" + - "obsm" + - "varm" + - "raw" + + If `None` and `check_uns` is not `False`, looks in + `adata.uns[key_name]` for the key and if unsuccessful uses + `default_key`. + default_key + The default location to use; see `key`. + key_name + The name of the key to use for lookup in `.uns`; see `key`. Also used + for meaningful error messages. + check_uns + Whether to check `.uns` for the key; if a string, overrides the value + of `key_name` for the lookup; see `key` and `key_name`. + search_order + In which order to check the properties of :class:`~anndata.AnnData` for + the `key`, if it is a string. 
The first hit will be returned. If + `check_uniqueness`, the remaining properties will be checked for + uniqueness of the hit, generating a warning for non-unique hits. In + addition to the properties 'X','layer','obs','var','obsm', and 'varm', + the two pseudo properties 'multi-obs' and 'multi-var' can be specified + to allow for the selection of multiple obs/var columns. + check_uniqueness + Whether to check for uniqueness of the hit; see `search_order`. + result_type + What type of data to look for. Options are: + + - "obs": return a :class:`~pandas.Series` with index like a `.obs` + column + - "var": return a :class:`~pandas.Series` with index like a `.var` + column + - "obsm": return a :class:`~pandas.DataFrame` with index like a `.obs` + column + - "varm": return a :class:`~pandas.DataFrame` with index like a `.var` + column + - "X": return an array-like of shape compatible with `adata` + + If "obs" or "obsm", then "var" and "varm" will be excluded from the + `search_order`. If "var" or "varm", then "obs" and "obsm" will be + excluded from the `search_order`. If "X", then "obs" and "var" will be + excluded from the `search_order`. + verbose + Level of verbosity, with `0` (no output), `1` (some output), ... + raise_on_error + Whether to raise on errors. Alternatively just return `None`. + + Returns + ------- + Depending on result_type, a :class:`~pandas.Series`,\ + :class:`~pandas.DataFrame`, or array-like containing the data (or `None`\ + on failure it not `raise_on_error`). 
+ + """ + + # generic exception handling + def raise_Exception(msg): + if raise_on_error: + raise ValueError(msg) + elif verbose > 0: + print(msg) + return None + def print_Message(msg, verbose): + if verbose > 0: + print(msg) + + # extra layer of indirection for adata access + if isinstance(adata, ad.AnnData): + adata_obs = adata.obs + adata_var = adata.var + adata_obsm = adata.obsm + adata_varm = adata.varm + adata_uns = adata.uns + adata_layers = adata.layers + real_adata = True + else: # support dataframes as replacement for adata.obs and adata.var + # result_type indicates whether obs or var should be used, so we can populate both here + adata_obs = adata + adata_var = adata + # put empty dictionaries here to enable the 'is in' evaluations below + adata_obsm = {} + adata_varm = {} + adata_uns = {} + adata_layers = {} + # here we need explicit code changes below for .X and .raw access... + real_adata = False + def no_real_adata_message(requested_property): + return f'The key path for the key named {key_name!r} is {key!r} and contains "{requested_property}", but the supplied adata is of type "{type(adata)!r}"! Specifying "{requested_property}" paths in the key is only possible if adata is indeed an AnnData object!' + + # ensure sanity of `key` + if key is None and check_uns != False: # check_uns can also be a string, so `if check_uns` is not sufficient + _key_name = key_name if isinstance(check_uns, bool) else check_uns + if _key_name in adata_uns: + key = adata_uns[_key_name] + if key is None and not default_key is None: + key = default_key + if key is None: + return raise_Exception(f'The key named {key_name!r} is `None` and did not get valid options for `default_key` and `check_uns`!') + + # ensure sanity of `result_type` + valid_result_types = ['obs','obsm','var','varm','X'] + if result_type not in valid_result_types: + return raise_Exception(f'The result_type {result_type!r} is invalid! 
Only {valid_result_types!r} can be used.') + + + def _get_hit(): + # look for hits + hits = [] + for element in search_order: + if element == 'X': # look in the counts + if not isinstance(key, str): + continue + if result_type in ['obs','obsm']: # look for a gene in the counts + if key in adata_var.index: + hits.append((element,key)) + if not check_uniqueness: + return hits[0] + elif result_type in ['var','varm']: # look for a cell in the counts + if key in adata_obs.index: + hits.append((element,key)) + if not check_uniqueness: + return hits[0] + elif element == 'layer': # look in the layers + if not isinstance(key, str): + continue + if result_type in ['X']: # only full count matrices can be specified like this + if key in adata_layers: + hits.append((element,key)) + if not check_uniqueness: + return hits[0] + elif element == 'obs': # look in the obs annotation + if not isinstance(key, str): + continue + if result_type in ['obs','obsm']: # look for a single annotation column + if key in adata_obs: + hits.append((element,key)) + if not check_uniqueness: + return hits[0] + elif element == 'var': # look in the var annotation + if not isinstance(key, str): + continue + if result_type in ['var','varm']: # look for a single annotation row + if key in adata_var: + hits.append((element,key)) + if not check_uniqueness: + return hits[0] + elif element == 'obsm': # look in the obsm annotation + if not isinstance(key, str): + continue + if result_type in ['obsm']: # obsms only fit to obsms + if key in adata_obsm: + hits.append((element,key)) + if not check_uniqueness: + return hits[0] + elif element == 'varm': # look in the varm annotation + if not isinstance(key, str): + continue + if result_type in ['varm']: # varms only fit to varms + if key in adata_varm: + hits.append((element,key)) + if not check_uniqueness: + return hits[0] + elif element == 'multi-obs': # look for a set of obs annotations + if isinstance(key, str): + continue + if result_type in ['obsm']: # a set of 
obs annotations only fit to obsms + if all(_key in adata_obs for _key in key): + hits.append(('obs',key)) + if not check_uniqueness: + return hits[0] + elif element == 'multi-var': # look for a set of var annotations + if isinstance(key, str): + continue + if result_type in ['varm']: # a set of var annotations only fit to varms + if all(_key in adata_var for _key in key): + hits.append(('var',key)) + if not check_uniqueness: + return hits[0] + else: + raise_Exception(f'The element {element!r} of search_order {search_order!r} is not valid!') + if len(hits) == 0: + if isinstance(key, str): + raise_Exception(f'The key {key!r} was not found anywhere in {search_order!r} for a result_type of {result_type!r}!') + else: + hits.append(key) + elif len(hits) > 1: + print_Message(f'The key {key!r} was not found in more than one location: {hits!r}! Continue using the first hit {hits[0]!r}', verbose=verbose) + return hits[0] + key = _get_hit() + + # look at the specified position + if len(key) < 1: + return raise_Exception(f'The key named {key_name!r} is {key!r} which is not a valid path! 
Only list-likes of length > 0 could be valid.') + + if key[0] == 'raw': + if not isinstance(adata, ad.AnnData): + return raise_Exception(no_real_adata_message('raw')) + return get_data_from_key( + adata=adata.raw.to_adata(), + key=key[1:], + default_key=default_key, + key_name=key_name, + check_uns=check_uns, + search_order=search_order, + check_uniqueness=check_uniqueness, + result_type=result_type, + verbose=verbose, + raise_on_error=raise_on_error, + ) + + + if result_type == 'obs': + if key[0] == 'X': + if not isinstance(adata, ad.AnnData): + return raise_Exception(no_real_adata_message('X')) + if len(key) != 2: + return raise_Exception(f'An obs result_type can only be retrieved from "X" with a key with length of exactly 2, but {key!r} was given!') + if key[1] not in adata_var.index: + return raise_Exception(f'An obs result_type can only be retrieved from "X" if the second path element matches a gene name, but {key[1]!r} is not an available gene name!') + return _as_series(adata[:,[key[1]]].X,index=adata_obs.index,name=key[1]) + elif key[0] == 'layer': + if len(key) != 3: + return raise_Exception(f'An obs result_type can only be retrieved from "layer" with a key with length of exactly 3, but {key!r} was given!') + if key[1] not in adata_layers: + return raise_Exception(f'An obs result_type can only be retrieved from "layer" if the second path element matches a layer name, but {key[1]!r} is not an available layer!') + if key[2] not in adata_var.index: + return raise_Exception(f'An obs result_type can only be retrieved from "layer" if the third path element matches a gene name, but {key[2]!r} is not an available gene name!') + return _as_series(adata[:,[key[2]]].layers[key[1]],index=adata_obs.index,name=key[2]) + elif key[0] == 'obs': + if len(key) != 2: + return raise_Exception(f'An obs result_type can only be retrieved from "obs" with a key with length of exactly 2, but {key!r} was given!') + if key[1] not in adata_obs.columns: + return raise_Exception(f'An obs 
result_type can only be retrieved from "obs" if the second path element matches a obs key, but {key[1]!r} is not an available obs column!') + return _as_series(adata_obs[key[1]],index=adata_obs.index,name=key[1]) + elif key[0] == 'obsm': + if len(key) != 3: + return raise_Exception(f'An obs result_type can only be retrieved from "obsm" with a key with length of exactly 3, but {key!r} was given!') + if key[1] not in adata_obsm: + return raise_Exception(f'An obs result_type can only be retrieved from "obsm" if the second path element matches an obsm name, but {key[1]!r} is not an available obsm!') + if key[2] not in adata_obsm[key[1]].columns: + return raise_Exception(f'An obs result_type can only be retrieved from "obsm" if the third path element matches a column name of the selected obsm, but {key[2]!r} is not a available there!') + return _as_series(adata_obsm[key[1]][key[2]],index=adata_obs.index,name=key[2]) + else: + return raise_Exception(f'An obs result_type can only be retrieved from "X","layer","obs", and "obsm" paths, but {key!r} starts with {key[0]!r}!') + + elif result_type == 'var': + if key[0] == 'X': + if not isinstance(adata, ad.AnnData): + return raise_Exception(no_real_adata_message('X')) + if len(key) != 2: + return raise_Exception(f'An var result_type can only be retrieved from "X" with a key with length of exactly 2, but {key!r} was given!') + if key[1] not in adata_obs.index: + return raise_Exception(f'An var result_type can only be retrieved from "X" if the second path element matches a cell name, but {key[1]!r} is not an available cell name!') + return _as_series(adata[[key[1]],:].X,index=adata_var.index,name=key[1]) + elif key[0] == 'layer': + if len(key) != 3: + return raise_Exception(f'An var result_type can only be retrieved from "layer" with a key with length of exactly 3, but {key!r} was given!') + if key[1] not in adata_layers: + return raise_Exception(f'An var result_type can only be retrieved from "layer" if the second path element 
matches a layer name, but {key[1]!r} is not an available layer!') + if key[2] not in adata_obs.index: + return raise_Exception(f'An var result_type can only be retrieved from "layer" if the third path element matches a cell name, but {key[2]!r} is not an available cell name!') + return _as_series(adata[[key[2]],:].layers[key[1]],index=adata_var.index,name=key[2]) + elif key[0] == 'var': + if len(key) != 2: + return raise_Exception(f'An var result_type can only be retrieved from "var" with a key with length of exactly 2, but {key!r} was given!') + if key[1] not in adata_var.columns: + return raise_Exception(f'An var result_type can only be retrieved from "var" if the second path element matches a var key, but {key[1]!r} is not an available var column!') + return _as_series(adata_var[key[1]],index=adata_var.index,name=key[1]) + elif key[0] == 'varm': + if len(key) != 3: + return raise_Exception(f'An var result_type can only be retrieved from "varm" with a key with length of exactly 3, but {key!r} was given!') + if key[1] not in adata_varm: + return raise_Exception(f'An var result_type can only be retrieved from "varm" if the second path element matches an varm name, but {key[1]!r} is not an available varm!') + if key[2] not in adata_varm[key[1]].columns: + return raise_Exception(f'An var result_type can only be retrieved from "varm" if the third path element matches a column name of the selected varm, but {key[2]!r} is not a available there!') + return _as_series(adata_varm[key[1]][key[2]],index=adata_var.index,name=key[2]) + else: + return raise_Exception(f'An var result_type can only be retrieved from "X","layer","var", and "varm" paths, but {key!r} starts with {key[0]!r}!') + + elif result_type == 'obsm': + if key[0] == 'X': + if not isinstance(adata, ad.AnnData): + return raise_Exception(no_real_adata_message('X')) + if len(key) != 2: + return raise_Exception(f'An obsm result_type can only be retrieved from "X" with a key with length of exactly 2, but {key!r} was 
given!') + columns = _as_list(key[1]) + for column in columns: + if column not in adata_var.index: + return raise_Exception(f'An obsm result_type can only be retrieved from "X" if every element of the second path element matches a gene name, but {column!r} is not an available gene name!') + return _as_frame(adata[:,columns].X,index=adata_obs.index,columns=columns) + elif key[0] == 'layer': + if len(key) != 3: + return raise_Exception(f'An obsm result_type can only be retrieved from "layer" with a key with length of exactly 3, but {key!r} was given!') + if key[1] not in adata_layers: + return raise_Exception(f'An obsm result_type can only be retrieved from "layer" if the second path element matches a layer name, but {key[1]!r} is not an available layer!') + columns = _as_list(key[2]) + for column in columns: + if column not in adata_var.index: + return raise_Exception(f'An obsm result_type can only be retrieved from "layer" if every element of the third path element matches a gene name, but {column!r} is not an available gene name!') + return _as_frame(adata[:,columns].layers[key[1]],index=adata_obs.index,columns=columns) + elif key[0] == 'obs': + if len(key) != 2: + return raise_Exception(f'An obsm result_type can only be retrieved from "obs" with a key with length of exactly 2, but {key!r} was given!') + columns = _as_list(key[1]) + for column in columns: + if column not in adata_obs.columns: + return raise_Exception(f'An obsm result_type can only be retrieved from "obs" if every element of the second path element matches a obs key, but {column!r} is not an available obs column!') + return _as_frame(adata_obs[columns],index=adata_obs.index,columns=columns) + elif key[0] == 'obsm': + if len(key) != 2: + return raise_Exception(f'An obsm result_type can only be retrieved from "obsm" with a key with length of exactly 2, but {key!r} was given!') + if key[1] not in adata_obsm: + return raise_Exception(f'An obsm result_type can only be retrieved from "obsm" if the second 
path element matches an obsm name, but {key[1]!r} is not an available obsm!') + return _as_frame(adata_obsm[key[1]],index=adata_obs.index,columns=...) + else: + return raise_Exception(f'An obsm result_type can only be retrieved from "X","layer","obs", and "obsm" paths, but {key!r} starts with {key[0]!r}!') + + elif result_type == 'varm': + if key[0] == 'X': + if not isinstance(adata, ad.AnnData): + return raise_Exception(no_real_adata_message('X')) + if len(key) != 2: + return raise_Exception(f'An varm result_type can only be retrieved from "X" with a key with length of exactly 2, but {key!r} was given!') + columns = _as_list(key[1]) + for column in columns: + if column not in adata_obs.index: + return raise_Exception(f'An varm result_type can only be retrieved from "X" if every element of the second path element matches a cell name, but {column!r} is not an available cell name!') + return _as_frame(adata[columns,:].X.T,index=adata_var.index,columns=columns) + elif key[0] == 'layer': + if len(key) != 3: + return raise_Exception(f'An varm result_type can only be retrieved from "layer" with a key with length of exactly 3, but {key!r} was given!') + if key[1] not in adata_layers: + return raise_Exception(f'An varm result_type can only be retrieved from "layer" if the second path element matches a layer name, but {key[1]!r} is not an available layer!') + columns = _as_list(key[2]) + for column in columns: + if column not in adata_obs.index: + return raise_Exception(f'An varm result_type can only be retrieved from "layer" if every element of the third path element matches a cell name, but {column!r} is not an available cell name!') + return _as_frame(adata[columns,:].layers[key[1]].T,index=adata_var.index,columns=columns) + elif key[0] == 'var': + if len(key) != 2: + return raise_Exception(f'An varm result_type can only be retrieved from "var" with a key with length of exactly 2, but {key!r} was given!') + columns = _as_list(key[1]) + for column in columns: + if column 
not in adata_var.columns: + return raise_Exception(f'An varm result_type can only be retrieved from "var" if every element of the second path element matches a var key, but {column!r} is not an available var column!') + return _as_frame(adata_var[columns],index=adata_var.index,columns=columns) + elif key[0] == 'varm': + if len(key) != 2: + return raise_Exception(f'An varm result_type can only be retrieved from "varm" with a key with length of exactly 2, but {key!r} was given!') + if key[1] not in adata_varm: + return raise_Exception(f'An varm result_type can only be retrieved from "varm" if the second path element matches an varm name, but {key[1]!r} is not an available varm!') + return _as_frame(adata_varm[key[1]],index=adata_var.index,columns=...) + else: + return raise_Exception(f'An varm result_type can only be retrieved from "X","layer","var", and "varm" paths, but {key!r} starts with {key[0]!r}!') + + elif result_type == 'X': + if key[0] == 'X': + if not isinstance(adata, ad.AnnData): + return raise_Exception(no_real_adata_message('X')) + if len(key) != 1: + return raise_Exception(f'An X result_type can only be retrieved from "X" with a key with length of exactly 1, but {key!r} was given!') + return adata.X + elif key[0] == 'layer': + if len(key) != 2: + return raise_Exception(f'An X result_type can only be retrieved from "layer" with a key with length of exactly 2, but {key!r} was given!') + if key[1] not in adata_layers: + return raise_Exception(f'An X result_type can only be retrieved from "layer" if the second path element matches a layer name, but {key[1]!r} is not an available layer!') + return adata_layers[key[1]] + else: + return raise_Exception(f'An X result_type can only be retrieved from "X" and "layer", but {key!r} starts with {key[0]!r}!') + + else: + return raise_Exception(f'The result_type {result_type!r} cannot be interpreted.') + diff --git a/tacco/get/_positions.py b/tacco/get/_positions.py index 202ed21..39d9027 100644 --- 
a/tacco/get/_positions.py +++ b/tacco/get/_positions.py @@ -1,5 +1,6 @@ import pandas as pd import anndata as ad +from ._data import get_data_from_key def get_positions( adata, @@ -16,27 +17,19 @@ def get_positions( treated like the `.obs` of an :class:`~anndata.AnnData`. position_key The `.obsm` key or array-like of `.obs` keys with the position space - coordinates. - + coordinates. Also supports data paths as specified in + :func:`~tacco.tl.get_data_from_key`. + Returns ------- A :class:`~pandas.DataFrame` with the positions of the observations. """ - coords = None - if isinstance(position_key, str): - if isinstance(adata, ad.AnnData) and position_key in adata.obsm: - coords = adata.obsm[position_key] - else: - position_key = [position_key] - if coords is None: - if isinstance(adata, ad.AnnData): - coords = adata.obs[list(position_key)] - else: - coords = adata[list(position_key)] - - if not hasattr(coords, 'columns'): - coords = pd.DataFrame(coords, index=adata.obs.index) - - return coords \ No newline at end of file + return get_data_from_key( + adata=adata, + key=position_key, + key_name='position_key', + result_type='obsm', + search_order=('obsm','multi-obs','obs'), + ) diff --git a/tacco/plots/_plots.py b/tacco/plots/_plots.py index 23966b8..9393507 100644 --- a/tacco/plots/_plots.py +++ b/tacco/plots/_plots.py @@ -121,7 +121,7 @@ def _scatter_frame(df, colors, ax_, id_=None, cmap=None, cmap_vmin_vmax=None, ou color_array = np.array([colors[c] for c in df.columns.difference(['x','y'])]) - dpi = matplotlib.rcParams['figure.dpi'] + dpi = ax_[0].get_figure().get_dpi() point_radius = np.sqrt(point_size / np.pi) / 2 * dpi # convert from pixel to axes units @@ -317,7 +317,7 @@ def _render_frame(df, colors, ax_, id_=None, cmap=None, cmap_vmin_vmax=None, out n_types = len(typing.columns) weights = typing - dpi = matplotlib.rcParams['figure.dpi'] + dpi = ax_[0].get_figure().get_dpi() if id_ is None: n_cols = len(colors) @@ -811,6 +811,8 @@ def subplots( 
height_ratios=None, x_shifts=None, y_shifts=None, + dpi=None, + **kwargs, ): """\ Creates a new figure with a grid of subplots. @@ -857,6 +859,11 @@ def subplots( y_shifts The absolute shifts in position in vertical/y direction per row of subplots; if `None`, the rows are not shifted + dpi + The dpi setting to use for this figure + **kwargs + Extra keyword arguments are forwarded to + :func:`matplotlib.pyplot.subplots` Returns ------- @@ -894,7 +901,11 @@ def subplots( fig_height += title_space top = 1 - title_space / fig_height - fig, axs = plt.subplots(n_y,n_x,figsize=(fig_width,fig_height), squeeze=False, sharex=sharex, sharey=sharey, gridspec_kw={'wspace':effective_wspace,'hspace':effective_hspace,'left':0,'right':1,'top':top,'bottom':0,'width_ratios': width_ratios,'height_ratios': height_ratios}) + if dpi is not None: + kwargs = {**kwargs} + kwargs['dpi'] = dpi + + fig, axs = plt.subplots(n_y,n_x,figsize=(fig_width,fig_height), squeeze=False, sharex=sharex, sharey=sharey, gridspec_kw={'wspace':effective_wspace,'hspace':effective_hspace,'left':0,'right':1,'top':top,'bottom':0,'width_ratios': width_ratios,'height_ratios': height_ratios}, **kwargs) if title is not None: fig.suptitle(title, fontsize=16, y=1) @@ -918,7 +929,7 @@ def _add_legend_or_colorbars(fig, axs, colors, cmap=None, min_max=None, scale_le bbox_to_anchor=(1, 1), loc='upper left', ncol=1) elif cmap is not None: - rel_dpi_factor = matplotlib.rcParams['figure.dpi'] / 72 + rel_dpi_factor = fig.get_dpi() / 72 height_pxl = 200 * rel_dpi_factor * scale_legend width_pxl = 15 * rel_dpi_factor * scale_legend offset_top_pxl = 0 * rel_dpi_factor * scale_legend @@ -2843,6 +2854,8 @@ def significances( p_key, value_key, group_key, + enrichment_key='enrichment', + enriched_label='enriched', pmax=0.05, pmin=1e-5, annotate_pvalues=True, @@ -2861,13 +2874,24 @@ def significances( Parameters ---------- significances - A :class:`~pandas.DataFrame` with p-values and their annotation.. 
+ A :class:`~pandas.DataFrame` with p-values and their annotation. If it + contains significances for enrichment and depletion, this direction has + to be specified with values "enriched" and something else (e.g. + "depleted" or "purified") in a column "enrichment" of the DataFrame. + See also the parameters `enrichment_key` and `enrichment_label`. p_key The key with the p-values. value_key The key with the values for which the enrichment was determined. group_key The key with the groups in which the enrichment was determined. + enrichment_key + The key with the direction of enrichment, i.e something like "enriched" + and "purified". See also parameter `enriched_label`. Default: + "enrichment". + enriched_label + The value under the key `enrichment_key` which indicates something like + enrichment. Default: "enriched". pmax The maximum p-value to show. pmin @@ -2900,25 +2924,45 @@ def significances( A :class:`~matplotlib.figure.Figure`. """ - - enr_e = pd.pivot(significances[significances['enrichment']=='enriched'], value_key, group_key, p_key) - enr_p = pd.pivot(significances[significances['enrichment']!='enriched'], value_key, group_key, p_key) small_value = 1e-300 max_log = -np.log(pmin) min_log = -np.log(pmax) - enr_e = np.maximum(enr_e,small_value) - enr_p = np.maximum(enr_p,small_value) - - enr_p = enr_p.reindex_like(enr_e) - enr = pd.DataFrame(np.where(enr_e < enr_p, -np.log(enr_e), np.log(enr_p)),index=enr_e.index,columns=enr_e.columns) + depleted_label = None + if enrichment_key is not None: + if enrichment_key not in significances: + raise ValueError(f'The column "{enrichment_key}" does not exist in the supplied dataframe! 
If this is intentional and you want to supress this error, supply "enrichment_key=None" as argument.') + unique_significance_labels = significances[enrichment_key].unique() + if len(unique_significance_labels) == 1: + enriched_label = unique_significance_labels[0] + else: + if len(unique_significance_labels) > 2 or enriched_label not in unique_significance_labels: + raise ValueError(f'The column "{enrichment_key}" is expected to have exactly 2 different values: "{enriched_label}" and something else, e.g. "depleted" or "purified"! The supplied column has the values {significances[enrichment_key].unique()}!') + depleted_label = unique_significance_labels[unique_significance_labels!=enriched_label][0] - if annotate_pvalues: + if depleted_label is not None: + enr_e = pd.pivot(significances[significances[enrichment_key]==enriched_label], index=value_key, columns=group_key, values=p_key) + enr_p = pd.pivot(significances[significances[enrichment_key]==depleted_label], index=value_key, columns=group_key, values=p_key) + + enr_e = np.maximum(enr_e,small_value) + enr_p = np.maximum(enr_p,small_value) + + enr_p = enr_p.reindex_like(enr_e) + + enr = pd.DataFrame(np.where(enr_e < enr_p, -np.log(enr_e), np.log(enr_p)),index=enr_e.index,columns=enr_e.columns) ann = pd.DataFrame(np.where(enr_e < enr_p, enr_e, enr_p),index=enr_e.index,columns=enr_e.columns) - ann = pd.DataFrame(np.where(ann <= pmax, ann.applymap(lambda x: f'{x:.2}'), ''),index=enr_e.index,columns=enr_e.columns) - ann = ann.T else: + ann = pd.pivot(significances, index=value_key, columns=group_key, values=p_key) + enr = -np.log(ann) + + # avoid discrepancies between color and annotation by basing both color and annotation on cuts on the same values + enr = pd.DataFrame(np.where(ann > pmax, 0, enr),index=enr.index,columns=enr.columns) + ann = pd.DataFrame(np.where(ann > pmax, '', ann.applymap(lambda x: f'{x:.2}')),index=enr.index,columns=enr.columns) + + enr = enr.T + ann = ann.T + if not annotate_pvalues: ann = None 
# setup the plotting @@ -2930,41 +2974,64 @@ def significances( np.array([[slightly_weight,1-slightly_weight,0.0],[0.0,1-slightly_weight,slightly_weight],]), np.array([list(enriched_color),list(null_color),list(depleted_color)]) ) - ct1 = 0.5 * (1 - min_log/max_log) - ct2 = 0.5 * (1 + min_log/max_log) - cdict = {'red': [[0.0, depleted_color[0], depleted_color[0]], - [ct1, slightly_depleted_color[0], null_color[0]], - [ct2, null_color[0], slightly_enriched_color[0]], - [1.0, enriched_color[0], enriched_color[0]]], - 'green': [[0.0, depleted_color[1], depleted_color[1]], - [ct1, slightly_depleted_color[1], null_color[1]], - [ct2, null_color[1], slightly_enriched_color[1]], - [1.0, enriched_color[1], enriched_color[1]]], - 'blue': [[0.0, depleted_color[2], depleted_color[2]], - [ct1, slightly_depleted_color[2], null_color[2]], - [ct2, null_color[2], slightly_enriched_color[2]], - [1.0, enriched_color[2], enriched_color[2]]]} - cmap = LinearSegmentedColormap('sigmap', segmentdata=cdict, N=256) - - fig = heatmap(enr.T, None, None, cmap=cmap, cmap_vmin_vmax=(-max_log,max_log), annotation=ann, colorbar=False, value_cluster=value_cluster, group_cluster=group_cluster, value_order=value_order, group_order=group_order, axsize=axsize, ax=ax); - - rel_dpi_factor = matplotlib.rcParams['figure.dpi'] / 72 + + if depleted_label is None: + ct2 = min_log/max_log + cdict = {'red': [[0.0, null_color[0], null_color[0]], + [ct2, null_color[0], slightly_enriched_color[0]], + [1.0, enriched_color[0], enriched_color[0]]], + 'green': [[0.0, null_color[1], null_color[1]], + [ct2, null_color[1], slightly_enriched_color[1]], + [1.0, enriched_color[1], enriched_color[1]]], + 'blue': [[0.0, null_color[2], null_color[2]], + [ct2, null_color[2], slightly_enriched_color[2]], + [1.0, enriched_color[2], enriched_color[2]]]} + cmap = LinearSegmentedColormap('sigmap', segmentdata=cdict, N=256) + + fig = heatmap(enr, None, None, cmap=cmap, cmap_vmin_vmax=(0.0,max_log), cmap_center=max_log/2, 
annotation=ann, colorbar=False, value_cluster=value_cluster, group_cluster=group_cluster, value_order=value_order, group_order=group_order, axsize=axsize, ax=ax); + + else: + ct1 = 0.5 * (1 - min_log/max_log) + ct2 = 0.5 * (1 + min_log/max_log) + cdict = {'red': [[0.0, depleted_color[0], depleted_color[0]], + [ct1, slightly_depleted_color[0], null_color[0]], + [ct2, null_color[0], slightly_enriched_color[0]], + [1.0, enriched_color[0], enriched_color[0]]], + 'green': [[0.0, depleted_color[1], depleted_color[1]], + [ct1, slightly_depleted_color[1], null_color[1]], + [ct2, null_color[1], slightly_enriched_color[1]], + [1.0, enriched_color[1], enriched_color[1]]], + 'blue': [[0.0, depleted_color[2], depleted_color[2]], + [ct1, slightly_depleted_color[2], null_color[2]], + [ct2, null_color[2], slightly_enriched_color[2]], + [1.0, enriched_color[2], enriched_color[2]]]} + cmap = LinearSegmentedColormap('sigmap', segmentdata=cdict, N=256) + + fig = heatmap(enr, None, None, cmap=cmap, cmap_vmin_vmax=(-max_log,max_log), cmap_center=0.0, annotation=ann, colorbar=False, value_cluster=value_cluster, group_cluster=group_cluster, value_order=value_order, group_order=group_order, axsize=axsize, ax=ax); + + rel_dpi_factor = fig.get_dpi() / 72 height_pxl = 200 * rel_dpi_factor * scale_legend width_pxl = 15 * rel_dpi_factor * scale_legend offset_top_pxl = 0 * rel_dpi_factor * scale_legend offset_left_pxl = 30 * rel_dpi_factor * scale_legend - ax = fig.axes[0] + if ax is None: + ax = fig.axes[0] left,bottom = fig.transFigure.inverted().transform(ax.transAxes.transform((1,1))+np.array([offset_left_pxl,-offset_top_pxl-height_pxl])) width,height = fig.transFigure.inverted().transform(fig.transFigure.transform((0,0))+np.array([width_pxl,height_pxl])) cax = fig.add_axes((left, bottom, width, height)) - norm = Normalize(vmin=-max_log, vmax=max_log) + norm = Normalize(vmin=(0.0 if depleted_label is None else -max_log), vmax=max_log) cb = fig.colorbar(ScalarMappable(norm=norm, cmap=cmap), 
cax=cax) - cb.set_ticks([-max_log,-min_log,min_log,max_log]) - cb.set_ticklabels([pmin,pmax,pmax,pmin]) + if depleted_label is None: + cb.set_ticks([min_log,max_log]) + cb.set_ticklabels([pmax,pmin]) + else: + cb.set_ticks([-max_log,-min_log,min_log,max_log]) + cb.set_ticklabels([pmin,pmax,pmax,pmin]) cb.ax.annotate('enriched', xy=(0, 1), xycoords='axes fraction', xytext=(-3, -5), textcoords='offset pixels', horizontalalignment='right', verticalalignment='top', rotation=90, fontsize=10*scale_legend) - cb.ax.annotate('insignificant', xy=(0, 0.5), xycoords='axes fraction', xytext=(-3, 5), textcoords='offset pixels', horizontalalignment='right', verticalalignment='center', rotation=90, fontsize=10*scale_legend) - cb.ax.annotate('depleted', xy=(0, 0), xycoords='axes fraction', xytext=(-3, 5), textcoords='offset pixels', horizontalalignment='right', verticalalignment='bottom', rotation=90, fontsize=10*scale_legend) + cb.ax.annotate('insignificant', xy=(0, (0.0 if depleted_label is None else 0.5)), xycoords='axes fraction', xytext=(-3, 5), textcoords='offset pixels', horizontalalignment='right', verticalalignment=('bottom' if depleted_label is None else 'center'), rotation=90, fontsize=10*scale_legend) + if depleted_label is not None: + cb.ax.annotate('depleted', xy=(0, 0), xycoords='axes fraction', xytext=(-3, 5), textcoords='offset pixels', horizontalalignment='right', verticalalignment='bottom', rotation=90, fontsize=10*scale_legend) return fig @@ -3000,6 +3067,11 @@ def _get_co_occurrence(adata, analysis_key, show_only, show_only_center, colors, raise ValueError(f'The `show_only` categories {[s for s in show_only if s not in annotation]!r} are not available in the data!') annotation = annotation[select] mean_scores = mean_scores[select,:,:] + # check if the order is the same as in the show_only selection + permutation = np.argsort(annotation)[np.argsort(np.argsort(show_only))] + if not np.all(permutation == np.arange(len(permutation))): + annotation = 
annotation[permutation] + mean_scores = mean_scores[permutation,:,:] if show_only_center is not None: if isinstance(show_only_center, str): @@ -3009,6 +3081,11 @@ def _get_co_occurrence(adata, analysis_key, show_only, show_only_center, colors, raise ValueError(f'The `show_only_center` categories {[s for s in show_only_center if s not in center]!r} are not available in the data!') center = center[select] mean_scores = mean_scores[:,select,:] + # check if the order is the same as in the show_only_center selection + permutation = np.argsort(center)[np.argsort(np.argsort(show_only_center))] + if not np.all(permutation == np.arange(len(permutation))): + center = center[permutation] + mean_scores = mean_scores[:,permutation,:] colors, types = _get_colors(colors, pd.Series(annotation)) @@ -3121,7 +3198,7 @@ def co_occurrence( provided in `adata`. ax The :class:`~matplotlib.axes.Axes` to plot on. If `None`, creates a - fresh figure for plotting. Incompatible with dendrogram plotting. + fresh figure for plotting. 
Returns ------- @@ -3144,20 +3221,18 @@ def co_occurrence( raise ValueError(f'`merged==True` is ony possible with up to {len(linestyles)} andatas!') - #if fig is None: - # fig, axs = subplots(len(center), 1, axsize=axsize, hspace=hspace, wspace=wspace, sharex=sharex, sharey=sharey) - - if ax is not None: - if isinstance(ax, matplotlib.axes.Axes): - axs = np.array([[ax]]) - else: - axs = ax - if axs.shape != (len(center), 1): - raise ValueError(f'The `ax` argument got the wrong shape of axes: needed is {(len(center), 1)!r} supplied was {axs.shape!r}!') - axsize = axs[0,0].get_window_extent().transformed(axs[0,0].get_figure().dpi_scale_trans.inverted()).size - fig = axs[0,0].get_figure() - else: - fig, axs = subplots(len(center), 1, axsize=axsize, hspace=hspace, wspace=wspace, sharex=sharex, sharey=sharey, ) + if fig is None: + if ax is not None: + if isinstance(ax, matplotlib.axes.Axes): + axs = np.array([[ax]]) + else: + axs = ax + if axs.shape != (len(center), 1): + raise ValueError(f'The `ax` argument got the wrong shape of axes: needed is {(len(center), 1)!r} supplied was {axs.shape!r}!') + axsize = axs[0,0].get_window_extent().transformed(axs[0,0].get_figure().dpi_scale_trans.inverted()).size + fig = axs[0,0].get_figure() + else: + fig, axs = subplots(len(center), 1, axsize=axsize, hspace=hspace, wspace=wspace, sharex=sharex, sharey=sharey, ) for ir, nr in enumerate(center): @@ -3179,20 +3254,18 @@ def co_occurrence( else: - #if fig is None: - # fig, axs = subplots(len(center), len(adatas), axsize=axsize, hspace=hspace, wspace=wspace, sharex=sharex, sharey=sharey) - - if ax is not None: - if isinstance(ax, matplotlib.axes.Axes): - axs = np.array([[ax]]) + if fig is None: + if ax is not None: + if isinstance(ax, matplotlib.axes.Axes): + axs = np.array([[ax]]) + else: + axs = ax + if axs.shape != (len(center), len(adatas)): + raise ValueError(f'The `ax` argument got the wrong shape of axes: needed is {(len(center), len(adatas))!r} supplied was {axs.shape!r}!') + 
axsize = axs[0,0].get_window_extent().transformed(axs[0,0].get_figure().dpi_scale_trans.inverted()).size + fig = axs[0,0].get_figure() else: - axs = ax - if axs.shape != (len(center), len(adatas)): - raise ValueError(f'The `ax` argument got the wrong shape of axes: needed is {(len(center), len(adatas))!r} supplied was {axs.shape!r}!') - axsize = axs[0,0].get_window_extent().transformed(axs[0,0].get_figure().dpi_scale_trans.inverted()).size - fig = axs[0,0].get_figure() - else: - fig, axs = subplots(len(center), len(adatas), axsize=axsize, hspace=hspace, wspace=wspace, sharex=sharex, sharey=sharey, ) + fig, axs = subplots(len(center), len(adatas), axsize=axsize, hspace=hspace, wspace=wspace, sharex=sharex, sharey=sharey, ) for ir, nr in enumerate(center): x = (intervals[1:] + intervals[:-1]) / 2 @@ -3332,18 +3405,6 @@ def co_occurrence_matrix( min_max = None - if ax is not None: - if isinstance(ax, matplotlib.axes.Axes): - axs = np.array([[ax]]) - else: - axs = ax - if axs.shape != (len(restrict_intervals), len(adatas)): - raise ValueError(f'The `ax` argument got the wrong shape of axes: needed is {(len(restrict_intervals), len(adatas))!r} supplied was {axs.shape!r}!') - axsize = axs[0,0].get_window_extent().transformed(axs[0,0].get_figure().dpi_scale_trans.inverted()).size - fig = axs[0,0].get_figure() - else: - fig, axs = subplots(len(restrict_intervals), len(adatas), axsize=axsize, hspace=hspace, wspace=wspace, x_padding=x_padding, y_padding=y_padding, ) - # first pass through the data to get global min/max of the values for colormap for adata_i, (adata_name, adata) in enumerate(adatas.items()): @@ -3366,6 +3427,18 @@ def co_occurrence_matrix( if axsize is None: axsize = (0.2*len(center),0.2*len(annotation)) + + if ax is not None: + if isinstance(ax, matplotlib.axes.Axes): + axs = np.array([[ax]]) + else: + axs = ax + if axs.shape != (len(adatas), len(restrict_intervals)): + raise ValueError(f'The `ax` argument got the wrong shape of axes: needed is 
{(len(adatas), len(restrict_intervals))!r} supplied was {axs.shape!r}!') + axsize = axs[0,0].get_window_extent().transformed(axs[0,0].get_figure().dpi_scale_trans.inverted()).size + fig = axs[0,0].get_figure() + else: + fig, axs = subplots(len(restrict_intervals), len(adatas), axsize=axsize, hspace=hspace, wspace=wspace, x_padding=x_padding, y_padding=y_padding, ) # second pass for actual plotting for adata_i, (adata_name, adata) in enumerate(adatas.items()): @@ -3603,7 +3676,7 @@ def annotated_heatmap( im = axs[1,1].imshow(data.T,aspect='auto',cmap=cmap, vmin=cmap_vmin_vmax[0], vmax=cmap_vmin_vmax[1]) axs[1,1].set_xticks([]) axs[1,1].set_yticks([]) - rel_dpi_factor = matplotlib.rcParams['figure.dpi'] / 72 + rel_dpi_factor = fig.get_dpi() / 72 cax_width = 100 * rel_dpi_factor # color bar width in pixel cax_height = 10 * rel_dpi_factor # color bar height in pixel cax_offset = 10 * rel_dpi_factor # color bar y offset in pixel @@ -3864,8 +3937,9 @@ def dotplot( log1p=True, marks=None, marks_colors=None, + swap_axes=True, ): - + """\ Dot plot of expression values. @@ -3888,46 +3962,59 @@ def dotplot( marks_colors A mapping from the categories in `marks` to colors; if `None`, default colors are used. + swap_axes + If `False`, the x axis contains the genes and the y axis the groups. + Otherwise the axes are swapped. Returns ------- A :class:`~matplotlib.figure.Figure`. 
""" - + if not pd.Index(genes).isin(adata.var.index).all(): raise ValueError(f'The genes {pd.Index(genes).difference(adata.var.index)!r} are not available in `adata.var`!') - + markers = genes[::-1] - + if group_key not in adata.obs.columns: raise ValueError(f'The `group_key` {group_key!r} is not available in `adata.obs`!') - + if hasattr(adata.obs[group_key], 'cat'): cluster = adata.obs[group_key].cat.categories else: cluster = adata.obs[group_key].unique() + + if swap_axes: + xticklabels = cluster + yticklabels = markers + else: + xticklabels = markers + yticklabels = cluster - fig,axs = subplots(axsize=0.25*np.array([len(cluster),len(markers)])) + fig,axs = subplots(axsize=0.25*np.array([len(xticklabels),len(yticklabels)])) - x = np.arange(len(cluster)) # the label locations + x = np.arange(len(xticklabels)) # the label locations + y = np.arange(len(yticklabels)) # the label locations + if not swap_axes: + x = x[::-1] + y = y[::-1] axs[0,0].set_xticks(x) - axs[0,0].set_xticklabels(cluster, rotation=45, ha='right',) - y = np.arange(len(markers)) # the label locations + axs[0,0].set_xticklabels(xticklabels, rotation=45, ha='right',) axs[0,0].set_yticks(y) - axs[0,0].set_yticklabels(markers) + axs[0,0].set_yticklabels(yticklabels) axs[0,0].set_xlim((x.min()-0.5,x.max()+0.5)) axs[0,0].set_ylim((y.min()-0.5,y.max()+0.5)) axs[0,0].set_axisbelow(True) axs[0,0].grid(True) - + marker_counts = adata[:,markers].to_df() if log1p: marker_counts = np.log1p(marker_counts) mean_exp = pd.DataFrame({c: marker_counts.loc[df.index].mean(axis=0) for c,df in adata.obs.groupby(group_key) }) mean_pos = pd.DataFrame({c: (marker_counts.loc[df.index] != 0).mean(axis=0) for c,df in adata.obs.groupby(group_key) }) - + if marks is not None: marks = marks.reindex_like(mean_pos) @@ -3935,7 +4022,7 @@ def dotplot( mean_pos_index_name = 'index' if mean_pos.index.name is None else mean_pos.index.name mean_exp = pd.melt(mean_exp, 
ignore_index=False).reset_index().rename(columns={mean_exp_index_name:'value','variable':'cluster','value':'mean_exp'}) mean_pos = pd.melt(mean_pos, ignore_index=False).reset_index().rename(columns={mean_pos_index_name:'value','variable':'cluster','value':'mean_pos'}) - + if marks is not None: marks.index.name = None marks.columns.name = None @@ -3945,21 +4032,27 @@ def dotplot( marks_colors = get_default_colors(marks['marks'].unique()) all_df = pd.merge(mean_exp, mean_pos, on=['value', 'cluster']) - + if marks is not None: all_df = pd.merge(all_df, marks, on=['value', 'cluster'], how='outer') - - all_df['x'] = all_df['cluster'].map(pd.Series(x,index=cluster)) - all_df['y'] = all_df['value'].map(pd.Series(y,index=markers)) + if all_df['marks'].isna().any(): + raise ValueError(f'There were gene-group combinations without a match in "marks"!') + + if swap_axes: + all_df['x'] = all_df['cluster'].map(pd.Series(x,index=xticklabels)) + all_df['y'] = all_df['value'].map(pd.Series(y,index=yticklabels)) + else: + all_df['x'] = all_df['value'].map(pd.Series(x,index=xticklabels)) + all_df['y'] = all_df['cluster'].map(pd.Series(y,index=yticklabels)) legend_items = [] - + mean_exp_min, mean_exp_max = all_df['mean_exp'].min(), all_df['mean_exp'].max() norm = Normalize(vmin=mean_exp_min, vmax=mean_exp_max) cmap='Reds'#LinearSegmentedColormap.from_list('mean_exp', [(0,(1, 1, 1)),(1,(1, g, b))]) mapper = ScalarMappable(norm=norm, cmap=cmap) color = [ tuple(x) for x in mapper.to_rgba(all_df['mean_exp'].to_numpy()) ] - + legend_items.append(mpatches.Patch(color='#0000', label='mean expression')) mean_exp_for_legend = np.linspace(mean_exp_min, mean_exp_max, 4) legend_items.extend([mpatches.Patch(color=color, label=f'{ind:.2f}') for color,ind in zip(mapper.to_rgba(mean_exp_for_legend),mean_exp_for_legend)]) @@ -3968,20 +4061,20 @@ def dotplot( def size_map(x): return (x/mean_pos_max * 14)**2 size = size_map(all_df['mean_pos']) - + legend_items.append(mpatches.Patch(color='#0000', 
label='fraction of expressing cells')) mean_pos_for_legend = np.linspace(mean_pos_min, mean_pos_max, 5)[1:] legend_items.extend([mlines.Line2D([], [], color='#aaa', linestyle='none', marker='o', markersize=np.sqrt(size_map(ind)), label=f'{ind:.2f}') for ind in mean_pos_for_legend]) edgecolors = '#aaa' if marks is None else all_df['marks'].map(marks_colors) - + if marks is not None: marks_name = marks_colors.name if hasattr(marks_colors, 'name') else '' legend_items.append(mpatches.Patch(color='#0000', label=marks_name)) legend_items.extend([mlines.Line2D([], [], color='#aaa', linestyle='none', fillstyle='none', markeredgecolor=color, marker='o', markersize=np.sqrt(size_map(mean_pos_for_legend[-2])), label=f'{ind}') for ind,color in marks_colors.items()]) axs[0,0].scatter(all_df['x'], all_df['y'], c=color, s=size, edgecolors=edgecolors) - + axs[0,0].legend(handles=legend_items, bbox_to_anchor=(1, 1), loc='upper left', ncol=1) - + return fig diff --git a/tacco/tools/__init__.py b/tacco/tools/__init__.py index 07b2ad7..3c36937 100644 --- a/tacco/tools/__init__.py +++ b/tacco/tools/__init__.py @@ -21,3 +21,4 @@ from ._wot import annotate_wot from ._SingleR import annotate_SingleR from ._goa import setup_goa_analysis, run_goa_analysis +from ._orthology import setup_orthology_converter, run_orthology_converter diff --git a/tacco/tools/_annotate.py b/tacco/tools/_annotate.py index a81be9f..83f653e 100644 --- a/tacco/tools/_annotate.py +++ b/tacco/tools/_annotate.py @@ -377,7 +377,7 @@ def _method(adata, reference, annotation_key, annotation_prior, verbose): nonlocal annotation_method, multi_center, prepare_reconstruction - if multi_center is None or multi_center < 1: + if multi_center is None or multi_center <= 1: cell_type = annotation_method(adata, reference, annotation_key, annotation_prior, verbose) if prepare_reconstruction is not None: prepare_reconstruction['annotation'] = cell_type.copy() @@ -419,6 +419,7 @@ def _method(adata, reference, annotation_key, 
annotation_prior, verbose): random_state=42, batch_size=100, max_iter=100, + n_init=3, # avoid FutureWarning about changing the default n_init to 'auto'; possible speedup by using 'auto' - needs evaluation ).fit(X) for c in range(_multi_center): @@ -459,6 +460,7 @@ def _method(adata, reference, annotation_key, annotation_prior, verbose): def max_annotation_method( annotation_method, max_annotation, + prepare_reconstruction ): """\ @@ -474,6 +476,13 @@ def max_annotation_method( the maximum annotation, higher values assign the top annotations and distribute the remaining annotations equally on the top annotations. If `None` or smaller than `1`, no restrictions are imposed. + prepare_reconstruction + This is an out-argument providing a dictionary to fill with the data + necessary for the reconstruction of "denoised" profiles. The necessary + data is a :class:`~pandas.DataFrame` containing the annotation on + sub-categories, another :class:`~pandas.DataFrame` containing the + profiles of the sub-categories, and a mapping of sub-categories to + their original categories. 
Returns ------- @@ -485,7 +494,7 @@ def max_annotation_method( def _method(adata, reference, annotation_key, annotation_prior, verbose): - nonlocal annotation_method, max_annotation + nonlocal annotation_method, max_annotation, prepare_reconstruction cell_type = annotation_method(adata, reference, annotation_key, annotation_prior, verbose) @@ -495,6 +504,15 @@ def _method(adata, reference, annotation_key, annotation_prior, verbose): _cell_type[_cell_type cuts).sum(axis=1) + + # then remove the points near the cuts symmetrically for each cut + dropped = np.any((sorted_points[:,None] > cuts - minimum_separation/2) & (sorted_points[:,None] < cuts + minimum_separation/2),axis=1) + + interval[dropped] = -1 + + interval = interval[reodering_indices] + + if check_splits: + # check some stats about the split + stats = pd.Series(interval).value_counts() + normed_stats = stats / stats.sum() + if len(normed_stats) <= 1: + print(f'WARNING: All points are assigned to the separations! Maybe minimum_separation or n_intervals is too large, or the number of points to small?') + else: + if normed_stats.iloc[-1] > 0.5: + print(f'WARNING: The fraction of points assigned to the separations is very high ({normed_stats.iloc[-1]})! Maybe minimum_separation or n_intervals is too large?') + if len(stats) < n_intervals + 1: + print(f'WARNING: At least one interval did not get any points assigned! Maybe minimum_separation or n_intervals is too large?') + + return interval + +def split_spatial_samples( + adata, + buffer_thickness, + position_key=('x','y'), + split_direction=None, + split_scheme=2, + sample_key=None, + result_key=None, + check_splits=True, +): + """\ + Splits a dataset into separated spatial patches. The patches are selected + to have approximately equal amounts of observations per patch. Between the + patches a buffer layer of specified thickness is discarded to reduce the + correlations between the patches. 
Therefore the thickness should be chosen + to accomodate the largest relevant correlation length. + + Parameters + ---------- + adata + An :class:`~anndata.AnnData` with annotation in `.obs` (and `.obsm`). + Can also be a :class:`~pandas.DataFrame` which is then used in place of + `.obs`. + buffer_thickness + The thickness of the buffer layer to discard between the spatial + patches. The units are the same as thosed used in the specification of + the position information. This should be chosen carefully, as the + remaining correlation between the spatial patches depends on it. + position_key + The `.obsm` key or array-like of `.obs` keys with the position space + coordinates. + split_direction + The direction(s) to use for the spatial splits. Use e.g. ('y','x') for + two splits with the first in 'y' and the second in 'x' direction, or + 'z' to always split along the 'z' direction. Use `None` instead of the + name of a coordinate direction to automatically determine the direction + as a flavor of direction with largest extent (first principal axis). + split_scheme + The specification of the strategy used for defining spatial splits. In + the simplest case this is just an integer specifying the number of + patches per split. It can also be a tuple to specify a different number + of patches per split. + sample_key + The `.obs` key with categorical sample information: every sample is + split separately. Can also be a :class:`~pandas.Series` containing the + sample information. If `None`, assume a single sample. + result_key + The `.obs` key to write the split sample annotation to. If `None`, + returns the split sample annotation as :class:`~pandas.Series`. + check_splits + Whether to warn about unusual split properties + + Returns + ------- + Depending on `result_key` returns either a :class:`~pandas.Series`\ + containing the split sample annotation or the input `adata` with the split\ + sample annotation written to `adata.obs[result_key]`. 
+ + """ + + if isinstance(adata, ad.AnnData): + adata_obs = adata.obs + else: + adata_obs = adata + + if sample_key is None: + sample_column = pd.Series(np.full(shape=adata_obs.shape[0],fill_value=''),index=adata_obs.index) + elif isinstance(sample_key, pd.Series): + sample_column = sample_key.reindex(index=adata_obs.index) + elif sample_key in adata_obs: + if not hasattr(adata_obs[sample_key], 'cat'): + raise ValueError(f'`adata.obs[sample_key]` has to be categorical, but `adata.obs["{sample_key}"]` is not!') + sample_column = adata_obs[sample_key] + else: + raise ValueError(f'The `sample_key` argument is {sample_key!r} but has to be either a key of `adata.obs["{sample_key}"]`, a `pandas.Series` or `None`!') + + positions = get.positions(adata, position_key) + + # divide spatial samples spatially into subsamples: keeps all the correlation structure + ndim = positions.shape[1] + + # get consensus split plan from direction and scheme + split_direction_array = np.array(split_direction) # use array for checking dimensionality - but cast to lists to preserve None values... 
+ split_scheme_array = np.array(split_scheme) + if (len(split_direction_array.shape) == 0) and (len(split_scheme_array.shape) == 0): + split_direction = [split_direction] + split_scheme = [split_scheme] + elif (len(split_direction_array.shape) == 0) and (len(split_scheme_array.shape) == 1): + split_direction = [split_direction] * len(split_scheme) + elif (len(split_direction_array.shape) == 1) and (len(split_scheme_array.shape) == 0): + split_scheme = [split_scheme] * len(split_direction) + elif (len(split_direction_array.shape) == 1) and (len(split_scheme_array.shape) == 1): + if len(split_direction) != len(split_scheme): + raise ValueError(f'The length of "split_direction" ({len(split_direction)}) does not fit to the length of "split_scheme" ({len(split_scheme)})!') + else: + raise ValueError(f'The "split_direction" and "split_scheme" must be of shape 0 or 1!') + + sample_column = sample_column.astype(str) + for direction,n_patches in zip(split_direction,split_scheme): + new_sample_column = sample_column.copy() + for sample, sub in positions.groupby(sample_column): + + # get direction vector + if direction is None: + direction_vector = get_first_principal_axis(sub) + else: + dir_loc = positions.columns.get_loc(direction) + if not isinstance(dir_loc, int): + raise ValueError(f'The direction "{direction}" is neither `None` nor does it correspond to a unique coordinate direction!') + direction_vector = np.array([0.0]*len(positions.columns)) + direction_vector[dir_loc] = 1.0 + + # project positions on the direction_vector + projections = sub @ direction_vector + + # find optimal division into patches + patches = get_balanced_separated_intervals(projections, n_patches, buffer_thickness, check_splits=check_splits) + + new_values = new_sample_column.loc[sub.index] + '|' + patches.astype(str) + new_values.loc[patches == -1] = None + new_sample_column.loc[sub.index] = new_values + + sample_column = new_sample_column + + result = sample_column.astype('category') + if 
result_key is not None: + adata_obs[result_key] = result + result = adata + + return result + def get_maximum_annotation(adata, obsm_key, result_key=None): """\ Turns a soft annotation into a categorical annotation by reporting the diff --git a/tests/test_annotation.py b/tests/test_annotation.py index c6cea4f..be8fb36 100644 --- a/tests/test_annotation.py +++ b/tests/test_annotation.py @@ -70,26 +70,25 @@ def adata_reference_and_typing1(): adata.obsm['annotation0']=pd.DataFrame(np.array([ [0.5,0.5], ]),index=adata.obs.index,columns=pd.Index(reference.obs['type'].cat.categories,dtype=reference.obs['type'].dtype)) - adata.obsm['annotation1']=pd.DataFrame(np.array([ + adata.obsm['annotation1']=adata.obsm['annotation0'] + adata.obsm['annotation2']=pd.DataFrame(np.array([ [0.5,0.5], ]),index=adata.obs.index,columns=pd.Index(['0-0', '1-0'])) - adata.obsm['annotation2']=adata.obsm['annotation1'] - adata.obsm['annotation10']=adata.obsm['annotation1'] + adata.obsm['annotation10']=adata.obsm['annotation2'] adata.varm['profiles0']=pd.DataFrame(np.array([ [1, 0 ], [0,2/3], [0,1/3], ]),index=adata.var.index,columns=pd.Index(reference.obs['type'].cat.categories.astype('category'))) - adata.varm['profiles1']=pd.DataFrame(np.array([ + adata.varm['profiles1']=adata.varm['profiles0'] + adata.varm['profiles2']=pd.DataFrame(np.array([ [1, 0 ], [0,2/3], [0,1/3], - ]),index=adata.var.index,columns=pd.Index(['0-0', '1-0'])) - adata.varm['profiles2']=adata.varm['profiles1'] - adata.varm['profiles10']=adata.varm['profiles1'] - adata.uns['mapping1']=pd.Series([0,1],index=['0-0', '1-0']).astype('category') - adata.uns['mapping2']=adata.uns['mapping1'] - adata.uns['mapping10']=adata.uns['mapping1'] + ]),index=adata.var.index,columns=pd.Index(['0', '1'])) + adata.varm['profiles10']=adata.varm['profiles2'] + adata.uns['mapping2']=pd.Series([0,1],index=['0-0', '1-0']).astype('category') + adata.uns['mapping10']=adata.uns['mapping2'] 
adata.uns['annotation_prior']=pd.Series(adata.obsm['type'].sum(axis=0).to_numpy(),index=adata.obsm['type'].columns) return ( adata, reference, ) @@ -154,20 +153,7 @@ def adata_reference_and_typing2(): [0.5,0.5,0], [0.5,0.5,0], ]),index=adata.obs.index,columns=pd.Index(reference.obs['type'].cat.categories,dtype=reference.obs['type'].dtype)) - adata.obsm['annotation1']=pd.DataFrame(np.array([ - [1,0,0], - [1,0,0], - [1,0,0], - [0,1,0], - [0,1,0], - [0,1,0], - [0,0,1], - [0,0,1], - [0,0,1], - [0.5,0.5,0], - [0.5,0.5,0], - [0.5,0.5,0], - ]),index=adata.obs.index,columns=pd.Index(['0-0', '1-0', '2-0'])) + adata.obsm['annotation1'] = adata.obsm['annotation0'] adata.obsm['annotation2']=pd.DataFrame(np.array([ [1/2,1/2, 0 , 0 , 0 , 0 ], [1/2,1/2, 0 , 0 , 0 , 0 ], @@ -210,20 +196,7 @@ def adata_reference_and_typing2(): [ 0 , 0 ,1/4], [ 0 , 0 ,1/4], ],dtype=np.float32),index=adata.var.index,columns=pd.Index(reference.obs['type'].cat.categories.astype('category'))) - adata.varm['profiles1']=pd.DataFrame(np.array([ - [1/4, 0 , 0 ], - [1/4, 0 , 0 ], - [1/4, 0 , 0 ], - [1/4, 0 , 0 ], - [ 0 ,1/4, 0 ], - [ 0 ,1/4, 0 ], - [ 0 ,1/4, 0 ], - [ 0 ,1/4, 0 ], - [ 0 , 0 ,1/4], - [ 0 , 0 ,1/4], - [ 0 , 0 ,1/4], - [ 0 , 0 ,1/4], - ],dtype=np.float32),index=adata.var.index,columns=pd.Index(['0-0', '1-0', '2-0'])) + adata.varm['profiles1']=adata.varm['profiles0'] adata.varm['profiles2']=pd.DataFrame(np.array([ [1/2, 0 , 0 , 0 , 0 , 0 ], [1/2, 0 , 0 , 0 , 0 , 0 ], @@ -252,7 +225,6 @@ def adata_reference_and_typing2(): [ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,1/1, 0 ], [ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,1/1], ],dtype=np.float32),index=adata.var.index,columns=pd.Index(['0-0', '0-1', '0-2', '0-3', '1-0', '1-1', '1-2', '1-3', '2-0', '2-1', '2-2', '2-3'])) - adata.uns['mapping1']=pd.Series([0,1,2],index=['0-0', '1-0', '2-0']).astype('category') adata.uns['mapping2']=pd.Series([0,0,1,1,2,2],index=['0-0', '0-1', '1-0', '1-1', '2-0', '2-1']).astype('category') 
adata.uns['mapping10']=pd.Series([0,0,0,0,1,1,1,1,2,2,2,2],index=['0-0', '0-1', '0-2', '0-3', '1-0', '1-1', '1-2', '1-3', '2-0', '2-1', '2-2', '2-3']).astype('category') adata.uns['annotation_prior']=pd.Series(adata.obsm['type'].sum(axis=0).to_numpy(),index=adata.obsm['type'].columns) @@ -350,7 +322,7 @@ def test_annotate_OT_multi_center(adata_reference_and_typing, dataset, multi_cen tc.testing.assert_index_equal(result_profiles.index, result_profiles.index, rtol=1e-14, atol=1e-14) tc.testing.assert_index_equal(result_profiles.columns, result_profiles.columns, rtol=1e-14, atol=1e-14) - if multi_center > 0: + if multi_center > 1: mapping = adata.uns[f'mapping{multi_center}'] result_mapping = adata.uns[reconstruction_key] tc.testing.assert_series_equal(result_mapping, mapping, rtol=1e-14, atol=1e-14) @@ -445,6 +417,7 @@ def test_annotate_NovoSpaRc(adata_reference_and_typing, dataset): tc.testing.assert_frame_equal(result, typing, rtol=1e-7, atol=3.4e-1) # exact marginal enforcement by optimal transport makes a better result impossible... 
@pytest.mark.parametrize('dataset', [0,1,2,]) +@pytest.mark.skipif(not tc.benchmarking._benchmarking.BENCHMARKING_AVAILABLE, reason='Benchmarking not available on this system') def test_benchmark_annotate(adata_reference_and_typing, dataset): adata, reference = adata_reference_and_typing[dataset] typing = adata.obsm['type'] diff --git a/tests/test_get.py b/tests/test_get.py new file mode 100644 index 0000000..00c6105 --- /dev/null +++ b/tests/test_get.py @@ -0,0 +1,231 @@ +import pytest +import numpy as np +import pandas as pd +import anndata as ad +import tacco as tc +import scipy.sparse + +@pytest.fixture(scope="session") +def adata_with_data(): + adata = ad.AnnData(X=np.array([ + [1,0,0], + [0,1,0], + [0,0,1], + [1,1,0], + ]),dtype=np.float32, obs=pd.DataFrame(index=['A','T','C','G'])) + adata.layers['sparse'] = scipy.sparse.csr_matrix(adata.X) + adata.layers['dense'] = adata.X + adata.var['type']=pd.Series(['a','b','c'],index=adata.var.index,name='type') + adata.var['blub']=pd.Series(['1','2','3'],index=adata.var.index,name='blub') + adata.varm['type']=pd.DataFrame(np.array([ + [1,0,0], + [0,1,0], + [0,0,1], + ]),index=adata.var.index,columns=['A','B','C']) + adata.obs['type']=pd.Series(['A','B','C','AB'],index=adata.obs.index,name='type') + adata.obs['blub']=pd.Series(['1','2','3','12'],index=adata.obs.index,name='blub') + adata.obsm['type']=pd.DataFrame(np.array([ + [1,0,0], + [0,1,0], + [0,0,1], + [0.5,0.5,0], + ]),index=adata.obs.index,columns=['A','B','C']) + adata.obsm['array']=np.array([ + [1,0,0], + [0,1,0], + [0,0,1], + [0.5,0.5,0], + ]) + adata.raw = adata.copy() + return adata + +def test_get_data_from_key(adata_with_data): + data = adata_with_data.obs['type'] + result = tc.get.data_from_key(adata_with_data, 'type') + tc.testing.assert_series_equal(result, data) + +def test_get_data_from_key__result_type_obs(adata_with_data): + data = adata_with_data.obs['type'] + result = tc.get.data_from_key(adata_with_data, 'type', result_type='obs') + 
tc.testing.assert_series_equal(result, data) + +def test_get_data_from_key__result_type_var(adata_with_data): + data = adata_with_data.var['type'] + result = tc.get.data_from_key(adata_with_data, 'type', result_type='var') + tc.testing.assert_series_equal(result, data) + +def test_get_data_from_key__result_type_obsm(adata_with_data): + data = pd.DataFrame(adata_with_data.obs['type']) + result = tc.get.data_from_key(adata_with_data, 'type', result_type='obsm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_varm(adata_with_data): + data = pd.DataFrame(adata_with_data.var['type']) + result = tc.get.data_from_key(adata_with_data, 'type', result_type='varm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_obsm__path(adata_with_data): + data = adata_with_data.obsm['type'] + result = tc.get.data_from_key(adata_with_data, ('obsm','type'), result_type='obsm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_varm__path(adata_with_data): + data = adata_with_data.varm['type'] + result = tc.get.data_from_key(adata_with_data, ('varm','type'), result_type='varm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_X__path(adata_with_data): + data = adata_with_data.X + result = tc.get.data_from_key(adata_with_data, ('X',), result_type='X') + tc.testing.assert_dense_equal(result, data) + +def test_get_data_from_key__result_type_X__path_layer(adata_with_data): + data = adata_with_data.layers['sparse'] + result = tc.get.data_from_key(adata_with_data, ('layer','sparse'), result_type='X') + tc.testing.assert_sparse_equal(result, data) + +def test_get_data_from_key__result_type_obs__path_layer_sparse(adata_with_data): + data = pd.Series(adata_with_data[:,'1'].layers['sparse'].toarray().flatten(), index=adata_with_data.obs.index, name='1') + result = tc.get.data_from_key(adata_with_data, ('layer','sparse','1'), result_type='obs') 
+ tc.testing.assert_series_equal(result, data) + +def test_get_data_from_key__result_type_obs__path_layer_dense(adata_with_data): + data = pd.Series(adata_with_data[:,'1'].layers['dense'].flatten(), index=adata_with_data.obs.index, name='1') + result = tc.get.data_from_key(adata_with_data, ('layer','dense','1'), result_type='obs') + tc.testing.assert_series_equal(result, data) + +def test_get_data_from_key__result_type_obsm__path_layer_sparse(adata_with_data): + data = pd.DataFrame(adata_with_data[:,['1','2']].layers['sparse'].toarray(), index=adata_with_data.obs.index, columns=['1','2']) + result = tc.get.data_from_key(adata_with_data, ('layer','sparse',['1','2']), result_type='obsm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_obsm__path_layer_dense(adata_with_data): + data = pd.DataFrame(adata_with_data[:,['1','2']].layers['dense'], index=adata_with_data.obs.index, columns=['1','2']) + result = tc.get.data_from_key(adata_with_data, ('layer','dense',['1','2']), result_type='obsm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_var__path_layer_sparse(adata_with_data): + data = pd.Series(adata_with_data['T',:].layers['sparse'].toarray().flatten(), index=adata_with_data.var.index, name='T') + result = tc.get.data_from_key(adata_with_data, ('layer','sparse','T'), result_type='var') + tc.testing.assert_series_equal(result, data) + +def test_get_data_from_key__result_type_var__path_layer_dense(adata_with_data): + data = pd.Series(adata_with_data['T',:].layers['dense'].flatten(), index=adata_with_data.var.index, name='T') + result = tc.get.data_from_key(adata_with_data, ('layer','dense','T'), result_type='var') + tc.testing.assert_series_equal(result, data) + +def test_get_data_from_key__result_type_varm__path_layer_sparse(adata_with_data): + data = pd.DataFrame(adata_with_data[['T','G'],:].layers['sparse'].toarray().T, index=adata_with_data.var.index, columns=['T','G']) + result = 
tc.get.data_from_key(adata_with_data, ('layer','sparse',['T','G']), result_type='varm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_varm__path_layer_dense(adata_with_data): + data = pd.DataFrame(adata_with_data[['T','G'],:].layers['dense'].T, index=adata_with_data.var.index, columns=['T','G']) + result = tc.get.data_from_key(adata_with_data, ('layer','dense',['T','G']), result_type='varm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_obsm__path_obs(adata_with_data): + data = adata_with_data.obs[['blub','type']] + result = tc.get.data_from_key(adata_with_data, ('obs',['blub','type']), result_type='obsm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_varm__path_var(adata_with_data): + data = adata_with_data.var[['blub','type']] + result = tc.get.data_from_key(adata_with_data, ('var',['blub','type']), result_type='varm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_varm__path_raw_var(adata_with_data): + data = adata_with_data.raw.to_adata().var[['blub','type']] + result = tc.get.data_from_key(adata_with_data, ('raw','var',['blub','type']), result_type='varm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_obsm__array(adata_with_data): + data = pd.DataFrame(adata_with_data.obsm['array'], index=adata_with_data.obs.index) + result = tc.get.data_from_key(adata_with_data, 'array', result_type='obsm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__direct_frame_obs(adata_with_data): + data = adata_with_data.obs['type'] + result = tc.get.data_from_key(adata_with_data.obs, 'type') + tc.testing.assert_series_equal(result, data) + +def test_get_data_from_key__result_type_obs__direct_frame_obs(adata_with_data): + data = adata_with_data.obs['type'] + result = tc.get.data_from_key(adata_with_data.obs, 'type', result_type='obs') + 
tc.testing.assert_series_equal(result, data) + +def test_get_data_from_key__result_type_var__direct_frame_var(adata_with_data): + data = adata_with_data.var['type'] + result = tc.get.data_from_key(adata_with_data.var, 'type', result_type='var') + tc.testing.assert_series_equal(result, data) + +def test_get_data_from_key__result_type_obsm__direct_frame_obs(adata_with_data): + data = pd.DataFrame(adata_with_data.obs['type']) + result = tc.get.data_from_key(adata_with_data.obs, 'type', result_type='obsm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_varm__direct_frame_var(adata_with_data): + data = pd.DataFrame(adata_with_data.var['type']) + result = tc.get.data_from_key(adata_with_data.var, 'type', result_type='varm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_obsm__path_obs__direct_frame_obs(adata_with_data): + data = adata_with_data.obs[['blub','type']] + result = tc.get.data_from_key(adata_with_data.obs, ('obs',['blub','type']), result_type='obsm') + tc.testing.assert_frame_equal(result, data) + +def test_get_data_from_key__result_type_varm__path_var__direct_frame_var(adata_with_data): + data = adata_with_data.var[['blub','type']] + result = tc.get.data_from_key(adata_with_data.var, ('var',['blub','type']), result_type='varm') + tc.testing.assert_frame_equal(result, data) + + +def test_get_positions(adata_with_data): + data = adata_with_data.obsm['type'] + result = tc.get.positions(adata_with_data, 'type') + tc.testing.assert_frame_equal(result, data) + +def test_get_positions__array(adata_with_data): + data = pd.DataFrame(adata_with_data.obsm['array'], index=adata_with_data.obs.index) + result = tc.get.positions(adata_with_data, 'array') + tc.testing.assert_frame_equal(result, data) + +def test_get_positions__single_obs(adata_with_data): + data = pd.DataFrame(adata_with_data.obs['blub']) + result = tc.get.positions(adata_with_data, 'blub') + 
tc.testing.assert_frame_equal(result, data) + +def test_get_positions__multiple_obs(adata_with_data): + data = adata_with_data.obs[['blub','type']] + result = tc.get.positions(adata_with_data, ['blub','type']) + tc.testing.assert_frame_equal(result, data) + +def test_get_positions__multiple_obs__tuple(adata_with_data): + data = adata_with_data.obs[['blub','type']] + result = tc.get.positions(adata_with_data, ('blub','type')) + tc.testing.assert_frame_equal(result, data) + +def test_get_positions__single_obs__direct_frame_obs(adata_with_data): + data = pd.DataFrame(adata_with_data.obs['blub']) + result = tc.get.positions(adata_with_data.obs, 'blub') + tc.testing.assert_frame_equal(result, data) + +def test_get_positions__multiple_obs__direct_frame_obs(adata_with_data): + data = adata_with_data.obs[['blub','type']] + result = tc.get.positions(adata_with_data.obs, ['blub','type']) + tc.testing.assert_frame_equal(result, data) + +def test_get_positions__path_single_obs(adata_with_data): + data = pd.DataFrame(adata_with_data.obs['type']) + result = tc.get.positions(adata_with_data, ('obs','type')) + tc.testing.assert_frame_equal(result, data) + +def test_get_positions__path_multiple_obs(adata_with_data): + data = adata_with_data.obs[['blub','type']] + result = tc.get.positions(adata_with_data, ('obs',['blub','type'])) + tc.testing.assert_frame_equal(result, data) + +def test_get_positions__path_layer_sparse(adata_with_data): + data = pd.DataFrame(adata_with_data[:,['1','2']].layers['sparse'].toarray(), index=adata_with_data.obs.index, columns=['1','2']) + result = tc.get.positions(adata_with_data, ('layer','sparse',['1','2'])) + tc.testing.assert_frame_equal(result, data) diff --git a/tests/test_orthology.py b/tests/test_orthology.py new file mode 100644 index 0000000..b8fdebe --- /dev/null +++ b/tests/test_orthology.py @@ -0,0 +1,30 @@ +import pytest +import numpy as np +import pandas as pd +import anndata as ad +import tacco as tc +import scipy.sparse + 
+@pytest.fixture(scope="session") +def adata_to_convert(): + human_adata = ad.AnnData(scipy.sparse.csr_matrix(np.eye(4)), + var=pd.DataFrame(index=['ISG15','TP53','GSTM3','GSTM2']), + dtype=np.float32, + ) + mouse_adata = ad.AnnData(scipy.sparse.csr_matrix(np.eye(4)), + var=pd.DataFrame(index=['Isg15','Trp53','Gstm5','Gstm7']), + dtype=np.float32, + ) + return {'human':human_adata,'mouse':mouse_adata} + +def test_orthology_converter(adata_to_convert): + human_adata = adata_to_convert['human'] + mouse_adata = adata_to_convert['mouse'] + + result = tc.tools.run_orthology_converter(adata=human_adata, source_tax_id='human',target_tax_id='mouse', target_gene_symbols=mouse_adata.var.index) + + result = result[:,mouse_adata.var.index].copy() # avoid assertion error for comparing differently ordered .var.index + + result.X = scipy.sparse.csr_matrix(result.X) # avoid assertion error for comparing a csr matrix view with an actual csr matrix + + tc.testing.assert_adata_equal(result, mouse_adata) diff --git a/tests/test_spatial_split.py b/tests/test_spatial_split.py new file mode 100644 index 0000000..66329fd --- /dev/null +++ b/tests/test_spatial_split.py @@ -0,0 +1,62 @@ +import pytest +import numpy as np +import pandas as pd +import anndata as ad +import tacco as tc +import scipy.sparse + +@pytest.fixture(scope="session") +def spatial_adata_to_split(): + points = pd.DataFrame({ + 'x': pd.Series([ 0, 0, 10, 10, 20, 20, 30, 30, 40, 40, 50, 50],dtype=float), + 'y': pd.Series([ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],dtype=float), + 'split_0_N_N': pd.Series([ '|0', '|0', '|0', '|0', '|0', '|0', '|0', '|0', '|0', '|0', '|0', '|0']).astype('category'), + 'split_0_x_2': pd.Series([ '|0', '|0', '|0', '|0', '|0', '|0', '|1', '|1', '|1', '|1', '|1', '|1']).astype('category'), + 'split_0_x_3': pd.Series([ '|0', '|0', '|0', '|0', '|1', '|1', '|1', '|1', '|2', '|2', '|2', '|2']).astype('category'), + 'split_0_x_6': pd.Series([ '|0', '|0', '|1', '|1', '|2', '|2', '|3', '|3', '|4', '|4', 
'|5', '|5']).astype('category'), + 'split_0_y_2': pd.Series([ '|0', '|1', '|0', '|1', '|0', '|1', '|0', '|1', '|0', '|1', '|0', '|1']).astype('category'), + 'split_0_x_6_y_2': pd.Series(['|0|0','|0|1','|1|0','|1|1','|2|0','|2|1','|3|0','|3|1','|4|0','|4|1','|5|0','|5|1']).astype('category'), + 'split_25_x_2': pd.Series([ '|0', '|0', '|0', '|0', None, None, None, None, '|1', '|1', '|1', '|1']).astype('category'), + 'split_45_x_2': pd.Series([ '|0', '|0', None, None, None, None, None, None, None, None, '|1', '|1']).astype('category'), + }) + adata = ad.AnnData(np.zeros((len(points),0), dtype=np.int8), + obs=points, + dtype=np.int8, + ) + return adata + +@pytest.mark.parametrize('label_thickness_direction_scheme', [ + ('split_0_N_N',0,None,1), + ('split_0_N_N',0,'x',1), + ('split_0_N_N',0,'y',1), + ('split_0_x_2',0,None,2), + ('split_0_x_2',0,'x',2), + ('split_0_x_3',0,None,3), + ('split_0_x_3',0,'x',3), + ('split_0_x_6',0,None,6), + ('split_0_x_6',0,'x',6), + ('split_0_y_2',0,'y',2), + ('split_0_x_6_y_2',0,('x','y'),(6,2)), + ('split_0_x_6_y_2',0,None,(6,2)), + + ('split_0_N_N',5,None,1), + ('split_0_N_N',5,'x',1), + ('split_0_x_2',5,None,2), + ('split_0_x_2',5,'x',2), + ('split_0_x_3',5,None,3), + ('split_0_x_3',5,'x',3), + ('split_0_x_6',5,None,6), + ('split_0_x_6',5,'x',6), + + ('split_25_x_2',25,'x',2), + ('split_45_x_2',45,'x',2), +]) +def test_distance_matrix(spatial_adata_to_split, label_thickness_direction_scheme): + adata = spatial_adata_to_split + + label, thickness, direction, scheme = label_thickness_direction_scheme + + result = tc.utils.split_spatial_samples(adata, buffer_thickness=thickness, split_direction=direction, split_scheme=scheme) + + result.name = label + tc.testing.assert_series_equal(result, adata.obs[label])