Skip to content

Commit

Permalink
Clickable PT.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Nov 21, 2024
1 parent 376c72a commit f688492
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 26 deletions.
143 changes: 117 additions & 26 deletions src/matpes/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,26 @@ def get_full_data(functional: str) -> pd.DataFrame:
return DB.get_df(functional)


def get_data(functional: str, element_filter: list, chemsys_filter: str, bandgap_filter) -> pd.DataFrame:
def get_data(
functional: str,
element_filter: list,
chemsys_filter: str,
min_coh_e_filter,
max_coh_e_filter,
min_form_e_filter,
max_form_e_filter,
) -> pd.DataFrame:
"""
Filter data based on the selected functional, element, and chemical system.
Args:
functional (str): Functional to filter data for.
element_filter (list | None): Elements to filter (if any).
chemsys_filter (str | None): Chemical system to filter (if any).
bandgap_filter (str | None): Bandgap to filter (not used at the moment).
min_coh_e_filter (float): Minimum cohesive energy filter.
max_coh_e_filter (float): Maximum cohesive energy filter.
min_form_e_filter (float): Minimum form energy filter.
max_form_e_filter (float): Maximum form energy filter.
Returns:
pd.DataFrame: Filtered data.
Expand All @@ -51,9 +62,29 @@ def get_data(functional: str, element_filter: list, chemsys_filter: str, bandgap
sorted_chemsys = "-".join(sorted(chemsys_filter.split("-")))
df = df[df["chemsys"] == sorted_chemsys]

# df = df[bandgap_filter[0] <= df["bandgap"]]
# df = df[df["bandgap"] <= bandgap_filter[1]]
return df
df = df[min_coh_e_filter <= df["cohesive_energy_per_atom"]]
df = df[df["cohesive_energy_per_atom"] <= max_coh_e_filter]
df = df[min_form_e_filter <= df["formation_energy_per_atom"]]
return df[df["formation_energy_per_atom"] <= max_form_e_filter]


@callback(
[
Output("min_coh_e_filter", "value"),
Output("max_coh_e_filter", "value"),
Output("min_form_e_filter", "value"),
Output("max_form_e_filter", "value"),
],
[
Input("functional", "value"),
],
)
def update_sliders(functional):
"""Update sliders based on functional."""
df = get_full_data(functional)
coh_energy = df["cohesive_energy_per_atom"]
form_energy = df["formation_energy_per_atom"]
return coh_energy.min(), coh_energy.max(), form_energy.min(), form_energy.max()


@callback(
Expand All @@ -68,12 +99,19 @@ def get_data(functional: str, element_filter: list, chemsys_filter: str, bandgap
Input("functional", "value"),
Input("el_filter", "value"),
Input("chemsys_filter", "value"),
# Input("bandgap_filter", "value")
Input("min_coh_e_filter", "value"),
Input("max_coh_e_filter", "value"),
Input("min_form_e_filter", "value"),
Input("max_form_e_filter", "value"),
],
)
def update_graph(functional, el_filter, chemsys_filter, bandgap_filter):
def update_graph(
functional, el_filter, chemsys_filter, min_coh_e_filter, max_coh_e_filter, min_form_e_filter, max_form_e_filter
):
"""Update graphs based on user inputs."""
df = get_data(functional, el_filter, chemsys_filter, bandgap_filter)
df = get_data(
functional, el_filter, chemsys_filter, min_coh_e_filter, max_coh_e_filter, min_form_e_filter, max_form_e_filter
)
element_counts = collections.Counter(itertools.chain(*df["elements"]))
heatmap_figure = pt_heatmap(element_counts, label="Count", log=True)
return (
Expand Down Expand Up @@ -119,6 +157,20 @@ def download_data(n_clicks, functional, el_filter, chemsys_filter):
)


@callback(Output("el_filter", "value"), Input("ptheatmap", "clickData"), State("el_filter", "value"))
def display_click_data(clickdata, el_filter):
"""
Update el filter when PT table is clicked.
Args:
clickdata (dict): Click data.
el_filter (dict): Element filter.
"""
el_filter = el_filter or []
new_el_filter = {*el_filter, Element.from_Z(clickdata["points"][0]["pointNumber"] + 1).symbol}
return list(new_el_filter)


def main():
"""Main entry point for MatPES Explorer UI."""
app = Dash("MatPES Explorer", external_stylesheets=[dbc.themes.CERULEAN], title="MatPES Explorer")
Expand Down Expand Up @@ -146,7 +198,7 @@ def main():
),
dbc.Row(
[
html.Div("Explorer", className="text-primary text-center fs-3"),
html.H2("Explorer", className="text-primary text-center fs-3"),
]
),
dbc.Row(
Expand All @@ -161,48 +213,87 @@ def main():
clearable=False,
),
],
width=2,
width=4,
)
]
),
dbc.Row(
[
dbc.Col(
[
html.Div("Filters: "),
],
width=1,
),
dbc.Col(
[
html.Label("Filter by Element(s)"),
html.Label("Element(s)"),
dcc.Dropdown(
id="el_filter",
options=[{"label": el.symbol, "value": el.symbol} for el in Element],
options=[
{"label": el.symbol, "value": el.symbol}
for el in Element
if el.name not in ("D", "T")
],
multi=True,
),
],
width=2,
),
dbc.Col(
[
html.Div("Filter by Chemsys"),
html.Div("Chemsys"),
dcc.Input(
id="chemsys_filter",
placeholder="Li-Fe-O",
),
],
width=2,
),
# dbc.Col(
# [
# html.Div("Bandgap", className="text-center"),
# dcc.RangeSlider(0, 10, 0.1,
# marks={i: str(i) for i in range(0, 10)},
# value=[0, 10], id='bandgap_filter'),
# ],
# width=2,
# ),
dbc.Col(
[
html.Label("Data Tools"),
html.Button("Download", id="btn-download"),
dcc.Download(id="download-data"),
html.Div("Coh. Energy", className="text-center"),
dcc.Input(0, type="number", id="min_coh_e_filter"),
dcc.Input(10, type="number", id="max_coh_e_filter"),
],
width=1,
width=2,
),
dbc.Col(
[
html.Div("Form. Energy", className="text-center"),
dcc.Input(0, type="number", id="min_form_e_filter"),
dcc.Input(10, type="number", id="max_form_e_filter"),
],
width=2,
),
]
),
dbc.Col(
[
html.Button("Download", id="btn-download"),
dcc.Download(id="download-data"),
],
width=1,
),
html.Div(
[
html.P("Help:"),
html.Ul(
[
html.Li("Clicking on the PT adds an element to the element filter."),
html.Li(
"Element filter is restrictive, i.e., only data containing all selected elements + any "
"other elements are shown."
),
html.Li(
"Chemsys filter: Only data within the chemsys are shown. Typically you should only"
" use either element or chemsys but not both."
),
]
),
],
style={"padding": 5},
),
dbc.Row(
dbc.Col(
dcc.Graph(id="ptheatmap", style={"marginLeft": "auto", "marginRight": "auto"}),
Expand Down
1 change: 1 addition & 0 deletions src/matpes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def get_pt_df() -> pd.DataFrame:
"category": get_category(el),
}
for el in Element
if el.name not in ["D", "T"]
]
df = pd.DataFrame(elements)
df["label"] = df.apply(lambda row: f"{row['Z']}<br>{row['symbol']}", axis=1)
Expand Down

0 comments on commit f688492

Please sign in to comment.