From 2a84d4fbb26c0fae271d2e47c5bf95d6b4ea945d Mon Sep 17 00:00:00 2001
From: dr-david
Date: Thu, 9 Nov 2023 14:29:03 +0100
Subject: [PATCH] likelihood fitting

---
 src/covvfit/__init__.py               |  17 +++-
 src/covvfit/_plotting.py              |  85 +++++++++++++++
 src/covvfit/_preprocess_abundances.py | 123 ++++++++++++++++++++
 3 files changed, 224 insertions(+), 1 deletion(-)
 create mode 100644 src/covvfit/_plotting.py
 create mode 100644 src/covvfit/_preprocess_abundances.py

diff --git a/src/covvfit/__init__.py b/src/covvfit/__init__.py
--- a/src/covvfit/__init__.py
+++ b/src/covvfit/__init__.py
@@ -1,6 +1,21 @@
 from covvfit._splines import create_spline_matrix
+from covvfit._preprocess_abundances import make_data_list, preprocess_df, load_data
+# from covvfit._frequentist import create_model_fixed, softmax, softmax_1, fitted_values, pred_values, compute_overdispersion, make_jacobian, project_se, make_ratediff_confints, make_fitness_confints
+# Bind the submodules by name; their ``__all__`` lists are re-exported below.
+from covvfit import _frequentist, _plotting
+from covvfit._frequentist import *
+from covvfit._plotting import *
 
 VERSION = "0.1.0"
 
 
-__all__ = ["create_spline_matrix", "VERSION"]
+__all__ = [
+    "create_spline_matrix",
+    "make_data_list",
+    "preprocess_df",
+    "load_data",
+    "VERSION",
+]
+
+__all__ += _frequentist.__all__
+__all__ += _plotting.__all__
diff --git a/src/covvfit/_plotting.py b/src/covvfit/_plotting.py
new file mode 100644
--- /dev/null
+++ b/src/covvfit/_plotting.py
@@ -0,0 +1,85 @@
+"""utilities to plot"""
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+import numpy as np
+
+import matplotlib.dates as mdates
+import matplotlib.ticker as ticker
+import matplotlib.cm as cm
+import matplotlib.lines as mlines
+import matplotlib.patches as mpatches
+
+__all__ = ['colors_covsp', 'make_legend', 'num_to_date']
+
+# Hex colour associated with each variant label.
+colors_covsp = {
+    'B.1.1.7': '#D16666',
+    'B.1.351': '#FF6666',
+    'P.1': '#FFB3B3',
+    'B.1.617.1': '#A3FFD1',
+    'B.1.617.2': '#66C266',
+    'BA.1': '#A366A3',
+    'BA.2': '#CFAFCF',
+    'BA.4': '#8467F6',
+    'BA.5': '#595EF6',
+    'BA.2.75': '#DE9ADB',
+    'BQ.1.1': '#8fe000',
+    'XBB.1.9': '#dd6bff',
+    'XBB.1.5': '#ff5656',
+    'XBB.1.16': '#e99b30',
+    'XBB.2.3': '#b4b82a',
+    'EG.5': '#359f99',
+    'BA.2.86': '#FF20E0',
+    'JN.1': '#00e9ff',
+    'undetermined': '#999696',
+}
+
+
+def make_legend(colors, variants):
+    """Build the handles of a legend shared across plots.
+
+    Args:
+        colors: iterable of matplotlib colours, one per variant.
+        variants: iterable of variant labels, in the same order as ``colors``.
+
+    Returns:
+        A list of legend handles: one coloured patch per (colour, variant)
+        pair, a blank spacer entry, then black line/marker samples labelled
+        'fitted', 'predicted' and 'daily estimates'.
+
+    Note:
+        Colours and labels are paired positionally; pairing stops at the
+        shorter of the two inputs.
+    """
+    # One coloured box per variant
+    variant_patches = [
+        mpatches.Patch(color=color, label=variant)
+        for color, variant in zip(colors, variants)
+    ]
+
+    # Samples for the "fitted", "predicted" and "observed" styles
+    fitted_line = mlines.Line2D([], [], color='black', linestyle='-', label='fitted')
+    predicted_line = mlines.Line2D([], [], color='black', linestyle='--', label='predicted')
+    observed_points = mlines.Line2D([], [], color='black', marker='o', linestyle='None', label='daily estimates')
+    # Unlabelled entry placed between the variant patches and the style samples
+    blank_line = mlines.Line2D([], [], color='white', linestyle='', label='')
+
+    return variant_patches + [blank_line, fitted_line, predicted_line, observed_points]
+
+
+def num_to_date(num, pos=None, date_min='2023-01-01', fmt='%b. \'%y'):
+    """Format a day offset from ``date_min`` as a date string.
+
+    Args:
+        num: number of days elapsed since ``date_min``.
+        pos: unused; presumably kept so the function matches the
+            ``(value, position)`` signature of matplotlib tick formatters.
+        date_min: date corresponding to ``num == 0``.
+        fmt: ``strftime`` format of the returned string.
+
+    Returns:
+        The formatted date.
+    """
+    date = pd.to_datetime(date_min) + pd.to_timedelta(num, "D")
+    return date.strftime(fmt)
diff --git a/src/covvfit/_preprocess_abundances.py b/src/covvfit/_preprocess_abundances.py
new file mode 100644
--- /dev/null
+++ b/src/covvfit/_preprocess_abundances.py
@@ -0,0 +1,123 @@
+"""utilities to preprocess relative abundances"""
+import pandas as pd
+import numpy as np
+
+
+# variants = [
+#     'B.1.1.7', 'B.1.351', 'P.1', 'undetermined',
+#     'B.1.617.2', 'BA.1', 'BA.2', 'BA.4', 'BA.5', 'BA.2.75',
+#     'BQ.1.1', 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86"
+# ]
+
+# variants2 = [
+#     'BA.4', 'BA.5', 'BA.2.75',
+#     'BQ.1.1', 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86"
+# ]
+
+
+# variants3 = [
+#     'BA.2.75', 'BA.5', 'BQ.1.1',
+#     'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5',
+#     'BA.2.86',
+# ]
+
+# variants4 = [
+#     'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86"
+# ]
+
+# variants5 = [
+#     # 'B.1.1.7', 'B.1.351', 'P.1', 'undetermined',
+#     'B.1.617.2', 'BA.1', 'BA.2', 'BA.4', 'BA.5', 'BA.2.75',
+#     'BQ.1.1', 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86"
+# ]
+
+# cities = ['Lugano (TI)', 'Zürich (ZH)', 'Chur (GR)', 'Altenrhein (SG)',
+#           'Laupen (BE)', 'Genève (GE)', 'Basel (BS)', 'Porrentruy (JU)',
+#           'Lausanne (VD)', 'Bern (BE)', 'Luzern (LU)', 'Solothurn (SO)',
+#           'Neuchâtel (NE)', 'Schwyz (SZ)']
+
+
+def load_data(file):
+    """Read a CSV file and name its first column 'time'.
+
+    Args:
+        file: CSV source, passed unchanged to ``pandas.read_csv``.
+
+    Returns:
+        The resulting data frame, whose first column is renamed to 'time'.
+    """
+    wwdat = pd.read_csv(file)
+    return wwdat.rename(columns={wwdat.columns[0]: 'time'})
+
+
+def preprocess_df(
+    df,
+    cities,
+    variants,
+    undertermined_thresh=0.01,
+    zero_date='2023-01-01',
+    date_min=None,
+    date_max=None,
+):
+    """Filter the abundance table and add a ``days_from`` column.
+
+    The input frame is not modified; a new frame is returned.
+
+    Args:
+        df: data frame with 'time', 'city' and 'undetermined' columns and
+            one column per entry of ``variants``.
+        cities: values of the 'city' column to keep.
+        variants: names of the variant columns to keep.
+        undertermined_thresh: rows are kept only if their 'undetermined'
+            value is strictly below this threshold.
+        zero_date: date that corresponds to ``days_from == 0``.
+        date_min: if given, keep only rows with ``time >= date_min``.
+        date_max: if given, keep only rows with ``time < date_max``.
+
+    Returns:
+        A data frame with the columns 'time' (converted to datetimes),
+        'city', the requested variants and 'days_from' (whole days between
+        ``zero_date`` and 'time').
+    """
+    # ``assign`` builds a new frame, so the caller's ``df`` is left untouched
+    out = df.assign(time=pd.to_datetime(df['time']))
+
+    # Keep rows below the undetermined threshold that belong to the cities
+    keep = (out['undetermined'] < undertermined_thresh) & out['city'].isin(cities)
+
+    # Optional date window: date_min <= time < date_max
+    if date_min is not None:
+        keep = keep & (out['time'] >= pd.to_datetime(date_min))
+    if date_max is not None:
+        keep = keep & (out['time'] < pd.to_datetime(date_max))
+
+    # Keep only the identifying columns and the requested variants
+    out = out.loc[keep, ['time', 'city'] + list(variants)]
+
+    # Days elapsed between zero_date and each row's date
+    days_from = (out['time'] - pd.to_datetime(zero_date)).dt.days
+    return out.assign(days_from=days_from)
+
+
+def make_data_list(df, cities, variants):
+    """Split the table into per-city arrays.
+
+    Args:
+        df: data frame with 'city' and 'days_from' columns and one column
+            per entry of ``variants``.
+        cities: cities to extract; fixes the order of the output lists.
+        variants: list of variant columns to extract.
+
+    Returns:
+        A tuple ``(ts_lst, ys_lst)`` of lists with one entry per city:
+        the 'days_from' values, and the variant values transposed so that
+        each row holds one variant.
+    """
+    ts_lst = []
+    ys_lst = []
+    for city in cities:
+        city_df = df[df['city'] == city]  # select the rows once per city
+        ts_lst.append(city_df['days_from'].values)
+        ys_lst.append(city_df[variants].values.T)
+
+    return (ts_lst, ys_lst)