From 148500f55c8d6358bb7702e6679ee55adcae92a8 Mon Sep 17 00:00:00 2001 From: Hans Kallekleiv Date: Wed, 8 Jan 2020 14:49:45 +0100 Subject: [PATCH] Add theming for plotly figures (#180) --- setup.py | 2 +- .../integration_tests/test_parameter_corr.py | 6 +- .../_private_plugins/tornado_plot.py | 99 +++++++++++-------- webviz_subsurface/plugins/_inplace_volumes.py | 46 +++++---- .../plugins/_inplace_volumes_onebyone.py | 12 ++- .../plugins/_parameter_correlation.py | 39 +++----- .../plugins/_parameter_distribution.py | 2 + .../_parameter_response_correlation.py | 31 +++--- .../_reservoir_simulation_timeseries.py | 69 ++++++------- ...eservoir_simulation_timeseries_onebyone.py | 13 ++- 10 files changed, 171 insertions(+), 148 deletions(-) diff --git a/setup.py b/setup.py index 5ecbd0812..bdce8aca6 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ "matplotlib~=3.0", "pillow~=6.1", "xtgeo~=2.1", - "webviz-config>=0.0.35", + "webviz-config>=0.0.41", # webviz-subsurface-components is part of the webviz-subsurface project, # just located in a separate repository for convenience, # and is therefore pinned exactly here: diff --git a/tests/integration_tests/test_parameter_corr.py b/tests/integration_tests/test_parameter_corr.py index f7aff6398..92936c144 100644 --- a/tests/integration_tests/test_parameter_corr.py +++ b/tests/integration_tests/test_parameter_corr.py @@ -2,6 +2,7 @@ import dash import pandas as pd from webviz_config.common_cache import CACHE +from webviz_config.themes import default_theme from webviz_config.plugins import ParameterCorrelation # mocked functions @@ -15,7 +16,10 @@ def test_parameter_corr(dash_duo): app.scripts.config.serve_locally = True app.config.suppress_callback_exceptions = True CACHE.init_app(app.server) - app.webviz_settings = {"shared_settings": {"scratch_ensembles": {"iter-0": ""}}} + app.webviz_settings = { + "shared_settings": {"scratch_ensembles": {"iter-0": ""}}, + "theme": default_theme, + } ensembles = ["iter-0"] with mock.patch(get_parameters) as mock_parameters: diff --git a/webviz_subsurface/_private_plugins/tornado_plot.py b/webviz_subsurface/_private_plugins/tornado_plot.py index 7b33bc9ed..7cac076fa 100644 --- a/webviz_subsurface/_private_plugins/tornado_plot.py +++ b/webviz_subsurface/_private_plugins/tornado_plot.py @@ -58,6 +58,7 @@ def __init__( ) self.allow_click = allow_click self.uid = uuid4() + self.plotly_theme = app.webviz_settings["theme"].plotly_theme self.set_callbacks(app) def ids(self, element): @@ -212,7 +213,14 @@ def _calc_tornado(reference, scale, cutbyref, data): self.realizations["ENSEMBLE"] == data["ENSEMBLE"] ] try: - return tornado_plot(realizations, values, reference, scale, cutbyref) + return tornado_plot( + realizations, + values, + plotly_theme=self.plotly_theme, + reference=reference, + scale=scale, + cutbyref=cutbyref, + ) except KeyError: return {} @@ -288,7 +296,12 @@ def cut_by_ref(tornadotable, refname): @CACHE.memoize(timeout=CACHE.TIMEOUT) def tornado_plot( - realizations, data, reference="rms_seed", scale="Percentage", cutbyref=True + realizations, + data, + plotly_theme, + reference="rms_seed", + scale="Percentage", + cutbyref=True, ): # pylint: disable=too-many-locals # Raise key error if no senscases, i.e. the ensemble has no design matrix @@ -394,44 +407,44 @@ def tornado_plot( df = sort_by_max(df) # Return tornado data as Plotly figure - return { - "data": [ - dict( - type="bar", - y=df["sensname"], - x=df["low"], - name="low", - customdata=df["low_reals"], - hovertext=[ - f"Case: {label}
True Value: {val:.2f}
Realizations:" - f"{min(reals) if reals else None}-{max(reals) if reals else None}" - for label, val, reals in zip( - df["low_label"], df["true_low"], df["low_reals"] - ) - ], - hoverinfo="x+text", - orientation="h", - marker=dict(color="rgb(235, 0, 54)"), - ), - dict( - type="bar", - y=df["sensname"], - x=df["high"], - name="high", - customdata=df["high_reals"], - hovertext=[ - f"Case: {label}
True Value: {val:.2f}
Realizations:" - f"{min(reals) if reals else None}-{max(reals) if reals else None}" - for label, val, reals in zip( - df["high_label"], df["true_high"], df["high_reals"] - ) - ], - hoverinfo="x+text", - orientation="h", - marker=dict(color="rgb(36, 55, 70)"), - ), - ], - "layout": { + plot_data = [ + dict( + type="bar", + y=df["sensname"], + x=df["low"], + name="low", + customdata=df["low_reals"], + hovertext=[ + f"Case: {label}
True Value: {val:.2f}
Realizations:" + f"{min(reals) if reals else None}-{max(reals) if reals else None}" + for label, val, reals in zip( + df["low_label"], df["true_low"], df["low_reals"] + ) + ], + hoverinfo="x+text", + orientation="h", + ), + dict( + type="bar", + y=df["sensname"], + x=df["high"], + name="high", + customdata=df["high_reals"], + hovertext=[ + f"Case: {label}
True Value: {val:.2f}
Realizations:" + f"{min(reals) if reals else None}-{max(reals) if reals else None}" + for label, val, reals in zip( + df["high_label"], df["true_high"], df["high_reals"] + ) + ], + hoverinfo="x+text", + orientation="h", + ), + ] + layout = {} + layout.update(plotly_theme["layout"]) + layout.update( + { "barmode": "relative", "margin": {"l": 50, "r": 50, "b": 20, "t": 50}, "xaxis": { @@ -448,6 +461,7 @@ def tornado_plot( "zeroline": False, "showline": False, "automargin": True, + "title": None, }, "showlegend": False, "annotations": [ @@ -467,5 +481,6 @@ def tornado_plot( "ay": -25, } ], - }, - } + } + ) + return {"data": plot_data, "layout": layout} diff --git a/webviz_subsurface/plugins/_inplace_volumes.py b/webviz_subsurface/plugins/_inplace_volumes.py index 688680b7a..e592e9571 100644 --- a/webviz_subsurface/plugins/_inplace_volumes.py +++ b/webviz_subsurface/plugins/_inplace_volumes.py @@ -82,7 +82,6 @@ def __init__( super().__init__() self.csvfile = csvfile if csvfile else None - self.colorway = app.webviz_settings.get("plotly_layout", {}).get("colorway", []) if csvfile and ensembles: raise ValueError( 'Incorrent arguments. Either provide a "csvfile" or "ensembles" and "volfiles"' @@ -109,7 +108,7 @@ def __init__( self.initial_response = response self.uid = uuid4() self.selectors_id = {x: str(uuid4()) for x in self.selectors} - + self.plotly_theme = app.webviz_settings["theme"].plotly_theme self.set_callbacks(app) def ids(self, element): @@ -426,7 +425,7 @@ def _render_vol_chart(*args): return ( { "data": plot_traces, - "layout": plot_layout(plot_type, response, colors=self.colorway), + "layout": plot_layout(plot_type, response, theme=self.plotly_theme), }, table, ) @@ -506,26 +505,33 @@ def plot_table(dframe, response, name): @CACHE.memoize(timeout=CACHE.TIMEOUT) -def plot_layout(plot_type, response, colors): +def plot_layout(plot_type, response, theme): + layout = {} + layout.update(theme["layout"]) + layout["height"] = 400 if plot_type == "Histogram": - output = { - "barmode": "overlay", - "bargap": 0.01, - "bargroupgap": 0.2, - "xaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)}, - "yaxis": {"title": "Count"}, - } + layout.update( + { + "barmode": "overlay", + "bargap": 0.01, + "bargroupgap": 0.2, + "xaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)}, + "yaxis": {"title": "Count"}, + } + ) elif plot_type == "Box plot": - output = {"yaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)}} + layout.update({"yaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)}}) else: - output = { - "margin": {"l": 40, "r": 40, "b": 30, "t": 10}, - "yaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)}, - "xaxis": {"title": "Realization"}, - } - output["height"] = 400 - output["colorway"] = colors - return output + layout.update( + { + "margin": {"l": 40, "r": 40, "b": 30, "t": 10}, + "yaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)}, + "xaxis": {"title": "Realization"}, + } + ) + + # output["colorway"] = colors + return layout @CACHE.memoize(timeout=CACHE.TIMEOUT) diff --git a/webviz_subsurface/plugins/_inplace_volumes_onebyone.py b/webviz_subsurface/plugins/_inplace_volumes_onebyone.py index 8e260b5c5..c22a2fe6f 100644 --- a/webviz_subsurface/plugins/_inplace_volumes_onebyone.py +++ b/webviz_subsurface/plugins/_inplace_volumes_onebyone.py @@ -147,6 +147,7 @@ def __init__( self.tornadoplot = TornadoPlot(app, realizations, allow_click=True) self.uid = uuid4() self.selectors_id = {x: self.ids(x) for x in self.selectors} + self.plotly_theme = app.webviz_settings["theme"].plotly_theme self.set_callbacks(app) def ids(self, element): @@ -477,8 +478,12 @@ def _render_vol_chart(plot_type, ensemble, response, source, *filters): table = calculate_table_rows(data, response) # Make Plotly figure + layout = {} + layout.update(self.plotly_theme["layout"]) + layout.update({"margin": {"l": 100}}) if plot_type == "Per realization": # One bar per realization + layout.update({"xaxis": {"title": "Realizations"}}) plot_data = data.groupby("REAL").sum().reset_index() figure = wcc.Graph( config={"displayModeBar": False}, @@ -492,11 +497,12 @@ def _render_vol_chart(plot_type, ensemble, response, source, *filters): "type": "bar", } ], - "layout": {"xaxis": {"title": "Realizations"}}, + "layout": layout, }, ) elif plot_type == "Box plot": # One box per sensitivity name + layout.update({"title": "Distribution for each sensitivity"}) figure = wcc.Graph( config={"displayModeBar": False}, id=self.ids("graph"), @@ -511,11 +517,9 @@ def _render_vol_chart(plot_type, ensemble, response, source, *filters): } for sensname, dframe in data.groupby(["SENSNAME"]) ], - "layout": {"title": "Distribution for each sensitivity"}, + "layout": layout, }, ) - else: - print(plot_type) tornado = json.dumps( { "ENSEMBLE": ensemble, diff --git a/webviz_subsurface/plugins/_parameter_correlation.py b/webviz_subsurface/plugins/_parameter_correlation.py index 7306be932..27dd17afa 100644 --- a/webviz_subsurface/plugins/_parameter_correlation.py +++ b/webviz_subsurface/plugins/_parameter_correlation.py @@ -43,6 +43,8 @@ def __init__(self, app, ensembles, drop_constants: bool = True): for ens in ensembles } self.drop_constants = drop_constants + self.plotly_theme = app.webviz_settings["theme"].plotly_theme + self.uid = uuid4() self.set_callbacks(app) @@ -191,7 +193,9 @@ def _update_matrix(ens, param1, param2): and it is not possible to assign callbacks to individual elements of a Plotly graph object """ - fig = render_matrix(ens, self.drop_constants) + fig = render_matrix( + ens, theme=self.plotly_theme, drop_constants=self.drop_constants + ) # Finds index of the currently selected cell x_index = list(fig["data"][0]["x"]).index(param1) y_index = list(fig["data"][0]["y"]).index(param2) @@ -223,7 +227,9 @@ def _update_matrix(ens, param1, param2): ], ) def _update_scatter(ens1, param1, ens2, param2, color, density): - return render_scatter(ens1, param1, ens2, param2, color, density) + return render_scatter( + ens1, param1, ens2, param2, color, density, self.plotly_theme + ) @app.callback( [ @@ -264,7 +270,7 @@ def get_parameters(ensemble_path) -> pd.DataFrame: @CACHE.memoize(timeout=CACHE.TIMEOUT) -def render_scatter(ens1, x_col, ens2, y_col, color, density): +def render_scatter(ens1, x_col, ens2, y_col, color, density, theme): if ens1 == ens2: real_text = [f"Realization:{r}" for r in get_parameters(ens1)["REAL"]] else: @@ -285,24 +291,8 @@ def render_scatter(ens1, x_col, ens2, y_col, color, density): "showlegend": False, } ) - data.append( - { - "x": x, - "type": "histogram", - "yaxis": "y2", - "showlegend": False, - "marker": {"color": "rgb(31, 119, 180)"}, - } - ) - data.append( - { - "y": y, - "type": "histogram", - "xaxis": "x2", - "showlegend": False, - "marker": {"color": "rgb(31, 119, 180)"}, - } - ) + data.append({"x": x, "type": "histogram", "yaxis": "y2", "showlegend": False}) + data.append({"y": y, "type": "histogram", "xaxis": "x2", "showlegend": False}) if density: data.append( { @@ -324,7 +314,6 @@ def render_scatter(ens1, x_col, ens2, y_col, color, density): ], "contours": { "coloring": "fill", - # 'end': 80.05, "showlines": True, "size": 5, "start": 5, @@ -339,6 +328,7 @@ def render_scatter(ens1, x_col, ens2, y_col, color, density): layout = { "margin": {"t": 20, "b": 50, "l": 200, "r": 200}, "bargap": 0.05, + "colorway": theme["layout"]["colorway"], "xaxis": { "title": x_col, "domain": [0, 0.85], @@ -394,10 +384,11 @@ def get_corr_data(ensemble_path, drop_constants=True): @CACHE.memoize(timeout=CACHE.TIMEOUT) -def render_matrix(ensemble_path, drop_constants=True): +def render_matrix(ensemble_path, theme, drop_constants=True): corrdf = get_corr_data(ensemble_path, drop_constants) # pylint: disable=no-member corrdf = corrdf.mask(np.tril(np.ones(corrdf.shape)).astype(np.bool)) + data = { "type": "heatmap", "x": corrdf.columns, @@ -405,8 +396,8 @@ def render_matrix(ensemble_path, drop_constants=True): "z": list(corrdf.values), "zmin": -1, "zmax": 1, + "colorscale": theme["layout"]["colorscale"]["sequential"], } - layout = { "paper_bgcolor": "rgba(0,0,0,0)", "plot_bgcolor": "rgba(0,0,0,0)", diff --git a/webviz_subsurface/plugins/_parameter_distribution.py b/webviz_subsurface/plugins/_parameter_distribution.py index f1b63912d..5f5673c03 100644 --- a/webviz_subsurface/plugins/_parameter_distribution.py +++ b/webviz_subsurface/plugins/_parameter_distribution.py @@ -60,6 +60,7 @@ def __init__(self, app, csvfile: Path = None, ensembles: list = None): if col not in ["REAL", "ENSEMBLE"] ] self.uid = uuid4() + self.plotly_theme = app.webviz_settings["theme"].plotly_theme self.set_callbacks(app) def ids(self, element): @@ -168,6 +169,7 @@ def _set_parameter(column): nbins=10, range_x=[param[column].min(), param[column].max()], marginal="box", + template=self.plotly_theme, ).for_each_trace(lambda t: t.update(name=t.name.replace("ENSEMBLE=", ""))) return plot diff --git a/webviz_subsurface/plugins/_parameter_response_correlation.py b/webviz_subsurface/plugins/_parameter_response_correlation.py index 52defcf48..733544c09 100644 --- a/webviz_subsurface/plugins/_parameter_response_correlation.py +++ b/webviz_subsurface/plugins/_parameter_response_correlation.py @@ -127,6 +127,7 @@ def __init__( inplace=True, ) + self.plotly_theme = app.webviz_settings["theme"].plotly_theme self.uid = uuid4() self.set_callbacks(app) @@ -387,7 +388,9 @@ def _update_correlation_graph(ensemble, response, *filters): corr_response = ( corrdf[response].dropna().drop(["REAL", response], axis=0) ) - return make_correlation_plot(corr_response, response, self.corr_method) + return make_correlation_plot( + corr_response, response, self.plotly_theme, self.corr_method + ) except KeyError: return { "layout": { @@ -423,7 +426,7 @@ def _update_distribution_graph(clickdata, ensemble, response, *filters): df = pd.merge(responsedf, parameterdf, on=["REAL"])[ ["REAL", parameter, response] ] - return make_distribution_plot(df, parameter, response) + return make_distribution_plot(df, parameter, response, self.plotly_theme) def add_webvizstore(self): if self.parameter_csv and self.response_csv: @@ -517,7 +520,7 @@ def _correlate(inputdf, method="pearson"): ) -def make_correlation_plot(series, response, corr_method): +def make_correlation_plot(series, response, theme, corr_method): """Make Plotly trace for correlation plot""" return { @@ -525,6 +528,7 @@ def make_correlation_plot(series, response, corr_method): {"x": series.values, "y": series.index, "orientation": "h", "type": "bar"} ], "layout": { + "colorway": theme["layout"]["colorway"], "barmode": "relative", "margin": {"l": 200, "r": 50, "b": 20, "t": 100}, "font": {"size": 8}, @@ -535,7 +539,7 @@ def make_correlation_plot(series, response, corr_method): } -def make_distribution_plot(df, parameter, response): +def make_distribution_plot(df, parameter, response, theme): """Make plotly traces for scatterplot and histograms for selected response and input parameter""" @@ -563,27 +567,14 @@ def make_distribution_plot(df, parameter, response): 1, ) fig.add_trace( - { - "type": "histogram", - "marker": {"color": "rgb(31, 119, 180)"}, - "x": df[parameter], - "showlegend": False, - }, - 3, - 1, + {"type": "histogram", "x": df[parameter], "showlegend": False,}, 3, 1, ) fig.add_trace( - { - "type": "histogram", - "marker": {"color": "rgb(31, 119, 180)"}, - "x": df[response], - "showlegend": False, - }, - 3, - 2, + {"type": "histogram", "x": df[response], "showlegend": False,}, 3, 2, ) fig["layout"].update( { + "colorway": theme["layout"]["colorway"], "height": 800, "bargap": 0.05, "xaxis": {"title": parameter,}, diff --git a/webviz_subsurface/plugins/_reservoir_simulation_timeseries.py b/webviz_subsurface/plugins/_reservoir_simulation_timeseries.py index 33f2073cc..88c02d90e 100644 --- a/webviz_subsurface/plugins/_reservoir_simulation_timeseries.py +++ b/webviz_subsurface/plugins/_reservoir_simulation_timeseries.py @@ -103,41 +103,13 @@ def __init__( ] self.ensembles = list(self.smry["ENSEMBLE"].unique()) - self.plotly_layout = app.webviz_settings["plotly_layout"] + self.plotly_theme = app.webviz_settings["theme"].plotly_theme self.plot_options = options if options else {} self.plot_options["date"] = ( str(self.plot_options.get("date")) if self.plot_options.get("date") else None ) - - colors = self.plotly_layout.get( - "colors", - [ - "#243746", - "#eb0036", - "#919ba2", - "#7d0023", - "#66737d", - "#4c9ba1", - "#a44c65", - "#80b7bc", - "#ff1243", - "#919ba2", - "#be8091", - "#b2d4d7", - "#ff597b", - "#bdc3c7", - "#d8b2bd", - "#ffe7d6", - "#d5eaf4", - "#ff88a1", - ], - ) - self.ens_colors = { - ens: colors[self.ensembles.index(ens)] for ens in self.ensembles - } - self.allow_delta = len(self.ensembles) > 1 self.uid = uuid4() self.set_callbacks(app) @@ -146,6 +118,37 @@ def ids(self, element): """Generate unique id for dom element""" return f"{element}-id-{self.uid}" + @property + def ens_colors(self): + try: + colors = self.plotly_theme["layout"]["colorway"] + except KeyError: + print("test") + colors = self.plotly_theme.get( + "colorway", + [ + "#243746", + "#eb0036", + "#919ba2", + "#7d0023", + "#66737d", + "#4c9ba1", + "#a44c65", + "#80b7bc", + "#ff1243", + "#919ba2", + "#be8091", + "#b2d4d7", + "#ff597b", + "#bdc3c7", + "#d8b2bd", + "#ffe7d6", + "#d5eaf4", + "#ff88a1", + ], + ) + return {ens: colors[self.ensembles.index(ens)] for ens in self.ensembles} + @property def tour_steps(self): return [ @@ -490,10 +493,10 @@ def _update_graph( # Add additional styling to layout fig["layout"].update( height=800, - font=self.plotly_layout.get("font"), - hoverlabel=self.plotly_layout.get("hoverlabel"), - paper_bgcolor=self.plotly_layout.get("paper_bgcolor", "rgba(0,0,0,0)"), - plot_bgcolor=self.plotly_layout.get("plot_bgcolor", "rgba(0,0,0,0)"), + font=self.plotly_theme.get("font"), + hoverlabel=self.plotly_theme.get("hoverlabel"), + paper_bgcolor=self.plotly_theme.get("paper_bgcolor", "rgba(0,0,0,0)"), + plot_bgcolor=self.plotly_theme.get("plot_bgcolor", "rgba(0,0,0,0)"), margin={"t": 20, "b": 0}, barmode="overlay", bargap=0.01, diff --git a/webviz_subsurface/plugins/_reservoir_simulation_timeseries_onebyone.py b/webviz_subsurface/plugins/_reservoir_simulation_timeseries_onebyone.py index d23c21610..46ddbf80a 100644 --- a/webviz_subsurface/plugins/_reservoir_simulation_timeseries_onebyone.py +++ b/webviz_subsurface/plugins/_reservoir_simulation_timeseries_onebyone.py @@ -141,6 +141,7 @@ def __init__( ) self.tornadoplot = TornadoPlot(app, realizations, allow_click=True) self.uid = uuid4() + self.plotly_theme = app.webviz_settings["theme"].plotly_theme self.set_callbacks(app) def ids(self, element): @@ -367,6 +368,8 @@ def _render_tornado(tornado_click, ensemble, vector, date_click, figure): # Redraw figure if ensemble/vector hanges if ctx == self.ids("ensemble") or ctx == self.ids("vector"): + layout = {} + layout.update(self.plotly_theme["layout"]) data = filter_ensemble(self.data, ensemble, vector) traces = [ { @@ -380,7 +383,7 @@ def _render_tornado(tornado_click, ensemble, vector, date_click, figure): for r, df in data.groupby(["REAL"]) ] traces[0]["hoverinfo"] = "x" - layout = {"showlegend": False} + layout.update({"showlegend": False, "margin": {"t": 50}}) figure = {"data": traces, "layout": layout} # Update line colors if a sensitivity is selected in tornado @@ -393,10 +396,14 @@ def _render_tornado(tornado_click, ensemble, vector, date_click, figure): else: for trace in figure["data"]: if trace["customdata"] in tornado_click["real_low"]: - trace["marker"] = {"color": "rgb(235, 0, 54)"} + trace["marker"] = { + "color": self.plotly_theme["layout"]["colorway"][0] + } trace["opacity"] = 1 elif trace["customdata"] in tornado_click["real_high"]: - trace["marker"] = {"color": "rgb(36, 55, 70)"} + trace["marker"] = { + "color": self.plotly_theme["layout"]["colorway"][1] + } trace["opacity"] = 1 else: trace["marker"] = {"color": "grey"}