TF implementation of RegNets #17554

Merged
merged 42 commits into from
Jun 29, 2022
Changes from 35 commits
Commits
42 commits
99641bd
chore: initial commit
ariG23498 Jun 4, 2022
7ea1d01
chore: porting the rest of the modules to tensorflow
ariG23498 Jun 5, 2022
74cd9a0
Fix initilizations (#1)
sayakpaul Jun 6, 2022
1b42157
chore: styling nits.
sayakpaul Jun 6, 2022
f1bf27a
fix: cross-loading bn params.
sayakpaul Jun 7, 2022
adec9db
fix: regnet tf model, integration passing.
sayakpaul Jun 13, 2022
e90fdfc
add: tests for TF regnet.
sayakpaul Jun 13, 2022
ec9bf1c
fix: code quality related issues.
sayakpaul Jun 13, 2022
419da37
chore: added rest of the files.
sayakpaul Jun 13, 2022
213e9e9
Merge pull request #2 from ariG23498/feat/tf-regnets
ariG23498 Jun 13, 2022
cf3d797
minor additions..
sayakpaul Jun 14, 2022
b9aa7b5
fix: repo consistency.
sayakpaul Jun 15, 2022
b93abe8
fix: regnet tf tests.
sayakpaul Jun 16, 2022
c4cd6db
chore: reorganize dummy_tf_objects for regnet.
sayakpaul Jun 16, 2022
7c58838
chore: remove checkpoint var.
sayakpaul Jun 16, 2022
ab7f80e
chore: remov unnecessary files.
sayakpaul Jun 16, 2022
5bf7d28
chore: run make style.
sayakpaul Jun 16, 2022
9e26607
Merge branch 'main' into aritra-regnets
ariG23498 Jun 16, 2022
23a20ad
Update docs/source/en/model_doc/regnet.mdx
sayakpaul Jun 16, 2022
6fdcc6d
chore: PR feedback I.
sayakpaul Jun 16, 2022
736b521
fix: pt test. thanks to @ydshieh.
sayakpaul Jun 20, 2022
55a8a0f
New adaptive pooler (#3)
sayakpaul Jun 22, 2022
882959e
Empty-Commit
sayakpaul Jun 22, 2022
221c76d
chore: remove image_size comment.
sayakpaul Jun 22, 2022
5c285e4
chore: remove playground_tf.py
sayakpaul Jun 22, 2022
eafaeb9
chore: minor changes related to spacing.
sayakpaul Jun 22, 2022
0ce07cd
Merge branch 'main' into aritra-regnets
ariG23498 Jun 22, 2022
cc2c1fe
chore: make style.
sayakpaul Jun 22, 2022
c416b59
Update src/transformers/models/regnet/modeling_tf_regnet.py
sayakpaul Jun 22, 2022
579fd86
Update src/transformers/models/regnet/modeling_tf_regnet.py
sayakpaul Jun 22, 2022
040032f
chore: refactored __init__.
sayakpaul Jun 22, 2022
ab641fe
chore: copied from -> taken from./g
sayakpaul Jun 22, 2022
222493c
Merge branch 'main' into aritra-regnets
ariG23498 Jun 23, 2022
478b352
adaptive pool -> global avg pool, channel check.
sayakpaul Jun 23, 2022
d292694
chore: move channel check to stem.
sayakpaul Jun 23, 2022
9e07109
pr comments - minor refactor and add regnets to doc tests.
sayakpaul Jun 23, 2022
969762f
Update src/transformers/models/regnet/modeling_tf_regnet.py
sayakpaul Jun 23, 2022
0fa00bc
minor fix in the xlayer.
sayakpaul Jun 28, 2022
fd053e8
Merge branch 'main' into aritra-regnets
ariG23498 Jun 28, 2022
a893f3e
Empty-Commit
sayakpaul Jun 28, 2022
0360f45
chore: removed from_pt=True.
sayakpaul Jun 29, 2022
d7d395a
Merge branch 'main' into aritra-regnets
ariG23498 Jun 29, 2022
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
@@ -261,7 +261,7 @@ Flax), PyTorch, and/or TensorFlow.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
16 changes: 14 additions & 2 deletions docs/source/en/model_doc/regnet.mdx
@@ -27,7 +27,8 @@ Tips:
- One can use [`AutoFeatureExtractor`] to prepare images for the model.
- The huge 10B model from [Self-supervised Pretraining of Visual Features in the Wild](https://arxiv.org/abs/2103.01988), trained on one billion Instagram images, is available on the [hub](https://huggingface.co/facebook/regnet-y-10b-seer)

This model was contributed by [Francesco](https://huggingface.co/Francesco).
This model was contributed by [Francesco](https://huggingface.co/Francesco). The TensorFlow version of the model
was contributed by [sayakpaul](https://huggingface.com/sayakpaul) and [ariG23498](https://huggingface.com/ariG23498).
The original code can be found [here](https://github.com/facebookresearch/pycls).


@@ -45,4 +46,15 @@ The original code can be found [here](https://github.com/facebookresearch/pycls)
## RegNetForImageClassification

[[autodoc]] RegNetForImageClassification
- forward
- forward

## TFRegNetModel

[[autodoc]] TFRegNetModel
- call


## TFRegNetForImageClassification

[[autodoc]] TFRegNetForImageClassification
- call
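For reference, a usage sketch of the newly documented TF classes (not part of the diff). The checkpoint name `facebook/regnet-y-040` and the example image URL are assumptions, mirroring the style of the existing PyTorch docs:

```python
import tensorflow as tf
import requests
from PIL import Image
from transformers import AutoFeatureExtractor, TFRegNetForImageClassification

# Standard COCO cats image used throughout the transformers docs.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Assumed checkpoint; any RegNet checkpoint with TF (or convertible PT) weights would do.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/regnet-y-040")
model = TFRegNetForImageClassification.from_pretrained("facebook/regnet-y-040")

inputs = feature_extractor(image, return_tensors="tf")
logits = model(**inputs).logits

predicted_label = int(tf.math.argmax(logits, axis=-1)[0])
print(model.config.id2label[predicted_label])
```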
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
@@ -2279,6 +2279,14 @@
"TFRagTokenForGeneration",
]
)
_import_structure["models.regnet"].extend(
[
"TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRegNetForImageClassification",
"TFRegNetModel",
"TFRegNetPreTrainedModel",
]
)
_import_structure["models.rembert"].extend(
[
"TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4551,6 +4559,12 @@
from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.regnet import (
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRegNetForImageClassification,
TFRegNetModel,
TFRegNetPreTrainedModel,
)
from .models.rembert import (
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRemBertForCausalLM,
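A minimal sketch of what these new top-level exports enable (default config values, no pretrained weights assumed):

```python
from transformers import RegNetConfig, TFRegNetModel

# Randomly initialised TF RegNet built from the default configuration;
# no hub checkpoint is involved here.
config = RegNetConfig()
model = TFRegNetModel(config)
print(type(model).__name__, config.hidden_sizes)
```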
62 changes: 62 additions & 0 deletions src/transformers/modeling_tf_outputs.py
@@ -46,6 +46,25 @@ class TFBaseModelOutput(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None


@dataclass
class TFBaseModelOutputWithNoAttention(ModelOutput):
"""
Base class for model's outputs, with potential hidden states.

Args:
last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""

last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None


@dataclass
class TFBaseModelOutputWithPooling(ModelOutput):
"""
@@ -80,6 +99,28 @@ class TFBaseModelOutputWithPooling(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None


@dataclass
class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.

Args:
last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
Last layer hidden-state after a pooling operation on the spatial dimensions.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, num_channels, height, width)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""

last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None


@dataclass
class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
"""
@@ -825,3 +866,24 @@ class TFSequenceClassifierOutputWithPast(ModelOutput):
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None


@dataclass
class TFImageClassifierOutputWithNoAttention(ModelOutput):
"""
Base class for outputs of image classification models.

Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called
feature maps) of the model at the output of each stage.
"""

loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
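To illustrate the new output classes, a small sketch that builds one by hand with dummy tensors (the shapes are illustrative, not taken from the PR):

```python
import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutputWithPoolingAndNoAttention

# Dummy activations standing in for real model outputs.
last_hidden_state = tf.random.normal((1, 1088, 7, 7))  # (batch_size, num_channels, height, width)
pooler_output = tf.random.normal((1, 1088))             # (batch_size, hidden_size)

outputs = TFBaseModelOutputWithPoolingAndNoAttention(
    last_hidden_state=last_hidden_state,
    pooler_output=pooler_output,
)
# ModelOutput subclasses behave like both dataclasses and dicts.
print(outputs.last_hidden_state.shape, outputs["pooler_output"].shape)
```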
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
@@ -62,6 +62,7 @@
("openai-gpt", "TFOpenAIGPTModel"),
("opt", "TFOPTModel"),
("pegasus", "TFPegasusModel"),
("regnet", "TFRegNetModel"),
("rembert", "TFRemBertModel"),
("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"),
@@ -173,6 +174,7 @@
# Model for Image-classsification
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("regnet", "TFRegNetForImageClassification"),
("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"),
]
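With the two mapping entries above, the TF auto classes can resolve RegNet checkpoints. A sketch assuming a hub checkpoint with TF weights (the checkpoint name is illustrative):

```python
from transformers import TFAutoModel, TFAutoModelForImageClassification

# Both auto classes now dispatch "regnet" configs to the TF implementations.
backbone = TFAutoModel.from_pretrained("facebook/regnet-y-040")
classifier = TFAutoModelForImageClassification.from_pretrained("facebook/regnet-y-040")
print(type(backbone).__name__)    # TFRegNetModel
print(type(classifier).__name__)  # TFRegNetForImageClassification
```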
29 changes: 27 additions & 2 deletions src/transformers/models/regnet/__init__.py
@@ -18,8 +18,7 @@
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_torch_available
from ...utils import OptionalDependencyNotAvailable
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available


_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}
@@ -37,6 +36,19 @@
"RegNetPreTrainedModel",
]

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_regnet"] = [
"TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRegNetForImageClassification",
"TFRegNetModel",
"TFRegNetPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
@@ -54,6 +66,19 @@
RegNetPreTrainedModel,
)

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_regnet import (
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRegNetForImageClassification,
TFRegNetModel,
TFRegNetPreTrainedModel,
)


else:
import sys
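The guarded import structure above means the TF classes are only importable when TensorFlow is installed; a sketch of how a caller can check this up front (the fallback to dummy objects is handled by the library itself):

```python
from transformers.utils import is_tf_available

if is_tf_available():
    # Resolved lazily through the _LazyModule machinery set up above.
    from transformers import TFRegNetModel
    print("TF backend available:", TFRegNetModel.__name__)
else:
    print("TensorFlow is not installed; only the PyTorch RegNet classes are usable.")
```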