Skip to content

Commit

Permalink
Light and heavy chain shims
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Feb 16, 2025
1 parent 954c28c commit a4b545f
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def chunked(iterable, n):
yield chunk


# TODO this is defunct now right?
def assume_single_sequence_is_heavy_chain(seq_arg_idx=0):
"""Wraps a function that takes a heavy/light sequence pair as its first argument and
returns a tuple of results.
Expand All @@ -359,6 +360,28 @@ def wrapper(*args, **kwargs):
return decorator


def heavy_chain_shim(paired_evaluator):
"""Returns a function that evaluates only heavy chains given a paired evaluator."""

def evaluate_heavy_chains(sequences):
paired_seqs = [[h, ""] for h in sequences]
paired_outputs = paired_evaluator(paired_seqs)
return [output[0] for output in paired_outputs]

return evaluate_heavy_chains


def light_chain_shim(paired_evaluator):
"""Returns a function that evaluates only light chains given a paired evaluator."""

def evaluate_light_chains(sequences):
paired_seqs = [["", l] for l in sequences]
paired_outputs = paired_evaluator(paired_seqs)
return [output[1] for output in paired_outputs]

return evaluate_light_chains


def chunk_function(
first_chunkable_idx=0, default_chunk_size=2048, progress_bar_name=None
):
Expand Down

0 comments on commit a4b545f

Please sign in to comment.