From f6884924cdd9d0623c8eb694a749b522c2dbfae2 Mon Sep 17 00:00:00 2001 From: Shyue Ping Ong Date: Thu, 21 Nov 2024 13:29:43 -0800 Subject: [PATCH] Clickable PT. --- src/matpes/ui.py | 143 ++++++++++++++++++++++++++++++++++++-------- src/matpes/utils.py | 1 + 2 files changed, 118 insertions(+), 26 deletions(-) diff --git a/src/matpes/ui.py b/src/matpes/ui.py index 92c4293..24cdcfb 100644 --- a/src/matpes/ui.py +++ b/src/matpes/ui.py @@ -31,7 +31,15 @@ def get_full_data(functional: str) -> pd.DataFrame: return DB.get_df(functional) -def get_data(functional: str, element_filter: list, chemsys_filter: str, bandgap_filter) -> pd.DataFrame: +def get_data( + functional: str, + element_filter: list, + chemsys_filter: str, + min_coh_e_filter, + max_coh_e_filter, + min_form_e_filter, + max_form_e_filter, +) -> pd.DataFrame: """ Filter data based on the selected functional, element, and chemical system. @@ -39,7 +47,10 @@ def get_data(functional: str, element_filter: list, chemsys_filter: str, bandgap functional (str): Functional to filter data for. element_filter (list | None): Elements to filter (if any). chemsys_filter (str | None): Chemical system to filter (if any). - bandgap_filter (str | None): Bandgap to filter (not used at the moment). + min_coh_e_filter (float): Minimum cohesive energy filter. + max_coh_e_filter (float): Maximum cohesive energy filter. + min_form_e_filter (float): Minimum form energy filter. + max_form_e_filter (float): Maximum form energy filter. Returns: pd.DataFrame: Filtered data. @@ -51,9 +62,29 @@ def get_data(functional: str, element_filter: list, chemsys_filter: str, bandgap sorted_chemsys = "-".join(sorted(chemsys_filter.split("-"))) df = df[df["chemsys"] == sorted_chemsys] - # df = df[bandgap_filter[0] <= df["bandgap"]] - # df = df[df["bandgap"] <= bandgap_filter[1]] - return df + df = df[min_coh_e_filter <= df["cohesive_energy_per_atom"]] + df = df[df["cohesive_energy_per_atom"] <= max_coh_e_filter] + df = df[min_form_e_filter <= df["formation_energy_per_atom"]] + return df[df["formation_energy_per_atom"] <= max_form_e_filter] + + +@callback( + [ + Output("min_coh_e_filter", "value"), + Output("max_coh_e_filter", "value"), + Output("min_form_e_filter", "value"), + Output("max_form_e_filter", "value"), + ], + [ + Input("functional", "value"), + ], +) +def update_sliders(functional): + """Update sliders based on functional.""" + df = get_full_data(functional) + coh_energy = df["cohesive_energy_per_atom"] + form_energy = df["formation_energy_per_atom"] + return coh_energy.min(), coh_energy.max(), form_energy.min(), form_energy.max() @callback( @@ -68,12 +99,19 @@ def get_data(functional: str, element_filter: list, chemsys_filter: str, bandgap Input("functional", "value"), Input("el_filter", "value"), Input("chemsys_filter", "value"), - # Input("bandgap_filter", "value") + Input("min_coh_e_filter", "value"), + Input("max_coh_e_filter", "value"), + Input("min_form_e_filter", "value"), + Input("max_form_e_filter", "value"), ], ) -def update_graph(functional, el_filter, chemsys_filter, bandgap_filter): +def update_graph( + functional, el_filter, chemsys_filter, min_coh_e_filter, max_coh_e_filter, min_form_e_filter, max_form_e_filter +): """Update graphs based on user inputs.""" - df = get_data(functional, el_filter, chemsys_filter, bandgap_filter) + df = get_data( + functional, el_filter, chemsys_filter, min_coh_e_filter, max_coh_e_filter, min_form_e_filter, max_form_e_filter + ) element_counts = collections.Counter(itertools.chain(*df["elements"])) heatmap_figure = pt_heatmap(element_counts, label="Count", log=True) return ( @@ -119,6 +157,20 @@ def download_data(n_clicks, functional, el_filter, chemsys_filter): ) +@callback(Output("el_filter", "value"), Input("ptheatmap", "clickData"), State("el_filter", "value")) +def display_click_data(clickdata, el_filter): + """ + Update el filter when PT table is clicked. + + Args: + clickdata (dict): Click data. + el_filter (dict): Element filter. + """ + el_filter = el_filter or [] + new_el_filter = {*el_filter, Element.from_Z(clickdata["points"][0]["pointNumber"] + 1).symbol} + return list(new_el_filter) + + def main(): """Main entry point for MatPES Explorer UI.""" app = Dash("MatPES Explorer", external_stylesheets=[dbc.themes.CERULEAN], title="MatPES Explorer") @@ -146,7 +198,7 @@ def main(): ), dbc.Row( [ - html.Div("Explorer", className="text-primary text-center fs-3"), + html.H2("Explorer", className="text-primary text-center fs-3"), ] ), dbc.Row( @@ -161,14 +213,28 @@ def main(): clearable=False, ), ], - width=2, + width=4, + ) + ] + ), + dbc.Row( + [ + dbc.Col( + [ + html.Div("Filters: "), + ], + width=1, ), dbc.Col( [ - html.Label("Filter by Element(s)"), + html.Label("Element(s)"), dcc.Dropdown( id="el_filter", - options=[{"label": el.symbol, "value": el.symbol} for el in Element], + options=[ + {"label": el.symbol, "value": el.symbol} + for el in Element + if el.name not in ("D", "T") + ], multi=True, ), ], @@ -176,7 +242,7 @@ def main(): ), dbc.Col( [ - html.Div("Filter by Chemsys"), + html.Div("Chemsys"), dcc.Input( id="chemsys_filter", placeholder="Li-Fe-O", @@ -184,25 +250,50 @@ def main(): ], width=2, ), - # dbc.Col( - # [ - # html.Div("Bandgap", className="text-center"), - # dcc.RangeSlider(0, 10, 0.1, - # marks={i: str(i) for i in range(0, 10)}, - # value=[0, 10], id='bandgap_filter'), - # ], - # width=2, - # ), dbc.Col( [ - html.Label("Data Tools"), - html.Button("Download", id="btn-download"), - dcc.Download(id="download-data"), + html.Div("Coh. Energy", className="text-center"), + dcc.Input(0, type="number", id="min_coh_e_filter"), + dcc.Input(10, type="number", id="max_coh_e_filter"), ], - width=1, + width=2, + ), + dbc.Col( + [ + html.Div("Form. Energy", className="text-center"), + dcc.Input(0, type="number", id="min_form_e_filter"), + dcc.Input(10, type="number", id="max_form_e_filter"), + ], + width=2, ), ] ), + dbc.Col( + [ + html.Button("Download", id="btn-download"), + dcc.Download(id="download-data"), + ], + width=1, + ), + html.Div( + [ + html.P("Help:"), + html.Ul( + [ + html.Li("Clicking on the PT adds an element to the element filter."), + html.Li( + "Element filter is restrictive, i.e., only data containing all selected elements + any " + "other elements are shown." + ), + html.Li( + "Chemsys filter: Only data within the chemsys are shown. Typically you should only" + " use either element or chemsys but not both." + ), + ] + ), + ], + style={"padding": 5}, + ), dbc.Row( dbc.Col( dcc.Graph(id="ptheatmap", style={"marginLeft": "auto", "marginRight": "auto"}), diff --git a/src/matpes/utils.py b/src/matpes/utils.py index 6b9c05f..03ded99 100644 --- a/src/matpes/utils.py +++ b/src/matpes/utils.py @@ -34,6 +34,7 @@ def get_pt_df() -> pd.DataFrame: "category": get_category(el), } for el in Element + if el.name not in ["D", "T"] ] df = pd.DataFrame(elements) df["label"] = df.apply(lambda row: f"{row['Z']}
{row['symbol']}", axis=1)