data_structs.py
import re

import numpy as np
import torch
import torch.utils.data as data

from utils import (canonicalize_smiles_from_file, construct_vocabulary,
                   create_var, replace_halogen)
class MolData(data.Dataset):
    """Custom PyTorch Dataset that yields the encoded scaffold together with
    its source molecule SMILES and scaffold SMILES."""

    def __init__(self, root_path, voc):
        self.voc = voc
        self.root_path = root_path
        self.mols, self.scas = canonicalize_smiles_from_file(self.root_path)

    def __getitem__(self, index):
        mol = self.mols[index]
        sca = self.scas[index]
        tokenized = self.voc.tokenize(sca)
        encoded = self.voc.encode(tokenized)
        return create_var(encoded), mol, sca

    def __len__(self):
        return len(self.mols)

    @staticmethod
    def collate_fn(batch):
        """Zero-pad the encoded sequences to the longest sequence in the batch."""
        max_length = max(seq.size(0) for seq, _, _ in batch)
        collated_arr = create_var(torch.zeros(len(batch), max_length))
        mol_batch = []
        sca_batch = []
        for i, (seq, mol, sca) in enumerate(batch):
            collated_arr[i, :seq.size(0)] = seq
            mol_batch.append(mol)
            sca_batch.append(sca)
        return mol_batch, sca_batch, collated_arr
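# Minimal usage sketch (assumption: 'data/data.txt' and the 'data/Voc'
# vocabulary file exist; both paths are illustrative). collate_fn pads every
# encoded scaffold to the longest sequence in the batch, so it has to be
# passed to the DataLoader explicitly:
#
#     voc = Vocabulary(init_from_file='data/Voc')
#     loader = data.DataLoader(MolData('data/data.txt', voc), batch_size=32,
#                              shuffle=True, collate_fn=MolData.collate_fn)
#     for mol_batch, sca_batch, collated_arr in loader:
#         ...  # collated_arr: (batch_size, longest_sequence_in_batch)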
class Vocabulary(object):
    """Maps between SMILES tokens and integer indices."""

    def __init__(self, init_from_file=None, max_length=140):
        self.special_tokens = ["EOS", "GO"]
        self.additional_chars = set()
        self.chars = list(self.special_tokens)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}
        self.max_length = max_length
        if init_from_file:
            self.init_from_file(init_from_file)

    def encode(self, char_list):
        """Takes a list of tokens (e.g. '[NH]') and encodes them as an array of indices."""
        smiles_matrix = np.zeros(len(char_list), dtype=np.float32)
        for i, char in enumerate(char_list):
            smiles_matrix[i] = self.vocab[char]
        return smiles_matrix

    def decode(self, matrix):
        """Takes an array of indices and returns the corresponding SMILES."""
        chars = []
        for i in matrix:
            if i == self.vocab['EOS']:
                break
            chars.append(self.reversed_vocab[i])
        smiles = "".join(chars)
        # Undo the single-character halogen placeholders introduced by replace_halogen().
        smiles = smiles.replace("L", "Cl").replace("R", "Br")
        return smiles

    def tokenize(self, smiles):
        """Takes a SMILES string and returns a list of tokens, terminated by 'EOS'."""
        regex = r'(\[[^\[\]]{1,6}\])'  # raw string so bracket atoms such as [NH] stay intact
        smiles = replace_halogen(smiles)
        char_list = re.split(regex, smiles)
        tokenized = []
        for char in char_list:
            if char.startswith('['):
                tokenized.append(char)
            else:
                tokenized.extend(char)
        tokenized.append('EOS')
        return tokenized

    def add_characters(self, chars):
        """Adds characters to the vocabulary and rebuilds the lookup tables."""
        for char in chars:
            self.additional_chars.add(char)
        char_list = sorted(self.additional_chars)
        self.chars = char_list + self.special_tokens
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}

    def init_from_file(self, file):
        """Takes a file of newline-separated characters to initialize the vocabulary."""
        with open(file, 'r') as f:
            chars = f.read().split()
        self.add_characters(chars)

    def __len__(self):
        return len(self.chars)
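# Tokenize/encode/decode roundtrip sketch (assumption: the vocabulary already
# contains these tokens, e.g. after construct_vocabulary() has run):
#
#     voc.tokenize('c1ccccc1Br')
#     # -> ['c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'R', 'EOS']
#     # (replace_halogen() is assumed to map 'Br' -> 'R' and 'Cl' -> 'L',
#     #  which is exactly what decode() undoes above)
#     arr = voc.encode(voc.tokenize('c1ccccc1Br'))
#     voc.decode(arr)  # -> 'c1ccccc1Br'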
def write_smiles_to_file(smiles_list, fname):
    """Write a list of SMILES to a file, one per line."""
    with open(fname, 'w') as f:
        for smiles in smiles_list:
            f.write(smiles + "\n")
if __name__ == "__main__":
    # Raw string so the Windows backslashes are not treated as escape sequences.
    data_path = r"D:\Python\ProjectOne\data\data.txt"
    voc_chars, max_len = construct_vocabulary(data_path)
    voc = Vocabulary(init_from_file='data/Voc', max_length=max_len)
    print(voc_chars)
    print(max_len)
    moldata = MolData(data_path, voc)
    test = moldata[0]  # indexing calls __getitem__
    print(test)
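    # Roundtrip sanity check (sketch; assumes every token of the first
    # scaffold is present in 'data/Voc'): tokenize -> encode -> decode
    # should reproduce the scaffold SMILES.
    sca0 = moldata.scas[0]
    print(voc.decode(voc.encode(voc.tokenize(sca0))))  # expected to equal sca0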