
API refactored.
Bugs fixed.
stsouko committed Oct 7, 2024
1 parent e0fdf30 commit d10859e
Showing 6 changed files with 24 additions and 121 deletions.
40 changes: 0 additions & 40 deletions chytorch/nn/activation.py

This file was deleted.

6 changes: 2 additions & 4 deletions chytorch/nn/transformer/attention/graphormer.py
@@ -100,7 +100,7 @@ def __init__(self, embed_dim, num_heads, dropout: float = .1, bias: bool = True,
         self._register_load_state_dict_pre_hook(_update_packed)
         self.o_proj = Linear(embed_dim, embed_dim, bias=bias)
 
-    def forward(self, x: Tensor, attn_mask: Optional[Tensor], pad_mask: Optional[Tensor] = None, *,
+    def forward(self, x: Tensor, attn_mask: Tensor, *,
                 cache: Optional[Tuple[Tensor, Tensor]] = None,
                 need_weights: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
         if self.separate_proj:
@@ -126,9 +126,7 @@ def forward(self, x: Tensor, attn_mask: Optional[Tensor], pad_mask: Optional[Ten
             v = v.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
 
         # BxHxTxE @ BxHxExS > BxHxTxS
-        a = (q @ k) * self._scale
-        if attn_mask is not None:
-            a = a + attn_mask
+        a = (q @ k) * self._scale + attn_mask
         a = softmax(a, dim=-1)
         if self.training and self.dropout:
             a = dropout(a, self.dropout)
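
Note: the attention refactor makes `attn_mask` a required tensor and folds it into the score computation unconditionally, so callers must always pass an additive bias. A minimal call-site sketch, assuming the import path matches the file path in this diff; tensor sizes are illustrative:

    import torch
    from chytorch.nn.transformer.attention.graphormer import GraphormerAttention

    attn = GraphormerAttention(64, 4)        # embed_dim, num_heads
    x = torch.randn(2, 10, 64)               # BxTxE input
    mask = torch.zeros(2, 4, 10, 10)         # additive bias, assumed BxHxTxS; zeros = attend everywhere
    e, a = attn(x, mask, need_weights=True)  # attn_mask is now mandatory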
16 changes: 8 additions & 8 deletions chytorch/nn/transformer/encoder.py
@@ -68,23 +68,23 @@ class EncoderLayer(Module):
     """
     def __init__(self, d_model, nhead, dim_feedforward, dropout=0.1, activation=GELU, layer_norm_eps=1e-5,
                  norm_first: bool = False, attention: Type[Module] = GraphormerAttention, mlp: Type[Module] = MLP,
-                 projection_bias: bool = True, ff_bias: bool = True):
+                 norm_layer: Type[Module] = LayerNorm, projection_bias: bool = True, ff_bias: bool = True):
         super().__init__()
         self.self_attn = attention(d_model, nhead, dropout, projection_bias)
         self.mlp = mlp(d_model, dim_feedforward, dropout, activation, ff_bias)
 
-        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
-        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
+        self.norm1 = norm_layer(d_model, eps=layer_norm_eps)
+        self.norm2 = norm_layer(d_model, eps=layer_norm_eps)
         self.dropout1 = Dropout(dropout)
         self.dropout2 = Dropout(dropout)
         self.norm_first = norm_first
         self._register_load_state_dict_pre_hook(_update)
 
-    def forward(self, x: Tensor, attn_mask: Optional[Tensor], pad_mask: Optional[Tensor] = None, *,
-                cache: Optional[Tuple[Tensor, Tensor]] = None,
-                need_embedding: bool = True, need_weights: bool = False) -> Tuple[Optional[Tensor], Optional[Tensor]]:
+    def forward(self, x: Tensor, attn_mask: Optional[Tensor], *,
+                need_embedding: bool = True, need_weights: bool = False,
+                **kwargs) -> Tuple[Optional[Tensor], Optional[Tensor]]:
         nx = self.norm1(x) if self.norm_first else x  # pre-norm or post-norm
-        e, a = self.self_attn(nx, attn_mask, pad_mask, cache=cache, need_weights=need_weights)
+        e, a = self.self_attn(nx, attn_mask, need_weights=need_weights, **kwargs)
 
         if need_embedding:
             x = x + self.dropout1(e)
@@ -96,4 +96,4 @@ def forward(self, x: Tensor, attn_mask: Optional[Tensor], pad_mask: Optional[Ten
         return None, a
 
 
-__all__ = ['EncoderLayer']
+__all__ = ['EncoderLayer', 'MLP']
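
Note: `EncoderLayer` now exposes the normalization class as a `norm_layer` parameter (defaulting to `LayerNorm`) and forwards extra keyword arguments, such as `cache=`, through to the attention module. A minimal sketch, assuming the import path from this diff and that any replacement norm is constructible as `norm_layer(d_model, eps=...)`:

    import torch
    from torch.nn import LayerNorm
    from chytorch.nn.transformer.encoder import EncoderLayer

    layer = EncoderLayer(64, 4, 256, norm_layer=LayerNorm)  # d_model, nhead, dim_feedforward
    x = torch.randn(2, 10, 64)
    mask = torch.zeros(2, 4, 10, 10)  # additive attention bias (assumed shape)
    e, _ = layer(x, mask)             # extra kwargs (e.g. cache=...) pass through to self_attn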
30 changes: 6 additions & 24 deletions chytorch/utils/data/molecule/_unpack.pyx
@@ -31,10 +31,6 @@ DTYPE = np.int32
 ctypedef cnp.int32_t DTYPE_t
 
 
-cdef extern from "Python.h":
-    dict _PyDict_NewPresized(Py_ssize_t minused)
-
-
 # Format specification::
 #
 # Big endian bytes order
@@ -68,15 +64,15 @@ cdef extern from "Python.h":
 @cython.cdivision(True)
 @cython.wraparound(False)
 def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsigned short symmetric_attention,
-           unsigned short components_attention, DTYPE_t max_neighbors, DTYPE_t max_distance, DTYPE_t padding):
+           unsigned short components_attention, DTYPE_t max_neighbors, DTYPE_t max_distance):
     """
     Optimized chython pack to graph tensor converter.
     Ignores charge, radicals, isotope, coordinates, bond order, and stereo info
     """
     cdef unsigned char a, b, c, hydrogens, neighbors_count
     cdef unsigned char *connections
 
-    cdef unsigned short atoms_count, bonds_count = 0, order_count = 0, cis_trans_count, padded_count
+    cdef unsigned short atoms_count, bonds_count = 0, order_count = 0, cis_trans_count
     cdef unsigned short i, j, k, n, m
     cdef unsigned short[4096] mapping
     cdef unsigned int size, shift = 4
@@ -85,9 +81,6 @@ def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsig
     cdef cnp.ndarray[DTYPE_t, ndim=2] distance
     cdef DTYPE_t d, attention
 
-    cdef object py_n
-    cdef dict py_mapping
-
     # read header
     if data[0] != 2:
         raise ValueError('invalid pack version')
@@ -98,12 +91,9 @@ def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsig
     atoms_count = (a << 4| b >> 4) + add_cls
     cis_trans_count = (b & 0x0f) << 8 | c
 
-    py_mapping = _PyDict_NewPresized(atoms_count)
-
-    padded_count = atoms_count + padding
-    atoms = np.empty(padded_count, dtype=DTYPE)
-    neighbors = np.zeros(padded_count, dtype=DTYPE)
-    distance = np.full((padded_count, padded_count), 9999, dtype=DTYPE)  # fill with unreachable value
+    atoms = np.empty(atoms_count, dtype=DTYPE)
+    neighbors = np.zeros(atoms_count, dtype=DTYPE)
+    distance = np.full((atoms_count, atoms_count), 9999, dtype=DTYPE)  # fill with unreachable value
 
     # allocate memory
     connections = <unsigned char*> PyMem_Malloc(atoms_count * sizeof(unsigned char))
@@ -126,7 +116,6 @@ def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsig
         a, b = data[shift], data[shift + 1]
         n = a << 4 | b >> 4
         mapping[n] = i
-        py_mapping[n] = i
         connections[i] = neighbors_count = b & 0x0f
         bonds_count += neighbors_count
 
@@ -187,13 +176,6 @@ def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsig
             else:
                 distance[i, j] = distance[j, i] = d + 2
 
-    # disable attention on padding
-    for i in range(atoms_count, padded_count):
-        atoms[i] = 2  # set explicit hydrogen
-        for j in range(padded_count):
-            distance[i, j] = distance[j, i] = 0
-        distance[i, i] = 1  # self-attention of padding
-
     size = shift + order_count + 4 * cis_trans_count
     PyMem_Free(connections)
-    return atoms, neighbors, distance, size, py_mapping
+    return atoms, neighbors, distance, size
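
Note: `unpack` loses the trailing `padding` argument and no longer builds or returns the Python-level atom mapping, so it now returns a 4-tuple. A call-site sketch; `packed` stands in for hypothetical chython v2 pack bytes, and the arguments are positional (add_cls, symmetric_attention, components_attention, max_neighbors, max_distance):

    from chytorch.utils.data.molecule._unpack import unpack

    packed = ...  # bytes in chython v2 pack format (placeholder)
    atoms, neighbors, distance, size = unpack(packed, 1, 1, 1, 14, 10)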
51 changes: 7 additions & 44 deletions chytorch/utils/data/molecule/encoder.py
@@ -27,7 +27,7 @@
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
 from torchtyping import TensorType
-from typing import Sequence, Union, NamedTuple, Optional, Tuple
+from typing import Sequence, Union, NamedTuple
 from zlib import decompress
 from .._abc import default_collate_fn_map
@@ -89,8 +89,7 @@ def collate_molecules(batch, *, padding_left: bool = False, collate_fn_map=None)
 
 class MoleculeDataset(Dataset):
     def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
-                 hydrogens: Optional[Sequence[Sequence[Tuple[int, ...]]]] = None, cls_token: int = 1,
-                 max_distance: int = 10, add_cls: bool = True, max_neighbors: int = 14,
+                 cls_token: int = 1, max_distance: int = 10, add_cls: bool = True, max_neighbors: int = 14,
                  symmetric_attention: bool = True, components_attention: bool = True,
                  unpack: bool = False, compressed: bool = True, distance_cutoff=None):
         """
@@ -106,7 +105,6 @@ def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
         that code unreachable atoms (e.g. salts).
         :param molecules: molecules collection
-        :param hydrogens: shared hydrogen mapping. First element is hydrogen donor, other are acceptors
         :param max_distance: set distances greater than cutoff to cutoff value
         :param add_cls: add special token at first position
         :param max_neighbors: set neighbors count greater than cutoff to cutoff value
@@ -116,10 +114,7 @@ def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
         :param compressed: packed molecules are compressed
         :param cls_token: idx of cls token
         """
-        assert hydrogens is None or len(hydrogens) == len(molecules), 'hydrogens and molecules must have the same size'
-
         self.molecules = molecules
-        self.hydrogens = hydrogens
         # distance_cutoff is deprecated
         self.max_distance = distance_cutoff if distance_cutoff is not None else max_distance
         self.add_cls = add_cls
@@ -132,14 +127,6 @@ def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
 
     def __getitem__(self, item: int) -> MoleculeDataPoint:
         mol = self.molecules[item]
-
-        if self.hydrogens is not None:
-            hmap = self.hydrogens[item]
-            pad = len(hmap)
-        else:
-            hmap = None
-            pad = 0
-
         if self.unpack:
             try:
                 from ._unpack import unpack
@@ -148,22 +135,15 @@ def __getitem__(self, item: int) -> MoleculeDataPoint:
             else:
                 if self.compressed:
                     mol = decompress(mol)
-                atoms, neighbors, distances, _, mapping = unpack(mol, self.add_cls, self.symmetric_attention,
-                                                                 self.components_attention, self.max_neighbors,
-                                                                 self.max_distance, pad)
-                if pad:
-                    for n, da in enumerate(hmap, -pad):
-                        neighbors[mapping[da[0]]] -= 1
-                        for m in da:
-                            m = mapping[m]
-                            distances[n, m] = distances[m, n] = 1
+                atoms, neighbors, distances, _ = unpack(mol, self.add_cls, self.symmetric_attention,
+                                                        self.components_attention, self.max_neighbors,
+                                                        self.max_distance)
                 if self.add_cls and self.cls_token != 1:
                     atoms[0] = self.cls_token
                 return MoleculeDataPoint(IntTensor(atoms), IntTensor(neighbors), IntTensor(distances))
 
         nc = self.max_neighbors
-        lp = len(mol) + pad
-        mapping = {}
+        lp = len(mol)
 
         if self.add_cls:
             lp += 1
@@ -176,7 +156,6 @@ def __getitem__(self, item: int) -> MoleculeDataPoint:
         ngb = mol._bonds  # noqa speedup
         hgs = mol._hydrogens  # noqa
         for i, (n, a) in enumerate(mol.atoms(), self.add_cls):
-            mapping[n] = i
             atoms[i] = a.atomic_number + 2
             nb = len(ngb[n]) + (hgs[n] or 0)  # treat bad valence as 0-hydrogen
             if nb > nc:
@@ -188,23 +167,7 @@ def __getitem__(self, item: int) -> MoleculeDataPoint:
         minimum(distances, self.max_distance + 2, out=distances)
         distances = IntTensor(distances)
 
-        if pad:
-            atoms[-pad:] = 2  # set explicit hydrogens
-            tmp = eye(lp, dtype=int32)
-            if self.add_cls:
-                tmp[0] = 1  # enable CLS to atom attention
-                tmp[1:, 0] = 1 if self.symmetric_attention else 0  # enable or disable atom to CLS attention
-                tmp[1:-pad, 1:-pad] = distances
-            else:
-                tmp[:-pad, :-pad] = distances
-            distances = tmp
-
-            for n, da in enumerate(hmap, -pad):
-                neighbors[mapping[da[0]]] -= 1
-                for m in da:
-                    m = mapping[m]
-                    distances[n, m] = distances[m, n] = 1
-        elif self.add_cls:
+        if self.add_cls:
             tmp = ones((lp, lp), dtype=int32)
             if not self.symmetric_attention:
                 tmp[1:, 0] = 0  # disable atom to CLS attention
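
Note: with the shared-hydrogen/padding path removed, `MoleculeDataset` is constructed from the molecule collection alone and `__getitem__` no longer tracks an atom mapping. A usage sketch, assuming `molecules` is a sequence of chython `MoleculeContainer` objects (or packed bytes with `unpack=True`) and importing from the module path in this diff:

    from chytorch.utils.data.molecule.encoder import MoleculeDataset

    ds = MoleculeDataset(molecules, add_cls=True, cls_token=1,
                         max_distance=10, max_neighbors=14)  # defaults, shown explicitly
    atoms, neighbors, distances = ds[0]  # MoleculeDataPoint of IntTensors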
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = 'chytorch'
-version = '1.64'
+version = '1.65'
 description = 'Library for modeling molecules and reactions in torch way'
 authors = ['Ramil Nugmanov <[email protected]>']
 license = 'MIT'
