Skip to content

Commit

Permalink
Add StaticVectors.v1 (#5)
Browse files Browse the repository at this point in the history
* Add StaticVectors.v1

* Switch to registry for tok2vec

For tok2vec, freeze the implementations by retrieving all layers and
architectures from the registry instead of importing from `thinc` or
`spacy`.

* Temporarily revert to _character_embed.CharacterEmbed

* Move layer to thinc_layers

* Fix formatting
  • Loading branch information
adrianeboyd authored Apr 21, 2021
1 parent 2d5d239 commit f5855a8
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 3 deletions.
7 changes: 6 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[metadata]
version = 3.0.2
version = 3.0.3
description = Legacy registered functions for spaCy backwards compatibility
url = https://spacy.io
author = Explosion
Expand Down Expand Up @@ -39,6 +39,11 @@ spacy_architectures =
spacy-legacy.MishWindowEncoder.v1 = spacy_legacy.architectures.tok2vec:MishWindowEncoder_v1
spacy-legacy.TextCatEnsemble.v1 = spacy_legacy.architectures.textcat:TextCatEnsemble_v1
spacy-legacy.WandbLogger.v1 = spacy_legacy.loggers:wandb_logger_v1
spacy-legacy.HashEmbedCNN.v1 = spacy_legacy.architectures.tok2vec:HashEmbedCNN_v1
spacy-legacy.MultiHashEmbed.v1 = spacy_legacy.architectures.tok2vec:MultiHashEmbed_v1
spacy-legacy.CharacterEmbed.v1 = spacy_legacy.architectures.tok2vec:CharacterEmbed_v1
thinc_layers =
spacy-legacy.StaticVectors.v1 = spacy_legacy.layers.staticvectors_v1:StaticVectors_v1

[bdist_wheel]
universal = true
Expand Down
224 changes: 222 additions & 2 deletions spacy_legacy/architectures/tok2vec.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import List
from typing import List, Optional, Union
from thinc.api import Model, chain, with_array, clone, residual, expand_window
from thinc.api import Maxout, Mish
from thinc.api import concatenate, list2ragged, ragged2list
from thinc.types import Floats2d
from spacy.attrs import intify_attr
from spacy.errors import Errors
from spacy.ml import _character_embed
from spacy.tokens import Doc
from spacy.util import registry


def Tok2Vec_v1(
Expand Down Expand Up @@ -40,6 +44,7 @@ def MaxoutWindowEncoder_v1(
values are 2 or 3.
depth (int): The number of convolutional layers. Recommended value is 4.
"""
Maxout = registry.get("layers", "Maxout.v1")
cnn = chain(
expand_window(window_size=window_size),
Maxout(
Expand Down Expand Up @@ -69,10 +74,225 @@ def MishWindowEncoder_v1(
to construct the convolution. Recommended value is 1.
depth (int): The number of convolutional layers. Recommended value is 4.
"""
Mish = registry.get("layers", "Mish.v1")
cnn = chain(
expand_window(window_size=window_size),
Mish(nO=width, nI=width * ((window_size * 2) + 1), dropout=0.0, normalize=True),
)
model = clone(residual(cnn), depth)
model.set_dim("nO", width)
return model


def HashEmbedCNN_v1(
*,
width: int,
depth: int,
embed_size: int,
window_size: int,
maxout_pieces: int,
subword_features: bool,
pretrained_vectors: Optional[bool],
) -> Model[List[Doc], List[Floats2d]]:
"""Build spaCy's 'standard' tok2vec layer, which uses hash embedding
with subword features and a CNN with layer-normalized maxout.
width (int): The width of the input and output. These are required to be the
same, so that residual connections can be used. Recommended values are
96, 128 or 300.
depth (int): The number of convolutional layers to use. Recommended values
are between 2 and 8.
window_size (int): The number of tokens on either side to concatenate during
the convolutions. The receptive field of the CNN will be
depth * (window_size * 2 + 1), so a 4-layer network with window_size of
2 will be sensitive to 17 words at a time. Recommended value is 1.
embed_size (int): The number of rows in the hash embedding tables. This can
be surprisingly small, due to the use of the hash embeddings. Recommended
values are between 2000 and 10000.
maxout_pieces (int): The number of pieces to use in the maxout non-linearity.
If 1, the Mish non-linearity is used instead. Recommended values are 1-3.
subword_features (bool): Whether to also embed subword features, specifically
the prefix, suffix and word shape. This is recommended for alphabetic
languages like English, but not if single-character tokens are used for
a language such as Chinese.
pretrained_vectors (bool): Whether to also use static vectors.
"""
build_Tok2Vec_model = registry.get("architectures", "spacy.Tok2Vec.v2")
MultiHashEmbed = registry.get("architectures", "spacy.MultiHashEmbed.v1")
MaxoutWindowEncoder = registry.get("architectures", "spacy.MaxoutWindowEncoder.v2")
if subword_features:
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
else:
attrs = ["NORM"]
row_sizes = [embed_size]
return build_Tok2Vec_model(
embed=MultiHashEmbed(
width=width,
rows=row_sizes,
attrs=attrs,
include_static_vectors=bool(pretrained_vectors),
),
encode=MaxoutWindowEncoder(
width=width,
depth=depth,
window_size=window_size,
maxout_pieces=maxout_pieces,
),
)


def MultiHashEmbed_v1(
width: int,
attrs: List[Union[str, int]],
rows: List[int],
include_static_vectors: bool,
) -> Model[List[Doc], List[Floats2d]]:
"""Construct an embedding layer that separately embeds a number of lexical
attributes using hash embedding, concatenates the results, and passes it
through a feed-forward subnetwork to build a mixed representation.
The features used can be configured with the 'attrs' argument. The suggested
attributes are NORM, PREFIX, SUFFIX and SHAPE. This lets the model take into
account some subword information, without constructing a fully character-based
representation. If pretrained vectors are available, they can be included in
the representation as well, with the vectors table will be kept static
(i.e. it's not updated).
The `width` parameter specifies the output width of the layer and the widths
of all embedding tables. If static vectors are included, a learned linear
layer is used to map the vectors to the specified width before concatenating
it with the other embedding outputs. A single Maxout layer is then used to
reduce the concatenated vectors to the final width.
The `rows` parameter controls the number of rows used by the `HashEmbed`
tables. The HashEmbed layer needs surprisingly few rows, due to its use of
the hashing trick. Generally between 2000 and 10000 rows is sufficient,
even for very large vocabularies. A number of rows must be specified for each
table, so the `rows` list must be of the same length as the `attrs` parameter.
width (int): The output width. Also used as the width of the embedding tables.
Recommended values are between 64 and 300.
attrs (list of attr IDs): The token attributes to embed. A separate
embedding table will be constructed for each attribute.
rows (List[int]): The number of rows in the embedding tables. Must have the
same length as attrs.
include_static_vectors (bool): Whether to also use static word vectors.
Requires a vectors table to be loaded in the Doc objects' vocab.
"""
HashEmbed = registry.get("layers", "HashEmbed.v1")
FeatureExtractor = registry.get("layers", "spacy.FeatureExtractor.v1")
Maxout = registry.get("layers", "Maxout.v1")
StaticVectors = registry.get("layers", "spacy.StaticVectors.v1")
if len(rows) != len(attrs):
raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")
seed = 7

def make_hash_embed(index):
nonlocal seed
seed += 1
return HashEmbed(width, rows[index], column=index, seed=seed, dropout=0.0)

embeddings = [make_hash_embed(i) for i in range(len(attrs))]
concat_size = width * (len(embeddings) + include_static_vectors)
if include_static_vectors:
model = chain(
concatenate(
chain(
FeatureExtractor(attrs),
list2ragged(),
with_array(concatenate(*embeddings)),
),
StaticVectors(width, dropout=0.0),
),
with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
ragged2list(),
)
else:
model = chain(
FeatureExtractor(list(attrs)),
list2ragged(),
with_array(concatenate(*embeddings)),
with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
ragged2list(),
)
return model


def CharacterEmbed_v1(
width: int,
rows: int,
nM: int,
nC: int,
include_static_vectors: bool,
feature: Union[int, str] = "LOWER",
) -> Model[List[Doc], List[Floats2d]]:
"""Construct an embedded representation based on character embeddings, using
a feed-forward network. A fixed number of UTF-8 byte characters are used for
each word, taken from the beginning and end of the word equally. Padding is
used in the centre for words that are too short.
For instance, let's say nC=4, and the word is "jumping". The characters
used will be jung (two from the start, two from the end). If we had nC=8,
the characters would be "jumpping": 4 from the start, 4 from the end. This
ensures that the final character is always in the last position, instead
of being in an arbitrary position depending on the word length.
The characters are embedded in a embedding table with a given number of rows,
and the vectors concatenated. A hash-embedded vector of the LOWER of the word is
also concatenated on, and the result is then passed through a feed-forward
network to construct a single vector to represent the information.
feature (int or str): An attribute to embed, to concatenate with the characters.
width (int): The width of the output vector and the feature embedding.
rows (int): The number of rows in the LOWER hash embedding table.
nM (int): The dimensionality of the character embeddings. Recommended values
are between 16 and 64.
nC (int): The number of UTF-8 bytes to embed per word. Recommended values
are between 3 and 8, although it may depend on the length of words in the
language.
include_static_vectors (bool): Whether to also use static word vectors.
Requires a vectors table to be loaded in the Doc objects' vocab.
"""
# TODO: replace with registered layer after spacy v3.0.6
#CharEmbed = registry.get("layers", "spacy.CharEmbed.v1")
CharEmbed = _character_embed.CharacterEmbed
FeatureExtractor = registry.get("layers", "spacy.FeatureExtractor.v1")
Maxout = registry.get("layers", "Maxout.v1")
HashEmbed = registry.get("layers", "HashEmbed.v1")
StaticVectors = registry.get("layers", "spacy.StaticVectors.v1")
feature = intify_attr(feature)
if feature is None:
raise ValueError(Errors.E911(feat=feature))
if include_static_vectors:
model = chain(
concatenate(
chain(CharEmbed(nM=nM, nC=nC), list2ragged()),
chain(
FeatureExtractor([feature]),
list2ragged(),
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
),
StaticVectors(width, dropout=0.0),
),
with_array(
Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0)
),
ragged2list(),
)
else:
model = chain(
concatenate(
chain(CharEmbed(nM=nM, nC=nC), list2ragged()),
chain(
FeatureExtractor([feature]),
list2ragged(),
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
),
),
with_array(
Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)
),
ragged2list(),
)
return model
Empty file added spacy_legacy/layers/__init__.py
Empty file.
96 changes: 96 additions & 0 deletions spacy_legacy/layers/staticvectors_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import List, Tuple, Callable, Optional, cast
from thinc.initializers import glorot_uniform_init
from thinc.util import partial
from thinc.types import Ragged, Floats2d, Floats1d
from thinc.api import Model, Ops
from spacy.tokens import Doc
from spacy.errors import Errors


def StaticVectors_v1(
nO: Optional[int] = None,
nM: Optional[int] = None,
*,
dropout: Optional[float] = None,
init_W: Callable = glorot_uniform_init,
key_attr: str = "ORTH"
) -> Model[List[Doc], Ragged]:
"""Embed Doc objects with their vocab's vectors table, applying a learned
linear projection to control the dimensionality. If a dropout rate is
specified, the dropout is applied per dimension over the whole batch.
"""
return Model(
"static_vectors",
forward,
init=partial(init, init_W),
params={"W": None},
attrs={"key_attr": key_attr, "dropout_rate": dropout},
dims={"nO": nO, "nM": nM},
)


def forward(
model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
if not sum(len(doc) for doc in docs):
return _handle_empty(model.ops, model.get_dim("nO"))
key_attr = model.attrs["key_attr"]
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
V = cast(Floats2d, docs[0].vocab.vectors.data)
rows = model.ops.flatten(
[doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
)
try:
vectors_data = model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True)
except ValueError:
raise RuntimeError(Errors.E896)
output = Ragged(
vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i")
)
mask = None
if is_train:
mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
if mask is not None:
output.data *= mask

def backprop(d_output: Ragged) -> List[Doc]:
if mask is not None:
d_output.data *= mask
model.inc_grad(
"W",
model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True),
)
return []

return output, backprop


def init(
init_W: Callable,
model: Model[List[Doc], Ragged],
X: Optional[List[Doc]] = None,
Y: Optional[Ragged] = None,
) -> Model[List[Doc], Ragged]:
nM = model.get_dim("nM") if model.has_dim("nM") else None
nO = model.get_dim("nO") if model.has_dim("nO") else None
if X is not None and len(X):
nM = X[0].vocab.vectors.data.shape[1]
if Y is not None:
nO = Y.data.shape[1]

if nM is None:
raise ValueError(Errors.E905)
if nO is None:
raise ValueError(Errors.E904)
model.set_dim("nM", nM)
model.set_dim("nO", nO)
model.set_param("W", init_W(model.ops, (nO, nM)))
return model


def _handle_empty(ops: Ops, nO: int):
return Ragged(ops.alloc2f(0, nO), ops.alloc1i(0)), lambda d_ragged: []


def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
return ops.get_dropout_mask((nO,), rate) if rate is not None else None

0 comments on commit f5855a8

Please sign in to comment.