From cee872e47dffb73eab8fc88723597f3f2f528d48 Mon Sep 17 00:00:00 2001 From: Shyue Ping Ong Date: Wed, 20 Nov 2024 10:43:54 -0800 Subject: [PATCH] Much improved matpes explorer. --- notebooks/Insert Calculations.ipynb | 76 +++++++++---------- pyproject.toml | 11 ++- requirements-ci.txt | 7 -- requirements.txt | 1 - src/matpes/__init__.py | 0 app.py => src/matpes/ui.py | 30 +++++--- src/matpes/utils.py | 113 ++++++++++++++++++++++++++++ 7 files changed, 178 insertions(+), 60 deletions(-) delete mode 100644 requirements-ci.txt delete mode 100644 requirements.txt create mode 100644 src/matpes/__init__.py rename app.py => src/matpes/ui.py (87%) create mode 100644 src/matpes/utils.py diff --git a/notebooks/Insert Calculations.ipynb b/notebooks/Insert Calculations.ipynb index b60c32d..7ae7baa 100644 --- a/notebooks/Insert Calculations.ipynb +++ b/notebooks/Insert Calculations.ipynb @@ -2,26 +2,27 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, - "id": "1671fa59-c236-4401-8548-afaee3e85fcd", + "execution_count": null, + "id": "0", "metadata": {}, "outputs": [], "source": [ "from __future__ import annotations\n", + "\n", "import os\n", - "from tqdm import tqdm\n", "import warnings\n", "\n", "from monty.serialization import loadfn\n", "from pymongo import MongoClient\n", + "from tqdm import tqdm\n", "\n", - "warnings.simplefilter('ignore')" + "warnings.simplefilter(\"ignore\")" ] }, { "cell_type": "code", - "execution_count": 3, - "id": "870765b8-719c-4405-b489-ba3f2733b10f", + "execution_count": null, + "id": "1", "metadata": {}, "outputs": [], "source": [ @@ -31,8 +32,8 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "5e60c4e6-c868-4dc1-ac1e-ed25ceca96ba", + "execution_count": null, + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -45,13 +46,13 @@ " Args:\n", " functional (str): The name of the functional. This is used to locate\n", " the appropriate file for loading the data and also defines the MongoDB collection name.\n", - " \n", + "\n", " Raises:\n", " FileNotFoundError: If the designated file does not exist or cannot be accessed.\n", "\n", " Example:\n", " To create the database and indexes for a given functional 'pbe', run:\n", - " \n", + "\n", " ```python\n", " make_db(\"pbe\")\n", " ```\n", @@ -60,7 +61,7 @@ " -------------\n", " 1. Load Data:\n", " - Reads data from a gzipped JSON file for the specified `functional`.\n", - " \n", + "\n", " 2. Extract & Process Fields:\n", " - Each dataset entry is extracted and processed to include information such as:\n", " - `matpesid`: A unique identifier for the material.\n", @@ -72,12 +73,12 @@ " - `composition`: Dictionary depicting the element counts in the structure.\n", " - `formation_energy_per_atom`: Energy per atom (derived from `formation_energy`).\n", " - `structure`: The structure in dictionary format.\n", - " \n", + "\n", " 3. Store Data in MongoDB:\n", " - Deletes any existing records in the collection corresponding to `functional`.\n", " - Inserts the processed records.\n", - " \n", - " 4. Create Indexes: \n", + "\n", + " 4. Create Indexes:\n", " - Indexes are created on the following fields to optimize searching:\n", " - `natoms`\n", " - `elements`\n", @@ -99,7 +100,7 @@ " - nelements: int\n", " Number of distinct chemical elements.\n", " - chemsys:\n", - " String representation of the elements in the chemical system, \n", + " String representation of the elements in the chemical system,\n", " sorted alphabetically (e.g., 'H-O').\n", " - formula: str\n", " The reduced chemical formula of the material (e.g., 'H2O').\n", @@ -109,11 +110,11 @@ " Formation energy per atom for the material (extracted from `formation_energy`).\n", " - structure: dict\n", " The detailed structure of the material in dictionary format.\n", - " \n", + "\n", " Indexes:\n", " --------\n", " The created MongoDB indexes optimize the following fields:\n", - " \n", + "\n", " - `natoms`: Number of atoms per structure.\n", " - `elements`: Chemical elements present in the structure.\n", " - `nelements`: Number of distinct elements in the structure.\n", @@ -127,15 +128,14 @@ " - The JSON file path is specific to the user's system configuration.\n", "\n", " \"\"\"\n", - " \n", " raw = loadfn(os.path.expanduser(f\"~/Desktop/2024_11_18_MatPES-20240214-{functional}-training-data.json.gz\"))\n", " data = []\n", - " \n", + "\n", " for k, v in tqdm(raw.items()):\n", " # Combine IDs and structure information\n", " d = {\"matpesid\": k} | v\n", " comp = d[\"structure\"].composition\n", - " \n", + "\n", " # Populate additional fields based on composition\n", " d[\"natoms\"] = len(d[\"structure\"])\n", " d[\"elements\"] = list(comp.chemical_system_set)\n", @@ -144,37 +144,37 @@ " d[\"formula\"] = comp.reduced_formula\n", " d[\"composition\"] = {el.symbol: amt for el, amt in comp.items()}\n", " d[\"structure\"] = d[\"structure\"].as_dict()\n", - " \n", + "\n", " # Restructure formation energy data\n", " d[\"formation_energy_per_atom\"] = d[\"formation_energy\"]\n", " del d[\"formation_energy\"]\n", - " \n", + "\n", " # Add processed entry to list\n", " data.append(d)\n", - " \n", + "\n", " # Get collection from DB and clear old data\n", " collection = db[functional]\n", " collection.delete_many({})\n", - " \n", + "\n", " # new data\n", " collection.insert_many(data)\n", - " \n", + "\n", " # Create indexes for optimized query performance\n", " for k in [\"natoms\", \"elements\", \"nelements\", \"chemsys\", \"formula\", \"matpesid\"]:\n", - " collection.create_index(k)\n" + " collection.create_index(k)" ] }, { "cell_type": "code", "execution_count": null, - "id": "e7eb2367-d29b-4de2-91bb-489906689761", + "id": "3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - " 62%|█████████████████████████████████████████████████▎ | 271161/434712 [01:07<01:16, 2149.16it/s]" + "100%|███████████████████████████████████████████████████████████████████████████████| 434712/434712 [03:45<00:00, 1930.57it/s]\n" ] } ], @@ -185,20 +185,20 @@ { "cell_type": "code", "execution_count": null, - "id": "d167bd6c-197a-4e95-9267-c7b10b0ad6ef", + "id": "4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████████████████████████████████████████| 387897/387897 [02:50<00:00, 2278.72it/s]\n" + ] + } + ], "source": [ "make_db(\"r2SCAN\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "89b44d7f-b5e1-4f0f-bff8-f6926a6240c4", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index fd925a1..063398c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,9 +47,16 @@ dependencies = [ "pymatgen", "dash", "dash_bootstrap_components", - "pymatviz","pymongo" + "pymongo" ] -version = "2024.11.13" +version = "0.0.1" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["matpes", "matpes.*"] + +[project.scripts] +matpes_explorer = "matpes.ui:main" [tool.versioningit.vcs] method = "git" diff --git a/requirements-ci.txt b/requirements-ci.txt deleted file mode 100644 index 7c01d3d..0000000 --- a/requirements-ci.txt +++ /dev/null @@ -1,7 +0,0 @@ -pytest -pytest-cov -coverage -coveralls -mypy -ruff -black diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 7e03e4f..0000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -dash \ No newline at end of file diff --git a/src/matpes/__init__.py b/src/matpes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app.py b/src/matpes/ui.py similarity index 87% rename from app.py rename to src/matpes/ui.py index 83ed38a..ca4c6ce 100644 --- a/app.py +++ b/src/matpes/ui.py @@ -10,11 +10,13 @@ import dash_bootstrap_components as dbc import pandas as pd import plotly.express as px -import pymatviz as pmv +import plotly.figure_factory as ff from dash import Dash, Input, Output, State, callback, dcc, html from pymatgen.core import Element from pymongo import MongoClient +from matpes.utils import get_pt_heatmap + FUNCTIONALS = ("PBE", "r2SCAN") # Set up MongoDB client and database @@ -56,7 +58,7 @@ def get_data(functional, el, chemsys): # Initialize the Dash app with a Bootstrap theme external_stylesheets = [dbc.themes.CERULEAN] -app = Dash(__name__, external_stylesheets=external_stylesheets) +app = Dash("MatPES Explorer", external_stylesheets=external_stylesheets) # Define the app layout app.layout = dbc.Container( @@ -67,7 +69,7 @@ def get_data(functional, el, chemsys): dbc.Col( [ html.Label("Functional"), - dcc.RadioItems( + dcc.Dropdown( options=[{"label": f, "value": f} for f in FUNCTIONALS], value="PBE", id="functional" ), ], @@ -119,6 +121,13 @@ def get_data(functional, el, chemsys): ) +def get_dist_plot(data, label, nbins=100): + fig = ff.create_distplot([data], [label], show_rug=False) + + fig.update_layout(xaxis=dict(title=label), showlegend=False) + return fig + + # Define callback to update the heatmap based on selected functional @callback( [ @@ -135,16 +144,13 @@ def update_graph(functional, el_filter, chemsys_filter): df = get_data(functional, el_filter, chemsys_filter) el_count = {el.symbol: 0 for el in Element} el_count.update(collections.Counter(itertools.chain(*df["elements"]))) - heatmap_figure = pmv.ptable_heatmap_plotly(el_count) + heatmap_figure = get_pt_heatmap(el_count, label="Count", log=True) return ( heatmap_figure, - px.histogram( - df, x="cohesive_energy_per_atom", labels={"cohesive_energy_per_atom": "Cohesive Energy per Atom (eV/atom)"} - ), - px.histogram( - df, - x="formation_energy_per_atom", - labels={"formation_energy_per_atom": "Formation Energy per Atom (eV/atom)"}, + get_dist_plot(df["cohesive_energy_per_atom"], "Cohesive Energy per Atom (eV/atom)"), + get_dist_plot( + df["formation_energy_per_atom"].dropna(), + "Formation Energy per Atom (eV/atom)", ), px.histogram(df, x="natoms"), px.histogram(df, x="nelements"), @@ -174,5 +180,5 @@ def download(n_clicks, functional, el_filter, chemsys_filter): # Run the app -if __name__ == "__main__": +def main(): app.run(debug=True) diff --git a/src/matpes/utils.py b/src/matpes/utils.py new file mode 100644 index 0000000..34b8b33 --- /dev/null +++ b/src/matpes/utils.py @@ -0,0 +1,113 @@ +import plotly.express as px +import pandas as pd +from pymatgen.core.periodic_table import Element, _pt_data +import functools +import numpy as np +import warnings +# Define periodic table data + +@functools.lru_cache +def get_pt_df(): + """ + Returns a DataFrame with PT data. + """ + with warnings.catch_warnings(): + df = pd.DataFrame([{"symbol": el.symbol, "name": el.long_name, "Z": el.Z, "X": el.X, "group": get_group(el), + "period": get_period(el), "category": get_category(el)} for el in Element]) + + # Create hover text for each element + df["label"] = df.apply( + lambda row: f"{row['Z']}
{row['symbol']}", + axis=1 + ) + + return df + + +def get_period(el): + """ + Special handling of period for pt plotting of rare earths. + """ + if el.is_actinoid or el.is_lanthanoid: + return el.row + 3 + return el.row + + +def get_group(el): + """ + Special handling of group for pt plotting of rare earths. + """ + if el.is_actinoid: + return el.group + el.Z - 89 + if el.is_lanthanoid: + return el.group + el.Z - 57 + return el.group + + +def get_category(el): + if el.Z > 92: + return "transuranic" + for cat in ["alkali", "alkaline", "actinoid", "lanthanoid", "halogen", "noble_gas", "metal", "chalcogen"]: + if getattr(el, f"is_{cat}"): + return cat + return "" + + +def get_pt_heatmap(values, label="value", log=False): + """ + Args: + values (dict[str, float]): Values to plot. + label (str): Label for values. + """ + df = get_pt_df() + df[label] = df.apply( + lambda row: values.get(row['symbol'], 0), + axis=1 + ) + if log: + df[f"log10_{label}"] = np.log10(df[label]) + + + + # Create the plot + fig = px.scatter( + df, + x="group", + y="period", + color=label if not log else f"log10_{label}", + text="label", + # hover_name=label, + hover_data={"Z": False, "name": False, "label": False, label: True, "X": False, "group": False, + "period": False, f"log10_{label}": False}, + color_continuous_scale=px.colors.sequential.Viridis, + ) + + fig.update_traces(marker=dict( + symbol='square', + size=40, + line=dict(color="black", width=1), + )) + + # Update layout + fig.update_layout( + xaxis=dict(title="Group", range=[0.5, 18.5], dtick=1), + yaxis=dict(title="Period", range=[0.5, 10.5], dtick=1, autorange="reversed"), + showlegend=False, + plot_bgcolor="white", + width=1100, + height=650, + font=dict( + family='Arial', + size=14, + color='black', + weight='bold' # Make the text bold + ) + ) + if log: + max_log = int(df[f"log10_{label}"].max()) + fig.update_layout(coloraxis_colorbar=dict( + title=label, + tickvals=list(range(1, max_log+1)), + ticktext=[f"10^{i}" for i in range(1, max_log+1)] + )) + return fig