Skip to content

Commit

Permalink
likelihood fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
dr-david committed Nov 9, 2023
1 parent 6f8bb82 commit 2a84d4f
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/covvfit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
from covvfit._splines import create_spline_matrix
from covvfit._preprocess_abundances import make_data_list, preprocess_df, load_data
# from covvfit._frequentist import create_model_fixed, softmax, softmax_1, fitted_values, pred_values, compute_overdispersion, make_jacobian, project_se, make_ratediff_confints, make_fitness_confints
from covvfit._frequentist import *
from covvfit._plotting import *

VERSION = "0.1.0"


__all__ = ["create_spline_matrix", "VERSION"]
__all__ = [
"create_spline_matrix",
"make_data_list",
"preprocess_df",
"load_data",
"VERSION"
]

__all__ += _frequentist.__all__
__all__ += _plotting.__all__
59 changes: 59 additions & 0 deletions src/covvfit/_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""utilities to plot"""
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import matplotlib.dates as mdates
import matplotlib.ticker as ticker
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.patches as mpatches

__all__ = ['colors_covsp', 'make_legend', 'num_to_date']

colors_covsp = {
'B.1.1.7': '#D16666',
'B.1.351': '#FF6666',
'P.1': '#FFB3B3',
'B.1.617.1': '#A3FFD1',
'B.1.617.2': '#66C266',
'BA.1': '#A366A3',
'BA.2': '#CFAFCF',
'BA.4': '#8467F6',
'BA.5': '#595EF6',
'BA.2.75': '#DE9ADB',
'BQ.1.1': '#8fe000',
'XBB.1.9': '#dd6bff',
'XBB.1.5': '#ff5656',
'XBB.1.16': '#e99b30',
'XBB.2.3': '#b4b82a',
'EG.5': '#359f99',
'BA.2.86': '#FF20E0',
'JN.1': '#00e9ff',
'undetermined': '#999696',
}


def make_legend(colors, variants):
"""make a shared legend for the plot"""
# Create a patch (i.e., a colored box) for each variant
variant_patches = [mpatches.Patch(color=color, label=variants[i]) for i, color in enumerate(colors)]

# Create lines for "fitted", "predicted", and "observed" labels
fitted_line = mlines.Line2D([], [], color='black', linestyle='-', label='fitted')
predicted_line = mlines.Line2D([], [], color='black', linestyle='--', label='predicted')
observed_points = mlines.Line2D([], [], color='black', marker='o', linestyle='None', label='daily estimates')
blank_line = mlines.Line2D([], [], color='white', linestyle='', label='')

# Combine all the legend handles
handles = variant_patches + [blank_line, fitted_line, predicted_line, observed_points]

return handles

def num_to_date(num, pos=None, date_min='2023-01-01', fmt='%b. \'%y'):
"""convert days number into a date format"""
date = pd.to_datetime(date_min) + pd.to_timedelta(num, "D")
return date.strftime(fmt)

85 changes: 85 additions & 0 deletions src/covvfit/_preprocess_abundances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""utilities to preprocess relative abundances"""
import pandas as pd
import numpy as np


# variants = [
# 'B.1.1.7', 'B.1.351', 'P.1', 'undetermined',
# 'B.1.617.2', 'BA.1', 'BA.2', 'BA.4', 'BA.5', 'BA.2.75',
# 'BQ.1.1', 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86"
# ]

# variants2 = [
# 'BA.4', 'BA.5', 'BA.2.75',
# 'BQ.1.1', 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86"
# ]


# variants3 = [
# 'BA.2.75', 'BA.5', 'BQ.1.1',
# 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5',
# 'BA.2.86',
# ]

# variants4 = [
# 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86"
# ]

# variants5 = [
# # 'B.1.1.7', 'B.1.351', 'P.1', 'undetermined',
# 'B.1.617.2', 'BA.1', 'BA.2', 'BA.4', 'BA.5', 'BA.2.75',
# 'BQ.1.1', 'XBB.1.5', 'XBB.1.9', 'XBB.1.16', 'XBB.2.3', 'EG.5', "BA.2.86"
# ]

# cities = ['Lugano (TI)', 'Zürich (ZH)', 'Chur (GR)', 'Altenrhein (SG)',
# 'Laupen (BE)', 'Genève (GE)', 'Basel (BS)', 'Porrentruy (JU)',
# 'Lausanne (VD)', 'Bern (BE)', 'Luzern (LU)', 'Solothurn (SO)',
# 'Neuchâtel (NE)', 'Schwyz (SZ)']

def load_data(file):
wwdat = pd.read_csv(file)
wwdat = wwdat.rename(columns={wwdat.columns[0]: 'time'})
return wwdat


def preprocess_df(
df,
cities,
variants,
undertermined_thresh=0.01,
zero_date='2023-01-01',
date_min=None,
date_max=None,
):
# Convert the 'time' column to datetime
df['time'] = pd.to_datetime(df['time'])

# Remove days with too high undetermined
df = df[df['undetermined'] < undertermined_thresh]

# Subset the 'BQ.1.1' column
df = df[['time', 'city'] + variants]

# Subset only the specified cities
df = df[df['city'].isin(cities)]

# Create a new column which is the difference in days between zero_date and the date
df['days_from'] = (df['time'] - pd.to_datetime(zero_date)).dt.days

# Subset dates
if date_min is not None:
df = df[df['time'] >= pd.to_datetime(date_min)]
if date_max is not None:
df = df[df['time'] < pd.to_datetime(date_max)]


return df


def make_data_list(df, cities, variants):
ts_lst = [df[(df.city == city)].days_from.values for city in cities]
ys_lst = [df[(df.city == city)][variants].values.T for city in cities]

return (ts_lst, ys_lst)


0 comments on commit 2a84d4f

Please sign in to comment.