Skip to content

Commit

Permalink
fixup! Refactor 3Di functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
padix-key committed Oct 6, 2024
1 parent 121eafb commit 5a1ad8f
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 40 deletions.
10 changes: 5 additions & 5 deletions src/biotite/structure/alphabet/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@

__name__ = "biotite.structure.alphabet"
__author__ = "Martin Larralde"
__all__ = ["Encoder", "VirtualCenterEncoder", "PartnerIndexEncoder", "FeatureEncoder"]

import abc
import typing
from importlib.resources import files as resource_files

import numpy
import numpy.ma

from biotite.structure.alphabet.unkerasify import unkerasify
from biotite.structure.alphabet.unkerasify import load
from biotite.structure.alphabet.layers import CentroidLayer, Model
from biotite.structure.alphabet.unkerasify import load_kerasify


class _BaseEncoder(abc.ABC):
Expand Down Expand Up @@ -351,9 +352,8 @@ class Encoder(_BaseEncoder):

def __init__(self) -> None:
self.feature_encoder = FeatureEncoder()
with resource_files(__package__).joinpath("encoder_weights_3di.kerasify").open("rb") as f:
layers = unkerasify.load(f)
layers.append(CentroidLayer(self._CENTROIDS))
layers = load_kerasify(resource_files(__package__).joinpath("encoder_weights_3di.kerasify"))
layers.append(CentroidLayer(self._CENTROIDS))
self.vae_encoder = Model(layers)

def encode(
Expand Down
47 changes: 22 additions & 25 deletions src/biotite/structure/alphabet/i3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

__name__ = "biotite.structure.alphabet"
__author__ = "Martin Larralde"
__all__ = ["I3DSequence"]
__all__ = ["I3DSequence", "to_3di"]

from biotite.sequence.alphabet import LetterAlphabet
from biotite.sequence.sequence import Sequence
from biotite.structure.alphabet.encoder import Encoder
from biotite.structure.error import BadStructureError
import numpy as np


class I3DSequence(Sequence):
Expand Down Expand Up @@ -70,7 +72,7 @@ def __repr__(self):
return f'I3DSequence("{"".join(self.symbols)}")'


def to_3di(array):
def to_3di(atoms):
r"""
Encode the atoms to the 3di structure alphabet.
Expand Down Expand Up @@ -98,22 +100,20 @@ def to_3di(array):


def _encode_atoms(
atoms,
ca_residue: bool = True,
disordered_atom: Literal["best", "last"] = "best",
) -> T:
if not numpy.all(array.chain_id == array.chain_id[0]):
raise BadStructureError("structure contains more than one chain")

ca_atoms = array[array.atom_name == 'CA']
cb_atoms = array[array.atom_name == 'CB']
n_atoms = array[array.atom_name == 'N']
c_atoms = array[array.atom_name == 'C']

r = array.res_id.max()

ca = numpy.zeros((r + 1, 3), dtype=numpy.float32)
ca.fill(numpy.nan)
atoms
):
if not np.all(atoms.chain_id == atoms.chain_id[0]):
raise BadStructureError("Structure contains more than one chain")

ca_atoms = atoms[atoms.atom_name == 'CA']
cb_atoms = atoms[atoms.atom_name == 'CB']
n_atoms = atoms[atoms.atom_name == 'N']
c_atoms = atoms[atoms.atom_name == 'C']

r = atoms.res_id.max()

ca = np.zeros((r + 1, 3), dtype=np.float32)
ca.fill(np.nan)
cb = ca.copy()
n = ca.copy()
c = ca.copy()
Expand All @@ -123,12 +123,9 @@ def _encode_atoms(
n[n_atoms.res_id, :] = n_atoms.coord
c[c_atoms.res_id, :] = c_atoms.coord

if ca_residue:
ca = ca[ca_atoms.res_id]
cb = cb[ca_atoms.res_id]
n = n[ca_atoms.res_id]
c = c[ca_atoms.res_id]
else:
raise NotImplementedError
ca = ca[ca_atoms.res_id]
cb = cb[ca_atoms.res_id]
n = n[ca_atoms.res_id]
c = c[ca_atoms.res_id]

return Encoder().encode(ca, cb, n, c)
5 changes: 5 additions & 0 deletions src/biotite/structure/alphabet/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
"""
Mini implementation of the neural network layers used in ``foldseek``.
"""

__name__ = "biotite.structure.alphabet"
__author__ = "Martin Larralde"
__all__ = ["Layer", "DenseLayer", "CentroidLayer", "Model"]

import abc
import functools
from typing import Iterable, Optional
Expand Down
26 changes: 16 additions & 10 deletions src/biotite/structure/alphabet/unkerasify.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
Parser for extracting weights from Keras files.
Adapted from `moof2k/kerasify <https://github.com/moof2k/kerasify>`_.
"""

__name__ = "biotite.structure.alphabet"
__author__ = "Martin Larralde"
__all__ = ["load_kerasify"]

import functools
import enum
import itertools
import struct
import typing

import numpy
import numpy as np

from .layers import Layer, DenseLayer

Expand Down Expand Up @@ -47,7 +51,7 @@ class KerasifyParser:
"""

def __init__(self, file: typing.BinaryIO) -> None:
def __init__(self, file) -> None:
self.file = file
self.buffer = bytearray(1024)
(self.n_layers,) = self._get("I")
Expand All @@ -71,11 +75,11 @@ def _read(self, format: str) -> memoryview:
self.file.readinto(v) # type: ignore
return v

def _get(self, format: str) -> typing.Tuple[typing.Any, ...]:
def _get(self, format: str):
v = self._read(format)
return struct.unpack(format, v)

def read(self) -> typing.Optional[Layer]:
def read(self):
if self.n_layers == 0:
return None

Expand All @@ -86,11 +90,11 @@ def read(self) -> typing.Optional[Layer]:
(w1,) = self._get("I")
(b0,) = self._get("I")
weights = (
numpy.frombuffer(self._read(f"={w0*w1}f"), dtype="f4")
np.frombuffer(self._read(f"={w0*w1}f"), dtype="f4")
.reshape(w0, w1)
.copy()
)
biases = numpy.frombuffer(self._read(f"={b0}f"), dtype="f4").copy()
biases = np.frombuffer(self._read(f"={b0}f"), dtype="f4").copy()
activation = ActivationType(self._get("I")[0])
if activation not in (ActivationType.LINEAR, ActivationType.RELU):
raise NotImplementedError(f"Unsupported activation type: {activation!r}")
Expand All @@ -99,5 +103,7 @@ def read(self) -> typing.Optional[Layer]:
raise NotImplementedError(f"Unsupported layer type: {layer_type!r}")


def load(fh):
return list(KerasifyParser(fh))
@functools.cache
def load_kerasify(file_path):
with open(file_path, "rb") as file:
return list(KerasifyParser(file))

0 comments on commit 5a1ad8f

Please sign in to comment.