Skip to content

Commit

Permalink
Improved PT using heatmap.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Nov 24, 2024
1 parent 155d881 commit 2de1a60
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 52 deletions.
3 changes: 2 additions & 1 deletion src/matpes/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def display_click_data(clickdata, el_filter):
"""
new_el_filter = el_filter or []
if clickdata:
new_el_filter = {*new_el_filter, Element.from_Z(clickdata["points"][0]["pointNumber"] + 1).symbol}
z = clickdata["points"][0]["text"].split("<")[0]
new_el_filter = {*new_el_filter, Element.from_Z(int(z)).symbol}
return list(new_el_filter)


Expand Down
122 changes: 71 additions & 51 deletions src/matpes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from pymatgen.core.periodic_table import Element


@functools.lru_cache
def get_pt_df() -> pd.DataFrame:
def get_pt_df(include_artificial=False) -> pd.DataFrame:
"""
Creates a DataFrame containing periodic table data.
Expand All @@ -34,7 +34,7 @@ def get_pt_df() -> pd.DataFrame:
"category": get_category(el),
}
for el in Element
if el.name not in ["D", "T"]
if (el.name not in ["D", "T"]) and (el.Z <= 92 or include_artificial)
]
df = pd.DataFrame(elements)
df["label"] = df.apply(lambda row: f"{row['Z']}<br>{row['symbol']}", axis=1)
Expand All @@ -52,7 +52,7 @@ def get_period(el: Element) -> int:
int: The adjusted period number.
"""
if el.is_actinoid or el.is_lanthanoid:
return el.row + 3
return el.row + 2
return el.row


Expand Down Expand Up @@ -100,81 +100,101 @@ def get_category(el: Element) -> str:
return "other"


def pt_heatmap(values: dict[str, float], label: str = "value", log: bool = False) -> px.scatter:
def pt_heatmap(
values: dict[str, float], label: str = "value", log: bool = False, include_artificial=False
) -> go.Figure:
"""
Generate a heatmap visualization of the periodic table.
Args:
values (dict[str, float]): Mapping of element symbols to values to visualize.
label (str): Label for the values displayed.
log (bool): Whether to use logarithmic scaling for the color axis.
include_artificial (bool): Whether to include artificial elements. Defaults to False.
Returns:
plotly.graph_objects.Figure: A scatter plot representing the heatmap.
"""
df = get_pt_df()
df[label] = df["symbol"].map(values)
hover_data = {
"Z": False,
"name": False,
"label": False,
label: True,
"X": False,
"group": False,
"period": False,
}
df = get_pt_df(include_artificial=include_artificial)
df[label] = df["symbol"].map(values) if values else df["X"]
if log:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
df[f"log10_{label}"] = np.log10(df[label])
hover_data[f"log10_{label}"] = False

fig = px.scatter(
df,
x="group",
y="period",
color=label if not log else f"log10_{label}",
text="label",
hover_data=hover_data,
color_continuous_scale=px.colors.sequential.Viridis,
)

fig.update_traces(
marker=dict(
symbol="square",
size=40,
line=dict(color="black", width=1),
# Initialize periodic table grid
grid = np.full((9, 18), None, dtype=np.float_)
label_texts = np.full((9, 18), "", dtype=object)
hover_texts = np.full((9, 18), "", dtype=object)

# Fill grid with element symbols, hover text, and category colors
for _index, row in df.iterrows():
group, period = row["group"], row["period"]
grid[period - 1, group - 1] = row[label] if not log else row[f"log10_{label}"]
label_texts[period - 1, group - 1] = f'{row["Z"]}<br>{row["symbol"]}<br>{row[label]}'
hover_texts[period - 1, group - 1] = f'{row["Z"]}<br>{row["name"]}<br>{row[label]}'

# Create the plot
fig = go.Figure()
fig.add_trace(
go.Heatmap(
z=grid,
x=list(range(1, 19)),
y=list(range(1, 9)),
text=hover_texts,
hoverinfo="text",
showscale=True,
colorscale="Viridis",
xgap=1,
ygap=1,
coloraxis="coloraxis",
)
)

fig.update_layout(
xaxis=dict(title=None, range=[0.5, 18.5], dtick=1),
yaxis=dict(title=None, range=[0.5, 10.5], dtick=1, autorange="reversed"),
showlegend=False,
plot_bgcolor="white",
width=1080,
height=640,
font=dict(
family="Arial",
size=14,
color="black",
weight="bold",
),
)

# Add annotations for element symbols
for _index, row in df.iterrows():
group, period = row["group"], row["period"]
fig.add_annotation(
x=group,
y=period,
text=label_texts[period - 1, group - 1],
showarrow=False,
font=dict(
family="Arial",
size=14,
color="black",
weight="bold",
),
align="center",
)
# Hide x-axis
fig.update_xaxes(showticklabels=False, showgrid=False)

# Hide y-axis
fig.update_yaxes(showticklabels=False, showgrid=False)

# Update layout
fig.update_layout(
title=None,
xaxis=dict(title=None), # Maintain 1:1 aspect ratio
yaxis=dict(title=None, scaleanchor="x", scaleratio=1.33, autorange="reversed"),
plot_bgcolor="white",
width=1200,
height=900,
)

if log:
max_log = int(df[f"log10_{label}"].max())
fig.update_layout(
coloraxis_colorbar=dict(
title=label,
tickvals=list(range(1, max_log + 1)),
ticktext=[f"10^{i}" for i in range(1, max_log + 1)],
coloraxis=dict(
colorbar=dict(
title=label,
tickmode="array",
tickvals=list(range(1, max_log + 1)),
ticktext=[f"1e{i}" for i in range(1, max_log + 1)],
tickfont=dict(size=14, family="Arial", color="black"),
)
)
)

return fig

0 comments on commit 2de1a60

Please sign in to comment.