Skip to content

Commit

Permalink
Fixes to parameter analysis (#509)
Browse files Browse the repository at this point in the history
  • Loading branch information
tnatt authored Dec 19, 2020
1 parent abc91e0 commit 493b955
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 78 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- [#529](https://github.com/equinor/webviz-subsurface/pull/529) - Added support for PVDO and PVTG to PVT plot and to respective data modules.
- [#509](https://github.com/equinor/webviz-subsurface/pull/509) - Added descriptive hoverinfo to `ParameterAnalysis`: the average and standard deviation of the parameter value
for each ensemble are shown on mouse hover over the figure. Included dynamic sizing of plot titles and plot spacing to optimize the appearance of plots when many parameters are plotted.

## [0.1.6] - 2020-11-30
### Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,12 @@ def _update_graphs(
raise PreventUpdate
ctx = dash.callback_context.triggered[0]["prop_id"].split(".")[0]

initial_run = timeseries_fig is None
color = options["color"] if options["color"] is not None else "#007079"
daterange = parent.vmodel.daterange_for_plot(vector=vector)

# Make timeseries graph
if (
relevant_ctx_for_plot(parent, ctx, plot="timeseries_fig")
or timeseries_fig is None
):
if relevant_ctx(parent, ctx, operation="timeseries_fig") or initial_run:
timeseries_fig = update_timeseries_graph(
parent.vmodel,
ensemble,
Expand All @@ -84,20 +82,19 @@ def _update_graphs(
real_filter=None,
)

vectors_filtered = filter_vectors(
parent, vector_type_filter, vector_item_filters
)
if vector not in vectors_filtered:
vectors_filtered.append(vector)

merged_df = merge_parameter_and_vector_df(
parent, ensemble, vectors_filtered, date
)
if parent.uuid("plot-options") not in ctx or initial_run:
vectors_filtered = filter_vectors(
parent, vector_type_filter, vector_item_filters
)
if vector not in vectors_filtered:
vectors_filtered.append(vector)
merged_df = merge_parameter_and_vector_df(
parent, ensemble, vectors_filtered, date
)

# Make correlation figure for vector
if options["autocompute_corr"] and (
relevant_ctx_for_plot(parent, ctx, plot="vector_correlation")
or corr_v_fig is None
relevant_ctx(parent, ctx, operation="vector_correlation") or initial_run
):
corr_v_fig = make_correlation_figure(
merged_df, response=vector, corrwith=parent.pmodel.parameters
Expand All @@ -111,8 +108,7 @@ def _update_graphs(

# Make correlation figure for parameter
if options["autocompute_corr"] and (
relevant_ctx_for_plot(parent, ctx, plot="parameter_correlation")
or corr_p_fig is None
relevant_ctx(parent, ctx, operation="parameter_correlation") or initial_run
):
corr_p_fig = make_correlation_figure(
merged_df, response=parameter, corrwith=vectors_filtered
Expand All @@ -121,7 +117,7 @@ def _update_graphs(
corr_p_fig = color_corr_bars(corr_p_fig, vector, color, options["opacity"])

# Create scatter plot of vector vs parameter
if relevant_ctx_for_plot(parent, ctx, plot="scatter") or scatter_fig is None:
if relevant_ctx(parent, ctx, operation="scatter") or initial_run:
scatter_fig = update_scatter_graph(merged_df, vector, parameter, color)

scatter_fig = scatter_fig_color_update(scatter_fig, color, options["opacity"])
Expand All @@ -130,9 +126,10 @@ def _update_graphs(
df_value_norm = parent.pmodel.get_real_and_value_df(
ensemble, parameter=parameter, normalize=True
)
timeseries_fig = color_timeseries_graph(
timeseries_fig, ensemble, parameter, vector, df_value_norm
)
if relevant_ctx(parent, ctx, operation="color_timeseries_fig") or initial_run:
timeseries_fig = color_timeseries_graph(
timeseries_fig, ensemble, parameter, vector, df_value_norm
)

# Draw date selected as line
timeseries_fig = add_date_line(timeseries_fig, date, options["show_dateline"])
Expand Down Expand Up @@ -320,7 +317,7 @@ def _update_parameter_selected(


# pylint: disable=inconsistent-return-statements
def relevant_ctx_for_plot(parent, ctx: list, plot: str):
def relevant_ctx(parent, ctx: list, operation: str):
"""Group relevant uuids for the different plots"""
vector = parent.uuid("vector-select") in ctx
date = parent.uuid("date-selected") in ctx
Expand All @@ -330,14 +327,16 @@ def relevant_ctx_for_plot(parent, ctx: list, plot: str):
parent.uuid("vtype-filter") in ctx or parent.uuid("vitem-filter") in ctx
)

if plot == "timeseries_fig":
if operation == "timeseries_fig":
return any([vector, ensemble])
if plot == "scatter":
if operation == "scatter":
return any([vector, date, parameter, ensemble])
if plot == "parameter_correlation":
if operation == "parameter_correlation":
return any([filtered_vectors, date, parameter, ensemble])
if plot == "vector_correlation":
if operation == "vector_correlation":
return any([vector, date, ensemble])
if operation == "color_timeseries_fig":
return any([parameter, ensemble, vector])


def find_vector_type(parent, vector: str):
Expand Down Expand Up @@ -375,7 +374,7 @@ def update_timeseries_graph(
ensemble=ensemble, vector=vector, real_filter=real_filter
),
"layout": dict(
margin={"r": 20, "l": 20, "t": 60, "b": 20},
margin={"r": 40, "l": 20, "t": 60, "b": 20},
yaxis={"automargin": True},
xaxis={"range": xaxisrange},
hovermode="closest",
Expand Down Expand Up @@ -425,13 +424,12 @@ def color_timeseries_graph(
):
"""Color timeseries lines by parameter value"""
if df_norm is not None:

for trace_no, trace in enumerate(figure.get("data", [])):
for trace in figure.get("data", []):
if trace["name"] == ensemble:
figure["data"][trace_no]["marker"]["color"] = set_real_color(
trace["marker"]["color"] = set_real_color(
real_no=trace["customdata"], df_norm=df_norm
)
figure["data"][trace_no]["hovertext"] = (
trace["hovertext"] = (
f"Real: {str(trace['customdata'])}, {selected_param}: "
f"{df_norm.loc[df_norm['REAL'] == trace['customdata']].iloc[0]['VALUE']}"
)
Expand All @@ -448,7 +446,7 @@ def set_real_color(df_norm, real_no: str):
Midpoint for the colorscale is set on the average value
"""
red = "rgba(255,18,67, 1)"
mid_color = "rgba(220,220,220,1"
mid_color = "rgba(220,220,220,1)"
green = "rgba(62,208,62, 1)"

mean = df_norm["VALUE_NORM"].mean()
Expand All @@ -460,7 +458,7 @@ def set_real_color(df_norm, real_no: str):
if norm_value > mean:
intermed = (norm_value - mean) / (1 - mean)
return find_intermediate_color(mid_color, green, intermed, colortype="rgba")
return "rgba(220,220,220, 0.2"
return "rgba(220,220,220, 0.2)"


def merge_parameter_and_vector_df(parent, ensemble: str, vectors: list, date: str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,57 +3,58 @@


def color_figure(
px_colors: dict,
custom_colors: dict = None,
colors: list,
bargap: float = None,
height: float = None,
):
"""
Create bar chart with colors, can be used as a color selector
Create bar chart with colors, can e.g. be used as a color selector
by retrieving clickdata.
Input can either be a ditionary with plotly colors with the key
beeing the px colormodul to use and the value is a list of
colorscale names, or can be given in as a custom_color dictionary where
the key is a name and the value is a list of colors.
The `colors` argument is a list of items where each item
is either the name of a built-in plotly colorscale or a list of colors.
"""
custom_colors = custom_colors if custom_colors is not None else {}

for px_cmodule, px_cscales in px_colors.items():
if px_cmodule == "diverging":
custom_colors.update(get_px_colors(px.colors.diverging, px_cscales))
if px_cmodule == "sequential":
custom_colors.update(get_px_colors(px.colors.sequential, px_cscales))
if px_cmodule == "cyclical":
custom_colors.update(get_px_colors(px.colors.cyclical, px_cscales))
if px_cmodule == "qualitative":
custom_colors.update(get_px_colors(px.colors.qualitative, px_cscales))
color_lists = []
for item in colors:
if isinstance(item, list) and item:
color_lists.append(item)
if isinstance(item, str):
color_lists.append(get_px_colors(item))

return go.Figure(
data=[
go.Bar(
orientation="h",
y=[name] * len(colors),
x=[1] * len(colors),
customdata=list(range(len(colors))),
marker=dict(color=colors),
y=[str(idx)] * len(clist),
x=[1] * len(clist),
customdata=list(range(len(clist))),
marker=dict(color=clist),
hovertemplate="%{marker.color}<extra></extra>",
)
for name, colors in custom_colors.items()
for idx, clist in enumerate(reversed(color_lists))
],
layout=dict(
title=None,
barmode="stack",
barnorm="fraction",
bargap=0.5,
bargap=bargap if bargap is not None else 0.5,
showlegend=False,
xaxis=dict(range=[-0.02, 1.02], showticklabels=False, showgrid=False),
yaxis_showticklabels=False,
height=height if height is not None else 20 * len(custom_colors),
height=height if height is not None else 20 * len(color_lists),
margin=dict(l=0, r=0, t=0, b=0),
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(0,0,0,0)",
),
)


def get_px_colors(px_cmodule, cscales: list):
return {k: v for (k, v) in px_cmodule.__dict__.items() if k in cscales}
def get_px_colors(px_cscale: str):
for cmodule in [
px.colors.diverging,
px.colors.sequential,
px.colors.cyclical,
px.colors.qualitative,
]:
if px_cscale in cmodule.__dict__:
return cmodule.__dict__[px_cscale]
return []
Original file line number Diff line number Diff line change
Expand Up @@ -207,39 +207,92 @@ def make_grouped_plot(
),
20,
),
facet_row_spacing=max((0.08 - (0.00071 * len(parameters))), 0.03),
color="ENSEMBLE",
color_discrete_sequence=self.colorway,
custom_data=["PARAMETER"],
)
.update_xaxes(
matches=None,
fixedrange=True,
title=None,
showticklabels=(len(parameters) < 20),
showticklabels=len(parameters) <= 100,
tickangle=0,
tickfont_size=max((18 - (0.4 * len(parameters))), 10),
)
.update_yaxes(showticklabels=False)
.for_each_trace(
lambda t: t.update(
y0=0,
hoveron="violins",
hoverinfo="name",
hoverinfo="none",
meanline_visible=True,
orientation="h",
side="positive",
width=3,
width=2,
points=False,
)
)
.for_each_annotation(
lambda a: a.update(
hovertext=a.text.split("=")[-1],
text=(a.text.split("=")[-1]) if len(parameters) < 40 else "",
text=(a.text.split("=")[-1]),
visible=len(parameters) <= 42,
font_size=max((18 - (0.4 * len(parameters))), 10),
)
)
)

# Create invisible boxes used for hoverinfo on the violin plots
# Necessary due to https://github.com/plotly/plotly.js/issues/2145
ensembles = df["ENSEMBLE"].unique()
hovertraces = []
for trace in fig["data"]:
parameter = trace["customdata"][0][0]
# check of parameter value to determine print formatter
value = abs(self.get_stat_value(parameter, ensembles[0], stat_column="Avg"))
form = ".1f" if value > 10 else ".2g"
hovertraces.append(
go.Scatter(
x=[min(trace.x), min(trace.x), max(trace.x), max(trace.x)],
y=[0, 1, 1, 0],
xaxis=trace.xaxis,
yaxis=trace.yaxis,
mode="lines",
fill="toself",
opacity=0,
showlegend=False,
text=(
f"<b>{parameter}</b><br>"
+ "<br>".join(
f"<b>{ens}:</b><br>"
"Avg: "
f"{self.get_stat_value(parameter, ens, stat_column='Avg'):{form}}<br>"
"Std: "
f"{self.get_stat_value(parameter, ens, stat_column='Stddev'):{form}}"
for ens in ensembles
)
),
hoverinfo="text",
hoverlabel=dict(
bgcolor="#E6FAEC", font=dict(color="#243746", size=15)
),
)
)
fig = fig.to_dict()
fig["data"].extend(hovertraces)
fig["layout"] = self.theme.create_themed_layout(fig["layout"])

fig["layout"].update(paper_bgcolor="white", plot_bgcolor="white")
return fig

def get_stat_value(self, parameter: str, ensemble: str, stat_column: str):
    """
    Retrieve a statistical value for a parameter in an ensemble.

    Looks up the first row of ``self.statframe`` whose ``PARAMETER`` and
    ``ENSEMBLE`` columns equal the given values, and returns that row's
    entry in ``stat_column`` (e.g. "Avg" or "Stddev").

    Raises IndexError if no row matches the parameter/ensemble pair.
    """
    statframe = self.statframe
    matches = statframe.loc[
        (statframe["PARAMETER"] == parameter) & (statframe["ENSEMBLE"] == ensemble)
    ]
    if matches.empty:
        # Same exception type the bare ``.iloc[0]`` lookup would raise, but
        # with a message that identifies the missing combination.
        raise IndexError(
            f"No statistics found for parameter '{parameter}' "
            f"in ensemble '{ensemble}'"
        )
    return matches.iloc[0][stat_column]

def get_real_and_value_df(
self, ensemble: str, parameter: str, normalize: bool = False
) -> pd.DataFrame:
Expand All @@ -255,4 +308,4 @@ def get_real_and_value_df(
df["VALUE_NORM"] = (df["VALUE"] - df["VALUE"].min()) / (
df["VALUE"].max() - df["VALUE"].min()
)
return df.reset_index()
return df.reset_index(drop=True)
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def timeseries_view(parent) -> html.Div:

def selector_view(parent) -> html.Div:

theme_colors = parent.theme.plotly_theme.get("layout", {}).get("colorway", [])
theme_colors = (
theme_colors[1:12] if theme_colors and len(theme_colors) >= 12 else theme_colors
)

return html.Div(
style={
"height": "80vh",
Expand Down Expand Up @@ -75,8 +80,9 @@ def selector_view(parent) -> html.Div:
color_selector(
parent=parent,
tab="response",
px_colors={"sequential": ["Greys"], "diverging": ["BrBG"]},
height=60,
colors=[theme_colors, "Greys", "BrBG"],
bargap=0.2,
height=50,
),
color_opacity_selector(parent=parent, tab="response", value=0.5),
],
Expand Down
Loading

0 comments on commit 493b955

Please sign in to comment.