Skip to content

Commit

Permalink
fix docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
willdumm committed Oct 28, 2024
1 parent 67c20aa commit bbef2e0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions netam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,8 +640,8 @@ def predict(self, representation: Tensor) -> Tensor:
representation: A tensor of shape (B, L, E) representing the
embedded parent sequences.
Returns:
A tensor of shape (B, L, 1) representing the log level of selection
for each amino acid site.
A tensor of shape (B, L, out_features) representing the log level
of selection for each amino acid site.
"""
return self.linear(representation).squeeze(-1)

Expand All @@ -656,8 +656,8 @@ def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
mask: A tensor of shape (B, L) representing the mask of valid amino acid sites.
Returns:
A tensor of shape (B, L, 20) representing the log level of selection
for each possible amino acid at each site.
A tensor of shape (B, L, out_features) representing the log level
of selection for each possible amino acid at each site.
"""
return self.predict(self.represent(amino_acid_indices, mask))

Expand Down

0 comments on commit bbef2e0

Please sign in to comment.