-
Notifications
You must be signed in to change notification settings - Fork 78
/
Copy pathequation_vae.py
71 lines (61 loc) · 2.96 KB
/
equation_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import nltk
import re
import eq_grammar
import molecule_vae
import models.model_eq
import models.model_eq_str
import numpy as np
def tokenize(s):
    """Split an equation string into terminal tokens of the equation grammar.

    The function names ``sin`` and ``exp`` are emitted fused with an opening
    parenthesis (``'sin('``, ``'exp('``). Every character that is not a
    lowercase letter or a space becomes a token of its own; runs of
    lowercase letters stay together and whitespace is discarded.

    >>> tokenize('sin(x)+1')
    ['sin(', 'x', ')', '+', '1']
    """
    function_names = ('sin', 'exp')
    text = s
    # Detach '(' from the function names so the padding step below does not
    # split it off as a separate token.
    for name in function_names:
        text = text.replace(name + '(', name + ' ')
    # Surround every character outside [a-z ] with spaces.
    text = re.sub(r'([^a-z ])', r' \1 ', text)
    # Glue '(' back onto every occurrence of a function name.
    for name in function_names:
        text = text.replace(name, name + '(')
    tokens = text.split()
    return tokens
class EquationGrammarModel(molecule_vae.ZincGrammarModel):
    """Grammar-VAE wrapper for arithmetic equations.

    Subclasses ``molecule_vae.ZincGrammarModel`` but replaces its grammar,
    tokenizer and network with the equation versions (``eq_grammar``,
    ``tokenize`` and ``models.model_eq``). Presumably the encode/decode
    methods are inherited from the base class — TODO confirm.
    """

    def __init__(self, weights_file, latent_rep_size=25):
        """Load the (trained) equation encoder/decoder, grammar model.

        weights_file: trained weights, forwarded to ``MoleculeVAE.load``.
        latent_rep_size: latent dimensionality, forwarded to ``load``.

        Note: the base-class ``__init__`` is not called; every attribute the
        model needs is set here.
        """
        grammar = eq_grammar
        self._grammar = grammar
        self._model = models.model_eq
        self.MAX_LEN = 15  # TODO: read from elsewhere

        productions = grammar.GCFG.productions()
        self._productions = productions
        # Production -> its position in self._productions.
        self._prod_map = {prod: ix for ix, prod in enumerate(productions)}
        self._parser = nltk.ChartParser(grammar.GCFG)
        # Stored as a plain function attribute, so it is not bound to self.
        self._tokenize = tokenize
        self._n_chars = len(productions)
        # Left-hand-side symbol -> its position in grammar.lhs_list.
        self._lhs_map = {lhs: ix for ix, lhs in enumerate(grammar.lhs_list)}

        self.vae = self._model.MoleculeVAE()
        self.vae.load(productions, weights_file,
                      max_length=self.MAX_LEN,
                      latent_rep_size=latent_rep_size)
class EquationCharacterModel(object):
    """Character-level VAE wrapper for equation strings.

    Equations are one-hot encoded character by character (vocabulary in
    ``charlist``) and decoded by sampling one character per position.
    """

    def __init__(self, weights_file, latent_rep_size=25):
        """Load the (trained) character-level equation encoder/decoder.

        weights_file: trained weights, forwarded to ``MoleculeVAE.load``.
        latent_rep_size: latent dimensionality, forwarded to ``load``.
        """
        self._model = models.model_eq_str
        # Fixed sequence length; shorter inputs are padded up to it.
        self.MAX_LEN = 19
        self.vae = self._model.MoleculeVAE()
        # Character vocabulary. The last entry (' ') is the padding symbol:
        # _one_hot fills unused positions with the last column and decode()
        # strips surrounding spaces.
        self.charlist = ['x', '+', '(', ')', '1', '2', '3', '*', '/',
                         's', 'i', 'n', 'e', 'p', ' ']
        self._char_index = {char: ix for ix, char in enumerate(self.charlist)}
        self.vae.load(self.charlist, weights_file,
                      max_length=self.MAX_LEN,
                      latent_rep_size=latent_rep_size)

    def _one_hot(self, entries):
        """Build the float32 one-hot tensor for a list of equation strings.

        Returns an array of shape (len(entries), MAX_LEN, len(charlist)).
        Positions past the end of a string are set to the last vocabulary
        column (the padding character).

        Raises KeyError for a character not in ``charlist`` and IndexError
        for a string longer than ``MAX_LEN``.
        """
        indices = [np.array([self._char_index[c] for c in entry], dtype=int)
                   for entry in entries]
        one_hot = np.zeros((len(indices), self.MAX_LEN, len(self.charlist)),
                           dtype=np.float32)
        for i, char_ids in enumerate(indices):
            n_chars = len(char_ids)
            if n_chars > self.MAX_LEN:
                raise IndexError('equation %d has %d characters; MAX_LEN is %d'
                                 % (i, n_chars, self.MAX_LEN))
            one_hot[i, np.arange(n_chars), char_ids] = 1.
            one_hot[i, np.arange(n_chars, self.MAX_LEN), -1] = 1.
        return one_hot

    def encode(self, smiles):
        """Encode a list of equation strings into the latent space.

        smiles: list of equation strings (the name is kept for compatibility
            with the molecule models' interface).
        Returns element 0 of ``encoderMV.predict`` — presumably the latent
        means; TODO confirm against models.model_eq_str.
        """
        return self.vae.encoderMV.predict(self._one_hot(smiles))[0]

    def decode(self, z):
        """Sample one equation string per row of the 2-D latent array ``z``.

        Characters are drawn with the Gumbel-max trick, i.e. argmax of
        log(out) + Gumbel noise. This assumes the decoder output holds
        per-position character probabilities (TODO confirm). Surrounding
        padding spaces are stripped from each result.
        """
        assert z.ndim == 2
        out = self.vae.decoder.predict(z)
        noise = np.random.gumbel(size=out.shape)
        # log(0) = -inf simply never wins the argmax; silence the warning.
        with np.errstate(divide='ignore'):
            log_probs = np.log(out)
        sampled_chars = np.argmax(log_probs + noise, axis=-1)
        char_matrix = np.array(self.charlist)[sampled_chars]
        return [''.join(row).strip() for row in char_matrix]