Merge pull request #664 from padix-key/pickle-performance
Improve performance of pickling `KmerTable`
padix-key authored Sep 16, 2024
2 parents 53afd0a + 51ad869 commit 231eefe
Showing 2 changed files with 57 additions and 44 deletions.
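For context, `__getstate__` previously produced one `bytes` object per non-empty bucket; after this commit it returns a single concatenated array plus a per-bucket length array (see the diff below). A minimal sketch of the pickle round-trip this speeds up, assuming the documented `KmerTable.from_sequences()` constructor; the example sequences are purely illustrative:

```python
import pickle

import biotite.sequence as seq
import biotite.sequence.align as align

# Illustrative input; any sequences over a common alphabet work
sequences = [seq.ProteinSequence("BIQTITE"), seq.ProteinSequence("NIQBITE")]

# Build a 3-mer table and send it through a pickle round-trip,
# the code path whose __getstate__/__setstate__ this commit optimizes
table = align.KmerTable.from_sequences(3, sequences)
restored = pickle.loads(pickle.dumps(table))
```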
94 changes: 52 additions & 42 deletions src/biotite/sequence/align/kmertable.pyx
@@ -1384,8 +1384,7 @@ cdef class KmerTable:


def __getstate__(self):
relevant_kmers = self.get_kmers()
return _pickle_c_arrays(self._ptr_array, relevant_kmers)
return _pickle_c_arrays(self._ptr_array)


def __setstate__(self, state):
@@ -2836,12 +2835,7 @@ cdef class BucketKmerTable:


def __getstate__(self):
cdef int64[:] relevant_buckets = np.where(
np.asarray(self._ptr_array) != 0
)[0]
return _pickle_c_arrays(self._ptr_array, relevant_buckets)


return _pickle_c_arrays(self._ptr_array)

def __setstate__(self, state):
_unpickle_c_arrays(self._ptr_array, state)
@@ -3097,27 +3091,44 @@ def _append_entries(ptr[:] trg_ptr_array, ptr[:] src_ptr_array):

@cython.boundscheck(False)
@cython.wraparound(False)
def _pickle_c_arrays(ptr[:] ptr_array, int64[:] relevant_buckets):
def _pickle_c_arrays(ptr[:] ptr_array):
"""
Pickle the `relevant_buckets` (i.e. the buckets that actualy point
to an array) of the `ptr_array` into a list of bytes.
Pickle the C arrays into a single concatenated :class:`ndarray`.
The length of each C-array in the concatenated array is saved as well.
"""
cdef int64 i
cdef int64 bucket
cdef int64 pointer_i, bucket_i, concat_i
cdef int64 length
cdef uint32* bucket_ptr

cdef list pickled_arrays = [b""] * relevant_buckets.shape[0]

for i in range(relevant_buckets.shape[0]):
bucket = relevant_buckets[i]
bucket_ptr = <uint32*>ptr_array[bucket]
length = (<int64*>bucket_ptr)[0]
# Get directly the bytes coding for each C-array
pickled_arrays[i] \
= <bytes>(<char*>bucket_ptr)[:sizeof(uint32) * length]
# First pass: Count the total concatenated size
cdef int64 total_length = 0
for pointer_i in range(ptr_array.shape[0]):
bucket_ptr = <uint32*>ptr_array[pointer_i]
if bucket_ptr != NULL:
# The first element of the C-array is the length
# of the array
total_length += (<int64*>bucket_ptr)[0]

# Second pass: Copy the C-arrays into a single concatenated array
# and track the start position of each C-array
cdef uint32[:] concatenated_array = np.empty(total_length, dtype=np.uint32)
cdef int64[:] lengths = np.empty(ptr_array.shape[0], dtype=np.int64)
concat_i = 0
for pointer_i in range(ptr_array.shape[0]):
bucket_ptr = <uint32*>ptr_array[pointer_i]
if bucket_ptr != NULL:
length = (<int64*>bucket_ptr)[0]
lengths[pointer_i] = length
memcpy(
&concatenated_array[concat_i],
bucket_ptr,
length * sizeof(uint32),
)
concat_i += length
else:
lengths[pointer_i] = 0

return np.asarray(relevant_buckets), pickled_arrays
return np.asarray(concatenated_array), np.asarray(lengths)
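The pickled state is therefore just two :class:`ndarray` objects: a `uint32` array holding all non-empty C-arrays back to back (each one starts with its own length, stored as an `int64` in its first elements) and an `int64` length for every pointer slot, with 0 marking empty slots. A pure-NumPy sketch of the concatenation idea, not the exact memory layout above; `concatenate_for_pickle` and `buckets` are illustrative names:

```python
import numpy as np

def concatenate_for_pickle(buckets):
    # Sketch: store many small per-bucket arrays as one concatenated
    # array plus a per-bucket length array, so pickle serializes two
    # ndarrays instead of one Python object per bucket.
    lengths = np.array([len(b) for b in buckets], dtype=np.int64)
    if len(buckets) == 0:
        return np.empty(0, dtype=np.uint32), lengths
    concatenated = np.concatenate(
        [np.asarray(b, dtype=np.uint32) for b in buckets]
    )
    return concatenated, lengths
```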


@cython.boundscheck(False)
@@ -3126,28 +3137,27 @@ def _unpickle_c_arrays(ptr[:] ptr_array, state):
"""
Unpickle the pickled `state` into the given `ptr_array`.
"""
cdef int64 i
cdef int64 bucket
cdef int64 byte_length
cdef int64 pointer_i, concat_i
cdef int64 length
cdef uint32* bucket_ptr
cdef bytes pickled_bytes

cdef int64[:] relevant_buckets = state[0]
cdef list pickled_pointers = state[1]

for i in range(relevant_buckets.shape[0]):
bucket = relevant_buckets[i]
if bucket < 0 or bucket >= ptr_array.shape[0]:
raise ValueError("Invalid bucket found while unpickling")
pickled_bytes = pickled_pointers[i]
byte_length = len(pickled_bytes)
if byte_length != 0:
bucket_ptr = <uint32*>malloc(byte_length)

cdef uint32[:] concatenated_array = state[0]
cdef int64[:] lengths = state[1]

concat_i = 0
for pointer_i in range(ptr_array.shape[0]):
length = lengths[pointer_i]
if length != 0:
bucket_ptr = <uint32*>malloc(length * sizeof(uint32))
if not bucket_ptr:
raise MemoryError
# Convert bytes back into C-array
memcpy(bucket_ptr, <char*>pickled_bytes, byte_length)
ptr_array[bucket] = <ptr>bucket_ptr
memcpy(
bucket_ptr,
&concatenated_array[concat_i],
length * sizeof(uint32),
)
concat_i += length
ptr_array[pointer_i] = <ptr>bucket_ptr
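Restoring a table is the mirror image: the stored lengths give the split points within the concatenated array, and every non-empty slice is copied into a freshly allocated C-array. A pure-NumPy sketch of that splitting step (`split_from_pickle` is an illustrative name):

```python
import numpy as np

def split_from_pickle(concatenated, lengths):
    # Sketch: recover one array per bucket by slicing the concatenated
    # array at the cumulative lengths; zero lengths yield empty slices.
    if len(lengths) == 0:
        return []
    split_points = np.cumsum(lengths)[:-1]
    return np.split(np.asarray(concatenated), split_points)
```

Combined with the concatenation sketch above, `split_from_pickle(*concatenate_for_pickle(buckets))` reproduces the original per-bucket contents.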


cdef inline void _deallocate_ptrs(ptr[:] ptrs):
7 changes: 5 additions & 2 deletions tests/sequence/align/conftest.py
@@ -4,8 +4,6 @@

from os.path import join
import pytest
import biotite.sequence as seq
import biotite.sequence.io.fasta as fasta
from tests.util import data_dir


@@ -14,5 +12,10 @@ def sequences():
"""
10 Cas9 sequences.
"""
# Import in function to avoid 'ModuleNotFoundError',
# if a Cython module is not compiled yet
import biotite.sequence as seq
import biotite.sequence.io.fasta as fasta

fasta_file = fasta.FastaFile.read(join(data_dir("sequence"), "cas9.fasta"))
return [seq.ProteinSequence(sequence) for sequence in fasta_file.values()]
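The deferred imports keep test collection from failing with a `ModuleNotFoundError` before the Cython extensions are compiled. An alternative way to express the same intent, assuming pytest's `importorskip` helper (not what this commit uses):

```python
import pytest

# Skip the dependent tests instead of erroring out at collection time
# when the compiled extension has not been built yet
align = pytest.importorskip("biotite.sequence.align")
```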
