-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
158 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,19 @@ | ||
from covvfit._splines import create_spline_matrix | ||
from covvfit._preprocess_abundances import make_data_list, preprocess_df, load_data | ||
# from covvfit._frequentist import create_model_fixed, softmax, softmax_1, fitted_values, pred_values, compute_overdispersion, make_jacobian, project_se, make_ratediff_confints, make_fitness_confints | ||
from covvfit._frequentist import * | ||
from covvfit._plotting import * | ||
|
||
VERSION = "0.1.0" | ||
|
||
|
||
__all__ = ["create_spline_matrix", "VERSION"] | ||
__all__ = [ | ||
"create_spline_matrix", | ||
"make_data_list", | ||
"preprocess_df", | ||
"load_data", | ||
"VERSION" | ||
] | ||
|
||
__all__ += _frequentist.__all__ | ||
__all__ += _plotting.__all__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
"""utilities to plot""" | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
import numpy as np | ||
|
||
import matplotlib.dates as mdates | ||
import matplotlib.ticker as ticker | ||
import matplotlib.cm as cm | ||
import matplotlib.pyplot as plt | ||
import matplotlib.lines as mlines | ||
import matplotlib.patches as mpatches | ||
|
||
__all__ = ['colors_covsp', 'make_legend', 'num_to_date'] | ||
|
||
colors_covsp = { | ||
'B.1.1.7': '#D16666', | ||
'B.1.351': '#FF6666', | ||
'P.1': '#FFB3B3', | ||
'B.1.617.1': '#A3FFD1', | ||
'B.1.617.2': '#66C266', | ||
'BA.1': '#A366A3', | ||
'BA.2': '#CFAFCF', | ||
'BA.4': '#8467F6', | ||
'BA.5': '#595EF6', | ||
'BA.2.75': '#DE9ADB', | ||
'BQ.1.1': '#8fe000', | ||
'XBB.1.9': '#dd6bff', | ||
'XBB.1.5': '#ff5656', | ||
'XBB.1.16': '#e99b30', | ||
'XBB.2.3': '#b4b82a', | ||
'EG.5': '#359f99', | ||
'BA.2.86': '#FF20E0', | ||
'JN.1': '#00e9ff', | ||
'undetermined': '#999696', | ||
} | ||
|
||
|
||
def make_legend(colors, variants): | ||
"""make a shared legend for the plot""" | ||
# Create a patch (i.e., a colored box) for each variant | ||
variant_patches = [mpatches.Patch(color=color, label=variants[i]) for i, color in enumerate(colors)] | ||
|
||
# Create lines for "fitted", "predicted", and "observed" labels | ||
fitted_line = mlines.Line2D([], [], color='black', linestyle='-', label='fitted') | ||
predicted_line = mlines.Line2D([], [], color='black', linestyle='--', label='predicted') | ||
observed_points = mlines.Line2D([], [], color='black', marker='o', linestyle='None', label='daily estimates') | ||
blank_line = mlines.Line2D([], [], color='white', linestyle='', label='') | ||
|
||
# Combine all the legend handles | ||
handles = variant_patches + [blank_line, fitted_line, predicted_line, observed_points] | ||
|
||
return handles | ||
|
||
def num_to_date(num, pos=None, date_min='2023-01-01', fmt='%b. \'%y'): | ||
"""convert days number into a date format""" | ||
date = pd.to_datetime(date_min) + pd.to_timedelta(num, "D") | ||
return date.strftime(fmt) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""utilities to preprocess relative abundances""" | ||
import pandas as pd | ||
import numpy as np | ||
|
||
|
||
# variants = [ | ||
# 'B.1.1.7', 'B.1.351', 'P.1', 'undetermined', | ||
# 'B.1.617.2', 'BA.1', 'BA.2', 'BA.4', 'BA.5', 'BA.2.75', | ||
# 'BQ.1.1', 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86" | ||
# ] | ||
|
||
# variants2 = [ | ||
# 'BA.4', 'BA.5', 'BA.2.75', | ||
# 'BQ.1.1', 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86" | ||
# ] | ||
|
||
|
||
# variants3 = [ | ||
# 'BA.2.75', 'BA.5', 'BQ.1.1', | ||
# 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', | ||
# 'BA.2.86', | ||
# ] | ||
|
||
# variants4 = [ | ||
# 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86" | ||
# ] | ||
|
||
# variants5 = [ | ||
# # 'B.1.1.7', 'B.1.351', 'P.1', 'undetermined', | ||
# 'B.1.617.2', 'BA.1', 'BA.2', 'BA.4', 'BA.5', 'BA.2.75', | ||
# 'BQ.1.1', 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86" | ||
# ] | ||
|
||
# cities = ['Lugano (TI)', 'Zürich (ZH)', 'Chur (GR)', 'Altenrhein (SG)', | ||
# 'Laupen (BE)', 'Genève (GE)', 'Basel (BS)', 'Porrentruy (JU)', | ||
# 'Lausanne (VD)', 'Bern (BE)', 'Luzern (LU)', 'Solothurn (SO)', | ||
# 'Neuchâtel (NE)', 'Schwyz (SZ)'] | ||
|
||
def load_data(file): | ||
wwdat = pd.read_csv(file) | ||
wwdat = wwdat.rename(columns={wwdat.columns[0]: 'time'}) | ||
return wwdat | ||
|
||
|
||
def preprocess_df( | ||
df, | ||
cities, | ||
variants, | ||
undertermined_thresh=0.01, | ||
zero_date='2023-01-01', | ||
date_min=None, | ||
date_max=None, | ||
): | ||
# Convert the 'time' column to datetime | ||
df['time'] = pd.to_datetime(df['time']) | ||
|
||
# Remove days with too high undetermined | ||
df = df[df['undetermined'] < undertermined_thresh] | ||
|
||
# Subset the 'BQ.1.1' column | ||
df = df[['time', 'city'] + variants] | ||
|
||
# Subset only the specified cities | ||
df = df[df['city'].isin(cities)] | ||
|
||
# Create a new column which is the difference in days between zero_date and the date | ||
df['days_from'] = (df['time'] - pd.to_datetime(zero_date)).dt.days | ||
|
||
# Subset dates | ||
if date_min is not None: | ||
df = df[df['time'] >= pd.to_datetime(date_min)] | ||
if date_max is not None: | ||
df = df[df['time'] < pd.to_datetime(date_max)] | ||
|
||
|
||
return df | ||
|
||
|
||
def make_data_list(df, cities, variants): | ||
ts_lst = [df[(df.city == city)].days_from.values for city in cities] | ||
ys_lst = [df[(df.city == city)][variants].values.T for city in cities] | ||
|
||
return (ts_lst, ys_lst) | ||
|
||
|