Skip to content

Commit

Permalink
add embedding_dim to single model
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Jan 14, 2025
1 parent d4956bb commit 9d15324
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,10 +602,7 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
# this model was defined, they are stripped out before feeding the
# sequence to the model, and the returned selection factors will be NaN
# at sites containing those unrecognized tokens.
if "embedding_dim" in self.hyperparameters:
model_valid_sites = aa_idxs < self.hyperparameters["embedding_dim"]
else:
model_valid_sites = torch.ones_like(aa_idxs, dtype=torch.bool)
model_valid_sites = aa_idxs < self.hyperparameters["embedding_dim"]
if self.hyperparameters["output_dim"] == 1:
result = torch.full((len(aa_str),), float("nan"), device=self.device)
else:
Expand Down Expand Up @@ -772,14 +769,15 @@ def predict(self, representation: Tensor):
class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
"""A one parameter selection model as a baseline."""

def __init__(self, output_dim: int = 1):
def __init__(self, output_dim: int = 1, embedding_dim: int = MAX_AA_TOKEN_IDX + 1):
super().__init__()
self.single_value = nn.Parameter(torch.tensor(0.0))
self.output_dim = output_dim
self.embedding_dim = embedding_dim

@property
def hyperparameters(self):
return {"output_dim": self.output_dim}
return {"output_dim": self.output_dim, "embedding_dim": self.embedding_dim}

def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
"""Build a binary log selection matrix from an index-encoded parent sequence."""
Expand Down

0 comments on commit 9d15324

Please sign in to comment.