diff --git a/src/genotyper/genotyping.py b/src/genotyper/genotyping.py index 29f0057..3bbc10e 100644 --- a/src/genotyper/genotyping.py +++ b/src/genotyper/genotyping.py @@ -1,5 +1,5 @@ import os -from typing import List, Optional +from typing import Any, Dict, List, Optional, Union import numpy as np import pandas as pd @@ -14,23 +14,22 @@ from .plotter import plot_clustering_preds, plot_complex_repeats -def decode_alleles_complex(gmm_out_dict, df): +def decode_alleles_complex(gmm_out_dict: Dict[str, Any], df: pd.DataFrame): if gmm_out_dict['is_hetero']: df1 = df.loc[gmm_out_dict['group1']] df2 = df.loc[gmm_out_dict['group2']] - mediangroup1 = [] + mediangroup1: List[int] = [] for col in df1.columns: if col != 'reverse': - mediangroup1.append(find_nearest(df1[col], np.median(df1[col]))) + mediangroup1.append(int(find_nearest(df1[col].values, np.median(df1[col])))) mediangroup1_cnt = len(gmm_out_dict['group1']) - mediangroup2 = [] + mediangroup2: List[int] = [] for col in df2.columns: if col != 'reverse': - mediangroup2.append(find_nearest(df2[col], np.median(df2[col]))) + mediangroup2.append(int(find_nearest(df2[col].values, np.median(df2[col])))) mediangroup2_cnt = len(gmm_out_dict['group2']) else: df1 = df.loc[gmm_out_dict['group1']] - df2 = None mediangroup1 = [] for col in df1.columns: if col != 'reverse': @@ -119,7 +118,7 @@ def store_predictions(gt: Genotype, gt_bc: Optional[Genotype], locus_path: str): print(f'Allele lengths as given by basecall: {gt_bc.alleles}') -def run_genotyping_complex(locus_path: str, df): +def run_genotyping_complex(locus_path: str, df: Union[pd.DataFrame, None]): if df is None: inpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_repeat_units.csv') if os.path.isfile(inpath): @@ -148,12 +147,37 @@ def run_genotyping_complex(locus_path: str, df): out['group2'] = [idx for idx, g in zip(df.index, preds) if g == 1] out['predictions'] = preds - alleles = decode_alleles_complex(out, df) if out['predictions'] is not None: df['allele'] = preds + alleles = decode_alleles_complex(out, df) + if out['predictions'] is not None: + homozygous = False + print('Genotyped complex repeats in 2 alleles:') + for idx, col in enumerate(cols): + print(f'Unit: {col:10} Repeats: {alleles[0][idx]:5} {alleles[2][idx]:5}', ) + print(f'There were {alleles[1]} reads for allele1 and {alleles[3]} for allele2') + + outpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_alleles.csv') + with open(outpath, 'w') as f: + f.write('unit,allele1_repeats,allele2_repeats\n') + for idx, col in enumerate(cols): + f.write(f'{col},{alleles[0][idx]},{alleles[2][idx]}\n') + else: + homozygous = True + print('Genotyped complex repeats in a homozygous allele:') + for idx, col in enumerate(cols): + print(f'Unit: {col:10} Repeats: {alleles[0][idx]:5} {alleles[2]:5}', ) + print(f'There were {alleles[1]} reads for allele1 and {alleles[3]} for allele2') + + outpath = os.path.join(locus_path, tmpl.PREDICTIONS_SUBDIR, tmpl.COMPLEX_SUBDIR, 'complex_alleles.csv') + with open(outpath, 'w') as f: + f.write('unit,allele1_repeats, allele2_repeats\n') + for idx, col in enumerate(cols): + f.write(f'{col},{alleles[0][idx]},{alleles[2]}\n') + img_path = os.path.join(locus_path, tmpl.SUMMARY_SUBDIR, 'complex_genotypes.svg') - plot_complex_repeats(df, cols, alleles, img_path) + plot_complex_repeats(df, cols, alleles, homozygous, img_path) def run_genotyping(unfilt_vals: List[int]): diff --git a/src/genotyper/plotter.py b/src/genotyper/plotter.py index 1ed1efb..f3ff58d 100644 --- a/src/genotyper/plotter.py +++ b/src/genotyper/plotter.py @@ -1,32 +1,45 @@ from typing import List, Optional, Tuple import numpy as np +import pandas as pd import seaborn as sns from matplotlib import pyplot as plt -from matplotlib.collections import PolyCollection -def plot_complex_repeats(df, cols, alleles, img_path: str): - fig, axes = plt.subplots(nrows=len(cols), ncols=1, figsize=(8, 6*len(cols))) +def plot_complex_repeats( + df: pd.DataFrame, + cols: List[str], + alleles: Tuple[List[int], int, List[int], int], + homozygous: bool, + img_path: str +): + _, axes = plt.subplots(nrows=len(cols), ncols=1, figsize=(8, 6*len(cols))) for idx, col in enumerate(cols): ex_df = df - val1 = alleles[0][idx] - val2 = alleles[2][idx] - name = 'repeat numbers: '+str(val1)+','+str(val2) - ex_df[name] = '' - axes[idx] = sns.violinplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical', - split=False, scale='count', whis=np.inf, inner=None, ax=axes[idx]) - plt.setp(axes[idx].collections, alpha=.3) - first = [r for r in axes[idx].get_children() if type(r) == PolyCollection] - c1 = first[0].get_facecolor()[0] - c2 = first[1].get_facecolor()[0] - if val1 != '-': - axes[idx].axhline(y=val1, color=c1, linestyle='--') - if val2 != '-': - axes[idx].axhline(y=val2, color=c2, linestyle='--') - axes[idx] = sns.stripplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical', - dodge=True, size=6, alpha=0.8, jitter=0.3, ax=axes[idx]) - axes[idx].get_legend().remove() + + if not homozygous: + val1 = alleles[0][idx] + val2 = alleles[2][idx] + name = 'repeat numbers: '+str(val1)+','+str(val2) + ex_df[name] = '' + axes[idx] = sns.violinplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical', + split=False, scale='count', whis=np.inf, inner=None, ax=axes[idx]) + plt.setp(axes[idx].collections, alpha=.3) + axes[idx].axhline(y=val1, color='b', linestyle='--') + axes[idx].axhline(y=val2, color='b', linestyle='--') + axes[idx] = sns.stripplot(data=ex_df, x=name, y=col, hue='allele', orient='vertical', + dodge=True, size=6, alpha=0.8, jitter=0.3, ax=axes[idx]) + axes[idx].get_legend().remove() + else: + val1 = alleles[0][idx] + name = 'repeat numbers: '+str(val1)+',-' + ex_df[name] = '' + axes[idx] = sns.violinplot(data=ex_df, x=name, y=col, orient='vertical', + split=False, scale='count', whis=np.inf, inner=None, ax=axes[idx]) + plt.setp(axes[idx].collections, alpha=.3) + axes[idx].axhline(y=val1, color='b', linestyle='--') + axes[idx] = sns.stripplot(data=ex_df, x=name, y=col, orient='vertical', + dodge=True, size=6, alpha=0.8, jitter=0.3, ax=axes[idx]) plt.savefig(img_path, bbox_inches='tight', format='svg') plt.close() diff --git a/src/schemas/genotype.py b/src/schemas/genotype.py index 5392b67..09df956 100644 --- a/src/schemas/genotype.py +++ b/src/schemas/genotype.py @@ -35,5 +35,5 @@ def alleles(self): return (self.first_allele, self.second_allele) -def find_nearest(array: List[int], value: float): +def find_nearest(array: List[int], value: float) -> float: return array[(np.abs(np.asarray(array) - value)).argmin()]