Skip to content

Commit

Permalink
Fix plotting for complex homozygous alleles (#8)
Browse files Browse the repository at this point in the history
* fixed plotting for homozyous complex allele

* added transparent violinplots for complex alleles
  • Loading branch information
xsitarcik authored Sep 11, 2023
1 parent 846ace3 commit dd32f55
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 31 deletions.
44 changes: 34 additions & 10 deletions src/genotyper/genotyping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import List, Optional
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
Expand All @@ -14,23 +14,22 @@
from .plotter import plot_clustering_preds, plot_complex_repeats


def decode_alleles_complex(gmm_out_dict, df):
def decode_alleles_complex(gmm_out_dict: Dict[str, Any], df: pd.DataFrame):
if gmm_out_dict['is_hetero']:
df1 = df.loc[gmm_out_dict['group1']]
df2 = df.loc[gmm_out_dict['group2']]
mediangroup1 = []
mediangroup1: List[int] = []
for col in df1.columns:
if col != 'reverse':
mediangroup1.append(find_nearest(df1[col], np.median(df1[col])))
mediangroup1.append(int(find_nearest(df1[col].values, np.median(df1[col]))))
mediangroup1_cnt = len(gmm_out_dict['group1'])
mediangroup2 = []
mediangroup2: List[int] = []
for col in df2.columns:
if col != 'reverse':
mediangroup2.append(find_nearest(df2[col], np.median(df2[col])))
mediangroup2.append(int(find_nearest(df2[col].values, np.median(df2[col]))))
mediangroup2_cnt = len(gmm_out_dict['group2'])
else:
df1 = df.loc[gmm_out_dict['group1']]
df2 = None
mediangroup1 = []
for col in df1.columns:
if col != 'reverse':
Expand Down Expand Up @@ -119,7 +118,7 @@ def store_predictions(gt: Genotype, gt_bc: Optional[Genotype], locus_path: str):
print(f'Allele lengths as given by basecall: {gt_bc.alleles}')


def run_genotyping_complex(locus_path: str, df):
def run_genotyping_complex(locus_path: str, df: Union[pd.DataFrame, None]):
if df is None:
inpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_repeat_units.csv')
if os.path.isfile(inpath):
Expand Down Expand Up @@ -148,12 +147,37 @@ def run_genotyping_complex(locus_path: str, df):
out['group2'] = [idx for idx, g in zip(df.index, preds) if g == 1]
out['predictions'] = preds

alleles = decode_alleles_complex(out, df)
if out['predictions'] is not None:
df['allele'] = preds

alleles = decode_alleles_complex(out, df)
if out['predictions'] is not None:
homozygous = False
print('Genotyped complex repeats in 2 alleles:')
for idx, col in enumerate(cols):
print(f'Unit: {col:10} Repeats: {alleles[0][idx]:5} {alleles[2][idx]:5}', )
print(f'There were {alleles[1]} reads for allele1 and {alleles[3]} for allele2')

outpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_alleles.csv')
with open(outpath, 'w') as f:
f.write('unit,allele1_repeats,allele2_repeats\n')
for idx, col in enumerate(cols):
f.write(f'{col},{alleles[0][idx]},{alleles[2][idx]}\n')
else:
homozygous = True
print('Genotyped complex repeats in a homozygous allele:')
for idx, col in enumerate(cols):
print(f'Unit: {col:10} Repeats: {alleles[0][idx]:5} {alleles[2]:5}', )
print(f'There were {alleles[1]} reads for allele1 and {alleles[3]} for allele2')

outpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_alleles.csv')
with open(outpath, 'w') as f:
f.write('unit,allele1_repeats, allele2_repeats\n')
for idx, col in enumerate(cols):
f.write(f'{col},{alleles[0][idx]},{alleles[2]}\n')

img_path = os.path.join(locus_path, tmpl.SUMMARY_SUBDIR, 'complex_genotypes.svg')
plot_complex_repeats(df, cols, alleles, img_path)
plot_complex_repeats(df, cols, alleles, homozygous, img_path)


def run_genotyping(unfilt_vals: List[int]):
Expand Down
53 changes: 33 additions & 20 deletions src/genotyper/plotter.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,45 @@
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.collections import PolyCollection


def plot_complex_repeats(df, cols, alleles, img_path: str):
fig, axes = plt.subplots(nrows=len(cols), ncols=1, figsize=(8, 6*len(cols)))
def plot_complex_repeats(
df: pd.DataFrame,
cols: List[str],
alleles: Tuple[List[int], int, List[int], int],
homozygous: bool,
img_path: str
):
_, axes = plt.subplots(nrows=len(cols), ncols=1, figsize=(8, 6*len(cols)))
for idx, col in enumerate(cols):
ex_df = df
val1 = alleles[0][idx]
val2 = alleles[2][idx]
name = 'repeat numbers: '+str(val1)+','+str(val2)
ex_df[name] = ''
axes[idx] = sns.violinplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical',
split=False, scale='count', whis=np.inf, inner=None, ax=axes[idx])
plt.setp(axes[idx].collections, alpha=.3)
first = [r for r in axes[idx].get_children() if type(r) == PolyCollection]
c1 = first[0].get_facecolor()[0]
c2 = first[1].get_facecolor()[0]
if val1 != '-':
axes[idx].axhline(y=val1, color=c1, linestyle='--')
if val2 != '-':
axes[idx].axhline(y=val2, color=c2, linestyle='--')
axes[idx] = sns.stripplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical',
dodge=True, size=6, alpha=0.8, jitter=0.3, ax=axes[idx])
axes[idx].get_legend().remove()

if not homozygous:
val1 = alleles[0][idx]
val2 = alleles[2][idx]
name = 'repeat numbers: '+str(val1)+','+str(val2)
ex_df[name] = ''
axes[idx] = sns.violinplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical',
split=False, scale='count', whis=np.inf, inner=None, ax=axes[idx])
plt.setp(axes[idx].collections, alpha=.3)
axes[idx].axhline(y=val1, color='b', linestyle='--')
axes[idx].axhline(y=val2, color='b', linestyle='--')
axes[idx] = sns.stripplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical',
dodge=True, size=6, alpha=0.8, jitter=0.3, ax=axes[idx])
axes[idx].get_legend().remove()
else:
val1 = alleles[0][idx]
name = 'repeat numbers: '+str(val1)+',-'
ex_df[name] = ''
axes[idx] = sns.violinplot(data=ex_df, x=name, y=col, orient='vertical',
split=False, scale='count', whis=np.inf, inner=None, ax=axes[idx])
plt.setp(axes[idx].collections, alpha=.3)
axes[idx].axhline(y=val1, color='b', linestyle='--')
axes[idx] = sns.stripplot(data=ex_df, x=name, y=col, orient='vertical',
dodge=True, size=6, alpha=0.8, jitter=0.3, ax=axes[idx])

plt.savefig(img_path, bbox_inches='tight', format='svg')
plt.close()
Expand Down
2 changes: 1 addition & 1 deletion src/schemas/genotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ def alleles(self):
return (self.first_allele, self.second_allele)


def find_nearest(array: List[int], value: float):
def find_nearest(array: List[int], value: float) -> float:
return array[(np.abs(np.asarray(array) - value)).argmin()]

0 comments on commit dd32f55

Please sign in to comment.