diff --git a/setup.cfg b/setup.cfg
index 849f37f..50bbe78 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [metadata]
-version = 3.0.2
+version = 3.0.3
 description = Legacy registered functions for spaCy backwards compatibility
 url = https://spacy.io
 author = Explosion
@@ -39,6 +39,11 @@ spacy_architectures =
     spacy-legacy.MishWindowEncoder.v1 = spacy_legacy.architectures.tok2vec:MishWindowEncoder_v1
     spacy-legacy.TextCatEnsemble.v1 = spacy_legacy.architectures.textcat:TextCatEnsemble_v1
     spacy-legacy.WandbLogger.v1 = spacy_legacy.loggers:wandb_logger_v1
+    spacy-legacy.HashEmbedCNN.v1 = spacy_legacy.architectures.tok2vec:HashEmbedCNN_v1
+    spacy-legacy.MultiHashEmbed.v1 = spacy_legacy.architectures.tok2vec:MultiHashEmbed_v1
+    spacy-legacy.CharacterEmbed.v1 = spacy_legacy.architectures.tok2vec:CharacterEmbed_v1
+thinc_layers =
+    spacy-legacy.StaticVectors.v1 = spacy_legacy.layers.staticvectors_v1:StaticVectors_v1
 
 [bdist_wheel]
 universal = true
diff --git a/spacy_legacy/architectures/tok2vec.py b/spacy_legacy/architectures/tok2vec.py
index 39fb3b3..239251e 100644
--- a/spacy_legacy/architectures/tok2vec.py
+++ b/spacy_legacy/architectures/tok2vec.py
@@ -1,8 +1,12 @@
-from typing import List
+from typing import List, Optional, Union
 from thinc.api import Model, chain, with_array, clone, residual, expand_window
-from thinc.api import Maxout, Mish
+from thinc.api import concatenate, list2ragged, ragged2list
 from thinc.types import Floats2d
+from spacy.attrs import intify_attr
+from spacy.errors import Errors
+from spacy.ml import _character_embed
 from spacy.tokens import Doc
+from spacy.util import registry
 
 
 def Tok2Vec_v1(
@@ -40,6 +44,7 @@ def MaxoutWindowEncoder_v1(
         values are 2 or 3.
     depth (int): The number of convolutional layers. Recommended value is 4.
     """
+    Maxout = registry.get("layers", "Maxout.v1")
     cnn = chain(
         expand_window(window_size=window_size),
         Maxout(
@@ -69,6 +74,7 @@ def MishWindowEncoder_v1(
         to construct the convolution. Recommended value is 1.
     depth (int): The number of convolutional layers. Recommended value is 4.
     """
+    Mish = registry.get("layers", "Mish.v1")
     cnn = chain(
         expand_window(window_size=window_size),
         Mish(nO=width, nI=width * ((window_size * 2) + 1), dropout=0.0, normalize=True),
@@ -76,3 +82,217 @@ def MishWindowEncoder_v1(
     model = clone(residual(cnn), depth)
     model.set_dim("nO", width)
     return model
+
+
+def HashEmbedCNN_v1(
+    *,
+    width: int,
+    depth: int,
+    embed_size: int,
+    window_size: int,
+    maxout_pieces: int,
+    subword_features: bool,
+    pretrained_vectors: Optional[bool],
+) -> Model[List[Doc], List[Floats2d]]:
+    """Build spaCy's 'standard' tok2vec layer, which uses hash embedding
+    with subword features and a CNN with layer-normalized maxout.
+
+    width (int): The width of the input and output. These are required to be the
+        same, so that residual connections can be used. Recommended values are
+        96, 128 or 300.
+    depth (int): The number of convolutional layers to use. Recommended values
+        are between 2 and 8.
+    window_size (int): The number of tokens on either side to concatenate during
+        the convolutions. The receptive field of the CNN will be
+        depth * (window_size * 2 + 1), so a 4-layer network with window_size of
+        2 will be sensitive to 17 words at a time. Recommended value is 1.
+    embed_size (int): The number of rows in the hash embedding tables. This can
+        be surprisingly small, due to the use of the hash embeddings. Recommended
+        values are between 2000 and 10000.
+    maxout_pieces (int): The number of pieces to use in the maxout non-linearity.
+        If 1, the Mish non-linearity is used instead. Recommended values are 1-3.
+    subword_features (bool): Whether to also embed subword features, specifically
+        the prefix, suffix and word shape. This is recommended for alphabetic
+        languages like English, but not if single-character tokens are used for
+        a language such as Chinese.
+    pretrained_vectors (bool): Whether to also use static vectors.
+    """
+    build_Tok2Vec_model = registry.get("architectures", "spacy.Tok2Vec.v2")
+    MultiHashEmbed = registry.get("architectures", "spacy.MultiHashEmbed.v1")
+    MaxoutWindowEncoder = registry.get("architectures", "spacy.MaxoutWindowEncoder.v2")
+    if subword_features:
+        attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+        row_sizes = [embed_size, embed_size // 2, embed_size // 2, embed_size // 2]
+    else:
+        attrs = ["NORM"]
+        row_sizes = [embed_size]
+    return build_Tok2Vec_model(
+        embed=MultiHashEmbed(
+            width=width,
+            rows=row_sizes,
+            attrs=attrs,
+            include_static_vectors=bool(pretrained_vectors),
+        ),
+        encode=MaxoutWindowEncoder(
+            width=width,
+            depth=depth,
+            window_size=window_size,
+            maxout_pieces=maxout_pieces,
+        ),
+    )
+
+
+def MultiHashEmbed_v1(
+    width: int,
+    attrs: List[Union[str, int]],
+    rows: List[int],
+    include_static_vectors: bool,
+) -> Model[List[Doc], List[Floats2d]]:
+    """Construct an embedding layer that separately embeds a number of lexical
+    attributes using hash embedding, concatenates the results, and passes it
+    through a feed-forward subnetwork to build a mixed representation.
+
+    The features used can be configured with the 'attrs' argument. The suggested
+    attributes are NORM, PREFIX, SUFFIX and SHAPE. This lets the model take into
+    account some subword information, without constructing a fully character-based
+    representation. If pretrained vectors are available, they can be included in
+    the representation as well; the vectors table will be kept static
+    (i.e. it's not updated).
+
+    The `width` parameter specifies the output width of the layer and the widths
+    of all embedding tables. If static vectors are included, a learned linear
+    layer is used to map the vectors to the specified width before concatenating
+    it with the other embedding outputs. A single Maxout layer is then used to
+    reduce the concatenated vectors to the final width.
+
+    The `rows` parameter controls the number of rows used by the `HashEmbed`
+    tables. The HashEmbed layer needs surprisingly few rows, due to its use of
+    the hashing trick. Generally between 2000 and 10000 rows is sufficient,
+    even for very large vocabularies. A number of rows must be specified for each
+    table, so the `rows` list must be of the same length as the `attrs` parameter.
+
+    width (int): The output width. Also used as the width of the embedding tables.
+        Recommended values are between 64 and 300.
+    attrs (list of attr IDs): The token attributes to embed. A separate
+        embedding table will be constructed for each attribute.
+    rows (List[int]): The number of rows in the embedding tables. Must have the
+        same length as attrs.
+    include_static_vectors (bool): Whether to also use static word vectors.
+        Requires a vectors table to be loaded in the Doc objects' vocab.
+    """
+    HashEmbed = registry.get("layers", "HashEmbed.v1")
+    FeatureExtractor = registry.get("layers", "spacy.FeatureExtractor.v1")
+    Maxout = registry.get("layers", "Maxout.v1")
+    StaticVectors = registry.get("layers", "spacy.StaticVectors.v1")
+    if len(rows) != len(attrs):
+        raise ValueError(f"Mismatched lengths: {len(rows)} vs {len(attrs)}")
+    seed = 7
+
+    def make_hash_embed(index):
+        nonlocal seed
+        seed += 1
+        return HashEmbed(width, rows[index], column=index, seed=seed, dropout=0.0)
+
+    embeddings = [make_hash_embed(i) for i in range(len(attrs))]
+    concat_size = width * (len(embeddings) + include_static_vectors)
+    if include_static_vectors:
+        model = chain(
+            concatenate(
+                chain(
+                    FeatureExtractor(attrs),
+                    list2ragged(),
+                    with_array(concatenate(*embeddings)),
+                ),
+                StaticVectors(width, dropout=0.0),
+            ),
+            with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
+            ragged2list(),
+        )
+    else:
+        model = chain(
+            FeatureExtractor(list(attrs)),
+            list2ragged(),
+            with_array(concatenate(*embeddings)),
+            with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
+            ragged2list(),
+        )
+    return model
+
+
+def CharacterEmbed_v1(
+    width: int,
+    rows: int,
+    nM: int,
+    nC: int,
+    include_static_vectors: bool,
+    feature: Union[int, str] = "LOWER",
+) -> Model[List[Doc], List[Floats2d]]:
+    """Construct an embedded representation based on character embeddings, using
+    a feed-forward network. A fixed number of UTF-8 byte characters are used for
+    each word, taken from the beginning and end of the word equally. Padding is
+    used in the centre for words that are too short.
+
+    For instance, let's say nC=4, and the word is "jumping". The characters
+    used will be jung (two from the start, two from the end). If we had nC=8,
+    the characters would be "jumpping": 4 from the start, 4 from the end. This
+    ensures that the final character is always in the last position, instead
+    of being in an arbitrary position depending on the word length.
+
+    The characters are embedded in an embedding table with a given number of rows,
+    and the vectors concatenated. A hash-embedded vector of the LOWER of the word is
+    also concatenated on, and the result is then passed through a feed-forward
+    network to construct a single vector to represent the information.
+
+    feature (int or str): An attribute to embed, to concatenate with the characters.
+    width (int): The width of the output vector and the feature embedding.
+    rows (int): The number of rows in the LOWER hash embedding table.
+    nM (int): The dimensionality of the character embeddings. Recommended values
+        are between 16 and 64.
+    nC (int): The number of UTF-8 bytes to embed per word. Recommended values
+        are between 3 and 8, although it may depend on the length of words in the
+        language.
+    include_static_vectors (bool): Whether to also use static word vectors.
+        Requires a vectors table to be loaded in the Doc objects' vocab.
+    """
+    # TODO: replace with registered layer after spacy v3.0.6
+    #CharEmbed = registry.get("layers", "spacy.CharEmbed.v1")
+    CharEmbed = _character_embed.CharacterEmbed
+    FeatureExtractor = registry.get("layers", "spacy.FeatureExtractor.v1")
+    Maxout = registry.get("layers", "Maxout.v1")
+    HashEmbed = registry.get("layers", "HashEmbed.v1")
+    StaticVectors = registry.get("layers", "spacy.StaticVectors.v1")
+    feature = intify_attr(feature)
+    if feature is None:
+        raise ValueError(Errors.E911(feat=feature))
+    if include_static_vectors:
+        model = chain(
+            concatenate(
+                chain(CharEmbed(nM=nM, nC=nC), list2ragged()),
+                chain(
+                    FeatureExtractor([feature]),
+                    list2ragged(),
+                    with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
+                ),
+                StaticVectors(width, dropout=0.0),
+            ),
+            with_array(
+                Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0)
+            ),
+            ragged2list(),
+        )
+    else:
+        model = chain(
+            concatenate(
+                chain(CharEmbed(nM=nM, nC=nC), list2ragged()),
+                chain(
+                    FeatureExtractor([feature]),
+                    list2ragged(),
+                    with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),
+                ),
+            ),
+            with_array(
+                Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)
+            ),
+            ragged2list(),
+        )
+    return model
diff --git a/spacy_legacy/layers/__init__.py b/spacy_legacy/layers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/spacy_legacy/layers/staticvectors_v1.py b/spacy_legacy/layers/staticvectors_v1.py
new file mode 100644
index 0000000..c049c70
--- /dev/null
+++ b/spacy_legacy/layers/staticvectors_v1.py
@@ -0,0 +1,96 @@
+from typing import List, Tuple, Callable, Optional, cast
+from thinc.initializers import glorot_uniform_init
+from thinc.util import partial
+from thinc.types import Ragged, Floats2d, Floats1d
+from thinc.api import Model, Ops
+from spacy.tokens import Doc
+from spacy.errors import Errors
+
+
+def StaticVectors_v1(
+    nO: Optional[int] = None,
+    nM: Optional[int] = None,
+    *,
+    dropout: Optional[float] = None,
+    init_W: Callable = glorot_uniform_init,
+    key_attr: str = "ORTH"
+) -> Model[List[Doc], Ragged]:
+    """Embed Doc objects with their vocab's vectors table, applying a learned
+    linear projection to control the dimensionality. If a dropout rate is
+    specified, the dropout is applied per dimension over the whole batch.
+    """
+    return Model(
+        "static_vectors",
+        forward,
+        init=partial(init, init_W),
+        params={"W": None},
+        attrs={"key_attr": key_attr, "dropout_rate": dropout},
+        dims={"nO": nO, "nM": nM},
+    )
+
+
+def forward(
+    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
+) -> Tuple[Ragged, Callable]:
+    if not sum(len(doc) for doc in docs):
+        return _handle_empty(model.ops, model.get_dim("nO"))
+    key_attr = model.attrs["key_attr"]
+    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
+    V = cast(Floats2d, docs[0].vocab.vectors.data)
+    rows = model.ops.flatten(
+        [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
+    )
+    try:
+        vectors_data = model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True)
+    except ValueError:
+        raise RuntimeError(Errors.E896)
+    output = Ragged(
+        vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i")
+    )
+    mask = None
+    if is_train:
+        mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
+        if mask is not None:
+            output.data *= mask
+
+    def backprop(d_output: Ragged) -> List[Doc]:
+        if mask is not None:
+            d_output.data *= mask
+        model.inc_grad(
+            "W",
+            model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True),
+        )
+        return []
+
+    return output, backprop
+
+
+def init(
+    init_W: Callable,
+    model: Model[List[Doc], Ragged],
+    X: Optional[List[Doc]] = None,
+    Y: Optional[Ragged] = None,
+) -> Model[List[Doc], Ragged]:
+    nM = model.get_dim("nM") if model.has_dim("nM") else None
+    nO = model.get_dim("nO") if model.has_dim("nO") else None
+    if X is not None and len(X):
+        nM = X[0].vocab.vectors.data.shape[1]
+    if Y is not None:
+        nO = Y.data.shape[1]
+
+    if nM is None:
+        raise ValueError(Errors.E905)
+    if nO is None:
+        raise ValueError(Errors.E904)
+    model.set_dim("nM", nM)
+    model.set_dim("nO", nO)
+    model.set_param("W", init_W(model.ops, (nO, nM)))
+    return model
+
+
+def _handle_empty(ops: Ops, nO: int):
+    return Ragged(ops.alloc2f(0, nO), ops.alloc1i(0)), lambda d_ragged: []
+
+
+def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
+    return ops.get_dropout_mask((nO,), rate) if rate is not None else None
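
For reference (not part of the patch): once this release is installed alongside spaCy v3, the entry points registered in setup.cfg above make the legacy architectures resolvable by name through spaCy's registry. Below is a minimal usage sketch; the hyperparameter values and sample texts are illustrative assumptions only, not values taken from the patch.

import spacy
from spacy.util import registry

# Resolve the legacy architecture registered via the spacy_architectures entry point.
HashEmbedCNN_v1 = registry.get("architectures", "spacy-legacy.HashEmbedCNN.v1")

# Build a tok2vec model with illustrative hyperparameters (no static vectors table needed).
tok2vec = HashEmbedCNN_v1(
    width=96,
    depth=4,
    embed_size=2000,
    window_size=1,
    maxout_pieces=3,
    subword_features=True,
    pretrained_vectors=None,
)

nlp = spacy.blank("en")
docs = [nlp("This is a sentence."), nlp("Another short example.")]
tok2vec.initialize(X=docs)
arrays = tok2vec.predict(docs)  # one array of shape (n_tokens, 96) per Doc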