Skip to content

Commit

Permalink
Merge pull request #49 from AllenNeuralDynamics/han_issue_27_37
Browse files Browse the repository at this point in the history
feat: some minor updates
  • Loading branch information
hanhou authored Apr 3, 2024
2 parents e600711 + 7a2d9de commit 4be9c98
Show file tree
Hide file tree
Showing 5 changed files with 924 additions and 31 deletions.
55 changes: 40 additions & 15 deletions code/Home.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
'x_y_plot_dot_size': 10,
'x_y_plot_dot_opacity': 0.5,
'x_y_plot_line_width': 2.0,

'session_plot_mode': 'sessions selected from table or plot',

'auto_training_history_x_axis': 'date',
'auto_training_history_sort_by': 'subject_id',
Expand Down Expand Up @@ -240,7 +242,7 @@ def draw_session_plots(df_to_draw_session):
except:
date_str = key["session_date"].split("T")[0]

st.markdown(f'''<h3 style='text-align: center; color: orange;'>{key["h2o"]}, Session {key["session"]}, {date_str}''',
st.markdown(f'''<h4 style='text-align: center; color: orange;'>{key["h2o"]}, Session {int(key["session"])}, {date_str}''',
unsafe_allow_html=True)
if len(st.session_state.selected_draw_types) > 1: # more than one types, use the pre-defined layout
for row, column_setting in enumerate(layout_definition):
Expand Down Expand Up @@ -314,11 +316,18 @@ def draw_mice_plots(df_to_draw_mice):
def session_plot_settings(need_click=True):
st.markdown('##### Show plots for individual sessions ')
cols = st.columns([2, 1])

session_plot_modes = [f'sessions selected from table or plot', f'all sessions filtered from sidebar']
st.session_state.selected_draw_sessions = cols[0].selectbox('Which session(s) to draw?',
[f'selected from table/plot ({len(st.session_state.df_selected_from_plotly)} sessions)',
f'filtered from sidebar ({len(st.session_state.df_session_filtered)} sessions)'],
index=0
)
session_plot_modes,
index=session_plot_modes.index(st.session_state['session_plot_mode'])
if 'session_plot_mode' in st.session_state else
session_plot_modes.index(st.query_params['session_plot_mode'])
if 'session_plot_mode' in st.query_params
else 0,
key='session_plot_mode',
)

st.session_state.num_cols = cols[1].number_input('Number of columns', 1, 10,
3 if 'num_cols' not in st.session_state else st.session_state.num_cols)

Expand Down Expand Up @@ -515,20 +524,35 @@ def init():

# Some ad-hoc modifications on df_sessions
st.session_state.df['sessions_bonsai'].columns = st.session_state.df['sessions_bonsai'].columns.get_level_values(1)
st.session_state.df['sessions_bonsai'].sort_values(['session_end_time'], ascending=False, inplace=True)
st.session_state.df['sessions_bonsai'] = st.session_state.df['sessions_bonsai'].reset_index().query('subject_id != "0"')
st.session_state.df['sessions_bonsai']['h2o'] = st.session_state.df['sessions_bonsai']['subject_id']
st.session_state.df['sessions_bonsai'].dropna(subset=['session'], inplace=True) # Remove rows with no session number (only leave the nwb file with the largest finished_trials for now)
st.session_state.df['sessions_bonsai'].drop(st.session_state.df['sessions_bonsai'].query('session < 1').index, inplace=True)

# # add something else
# st.session_state.df['sessions_bonsai']['abs(bias)'] = np.abs(st.session_state.df['sessions_bonsai'].biasL)

# add abs(bais) to all terms that have 'bias' in name
for col in st.session_state.df['sessions_bonsai'].columns:
if 'bias' in col:
st.session_state.df['sessions_bonsai'][f'abs({col})'] = np.abs(st.session_state.df['sessions_bonsai'][col])

# # delta weight
# diff_relative_weight_next_day = st.session_state.df['sessions_bonsai'].set_index(
# ['session']).sort_values('session', ascending=True).groupby('h2o').apply(
# lambda x: - x.relative_weight.diff(periods=-1)).rename("diff_relative_weight_next_day")

# weekday
# st.session_state.df['sessions_bonsai']['weekday'] = st.session_state.df['sessions_bonsai'].session_date.dt.dayofweek + 1
st.session_state.df['sessions_bonsai'].session_date = pd.to_datetime(st.session_state.df['sessions_bonsai'].session_date)
st.session_state.df['sessions_bonsai']['weekday'] = st.session_state.df['sessions_bonsai'].session_date.dt.day_name()

# foraging performance = foraing_eff * finished_rate
if 'foraging_performance' not in st.session_state.df['sessions_bonsai'].columns:
st.session_state.df['sessions_bonsai']['foraging_performance'] = \
st.session_state.df['sessions_bonsai']['foraging_eff'] \
* st.session_state.df['sessions_bonsai']['finished_rate']
st.session_state.df['sessions_bonsai']['foraging_performance_random_seed'] = \
st.session_state.df['sessions_bonsai']['foraging_eff_random_seed'] \
* st.session_state.df['sessions_bonsai']['finished_rate']

# st.session_state.df['sessions_bonsai'] = st.session_state.df['sessions_bonsai'].merge(
# diff_relative_weight_next_day, how='left', on=['h2o', 'session'])
Expand Down Expand Up @@ -563,7 +587,7 @@ def app():
if st.button('Reload data from AWS S3'):
st.cache_data.clear()
init()
st.experimental_rerun()
st.rerun()



Expand All @@ -575,11 +599,11 @@ def app():
cols = st.columns([2, 2, 2])
cols[0].markdown(f'### Filter the sessions on the sidebar ({len(st.session_state.df_session_filtered)} filtered)')
# if cols[1].button('Press this and then Ctrl + R to reload from S3'):
# st.experimental_rerun()
# st.rerun()
if cols[1].button('Reload data '):
st.cache_data.clear()
init()
st.experimental_rerun()
st.rerun()

# aggrid_outputs = aggrid_interactive_table_units(df=df['ephys_units'])
# st.session_state.df_session_filtered = aggrid_outputs['data']
Expand All @@ -599,7 +623,7 @@ def app():
st.session_state.df_selected_from_dataframe = pd.DataFrame(aggrid_outputs['selected_rows'])
st.session_state.df_selected_from_plotly = st.session_state.df_selected_from_dataframe # Sync selected on plotly
# if st.session_state.tab_id == "tab_session_x_y":
st.experimental_rerun()
st.rerun()

chosen_id = stx.tab_bar(data=[
stx.TabBarItemData(id="tab_session_x_y", title="📈 Session X-Y plot", description="Interactive session-wise scatter plot"),
Expand Down Expand Up @@ -632,7 +656,7 @@ def app():
st.session_state.df_selected_from_plotly.set_index(['h2o', 'session']).index):
st.session_state.df_selected_from_plotly = df_selected_from_plotly
st.session_state.df_selected_from_dataframe = df_selected_from_plotly # Sync selected on dataframe
st.experimental_rerun()
st.rerun()

elif chosen_id == "tab_pygwalker":
with placeholder:
Expand Down Expand Up @@ -667,10 +691,11 @@ def app():

elif chosen_id == "tab_session_inspector":
with placeholder:
with st.columns([4, 10])[0]:
cols = st.columns([6, 3, 7])
with cols[0]:
if_draw_all_sessions = session_plot_settings(need_click=False)
df_to_draw_sessions = st.session_state.df_selected_from_plotly if 'selected' in st.session_state.selected_draw_sessions else st.session_state.df_session_filtered

if if_draw_all_sessions and len(df_to_draw_sessions):
draw_session_plots(df_to_draw_sessions)

Expand Down
55 changes: 49 additions & 6 deletions code/pages/1_Old mice.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import extra_streamlit_components as stx

from aind_auto_train.auto_train_manager import DynamicForagingAutoTrainManager
from pygwalker.api.streamlit import StreamlitRenderer, init_streamlit_comm


# Sync widgets with URL query params
Expand Down Expand Up @@ -486,6 +487,7 @@ def init():
df_this_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}')
valid_field = df_this_model.columns[~np.all(~df_this_model.notna(), axis=0)]
to_add_model = st.session_state.df['model_fitting_params'].query(f'model_id == {selected_id}')[valid_field]
st.session_state.df['sessions'].drop(st.session_state.df['sessions'].query('session < 1').index, inplace=True)

st.session_state.df['sessions'] = st.session_state.df['sessions'].merge(to_add_model, on=('subject_id', 'session'), how='left')

Expand All @@ -497,6 +499,12 @@ def init():
['session']).sort_values('session', ascending=True).groupby('h2o').apply(
lambda x: - x.relative_weight.diff(periods=-1)).rename("diff_relative_weight_next_day")

# foraging performance = foraing_eff * finished_ratio
if 'foraging_performance' not in st.session_state.df['sessions'].columns:
st.session_state.df['sessions']['foraging_performance'] = \
st.session_state.df['sessions']['foraging_eff'] \
* (1 - st.session_state.df['sessions']['ignore_rate'])

# weekday
st.session_state.df['sessions']['weekday'] = st.session_state.df['sessions'].session_date.dt.dayofweek + 1

Expand All @@ -505,7 +513,9 @@ def init():

st.session_state.session_stats_names = [keys for keys in st.session_state.df['sessions'].keys()]


@st.cache_resource(ttl=24*3600)
def get_pyg_renderer(df, spec="./gw_config.json", **kwargs) -> "StreamlitRenderer":
return StreamlitRenderer(df, spec=spec, debug=False, **kwargs)


def app():
Expand All @@ -524,7 +534,7 @@ def app():
if st.button('Reload data from AWS S3'):
st.cache_data.clear()
init()
st.experimental_rerun()
st.rerun()



Expand All @@ -536,11 +546,11 @@ def app():
cols = st.columns([2, 2, 2])
cols[0].markdown(f'### Filter the sessions on the sidebar ({len(st.session_state.df_session_filtered)} filtered)')
# if cols[1].button('Press this and then Ctrl + R to reload from S3'):
# st.experimental_rerun()
# st.rerun()
if cols[1].button('Reload data '):
st.cache_data.clear()
init()
st.experimental_rerun()
st.rerun()

# aggrid_outputs = aggrid_interactive_table_units(df=df['ephys_units'])
# st.session_state.df_session_filtered = aggrid_outputs['data']
Expand All @@ -560,11 +570,12 @@ def app():
st.session_state.df_selected_from_dataframe = pd.DataFrame(aggrid_outputs['selected_rows'])
st.session_state.df_selected_from_plotly = st.session_state.df_selected_from_dataframe # Sync selected on plotly
# if st.session_state.tab_id == "tab_session_x_y":
st.experimental_rerun()
st.rerun()

chosen_id = stx.tab_bar(data=[
stx.TabBarItemData(id="tab_session_x_y", title="📈 Session X-Y plot", description="Interactive session-wise scatter plot"),
stx.TabBarItemData(id="tab_session_inspector", title="👀 Session Inspector", description="Select sessions from the table and show plots"),
stx.TabBarItemData(id="tab_pygwalker", title="📊 PyGWalker (Tableau)", description="Interactive dataframe explorer"),
stx.TabBarItemData(id="tab_auto_train_history", title="🎓 Automatic Training History", description="Track progress"),
stx.TabBarItemData(id="tab_mouse_inspector", title="🐭 Mouse Model Fitting", description="Mouse-level model fitting results"),
], default="tab_session_inspector" if 'tab_id' not in st.session_state else st.session_state.tab_id)
Expand All @@ -591,7 +602,7 @@ def app():
st.session_state.df_selected_from_plotly.set_index(['h2o', 'session']).index):
st.session_state.df_selected_from_plotly = df_selected_from_plotly
st.session_state.df_selected_from_dataframe = df_selected_from_plotly # Sync selected on dataframe
st.experimental_rerun()
st.rerun()

elif chosen_id == "tab_session_inspector":
st.session_state.tab_id = chosen_id
Expand All @@ -602,6 +613,38 @@ def app():

if if_draw_all_sessions and len(df_to_draw_sessions):
draw_session_plots(df_to_draw_sessions)

elif chosen_id == "tab_pygwalker":
with placeholder:
cols = st.columns([1, 4])
cols[0].markdown('##### Exploring data using [PyGWalker](https://docs.kanaries.net/pygwalker)')
with cols[1]:
with st.expander('Specify PyGWalker json'):
# Load json from ./gw_config.json
pyg_user_json = st.text_area("Export your plot settings to json by clicking `export_code` "
"button below and then paste your json here to reproduce your plots",
key='pyg_walker', height=100)

# If pyg_user_json is not empty, use it; otherwise, use the default gw_config.json
if pyg_user_json:
try:
pygwalker_renderer = get_pyg_renderer(
df=st.session_state.df_session_filtered,
spec=pyg_user_json,
)
except:
pygwalker_renderer = get_pyg_renderer(
df=st.session_state.df_session_filtered,
spec="./gw_config_old_mice.json",
)
else:
pygwalker_renderer = get_pyg_renderer(
df=st.session_state.df_session_filtered,
spec="./gw_config_old_mice.json",
)

pygwalker_renderer.render_explore(height=1010, scrolling=False)


elif chosen_id == "tab_auto_train_history": # Automatic training history
st.session_state.tab_id = chosen_id
Expand Down
19 changes: 11 additions & 8 deletions code/util/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def aggrid_interactive_table_session(df: pd.DataFrame):

options.configure_side_bar()

df = df.sort_values('session_date', ascending=False)
if 'session_end_time' in df.columns:
df = df.sort_values('session_end_time', ascending=False)
else:
df = df.sort_values('session_date', ascending=False)

# preselect
if (('df_selected_from_dataframe' in st.session_state and len(st.session_state.df_selected_from_dataframe))
Expand Down Expand Up @@ -390,7 +393,7 @@ def add_session_filter(if_bonsai=False, url_query={}):
def add_xy_selector(if_bonsai):
with st.expander("Select axes", expanded=True):
# with st.form("axis_selection"):
cols = st.columns([1, 1, 1])
cols = st.columns([1])
x_name = cols[0].selectbox("x axis",
st.session_state.session_stats_names,
index=st.session_state.session_stats_names.index(st.session_state['x_y_plot_xname'])
Expand All @@ -400,7 +403,7 @@ def add_xy_selector(if_bonsai):
else st.session_state.session_stats_names.index('session'),
key='x_y_plot_xname'
)
y_name = cols[1].selectbox("y axis",
y_name = cols[0].selectbox("y axis",
st.session_state.session_stats_names,
index=st.session_state.session_stats_names.index(st.session_state['x_y_plot_yname'])
if 'x_y_plot_yname' in st.session_state else
Expand All @@ -410,12 +413,12 @@ def add_xy_selector(if_bonsai):
key='x_y_plot_yname')

if if_bonsai:
options = ['h2o', 'task', 'user_name', 'rig']
options = ['h2o', 'task', 'user_name', 'rig', 'weekday']
else:
options = ['h2o', 'task', 'photostim_location', 'weekday',
'headbar', 'user_name', 'sex', 'rig']

group_by = cols[2].selectbox("grouped by",
group_by = cols[0].selectbox("grouped by",
options=options,
index=options.index(st.session_state['x_y_plot_group_by'])
if 'x_y_plot_group_by' in st.session_state else
Expand Down Expand Up @@ -574,21 +577,21 @@ def data_selector():

# if cols[1].button('❌'):
# st.session_state.df_selected_from_dataframe = pd.DataFrame()
# st.experimental_rerun()
# st.rerun()

cols = st.columns([5, 1, 1])
with cols[0].expander(f"Selected: {len(st.session_state.df_selected_from_plotly)} sessions, "
f"{len(st.session_state.df_selected_from_plotly.h2o.unique())} mice", expanded=False):
st.dataframe(st.session_state.df_selected_from_plotly)
if cols[1].button('all'):
st.session_state.df_selected_from_plotly = st.session_state.df_session_filtered
st.experimental_rerun()
st.rerun()


if cols[2].button('❌ '):
st.session_state.df_selected_from_plotly = pd.DataFrame(columns=['h2o', 'session'])
st.session_state.df_selected_from_dataframe = pd.DataFrame(columns=['h2o', 'session'])
st.experimental_rerun()
st.rerun()

def add_auto_train_manager():

Expand Down
Loading

0 comments on commit 4be9c98

Please sign in to comment.