Skip to content

Commit

Permalink
fixing some issues with command line script
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthias Lienhard committed Dec 20, 2021
1 parent 27cf00e commit 5da7453
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 135 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
## TODO: ideas, issues and planed extensions or changes that are not yet implemented
* optimize add_qc_metrics for run after new samples have been added - should not recompute everything

## [0.2.8]
* fix: version information lost when pickeling reference.
* fix missing genen name
* added pt_size parameter to plot_embedding and plot_diff_results function
* added colors parameter to plotting functions


## [0.2.7]
* added command line script run_isotools.py
* added test data for unit tests
Expand Down
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.7
0.2.8rc17
118 changes: 95 additions & 23 deletions docs/notebooks/isotools_alzheimer.ipynb

Large diffs are not rendered by default.

71 changes: 32 additions & 39 deletions notebooks/test_data.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ umap-learn
sklearn
scipy
statsmodels
importlib-metadata
11 changes: 6 additions & 5 deletions src/isotools/_transcriptome_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,6 @@ def import_ref_transcripts(fn, transcriptome, file_format, chromosomes=None, gen
returns a dict interval trees for the genes'''
if gene_categories is None:
gene_categories = ['gene']
logger.info('importing annotation from %s', fn)
if file_format == 'gtf':
exons, transcripts, genes, gene_set, cds_start, cds_stop, skipped = _read_gtf_file(fn, transcriptome, chromosomes, **kwargs)
else: # gff/gff3
Expand Down Expand Up @@ -966,13 +965,12 @@ def transcript_table(self, samples=None, groups=None, coverage=False, tpm=Fals
Exports all transcript isoforms within region to a table.
:param samples: provide a list of samples for which coverage / expression information is added.
:param samples: provide a list of samples for which coverage / expression information is added.
:param groups: provide groups as a dict (as from Transcriptome.groups()), for which coverage / expression information is added.
:param coverage: If set, coverage information is added for specified samples / groups.
:param tpm: If set, expression information (in tpm) is added for specified samples / groups.
:param extra_columns: Specify the additional information added to the table.
These can be any transcrit property as defined by the key in the transcript dict.
:param region: Specify the region, either as (chr, start, end) tuple or as "chr:start-end" string.
If omitted specify the complete genome.
:param query: Specify transcript filter query.
Expand All @@ -997,8 +995,10 @@ def transcript_table(self, samples=None, groups=None, coverage=False, tpm=Fals
samples_set.update(*groups.values())
assert all(s in self.samples for s in samples_set), 'Not all specified samples are known'
if len(samples_set) == len(self.samples):
all_samples = True
sample_i = slice(None)
else:
all_samples = False
sample_i = [i for i, sa in enumerate(self.samples) if sa in samples_set]

if not isinstance(extra_columns, list):
Expand All @@ -1012,7 +1012,8 @@ def transcript_table(self, samples=None, groups=None, coverage=False, tpm=Fals
cov = []
for g, trids, trs in self.iter_transcripts(**filter_args, genewise=True):
if sample_i:
cov.append(g.coverage[sample_i, trids])
idx = (slice(None), trids) if all_samples else np.ix_(sample_i, trids)
cov.append(g.coverage[idx])
for trid, tr in zip(trids, trs):
exons = tr['exons']
trlen = sum(e[1]-e[0] for e in exons)
Expand All @@ -1030,7 +1031,7 @@ def transcript_table(self, samples=None, groups=None, coverage=False, tpm=Fals
df = pd.DataFrame(rows, columns=colnames)
if cov:
df_list = [df]
cov = pd.DataFrame(np.concatenate(cov, 1).T, columns=self.samples if isinstance(sample_i, slice) else [sa for i, sa in self.samples if i in sample_i])
cov = pd.DataFrame(np.concatenate(cov, 1).T, columns=self.samples if all_samples else [sa for i, sa in enumerate(self.samples) if i in sample_i])
stab = self.sample_table.set_index('name')
if samples:
if coverage:
Expand Down
77 changes: 54 additions & 23 deletions src/isotools/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
logger = logging.getLogger('isotools')


def plot_diff_results(result_table, min_support=3, min_diff=.1, grid_shape=(5, 5), min_cov=10, splice_types=None):
def plot_diff_results(result_table, min_support=3, min_diff=.1, grid_shape=(5, 5), min_cov=10, splice_types=None,
group_colors=None, sample_colors=None, pt_size=20, lw=1, ls='solid'):
'''Plots differential splicing results.
For the first (e.g. most significant) differential splicing events from result_table
Expand All @@ -24,6 +25,11 @@ def plot_diff_results(result_table, min_support=3, min_diff=.1, grid_shape=(5, 5
:param grid_shape: Number of rows and columns for the figure.
:param splice_type: Only events from the splecified splice_type(s) are depicted.
If omitted, all types are selected.
:param group_colors: Specify the colors for the groups (e.g. the lines) as a dict or list of length two.
:param sample_colors: Specify the colors for the samples (e.g. the dots) as a dict. Defaults to the corresponding group color.
:param pt_size: Specify the size for the data points in the plot.
:param lw: Specify witdh of the lines. See matplotlib Line2D for details.
:param ls: Specify style of the lines. See matplotlib Line2D for details.
:return: figure, axes and list of plotted events
'''

Expand All @@ -34,29 +40,41 @@ def plot_diff_results(result_table, min_support=3, min_diff=.1, grid_shape=(5, 5
axs = axs.flatten()
x = [i / 100 for i in range(101)]
group_names = [c[:-4] for c in result_table.columns if c.endswith('_PSI')][:2]
groups = {gn: [c[:-10] for c in result_table.columns if c.endswith(gn + '_total_cov')] for gn in group_names}
groups = {gn: [c[:c.find(gn)-1] for c in result_table.columns if c.endswith(gn + '_total_cov')] for gn in group_names}
if group_colors is None:
group_colors = [0, 1]
if isinstance(group_colors, list):
group_colors = dict(zip(group_names, group_colors))
if sample_colors is None:
sample_colors = {}
sample_colors = {sa: sample_colors.get(sa, group_colors[gn]) for gn in group_names for sa in groups[gn]}
other = {group_names[0]: group_names[1], group_names[1]: group_names[0]}
logger.debug('groups: %s', str(groups))
for idx, row in result_table.iterrows():
logger.debug(f'plotting {idx}: {row.gene}')
logger.debug('plotting %s: %s', idx, row.gene)
if splice_types is not None and row.splice_type not in splice_types:
continue
if row.gene in set(plotted.gene):
continue
params_alt = {gn: (row[f'{gn}_PSI'], row[f'{gn}_disp']) for gn in group_names}
# select only samples covered >= min_support
psi_gr = {gn: [row[f'{sa}_in_cov'] / row[f'{sa}_total_cov'] for sa in gr if row[f'{sa}_total_cov'] >= min_support] for gn, gr in groups.items()}
support = {s: sum(abs(i - params_alt[s][0]) < abs(i - params_alt[o][0]) for i in psi_gr[s]) for s, o in zip(group_names, reversed(group_names))}
# select only samples covered >= min_cov
# psi_gr = {gn: [row[f'{sa}_in_cov'] / row[f'{sa}_total_cov'] for sa in gr if row[f'{sa}_total_cov'] >= min_cov] for gn, gr in groups.items()}
psi_gr_list = [(sa, gn, row[f'{sa}_{gn}_in_cov'] / row[f'{sa}_{gn}_total_cov'])
for gn, gr in groups.items() for sa in gr if row[f'{sa}_{gn}_total_cov'] >= min_cov]
psi_gr = pd.DataFrame(psi_gr_list, columns=['sample', 'group', 'psi'])
psi_gr['support'] = [abs(sa.psi - params_alt[sa['group']][0]) < abs(sa.psi - params_alt[other[sa['group']]][0]) for i, sa in psi_gr.iterrows()]
support = dict(psi_gr.groupby('group')['support'].sum())
if any(sup < min_support for sup in support.values()):
logger.debug(f'skipping {row.gene} with {support} supporters')
logger.debug('skipping %s with %s supporters', row.gene, support)
continue
if abs(params_alt[group_names[0]][0] - params_alt[group_names[1]][0]) < min_diff:
logger.debug(f'{row.gene} with {"vs".join(str(p[0]) for p in params_alt.values())}')
logger.debug('%s with %s', row.gene, "vs".join(str(p[0]) for p in params_alt.values()))
continue
# get the paramters for the beta distiribution
# print(param)
ax = axs[len(plotted)]
# ax.boxplot([mut,wt], labels=['mut','wt'])
sns.swarmplot(data=pd.DataFrame(list(psi_gr.values()), index=psi_gr).T, ax=ax, orient='h')
sns.swarmplot(data=psi_gr, x='psi', y='group', hue='sample', orient='h', size=np.sqrt(pt_size), palette=sample_colors, ax=ax)
ax.legend([], [], frameon=False)
for i, gn in enumerate(group_names):
max_i = int(params_alt[gn][0] * (len(x) - 1))
ax2 = ax.twinx() # instantiate a second axes that shares the same x-axis
Expand All @@ -68,7 +86,7 @@ def plot_diff_results(result_table, min_support=3, min_diff=.1, grid_shape=(5, 5
else:
y = np.zeros(len(x))
y[max_i] = 1 # point mass
ax2.plot(x, y, color=f'C{i}')
ax2.plot(x, y, color=group_colors[gn], lw=lw, ls=ls)
ax2.tick_params(right=False, labelright=False)
ax.set_title(f'{row.gene} {row.splice_type}\nFDR={row.padj:.5f}')
plotted = plotted.append(row)
Expand All @@ -79,8 +97,8 @@ def plot_diff_results(result_table, min_support=3, min_diff=.1, grid_shape=(5, 5


def plot_embedding(splice_bubbles, method='PCA', prior_count=3,
top_var=500, min_total=100, min_alt_fraction=.1, plot_components=[1, 2],
splice_types='all', labels=True, groups=None, colors=None, ax=None, **kwargs):
top_var=500, min_total=100, min_alt_fraction=.1, plot_components=(1, 2),
splice_types='all', labels=True, groups=None, colors=None, pt_size=20, ax=None, **kwargs):
''' Plots embedding of alternative splicing events.
Alternative splicing events are soreted by variance and only the top variable events are used for the embedding.
Expand All @@ -99,6 +117,7 @@ def plot_embedding(splice_bubbles, method='PCA', prior_count=3,
:param groups: Set a group definition (e.g. by isoseq.Transcirptome.groups()) to color the datapoints.
All samples within one group are depicted.
:param colors: List or dict of colors for the groups, if ommited colors are generated automatically.
:param pt_size: Specify the size for the data points in the plot.
:param ax: The axis for plotting.
:param \\**kwargs: Additional keyword parameters are passed to PCA() or UMAP().
:return: A dataframe with the proportions of the alternative events, the transformed data and the embedding object.'''
Expand Down Expand Up @@ -150,7 +169,7 @@ def plot_embedding(splice_bubbles, method='PCA', prior_count=3,
k = k.loc[covered]
# compute the proportions
scaled_mean = k.sum(1) / n.sum(1) * prior_count
p = ((k.values + scaled_mean[:, np.newaxis]) / (n.values + prior_count)).T
p = ((k.values + scaled_mean.values[:, np.newaxis]) / (n.values + prior_count)).T
topvar = p[:, p.var(0).argsort()[-top_var:]] # sort from low to high var

# compute embedding
Expand All @@ -172,7 +191,7 @@ def plot_embedding(splice_bubbles, method='PCA', prior_count=3,
ax.scatter(
transformed.loc[sa, plot_components[0] - 1],
transformed.loc[sa, plot_components[1] - 1],
c=colors[gr], label=gr)
c=colors[gr], label=gr, s=pt_size)
ax.set(**axparams)
if labels:
for idx, (x, y) in transformed[plot_components - 1].iterrows():
Expand All @@ -182,7 +201,7 @@ def plot_embedding(splice_bubbles, method='PCA', prior_count=3,
# plots


def plot_bar(df, ax=None, drop_categories=None, legend=True, annotate=True, rot=90, bar_width=.5, **axparams):
def plot_bar(df, ax=None, drop_categories=None, legend=True, annotate=True, rot=90, bar_width=.5, colors=None, **axparams):
'''Depicts data as a barplot.
This function is intended to be called with the result from
Expand All @@ -195,6 +214,7 @@ def plot_bar(df, ax=None, drop_categories=None, legend=True, annotate=True, rot=
:param annotate: If True, print numbers / fractions in the bars.
:param rot: Set rotation of the lables.
:param bar_width: Set relative width of the plotted bars.
:param colors: Provide a dictionary with label keys and color values. By default, colors are automatically assigned.
:param \\**axparams: Additional keyword parameters are passed to ax.set().'''

if ax is None:
Expand All @@ -209,7 +229,7 @@ def plot_bar(df, ax=None, drop_categories=None, legend=True, annotate=True, rot=
dcat = []
else:
dcat = [d for d in drop_categories if d in df.index]
fractions.drop(dcat).plot.bar(ax=ax, legend=legend, width=bar_width, rot=rot)
fractions.drop(dcat).plot.bar(ax=ax, legend=legend, width=bar_width, rot=rot, color=colors)
# add numbers
if annotate:
numbers = [int(v) for c in df.drop(dcat).T.values for v in c]
Expand All @@ -225,7 +245,7 @@ def plot_bar(df, ax=None, drop_categories=None, legend=True, annotate=True, rot=
return ax


def plot_distr(counts, ax=None, density=False, smooth=None, legend=True, fill=True, **axparams):
def plot_distr(counts, ax=None, density=False, smooth=None, legend=True, fill=True, lw=1, ls='solid', colors=None, **axparams):
'''Depicts data as density plot.
This function is intended to be called with the result from
Expand All @@ -239,10 +259,15 @@ def plot_distr(counts, ax=None, density=False, smooth=None, legend=True, fill=Tr
:param smooth: Ews smoothing span.
:param legend: If True, add a legend.
:param fill: If set, the area below the lines are filled with half transparent color.
:param lw: Specify witdh of the lines. See matplotlib Line2D for details.
:param ls: Specify style of the lines. See matplotlib Line2D for details.
:param colors: Provide a dictionary with label keys and color values. By default, colors are automatically assigned.
:param \\**axparams: Additional keyword parameters are passed to ax.set().'''
# maybe add smoothing
x = [sum(bin) / 2 for bin in counts.index]
sz = [bin[1] - bin[0] for bin in counts.index]
if colors is None:
colors = {}
if ax is None:
_, ax = plt.subplots()
if density:
Expand All @@ -256,9 +281,9 @@ def plot_distr(counts, ax=None, density=False, smooth=None, legend=True, fill=Tr
if smooth:
counts = counts.ewm(span=smooth).mean()
for gn, gc in counts.items():
ax.plot(x, gc / sz, label=gn)
lines = ax.plot(x, gc / sz, label=gn, color=colors.get(gn, None), lw=lw, ls=ls)
if fill:
ax.fill_between(x, 0, gc / sz, alpha=.5)
ax.fill_between(x, 0, gc / sz, alpha=.5, color=lines[-1].get_color())
# ax.plot(x, counts.divide(sz, axis=0))
ax.set(**axparams)
if legend:
Expand Down Expand Up @@ -294,7 +319,7 @@ def plot_saturation(isoseq=None, ax=None, cov_th=2, expr_th=[.5, 1, 2, 5, 10], x
chance = nbinom.cdf(k - cov_th, n=cov_th, p=tpm_th * 1e-6) # 0 to k-cov_th failiors
ax.plot(k / 1e6, chance, label=f'{tpm_th} TPM')
for sa, cov in n_reads.items():
ax.axvline(cov / 1e6, color='grey', linestyle='--')
ax.axvline(cov / 1e6, color='grey', ls='--')
if label:
ax.text((cov + (k[-1] - k[0]) / 200) / 1e6, 0.1, f'{sa} ({cov/1e6:.2f} M)', rotation=-90)
ax.set(**axparams)
Expand All @@ -304,18 +329,24 @@ def plot_saturation(isoseq=None, ax=None, cov_th=2, expr_th=[.5, 1, 2, 5, 10], x
return ax


def plot_rarefaction(rarefaction, total=None, ax=None, legend=True, **axparams):
def plot_rarefaction(rarefaction, total=None, ax=None, legend=True, colors=None, lw=1, ls='solid', **axparams):
'''Plots the rarefaction curve.
:param rarefaction: A DataFrame with the observed number of transcripts, as computed by Transcriptome.rarefaction().
:param total: A dictionary with the total number of reads per sample/sample group, as computed by Transcriptome.rarefaction().
:param ax: The axis for the plot.
:param legend: If set True, a legend is added to the plot.
:param colors: Provide a dictionary with label keys and color values. By default, colors are automatically assigned.
:param lw: Specify witdh of the lines. See matplotlib Line2D for details.
:param ls: Specify style of the lines. See matplotlib Line2D for details.
:param \\**axparams: Additional keyword parameters are passed to ax.set().'''
if ax is None:
_, ax = plt.subplots()
if colors is None:
colors = {}
for sa in rarefaction.columns:
ax.plot([float(f) * total[sa] / 1e6 if total is not None else float(f)*100 for f in rarefaction.index], rarefaction[sa], label=sa)
ax.plot([float(f) * total[sa] / 1e6 if total is not None else float(f)*100 for f in rarefaction.index], rarefaction[sa],
label=sa, ls=ls, lw=lw, color=colors.get(sa, None))

axparams.setdefault('title', 'Rarefaction Analysis') # [nr],{'fontsize':20}, loc='left', pad=10)
axparams.setdefault('ylabel', 'Number of discovered Transcripts')
Expand Down
Loading

0 comments on commit 5da7453

Please sign in to comment.