-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_data_loading.py
50 lines (42 loc) · 1.7 KB
/
_data_loading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import warnings
from typing import Tuple
import numpy as np
import torch
from . import MoleculeGraph
def collate_graph(
    batch: Tuple[np.ndarray, list[MoleculeGraph], list[np.ndarray]]
) -> Tuple[
    torch.Tensor,
    list[torch.Tensor],
    list[torch.Tensor],
    list[torch.Tensor],
    list[MoleculeGraph],
    list[np.ndarray],
]:
    """
    Collate graph samples into a batch.

    Molecules of length zero are skipped with a ``RuntimeWarning`` and their entries are removed
    from **X**, so that **X**, **pos**, **node_classes** and **edges** stay aligned with each other.
    Note that **mols** and **xyz** are passed through without filtering, so they still contain the
    skipped samples and can be longer than the other outputs.

    Arguments:
        batch:
            Tuple (**X**, **mols**, **xyz**), where

            - **X** - Input AFM image. Array of shape ``(batch_size, x, y, z)``.
            - **mols** - Input molecules.
            - **xyz** - List of original molecules. Arrays of shape ``(n_atoms, 5)``.

    Returns:
        Tuple (**X**, **pos**, **node_classes**, **edges**, **mols**, **xyz**), where

        - **X** - Input AFM images as a float tensor, with samples of empty molecules removed.
        - **pos** - Graph node xyz coordinates. One float tensor per non-empty molecule.
        - **node_classes** - Graph node class indices. One integer tensor per non-empty molecule.
        - **edges** - Graph edge indices. One integer tensor per non-empty molecule, transposed
          from the molecule's ``bonds``. A molecule without bonds yields an empty tensor of
          shape ``(2, 0)``.
        - **mols** - Input molecules. Unchanged from input argument.
        - **xyz** - List of original molecules. Unchanged from input argument.
    """
    X, mols, xyz = batch

    pos = []
    node_classes = []
    edges = []
    remove_inds = []
    for i, m in enumerate(mols):
        if len(m) == 0:
            warnings.warn("empty molecule in batch", RuntimeWarning)
            remove_inds.append(i)
            continue

        # Columns 0-2 are used as coordinates and column 3 as the class index.
        m_array = m.array(xyz=True, class_index=True)
        pos.append(torch.from_numpy(m_array[:, :3]).float())
        node_classes.append(torch.from_numpy(m_array[:, 3]).long())

        # Build the indices directly as int64 instead of going through float32, which would
        # lose precision for large indices.
        bonds = torch.tensor(m.bonds, dtype=torch.long)
        if bonds.numel() == 0:
            # An empty bond list would otherwise be a 1-D tensor; give it an explicit
            # (0, 2) shape so that the transpose below is well-defined and yields (2, 0).
            bonds = bonds.reshape(0, 2)
        edges.append(bonds.T)

    # Drop the images of the skipped molecules. np.delete returns a new array, so the
    # caller's array is not modified.
    X = np.delete(X, remove_inds, axis=0)
    X = torch.from_numpy(X).float()

    return X, pos, node_classes, edges, mols, xyz