From 14dd3348831ee9823f17e976d7acfbb916de09f4 Mon Sep 17 00:00:00 2001 From: Phuong Ho Date: Mon, 9 Dec 2024 02:28:20 -0800 Subject: [PATCH] Add func to simulate and plot dynamics for motifs --- protosignet/plot_results.py | 81 ++++++++++++++++++++++++++++++++++--- protosignet/util.py | 13 +++--- 2 files changed, 83 insertions(+), 11 deletions(-) diff --git a/protosignet/plot_results.py b/protosignet/plot_results.py index 8c8494d..f1f0159 100644 --- a/protosignet/plot_results.py +++ b/protosignet/plot_results.py @@ -1,11 +1,14 @@ +import ast from pathlib import Path import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np +import pandas as pd import seaborn as sns -from protosignet.util import eval_pareto, tag_objectives +from protosignet.model import sim_signet +from protosignet.util import calc_n_nodes, eval_pareto, tag_objectives CUSTOM_PALETTE = ["#648FFF", "#2ECC71", "#8069EC", "#EA822C", "#D143A4", "#F1C40F", "#34495E"] @@ -27,7 +30,7 @@ "axes.labelsize": 72, "xtick.labelsize": 56, "ytick.labelsize": 56, - "legend.fontsize": 56, + "legend.fontsize": 48, } @@ -45,7 +48,7 @@ def plot_figure_1d(data_dp, fig_fp): df_top = df.iloc[df.groupby("obj1")["obj2"].idxmax().values].copy() df_top["is_pareto"] = eval_pareto(df_top[["obj1", "obj2"]].to_numpy()) df_pareto = df_top.loc[df_top["is_pareto"] == 1] - # print(df_pareto) + print(df_pareto) with plt.style.context(("seaborn-v0_8-whitegrid", CUSTOM_STYLE)): fig, ax = plt.subplots(figsize=(24, 20)) sns.scatterplot(data=df_gen_001, x="obj1", y="obj2", edgecolor="#212121", facecolor="#2ECC71", alpha=0.8, linewidth=2, s=600) @@ -83,12 +86,80 @@ def plot_figure_1d(data_dp, fig_fp): plt.close("all") +def plot_figure_1e(address, data_dp, fig_fp): + """Simulate dynamics for a specified motif. + + Args: + address (list): [rep_i, gen_j, pop_k] + data_dp (str): absolute path to data directory + fig_fp (str): absolute path for saving generated figure + """ + rep_i, gen_j, pop_k = address + df_rep = pd.read_csv(data_dp / f"{int(rep_i)}.csv") + pop_rep = df_rep["population"].values + pop_gen = np.array(ast.literal_eval(pop_rep[int(gen_j)])) + n_params = pop_gen.shape[1] + n_nodes = calc_n_nodes(n_params) + indiv = pop_gen[int(pop_k)].reshape(int(n_nodes), -1) + print(indiv) + kr = indiv[:, 0] + ku = indiv[:, 1] + kX = indiv[:, 2:] + tu = np.arange(0, 121, 1.0) + uu = np.zeros_like(tu) + uu[40:80:10] = 1.0 # sparse input + uu[80:121:1] = 1.0 # dense input + tm, Xm = sim_signet(tu, uu, kr, ku, kX) + X1_df = pd.DataFrame({"t": tm, "y": Xm[0], "h": np.ones_like(tm) * 0}) + X2_df = pd.DataFrame({"t": tm, "y": Xm[1], "h": np.ones_like(tm) * 1}) + Xm_df = pd.concat([X1_df, X2_df], ignore_index=True) + with plt.style.context(("seaborn-v0_8-whitegrid", CUSTOM_STYLE)): + fig, ax = plt.subplots(figsize=(24, 20)) + sns.lineplot(data=Xm_df, x="t", y="y", hue="h", ax=ax, palette=["#8069EC", "#EA822C"], zorder=2.2) + ymin, ymax = ax.get_ylim() + for t in tu[uu > 0]: + ax.axvspan(t, t + 1, color="#648FFF", alpha=0.5, linewidth=0, zorder=2.1) + ax.set_ylim(ymin, ymax) + handles = [ + mpl.lines.Line2D([], [], color="#648FFF", linewidth=16, alpha=0.5), + mpl.lines.Line2D([], [], color="#8069EC", linewidth=16), + mpl.lines.Line2D([], [], color="#EA822C", linewidth=16), + ] + group_labels = ["Input", "Dense Decoder", "Sparse Decoder"] + ax.legend( + handles, + group_labels, + loc="best", + markerscale=4, + frameon=True, + shadow=False, + framealpha=1.0, + handletextpad=0.4, + borderpad=0.2, + labelspacing=0.2, + handlelength=1, + ) + ax.set_xlabel("Time") + ax.set_ylabel("AU") + ax.locator_params(axis="x", nbins=10) + ax.locator_params(axis="y", nbins=10) + fig.tight_layout() + fig.canvas.draw() + fig.savefig(fig_fp, pad_inches=0.3, dpi=200, bbox_inches="tight", transparent=False) + plt.close("all") + + def main(): data_dp = Path("/home/phuong/data/protosignet/dual_fm/data/") save_dp = Path("/home/phuong/data/protosignet/dual_fm/figs/") save_dp.mkdir(parents=True, exist_ok=True) - fig_fp = save_dp / "fig_1d.png" - plot_figure_1d(data_dp, fig_fp) + + # fig_fp = save_dp / "fig_1d.png" + # plot_figure_1d(data_dp, fig_fp) + + for a, address in enumerate([[0, 241, 93], [1, 235, 83], [3, 246, 45]]): + fig_fp = save_dp / f"fig_1e_{a}.png" + plot_figure_1e(address, data_dp, fig_fp) if __name__ == "__main__": diff --git a/protosignet/util.py b/protosignet/util.py index 5670744..097fd81 100644 --- a/protosignet/util.py +++ b/protosignet/util.py @@ -79,9 +79,10 @@ def eval_pareto(objectives): return np.array(is_pareto) -def fetch_indiv(csv_fp, gen_j, pop_k): - df = pd.read_csv(Path(csv_fp)) - pop_rep = df["population"].values - pop_gen = np.array(ast.literal_eval(pop_rep[int(gen_j)])) - indiv = pop_gen[int(pop_k)] - return indiv +def calc_n_nodes(n_params): + x = 0 + y = 1 + while y != 0: + x += 1 + y = x**2 + 2 * x - n_params + return x