Skip to content

Commit

Permalink
MOre code cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Nov 20, 2024
1 parent cee872e commit 0e82d6d
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 150 deletions.
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,12 @@ exclude_lines = [
"if TYPE_CHECKING:",
"except PackageNotFoundError:"
]

[tool.pyright]
typeCheckingMode = "off"
reportPossiblyUnboundVariable = true
reportUnboundVariable = true
reportMissingImports = false
reportMissingModuleSource = false
reportInvalidTypeForm = false
exclude = ["**/tests"]
237 changes: 143 additions & 94 deletions src/matpes/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,28 @@
from pymatgen.core import Element
from pymongo import MongoClient

from matpes.utils import get_pt_heatmap
from matpes.utils import pt_heatmap

# Define constants
FUNCTIONALS = ("PBE", "r2SCAN")
MONGO_DB_NAME = "matpes"

# Set up MongoDB client and database
CLIENT = MongoClient()
DB = CLIENT["matpes"]
DB = CLIENT[MONGO_DB_NAME]


@functools.lru_cache
def get_df(functional: str) -> pd.DataFrame:
"""
Retrieve data for the given functional from the MongoDB database.
Args:
functional (str): The functional to query (e.g., 'PBE').
Returns:
pd.DataFrame: Dataframe containing the functional's data.
"""
collection = DB[functional]
return pd.DataFrame(
collection.find(
Expand All @@ -43,92 +54,49 @@ def get_df(functional: str) -> pd.DataFrame:
)


@functools.lru_cache
def get_data(functional, el, chemsys):
"""Filter data with caching for improved performance."""
def get_data(functional: str, element_filter: list, chemsys: str) -> pd.DataFrame:
"""
Filter data based on the selected functional, element, and chemical system.
Args:
functional (str): Functional to filter data for.
element_filter (list | None): Elements to filter (if any).
chemsys (str | None): Chemical system to filter (if any).
Returns:
pd.DataFrame: Filtered data.
"""
df = get_df(functional)
if el is not None:
df = df[df["elements"].apply(lambda x: el in x)]
if element_filter:
df = df[df["elements"].apply(lambda x: set(x).issuperset(element_filter))]
if chemsys:
chemsys = "-".join(sorted(chemsys.split("-")))
df = df[df["chemsys"] == chemsys]

sorted_chemsys = "-".join(sorted(chemsys.split("-")))
df = df[df["chemsys"] == sorted_chemsys]
return df


# Initialize the Dash app with a Bootstrap theme
external_stylesheets = [dbc.themes.CERULEAN]
app = Dash("MatPES Explorer", external_stylesheets=external_stylesheets)
def get_dist_plot(data: pd.Series, label: str, ignore_nan: bool = True, nbins: int = 100):
"""
Create a distribution plot for a given dataset.
# Define the app layout
app.layout = dbc.Container(
[
dbc.Row([html.Div("MatPES Explorer", className="text-primary text-center fs-3")]),
dbc.Row(
[
dbc.Col(
[
html.Label("Functional"),
dcc.Dropdown(
options=[{"label": f, "value": f} for f in FUNCTIONALS], value="PBE", id="functional"
),
],
width=2,
),
dbc.Col(
[
html.Label("Element Filter"),
dcc.Dropdown(
options=[{"label": el.symbol, "value": el.symbol} for el in Element], id="el_filter"
),
],
width=2,
),
dbc.Col(
[
html.Label("Chemsys Filter"),
dcc.Input(id="chemsys_filter", placeholder="Li-Fe-O"),
],
width=2,
),
dbc.Col(
[html.Div([html.Button("Download", id="btn-download"), dcc.Download(id="download-data")])],
width=2,
),
]
),
dbc.Row(
[
dbc.Col(
[dcc.Graph(id="ptheatmap", style={"marginLeft": "auto", "marginRight": "auto"})],
width={"size": 8, "order": "last", "offset": 2},
)
],
),
dbc.Row(
[
dbc.Col([dcc.Graph(id="coh_energy_hist")], width=6),
dbc.Col([dcc.Graph(id="form_energy_hist")], width=6),
]
),
dbc.Row(
[
dbc.Col([dcc.Graph(id="natoms_hist")], width=6),
dbc.Col([dcc.Graph(id="nelements_hist")], width=6),
]
),
]
)


def get_dist_plot(data, label, nbins=100):
fig = ff.create_distplot([data], [label], show_rug=False)
Args:
data (pd.Series): The data to plot.
label (str): Label for the x-axis.
ignore_nan (bool): Whether to ignore NaN values.
nbins (int): Number of bins for the histogram.
Returns:
plotly.graph_objects.Figure: The distribution plot figure.
"""
if ignore_nan:
data = data.dropna()
bin_size = (data.max() - data.min()) / nbins
fig = ff.create_distplot([data], [label], bin_size=bin_size, show_rug=False)
fig.update_layout(xaxis=dict(title=label), showlegend=False)
return fig


# Define callback to update the heatmap based on selected functional
# Callback to update visualizations
@callback(
[
Output("ptheatmap", "figure"),
Expand All @@ -140,23 +108,30 @@ def get_dist_plot(data, label, nbins=100):
[Input("functional", "value"), Input("el_filter", "value"), Input("chemsys_filter", "value")],
)
def update_graph(functional, el_filter, chemsys_filter):
"""Update graph based on input."""
"""Update graphs based on user inputs."""
df = get_data(functional, el_filter, chemsys_filter)
el_count = {el.symbol: 0 for el in Element}
el_count.update(collections.Counter(itertools.chain(*df["elements"])))
heatmap_figure = get_pt_heatmap(el_count, label="Count", log=True)
element_counts = collections.Counter(itertools.chain(*df["elements"]))
heatmap_figure = pt_heatmap(element_counts, label="Count", log=True)
return (
heatmap_figure,
get_dist_plot(df["cohesive_energy_per_atom"], "Cohesive Energy per Atom (eV/atom)"),
get_dist_plot(
df["formation_energy_per_atom"].dropna(),
"Formation Energy per Atom (eV/atom)",
px.histogram(
df,
x="cohesive_energy_per_atom",
labels={"cohesive_energy_per_atom": "Cohesive Energy per Atom (eV/atom)"},
nbins=100,
),
px.histogram(
df,
x="formation_energy_per_atom",
labels={"formation_energy_per_atom": "Formation Energy per Atom (eV/atom)"},
nbins=100,
),
px.histogram(df, x="natoms"),
px.histogram(df, x="nelements"),
)


# Callback to download data
@callback(
Output("download-data", "data"),
Input("btn-download", "n_clicks"),
Expand All @@ -165,20 +140,94 @@ def update_graph(functional, el_filter, chemsys_filter):
State("chemsys_filter", "value"),
prevent_initial_call=True,
)
def download(n_clicks, functional, el_filter, chemsys_filter):
def download_data(n_clicks, functional, el_filter, chemsys_filter):
"""Handle data download requests."""
collection = DB[functional]
criteria = {}
if el_filter is not None:
if el_filter:
criteria["elements"] = el_filter
if chemsys_filter is not None:
chemsys = "-".join(sorted(chemsys_filter.split("-")))
criteria["chemsys"] = chemsys
if chemsys_filter:
criteria["chemsys"] = "-".join(sorted(chemsys_filter.split("-")))
data = list(collection.find(criteria))
for d in data:
del d["_id"]
return dict(content=json.dumps(data), filename=f"matpes_{functional}_{el_filter}_{chemsys_filter}.json")
for entry in data:
entry.pop("_id", None) # Remove MongoDB's internal ID
return dict(
content=json.dumps(data), filename=f"matpes_{functional}_{el_filter or 'all'}_{chemsys_filter or 'all'}.json"
)


# Run the app
def main():
"""Main entry point for MatPES Explorer UI."""
app = Dash("MatPES Explorer", external_stylesheets=[dbc.themes.CERULEAN], title="MatPES Explorer")

# Define app layout
app.layout = dbc.Container(
[
dbc.Row([html.Div("MatPES Explorer", className="text-primary text-center fs-3")]),
dbc.Row(
[
dbc.Col(
[
html.Label("Functional"),
dcc.Dropdown(
id="functional",
options=[{"label": f, "value": f} for f in FUNCTIONALS],
value="PBE",
clearable=False,
),
],
width=2,
),
dbc.Col(
[
html.Label("Filter by Element(s)"),
dcc.Dropdown(
id="el_filter",
options=[{"label": el.symbol, "value": el.symbol} for el in Element],
multi=True,
),
],
width=2,
),
dbc.Col(
[
html.Label("Filter by Chemsys"),
dcc.Input(
id="chemsys_filter",
placeholder="Li-Fe-O",
),
],
width=2,
),
dbc.Col(
[
html.Label("Data Tools"),
html.Button("Download", id="btn-download"),
dcc.Download(id="download-data"),
],
width=1,
),
]
),
dbc.Row(
dbc.Col(
dcc.Graph(id="ptheatmap", style={"marginLeft": "auto", "marginRight": "auto"}),
width={"size": 8, "offset": 2},
)
),
dbc.Row(
[
dbc.Col(dcc.Graph(id="coh_energy_hist"), width=6),
dbc.Col(dcc.Graph(id="form_energy_hist"), width=6),
]
),
dbc.Row(
[
dbc.Col(dcc.Graph(id="natoms_hist"), width=6),
dbc.Col(dcc.Graph(id="nelements_hist"), width=6),
]
),
]
)

app.run(debug=True)
Loading

0 comments on commit 0e82d6d

Please sign in to comment.