Skip to content

Commit

Permalink
Add data downselector.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Nov 20, 2024
1 parent 910db52 commit 080ef28
Showing 1 changed file with 97 additions and 21 deletions.
118 changes: 97 additions & 21 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import collections
import functools
import itertools
import json

import dash_bootstrap_components as dbc
import pandas as pd
import plotly.express as px
import pymatviz as pmv
from dash import Dash, Input, Output, callback, dcc, html
from dash import Dash, Input, Output, callback, dcc, html, State
from pymatgen.core import Element
from pymongo import MongoClient

Expand All @@ -19,23 +20,37 @@
# Set up MongoDB client and database
CLIENT = MongoClient()
DB = CLIENT["matpes"]
# print(DB["PBE"].find_one())
RAW_DATA = {}
for f in FUNCTIONALS:
collection = DB[f]
RAW_DATA[f] = pd.DataFrame(


@functools.lru_cache
def get_df(functional: str) -> pd.DataFrame:
collection = DB[functional]
return pd.DataFrame(
collection.find(
{}, projection=["elements", "energy", "cohesive_energy_per_atom", "formation_energy", "natoms", "nelements"]
{},
projection=[
"elements",
"energy",
"chemsys",
"cohesive_energy_per_atom",
"formation_energy_per_atom",
"natoms",
"nelements",
],
)
)


@functools.lru_cache
def get_data(functional, el):
def get_data(functional, el, chemsys):
"""Filter data with caching for improved performance."""
df = RAW_DATA[functional]
df = get_df(functional)
if el is not None:
df = df[df["elements"].apply(lambda x: el in x)]
if chemsys:
chemsys = "-".join(sorted(chemsys.split("-")))
df = df[df["chemsys"] == chemsys]

return df


Expand All @@ -49,13 +64,45 @@ def get_data(functional, el):
dbc.Row([html.Div("MatPES Explorer", className="text-primary text-center fs-3")]),
dbc.Row(
[
html.Label("Functional"),
dcc.RadioItems(options=[{"label": f, "value": f} for f in FUNCTIONALS], value="PBE", id="functional"),
html.Label("Element Filter"),
dcc.Dropdown(options=[{"label": el.symbol, "value": el.symbol} for el in Element], id="el_filter"),
dbc.Col(
[
html.Label("Functional"),
dcc.RadioItems(
options=[{"label": f, "value": f} for f in FUNCTIONALS], value="PBE", id="functional"
),
],
width=2,
),
dbc.Col(
[
html.Label("Element Filter"),
dcc.Dropdown(
options=[{"label": el.symbol, "value": el.symbol} for el in Element], id="el_filter"
),
],
width=2,
),
dbc.Col(
[
html.Label("Chemsys Filter"),
dcc.Input(id="chemsys_filter", placeholder="Li-Fe-O"),
],
width=2,
),
dbc.Col(
[html.Div([html.Button("Download", id="btn-download"), dcc.Download(id="download-data")])],
width=2,
),
]
),
dbc.Row([dcc.Graph(id="ptheatmap")]),
dbc.Row(
[
dbc.Col(
[dcc.Graph(id="ptheatmap", style={"marginLeft": "auto", "marginRight": "auto"})],
width={"size": 8, "order": "last", "offset": 2},
)
],
),
dbc.Row(
[
dbc.Col([dcc.Graph(id="coh_energy_hist")], width=6),
Expand All @@ -81,22 +128,51 @@ def get_data(functional, el):
Output("natoms_hist", "figure"),
Output("nelements_hist", "figure"),
],
[Input("functional", "value"), Input("el_filter", "value")],
[Input("functional", "value"), Input("el_filter", "value"), Input("chemsys_filter", "value")],
)
def update_graph(functional, el_filter):
def update_graph(functional, el_filter, chemsys_filter):
"""Update graph based on input."""
df = get_data(functional, el_filter)
el_count = collections.Counter(itertools.chain(*df["elements"]))
heatmap_figure = pmv.ptable_heatmap_plotly(el_count, log=True)
df = get_data(functional, el_filter, chemsys_filter)
el_count = {el.symbol: 0 for el in Element}
el_count.update(collections.Counter(itertools.chain(*df["elements"])))
heatmap_figure = pmv.ptable_heatmap_plotly(el_count)
return (
heatmap_figure,
px.histogram(df, x="cohesive_energy_per_atom"),
px.histogram(df, x="formation_energy"),
px.histogram(
df, x="cohesive_energy_per_atom", labels={"cohesive_energy_per_atom": "Cohesive Energy per Atom (eV/atom)"}
),
px.histogram(
df,
x="formation_energy_per_atom",
labels={"formation_energy_per_atom": "Formation Energy per Atom (eV/atom)"},
),
px.histogram(df, x="natoms"),
px.histogram(df, x="nelements"),
)


@callback(
Output("download-data", "data"),
Input("btn-download", "n_clicks"),
State("functional", "value"),
State("el_filter", "value"),
State("chemsys_filter", "value"),
prevent_initial_call=True,
)
def download(n_clicks, functional, el_filter, chemsys_filter):
collection = DB[functional]
criteria = {}
if el_filter is not None:
criteria["elements"] = el_filter
if chemsys_filter is not None:
chemsys = "-".join(sorted(chemsys_filter.split("-")))
criteria["chemsys"] = chemsys
data = list(collection.find(criteria))
for d in data:
del d["_id"]
return dict(content=json.dumps(data), filename=f"matpes_{functional}_{el_filter}_{chemsys_filter}.json")


# Run the app
if __name__ == "__main__":
app.run(debug=True)

0 comments on commit 080ef28

Please sign in to comment.