Skip to content

Commit

Permalink
Merge pull request #41 from sparks-baird/cdvae-cov-matching
Browse files Browse the repository at this point in the history
matching helper functions (StructureMatcher and CDVAE versions)
  • Loading branch information
sgbaird authored Aug 4, 2022
2 parents d647dbf + 883804d commit fc5e3bd
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 35 deletions.
104 changes: 73 additions & 31 deletions src/matbench_genmetrics/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

import numpy as np
from mp_time_split.core import MPTimeSplit
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.core.structure import Structure
from scipy.stats import wasserstein_distance
from tqdm import tqdm
from tqdm.notebook import tqdm as ipython_tqdm

from matbench_genmetrics import __version__
from matbench_genmetrics.utils.match import ALLOWED_MATCH_TYPES, get_match_matrix

# causes pytest to fail (tests not found, DLL load error)
# from matbench_genmetrics.cdvae.metrics import RecEval, GenEval, OptEval
Expand Down Expand Up @@ -48,32 +48,25 @@ def fib(n):
return a


sm = StructureMatcher(stol=0.5, ltol=0.3, angle_tol=10.0)


def pairwise_match(s1: Structure, s2: Structure):
return sm.fit(s1, s2)


IN_COLAB = "google.colab" in sys.modules

# try:
# import google.colab # type: ignore # noqa: F401

# IN_COLAB = True
# except ImportError:
# IN_COLAB = False


class GenMatcher(object):
def __init__(
self,
test_structures,
gen_structures: Optional[List[Structure]] = None,
verbose=True,
match_type="cdvae_coverage",
**match_kwargs,
) -> None:
self.test_structures = test_structures
self.verbose = verbose
assert (
match_type in ALLOWED_MATCH_TYPES
), f"type must be one of {ALLOWED_MATCH_TYPES}"
self.match_type = match_type
self.match_kwargs = match_kwargs

if gen_structures is None:
self.gen_structures = test_structures
Expand Down Expand Up @@ -102,14 +95,13 @@ def match_matrix(self):
if self._match_matrix is not None:
return self._match_matrix

match_matrix = np.zeros((self.num_test, self.num_gen))
for i, ts in enumerate(self.tqdm(self.test_structures, **self.tqdm_kwargs)):
for j, gs in enumerate(self.gen_structures):
if not self.symmetric or (self.symmetric and i < j):
match_matrix[i, j] = pairwise_match(ts, gs)

if self.symmetric:
match_matrix = match_matrix + match_matrix.T
match_matrix = get_match_matrix(
self.test_structures,
self.gen_structures,
match_type=self.match_type,
symmetric=self.symmetric,
**self.match_kwargs,
)

self._match_matrix = match_matrix

Expand Down Expand Up @@ -161,12 +153,16 @@ def __init__(
gen_structures,
test_pred_structures=None,
verbose=True,
match_type="cdvae_coverage",
**match_kwargs,
):
self.train_structures = train_structures
self.test_structures = test_structures
self.gen_structures = gen_structures
self.test_pred_structures = test_pred_structures
self.verbose = verbose
self.match_type = match_type
self.match_kwargs = match_kwargs
self._cdvae_metrics = None
self._mpts_metrics = None

Expand Down Expand Up @@ -206,15 +202,23 @@ def validity(self):
def coverage(self):
"""Match rate between test structures and generated structures."""
self.coverage_matcher = GenMatcher(
self.test_structures, self.gen_structures, verbose=self.verbose
self.test_structures,
self.gen_structures,
verbose=self.verbose,
match_type=self.match_type,
**self.match_kwargs,
)
return self.coverage_matcher.match_rate

@property
def novelty(self):
"""One minus match rate between train structures and generated structures."""
self.similarity_matcher = GenMatcher(
self.train_structures, self.gen_structures, verbose=self.verbose
self.train_structures,
self.gen_structures,
verbose=self.verbose,
match_type=self.match_type,
**self.match_kwargs,
)
similarity = (
self.similarity_matcher.match_count / self.similarity_matcher.num_gen
Expand All @@ -225,7 +229,11 @@ def novelty(self):
def uniqueness(self):
"""One minus duplicity rate within generated structures."""
self.commonality_matcher = GenMatcher(
self.gen_structures, self.gen_structures, verbose=self.verbose
self.gen_structures,
self.gen_structures,
verbose=self.verbose,
match_type=self.match_type,
**self.match_kwargs,
)
commonality = self.commonality_matcher.duplicity_rate
return 1.0 - commonality
Expand All @@ -242,10 +250,19 @@ def metrics(self):


class MPTSMetrics(object):
def __init__(self, dummy=False, verbose=True, num_gen=None):
def __init__(
self,
dummy=False,
verbose=True,
num_gen=None,
match_type="cdvae_coverage",
**match_kwargs,
):
self.dummy = dummy
self.verbose = verbose
self.num_gen = num_gen
self.match_type = match_type
self.match_kwargs = match_kwargs
self.mpt = MPTimeSplit(target="energy_above_hull")
self.folds = self.mpt.folds
self.gms: List[Optional[GenMetrics]] = [None] * len(self.folds)
Expand Down Expand Up @@ -276,6 +293,8 @@ def evaluate_and_record(self, fold, gen_structures, test_pred_structures=None):
gen_structures,
test_pred_structures=test_pred_structures,
verbose=self.verbose,
match_type=self.match_type,
**self.match_kwargs,
)

self.recorded_metrics[fold] = self.gms[fold].metrics
Expand All @@ -287,22 +306,38 @@ def evaluate_and_record(self, fold, gen_structures, test_pred_structures=None):

class MPTSMetrics10(MPTSMetrics):
def __init__(self, dummy=False, verbose=True):
MPTSMetrics.__init__(self, dummy=dummy, verbose=verbose, num_gen=10)
MPTSMetrics.__init__(
self, dummy=dummy, verbose=verbose, num_gen=10, match_type="cdvae_coverage"
)


class MPTSMetrics100(MPTSMetrics):
def __init__(self, dummy=False, verbose=True):
MPTSMetrics.__init__(self, dummy=dummy, verbose=verbose, num_gen=100)
MPTSMetrics.__init__(
self, dummy=dummy, verbose=verbose, num_gen=100, match_type="cdvae_coverage"
)


class MPTSMetrics1000(MPTSMetrics):
def __init__(self, dummy=False, verbose=True):
MPTSMetrics.__init__(self, dummy=dummy, verbose=verbose, num_gen=1000)
MPTSMetrics.__init__(
self,
dummy=dummy,
verbose=verbose,
num_gen=1000,
match_type="cdvae_coverage",
)


class MPTSMetrics10000(MPTSMetrics):
def __init__(self, dummy=False, verbose=True):
MPTSMetrics.__init__(self, dummy=dummy, verbose=verbose, num_gen=10000)
MPTSMetrics.__init__(
self,
dummy=dummy,
verbose=verbose,
num_gen=10000,
match_type="cdvae_coverage",
)


# def get_rms_dist(gen_structures, test_structures):
Expand Down Expand Up @@ -446,3 +481,10 @@ def run():
# generate_features(pd.DataFrame(dict(formula=self.test_formulas, target=0.0)))
# self.gen_cbfv, _, _, _ = generate_features(dict(formula=self.gen_formulas,
# target=0.0))

# try:
# import google.colab # type: ignore # noqa: F401

# IN_COLAB = True
# except ImportError:
# IN_COLAB = False
152 changes: 152 additions & 0 deletions src/matbench_genmetrics/utils/match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import numpy as np
from matminer.featurizers.composition.composite import ElementProperty
from matminer.featurizers.site.fingerprint import CrystalNNFingerprint
from pymatgen.analysis.bond_valence import BVAnalyzer
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.core.structure import Structure
from scipy.spatial.distance import cdist, pdist, squareform

# Module-level StructureMatcher instance shared by structure_matcher() below.
# The stol / ltol / angle_tol tolerances are set explicitly here rather than
# left at the library defaults.
sm = StructureMatcher(stol=0.5, ltol=0.3, angle_tol=10.0)


def structure_matcher(s1: Structure, s2: Structure):
    """Compare two structures with the module-level ``StructureMatcher``.

    Args:
        s1: First structure of the pair.
        s2: Second structure of the pair.

    Returns:
        Whatever ``sm.fit(s1, s2)`` returns for the pair.
    """
    is_match = sm.fit(s1, s2)
    return is_match


# Registry mapping a match_type name to the pairwise comparison function
# (called as fn(test_structure, gen_structure)) that
# structure_pairwise_match_matrix looks up.
pairwise_match_fn_dict = {"StructureMatcher": structure_matcher}


def structure_pairwise_match_matrix(
    test_structures,
    gen_structures,
    match_type="StructureMatcher",
    symmetric=False,
):
    """Build a match matrix by comparing test and generated structures pairwise.

    Args:
        test_structures: Structures indexing the rows of the result.
        gen_structures: Structures indexing the columns of the result.
        match_type: Key into ``pairwise_match_fn_dict`` selecting the pairwise
            comparison function.
        symmetric: When true, only pairs with row index < column index are
            compared and the result is mirrored across the diagonal; the
            diagonal itself is left at zero.

    Returns:
        A ``(len(test_structures), len(gen_structures))`` NumPy array holding
        the comparison result of each evaluated pair (zero elsewhere).

    Raises:
        KeyError: If ``match_type`` is not a key of ``pairwise_match_fn_dict``.
    """
    # TODO: replace with group_structures to be faster
    compare = pairwise_match_fn_dict[match_type]
    n_test, n_gen = len(test_structures), len(gen_structures)
    matches = np.zeros((n_test, n_gen))

    for row, test_struct in enumerate(test_structures):
        for col, gen_struct in enumerate(gen_structures):
            # In symmetric mode only the strict upper triangle is evaluated.
            if symmetric and row >= col:
                continue
            matches[row, col] = compare(test_struct, gen_struct)

    if not symmetric:
        return matches
    # Mirror the upper triangle into the lower one.
    return matches + matches.T


# Composition featurizer built from the "magpie" ElementProperty preset; used
# by cdvae_cov_comp_fingerprints to fingerprint each structure's composition.
CompFP = ElementProperty.from_preset("magpie")


def cdvae_cov_comp_fingerprints(structures):
    """Compute a composition fingerprint for every structure.

    Args:
        structures: Iterable of objects exposing a ``composition`` attribute.

    Returns:
        A list with one ``CompFP.featurize(composition)`` result per input
        structure, in input order.
    """
    compositions = (structure.composition for structure in structures)
    return [CompFP.featurize(composition) for composition in compositions]


# Site featurizer (CrystalNNFingerprint, "ops" preset) and bond-valence analyzer
# used together by cdvae_cov_struct_fingerprints: `bva` decorates structures
# with oxidation states, `CrystalNNFP` fingerprints each site.
CrystalNNFP = CrystalNNFingerprint.from_preset("ops")
bva = BVAnalyzer()


def cdvae_cov_struct_fingerprints(structures):
    """Compute a structural fingerprint for every structure.

    Each structure is first decorated with oxidation states via ``bva`` (the
    undecorated structure is used when that raises ``ValueError``). Then every
    site is featurized with ``CrystalNNFP`` and the per-site vectors are
    averaged into a single fingerprint.

    Args:
        structures: Iterable of structures supporting ``len()`` (site count).

    Returns:
        A list with one site-averaged fingerprint array per input structure,
        in input order.
    """

    def _with_oxidation_states(structure):
        # Fall back to the raw structure when valences cannot be assigned.
        try:
            return bva.get_oxi_state_decorated_structure(structure)
        except ValueError:
            # TODO: track how many couldn't have valences assigned
            return structure

    def _mean_site_fingerprint(structure):
        per_site = [
            CrystalNNFP.featurize(structure, site_index)
            for site_index in range(len(structure))
        ]
        return np.array(per_site).mean(axis=0)

    # Decorate everything first, then featurize (same two-pass order as before).
    decorated = [_with_oxidation_states(structure) for structure in structures]
    return [_mean_site_fingerprint(structure) for structure in decorated]


def cdvae_cov_dist_matrix(
    test_structures, gen_structures, composition_only=False, symmetric=False
):
    """Compute pairwise fingerprint distances between two sets of structures.

    Args:
        test_structures: Structures indexing the rows of the result.
        gen_structures: Structures indexing the columns; not fingerprinted
            when ``symmetric`` is true.
        composition_only: Use composition fingerprints when true, structural
            fingerprints otherwise.
        symmetric: When true, distances are computed within
            ``test_structures`` only (``pdist`` expanded by ``squareform``).

    Returns:
        A 2-D distance matrix computed with the default metric of
        ``scipy.spatial.distance.pdist`` / ``cdist``.
    """
    if composition_only:
        to_fingerprints = cdvae_cov_comp_fingerprints
    else:
        to_fingerprints = cdvae_cov_struct_fingerprints

    test_fps = to_fingerprints(test_structures)
    if symmetric:
        # Self-comparison: condensed distances expanded to a square matrix.
        condensed = pdist(test_fps)
        return squareform(condensed)

    gen_fps = to_fingerprints(gen_structures)
    return cdist(test_fps, gen_fps)


def cdvae_cov_match_matrix(
    test_structures,
    gen_structures,
    composition_only=False,
    symmetric=False,
    cutoff=10.0,
):
    """Threshold the fingerprint distance matrix into a boolean match matrix.

    Args:
        test_structures: Structures indexing the rows of the result.
        gen_structures: Structures indexing the columns of the result.
        composition_only: Forwarded to ``cdvae_cov_dist_matrix``.
        symmetric: Forwarded to ``cdvae_cov_dist_matrix``.
        cutoff: Largest distance (inclusive) that still counts as a match.

    Returns:
        The elementwise result of ``distance <= cutoff``.
    """
    distances = cdvae_cov_dist_matrix(
        test_structures,
        gen_structures,
        composition_only=composition_only,
        symmetric=symmetric,
    )
    within_cutoff = distances <= cutoff
    return within_cutoff


def cdvae_cov_compstruct_match_matrix(
    test_structures,
    gen_structures,
    symmetric=False,
    comp_cutoff=10.0,
    struct_cutoff=0.4,
):
    """Match structures that agree in both composition and structure.

    Args:
        test_structures: Structures indexing the rows of the result.
        gen_structures: Structures indexing the columns of the result.
        symmetric: Forwarded to ``cdvae_cov_match_matrix`` for both passes.
        comp_cutoff: Distance cutoff for the composition-fingerprint pass.
        struct_cutoff: Distance cutoff for the structure-fingerprint pass.

    Returns:
        The elementwise product of the composition match matrix and the
        structure match matrix, i.e. a pair matches only if both passes match.
    """

    def _matches(composition_only, cutoff):
        return cdvae_cov_match_matrix(
            test_structures,
            gen_structures,
            composition_only=composition_only,
            symmetric=symmetric,
            cutoff=cutoff,
        )

    # Composition pass first, then structure pass (same order as before).
    by_composition = _matches(True, comp_cutoff)
    by_structure = _matches(False, struct_cutoff)
    # multiply, since 0*0=0, 0*1=0, 1*0=0, 1*1=1
    return by_composition * by_structure


# Names accepted for the ``match_type`` argument of get_match_matrix.
ALLOWED_MATCH_TYPES = ["StructureMatcher", "cdvae_coverage"]


class _InvalidMatchTypeError(ValueError, AssertionError):
    """Raised when an unsupported ``match_type`` is requested.

    Also derives from ``AssertionError`` so that callers written against the
    earlier ``assert``-based validation keep catching it.
    """


def get_match_matrix(
    test_structures,
    gen_structures,
    match_type="cdvae_coverage",
    symmetric=False,
    **match_kwargs,
):
    """Compute the match matrix between test and generated structures.

    Args:
        test_structures: Structures indexing the rows of the result.
        gen_structures: Structures indexing the columns of the result.
        match_type: One of ``ALLOWED_MATCH_TYPES``. ``"cdvae_coverage"``
            dispatches to ``cdvae_cov_compstruct_match_matrix`` and
            ``"StructureMatcher"`` to ``structure_pairwise_match_matrix``.
        symmetric: Forwarded to the selected implementation.
        **match_kwargs: Extra keyword arguments forwarded unchanged to the
            selected implementation (e.g. ``comp_cutoff`` / ``struct_cutoff``
            for ``"cdvae_coverage"``). Keywords the implementation does not
            accept make it raise ``TypeError``.

    Returns:
        The match matrix produced by the selected implementation.

    Raises:
        ValueError: If ``match_type`` is not in ``ALLOWED_MATCH_TYPES``. The
            exception is also an ``AssertionError`` for backward compatibility.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which previously let an unknown match_type fall through
    # the dispatch below and silently return None.
    if match_type not in ALLOWED_MATCH_TYPES:
        raise _InvalidMatchTypeError(f"type must be one of {ALLOWED_MATCH_TYPES}")

    if match_type == "cdvae_coverage":
        return cdvae_cov_compstruct_match_matrix(
            test_structures,
            gen_structures,
            symmetric=symmetric,
            **match_kwargs,
        )
    if match_type == "StructureMatcher":
        return structure_pairwise_match_matrix(
            test_structures,
            gen_structures,
            match_type="StructureMatcher",
            symmetric=symmetric,
            **match_kwargs,
        )

    # Only reachable if ALLOWED_MATCH_TYPES gains an entry without a matching
    # branch above; fail loudly rather than return None.
    raise NotImplementedError(f"no match implementation for {match_type!r}")
Loading

0 comments on commit fc5e3bd

Please sign in to comment.