Merge pull request #141 from kylebgorman/cleanup
Cleanups and docs after #135.
kylebgorman authored Oct 3, 2023
2 parents f372dc5 + 0559726 commit 98a585f
Showing 5 changed files with 67 additions and 78 deletions.
88 changes: 40 additions & 48 deletions README.md
@@ -1,5 +1,4 @@
Yoyodyne 🪀
==========
# Yoyodyne 🪀

[![PyPI
version](https://badge.fury.io/py/yoyodyne.svg)](https://pypi.org/project/yoyodyne)
@@ -18,8 +17,7 @@ models are particularly well-suited for problems where the source-target
alignments are roughly monotonic (e.g., `transducer`) and/or where source and
target vocabularies have substantial overlap (e.g., `pointer_generator_lstm`).

Philosophy
----------
## Philosophy

Yoyodyne is inspired by [FairSeq](https://github.com/facebookresearch/fairseq)
(Ott et al. 2019) but differs on several key points of design:
@@ -38,8 +36,7 @@ Yoyodyne is inspired by [FairSeq](https://github.com/facebookresearch/fairseq)
- 🚧 UNDER CONSTRUCTION 🚧: It has exhaustive test suites.
- 🚧 UNDER CONSTRUCTION 🚧: It has performance benchmarks.

Install
-------
## Install

First install dependencies:

@@ -51,12 +48,11 @@ Then install:

It can then be imported like a regular Python module:

```python
``` python
import yoyodyne
```

Usage
-----
## Usage

### Training

@@ -87,8 +83,7 @@ The `--predict` file can either be a TSV file or an ordinary TXT file with one
source string per line; in the latter case, specify `--target_col 0`. Run
[`yoyodyne-predict --help`](yoyodyne/predict.py) for more information.

Data format
-----------
## Data format

The default data format is a two-column TSV file in which the first column is
the source string and the second the target string.
@@ -98,8 +93,8 @@ the source string and the second the target string.
To enable the use of a feature column, one specifies a (non-zero) argument to
`--features_col`. For instance in the [SIGMORPHON 2017 shared
task](https://sigmorphon.github.io/sharedtasks/2017/), the first column is the
source (a lemma), the second is the target (the inflection), and
the third contains semi-colon delimited feature strings:
source (a lemma), the second is the target (the inflection), and the third
contains semi-colon delimited feature strings:

source target feat1;feat2;...

@@ -115,16 +110,14 @@ this format is specified by `--features_col 2 --features_sep , --target_col 3`.
In order to ensure that targets are ignored during prediction, one can specify
`--target_col 0`.
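
As a purely illustrative aside (not part of the original README), here is a minimal Python sketch of writing a training file in the three-column layout described above; the rows and the file name are hypothetical.

```python
import csv

# Hypothetical SIGMORPHON 2017-style rows: lemma, inflection, and
# semicolon-delimited features.
rows = [
    ("sing", "sang", "V;PST"),
    ("walk", "walked", "V;PST"),
]
with open("train.tsv", "w", newline="", encoding="utf-8") as sink:
    writer = csv.writer(sink, delimiter="\t")
    writer.writerows(rows)
```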

Reserved symbols
----------------
## Reserved symbols

Yoyodyne reserves symbols of the form `<...>` for internal use.
Feature-conditioned models also use `[...]` to avoid clashes between feature
symbols and source and target symbols. Therefore, users should not provide any
symbols of the form `<...>` or `[...]`.
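
For illustration only (not from the README), a small sketch of the kind of sanity check a user might run over their own symbol inventory to catch clashes with the reserved `<...>` and `[...]` forms; the helper name and the sample symbols are hypothetical.

```python
import re

# Matches symbols of the reserved forms <...> or [...].
RESERVED = re.compile(r"^(<.*>|\[.*\])$")


def clashing_symbols(symbols):
    """Returns the symbols that collide with the reserved forms."""
    return [symbol for symbol in symbols if RESERVED.match(symbol)]


print(clashing_symbols(["walk", "<s>", "[V]"]))  # ['<s>', '[V]']
```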

Model checkpointing
-------------------
## Model checkpointing

Checkpointing is handled by
[Lightning](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html).
@@ -146,8 +139,7 @@ During training, we save the best `--save_top_k` checkpoints (by default, 1)
ranked according to accuracy on the `--val` set. For example, `--save_top_k 5`
will save the top 5 most accurate models.
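
For intuition only: since checkpointing is delegated to Lightning, `--save_top_k` behaves like Lightning's standard `ModelCheckpoint` callback. A rough sketch of an equivalent callback configuration, assuming a monitored metric named `val_accuracy` (the exact metric key is not shown in this diff):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the 5 checkpoints with the highest validation accuracy; "val_accuracy"
# is an assumed metric name used here purely for illustration.
checkpoint_callback = ModelCheckpoint(
    monitor="val_accuracy",
    mode="max",
    save_top_k=5,
)
```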

Models
------
## Models

The user specifies the overall architecture for the model using the `--arch`
flag. The value of this flag specifies the decoder's architecture and whether or
@@ -158,33 +150,39 @@ additional flags. Supported values for `--arch` are:
- `attentive_lstm`: This is an LSTM decoder with LSTM encoders (by default)
and an attention mechanism. The initial hidden state is treated as a learned
parameter.
- `lstm`: This is an LSTM decoder with LSTM encoders (by default); in lieu
of an attention mechanism, the last non-padding hidden state of the encoder
is concatenated with the decoder hidden state.
- `lstm`: This is an LSTM decoder with LSTM encoders (by default); in lieu of
an attention mechanism, the last non-padding hidden state of the encoder is
concatenated with the decoder hidden state.
- `pointer_generator_lstm`: This is an LSTM decoder with LSTM encoders (by
default) and a pointer-generator mechanism. Since this model contains a copy
mechanism, it may be superior to the `lstm` when the input and output
vocabularies overlap significantly. Note that this model requires that the
number of `--encoder_layers` and `--decoder_layers` match.
- `transducer`: This is an LSTM decoder with LSTM encoders (by default) and
a neural transducer mechanism. On model creation, expectation maximization
is used to learn a sequence of edit operations, and imitation learning is
used to train the model to implement the oracle policy, with roll-in
controlled by the `--oracle_factor` flag (default: `1`). Since this model
assumes monotonic alignment, it may be superior to attentive models when the
mechanism, it may be superior to an ordinary attentive LSTM when the source
and target vocabularies overlap significantly. Note that this model requires
that the number of `--encoder_layers` and `--decoder_layers` match.
- `pointer_generator_transformer`: This is a transformer decoder with
transformer encoders (by default) and a pointer-generator mechanism. Like
`pointer_generator_lstm`, it may be superior to an ordinary transformer when
the source and target vocabularies overlap significantly. When using
features, the user may wish to specify the number of features attention
heads (with `--features_attention_heads`; default: `1`).
- `transducer`: This is an LSTM decoder with LSTM encoders (by default) and a
neural transducer mechanism. On model creation, expectation maximization is
used to learn a sequence of edit operations, and imitation learning is used
to train the model to implement the oracle policy, with roll-in controlled
by the `--oracle_factor` flag (default: `1`). Since this model assumes
monotonic alignment, it may be superior to attentive models when the
alignment between input and output is roughly monotonic and when input and
output vocabularies overlap significantly.
- `transformer`: This is a transformer decoder with transformer encoders (by
default) and an attention mechanism. Sinusoidal positional encodings and
layer normalization are used. The user may wish to specify the number of
attention heads (with `--source_attention_heads`; default: `4`).
default). Sinusoidal positional encodings and layer normalization are used.
The user may wish to specify the number of attention heads (with
`--source_attention_heads`; default: `4`).

The user can override the default encoder architectures. One can override the
source encoder using the `--source_encoder` flag:

- `feature_invariant_transformer`: This is a variant of the transformer
encoder for use with features; it concatenates source and features and uses
a learned embedding to distinguish between source and features symbols.
encoder used with features; it concatenates source and features and uses a
learned embedding to distinguish between source and features symbols.
- `linear`: This is a linear encoder.
- `lstm`: This is a LSTM encoder.
- `transformer`: This is a transformer encoder.
@@ -202,8 +200,7 @@ For all models, the user may also wish to specify:
By default, LSTM encoders are bidirectional. One can disable this with the
`--no_bidirectional` flag.

Training options
----------------
## Training options

A non-exhaustive list includes:

@@ -285,8 +282,7 @@ decay scheduler.
[`wandb_sweeps`](examples/wandb_sweeps) shows how to use [Weights &
Biases](https://wandb.ai/site) to run hyperparameter sweeps.

Accelerators
------------
## Accelerators

[Hardware
accelerators](https://pytorch-lightning.readthedocs.io/en/stable/extensions/accelerator.html)
@@ -295,8 +291,7 @@ GPU (`--accelerator gpu`), [other
accelerators](https://pytorch-lightning.readthedocs.io/en/stable/extensions/accelerator.html)
may also be supported but not all have been tested yet.
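
For orientation only, and not taken from the Yoyodyne codebase: accelerator selection is a Lightning concept, and the flag value corresponds to the `accelerator` argument of Lightning's `Trainer`. A minimal sketch:

```python
from pytorch_lightning import Trainer

# Request a single GPU; "cpu" (or another supported accelerator) can be
# substituted.
trainer = Trainer(accelerator="gpu", devices=1)
```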

Precision
---------
## Precision

By default, training uses 32-bit precision. However, the `--precision` flag
allows the user to perform training with half precision (`16`) or with the
@@ -305,16 +300,14 @@ format](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) if
supported by the accelerator. This may reduce the size of the model and batches
in memory, allowing one to use larger batches.
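
As a quick aside (not from the README), the memory savings come from the smaller element size of the reduced-precision types:

```python
import torch

print(torch.tensor(0.0, dtype=torch.float32).element_size())   # 4 bytes
print(torch.tensor(0.0, dtype=torch.float16).element_size())   # 2 bytes
print(torch.tensor(0.0, dtype=torch.bfloat16).element_size())  # 2 bytes
```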

Examples
--------
## Examples

The [`examples`](examples) directory contains interesting examples, including:

- [`wandb_sweeps`](examples/wandb_sweeps) shows how to use [Weights &
Biases](https://wandb.ai/site) to run hyperparameter sweeps.

For developers
--------------
## For developers

*Developers, developers, developers!* - Steve Ballmer

@@ -341,8 +334,7 @@ This section contains instructions for the Yoyodyne maintainers.
9. Build the new release: `python -m build`
10. Upload the result to PyPI: `twine upload dist/*`

References
----------
## References

Ott, M., Edunov, S., Baevski, A., Fan, A., Gross, S., Ng, N., Grangier, D., and
Auli, M. 2019. [fairseq: a fast, extensible toolkit for sequence
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -11,7 +11,7 @@ exclude = ["examples*"]

[project]
name = "yoyodyne"
version = "0.2.5"
version = "0.2.6"
description = "Small-vocabulary neural sequence-to-sequence models"
readme = "README.md"
requires-python = ">= 3.9"
30 changes: 14 additions & 16 deletions yoyodyne/models/modules/transformer.py
@@ -93,7 +93,7 @@ def __call__(
module: nn.Module,
module_in: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
module_out: Tuple[torch.Tensor, torch.Tensor],
):
) -> None:
"""Stores the second return argument of `module`.
This is intended to be called on a multiheaded attention, which returns
@@ -108,7 +108,7 @@ def __call__(
"""
self.outputs.append(module_out[1])

def clear(self):
def clear(self) -> None:
"""Clears the outputs."""
self.outputs.clear()
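
For context on how a callable like `AttentionOutput` is typically wired up, here is a sketch under two assumptions not shown in this hunk: that the class initializes an empty `outputs` list, and that `patch_attention` registers the instance as a PyTorch forward hook.

```python
import torch
from torch import nn

attention_output = AttentionOutput()  # assumes __init__ creates .outputs = []
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
# A forward hook receives (module, inputs, output); __call__ above matches it.
mha.register_forward_hook(attention_output)

query = torch.randn(1, 5, 8)
# With need_weights=True (the default), forward returns (output, weights), so
# the hook stores the attention weights.
mha(query, query, query, need_weights=True)
print(len(attention_output.outputs))  # 1
```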

@@ -361,7 +361,7 @@ def forward(
self.norm1(x),
target_mask,
target_key_padding_mask,
# FIXME: Introduced in torch 2.0
# FIXME: Introduced in torch 2.0.
# is_causal=target_is_causal
)
x = self.norm2(x)
@@ -370,7 +370,7 @@ def forward(
memory,
memory_mask,
memory_key_padding_mask,
# FIXME Introduced in torch 2.0
# FIXME Introduced in torch 2.0.
# memory_is_causal,
)
# TODO: Do we want a nonlinear activation?
@@ -380,7 +380,7 @@ def forward(
features_memory,
features_memory_mask,
features_memory_mask,
# FIXME Introduced in torch 2.0
# FIXME Introduced in torch 2.0.
# memory_is_causal,
)
# TODO: Do we want a nonlinear activation?
@@ -394,7 +394,7 @@ def forward(
x,
target_mask,
target_key_padding_mask,
# FIXME: Introduced in torch 2.0
# FIXME: Introduced in torch 2.0.
# is_causal=target_is_causal
)
)
@@ -403,7 +403,7 @@ def forward(
memory,
memory_mask,
memory_key_padding_mask,
# FIXME Introduced in torch 2.0
# FIXME Introduced in torch 2.0.
# memory_is_causal,
)
# TODO: Do we want a nonlinear activation?
@@ -413,15 +413,14 @@ def forward(
features_memory,
features_memory_mask,
features_memory_mask,
# FIXME Introduced in torch 2.0
# FIXME Introduced in torch 2.0.
# memory_is_causal,
)
# TODO: Do we want a nonlinear activation?
feature_attention = self.features_linear(feature_attention)
x = x + torch.cat([symbol_attention, feature_attention], dim=2)
x = self.norm2(x)
x = self.norm3(x + self._ff_block(x))

return x

def _features_mha_block(
Expand All @@ -430,7 +429,7 @@ def _features_mha_block(
mem: torch.Tensor,
attn_mask: Optional[torch.Tensor],
key_padding_mask: Optional[torch.Tensor],
# FIXME: Introduced in torch 2.0
# FIXME: Introduced in torch 2.0.
# is_causal: bool = False,
) -> torch.Tensor:
"""Runs the multihead attention block that attends to features.
@@ -453,7 +452,7 @@ def _features_mha_block(
mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
# FIXME: Introduced in torch 2.0
# FIXME: Introduced in torch 2.0.
# is_causal=is_causal,
need_weights=False,
)[0]
@@ -501,7 +500,6 @@ def forward(
torch.Tensor: Output tensor.
"""
output = target

for mod in self.layers:
output = mod(
output,
@@ -624,7 +622,7 @@ def __init__(
self.separate_features = separate_features
self.features_attention_heads = features_attention_heads
super().__init__(*args, **kwargs)
# Call this to get the actual cross attentions
# Call this to get the actual cross attentions.
self.attention_output = AttentionOutput()
# multihead_attn refers to the attention from decoder to encoder.
self.patch_attention(self.module.layers[-1].multihead_attn)
@@ -664,7 +662,7 @@ def forward(
causal_mask = self.generate_square_subsequent_mask(
target_sequence_length
).to(self.device)
# -> B x seq_len x d_model
# -> B x seq_len x d_model.
if self.separate_features:
output = self.module(
target_embedding,
@@ -676,8 +674,8 @@ def forward(
target_key_padding_mask=target_mask,
)
else:
# TODO: Resolve mismatch between our 'target'
# naming convention and pytorch 'tgt'
# TODO: Resolve mismatch between our 'target' naming convention and
# torch's use of `tgt`.
output = self.module(
target_embedding,
encoder_hidden,