Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF implementation of RegNets #17554

Merged
merged 42 commits into from
Jun 29, 2022
Merged
Changes from 1 commit
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
99641bd
chore: initial commit
ariG23498 Jun 4, 2022
7ea1d01
chore: porting the rest of the modules to tensorflow
ariG23498 Jun 5, 2022
74cd9a0
Fix initilizations (#1)
sayakpaul Jun 6, 2022
1b42157
chore: styling nits.
sayakpaul Jun 6, 2022
f1bf27a
fix: cross-loading bn params.
sayakpaul Jun 7, 2022
adec9db
fix: regnet tf model, integration passing.
sayakpaul Jun 13, 2022
e90fdfc
add: tests for TF regnet.
sayakpaul Jun 13, 2022
ec9bf1c
fix: code quality related issues.
sayakpaul Jun 13, 2022
419da37
chore: added rest of the files.
sayakpaul Jun 13, 2022
213e9e9
Merge pull request #2 from ariG23498/feat/tf-regnets
ariG23498 Jun 13, 2022
cf3d797
minor additions..
sayakpaul Jun 14, 2022
b9aa7b5
fix: repo consistency.
sayakpaul Jun 15, 2022
b93abe8
fix: regnet tf tests.
sayakpaul Jun 16, 2022
c4cd6db
chore: reorganize dummy_tf_objects for regnet.
sayakpaul Jun 16, 2022
7c58838
chore: remove checkpoint var.
sayakpaul Jun 16, 2022
ab7f80e
chore: remov unnecessary files.
sayakpaul Jun 16, 2022
5bf7d28
chore: run make style.
sayakpaul Jun 16, 2022
9e26607
Merge branch 'main' into aritra-regnets
ariG23498 Jun 16, 2022
23a20ad
Update docs/source/en/model_doc/regnet.mdx
sayakpaul Jun 16, 2022
6fdcc6d
chore: PR feedback I.
sayakpaul Jun 16, 2022
736b521
fix: pt test. thanks to @ydshieh.
sayakpaul Jun 20, 2022
55a8a0f
New adaptive pooler (#3)
sayakpaul Jun 22, 2022
882959e
Empty-Commit
sayakpaul Jun 22, 2022
221c76d
chore: remove image_size comment.
sayakpaul Jun 22, 2022
5c285e4
chore: remove playground_tf.py
sayakpaul Jun 22, 2022
eafaeb9
chore: minor changes related to spacing.
sayakpaul Jun 22, 2022
0ce07cd
Merge branch 'main' into aritra-regnets
ariG23498 Jun 22, 2022
cc2c1fe
chore: make style.
sayakpaul Jun 22, 2022
c416b59
Update src/transformers/models/regnet/modeling_tf_regnet.py
sayakpaul Jun 22, 2022
579fd86
Update src/transformers/models/regnet/modeling_tf_regnet.py
sayakpaul Jun 22, 2022
040032f
chore: refactored __init__.
sayakpaul Jun 22, 2022
ab641fe
chore: copied from -> taken from./g
sayakpaul Jun 22, 2022
222493c
Merge branch 'main' into aritra-regnets
ariG23498 Jun 23, 2022
478b352
adaptive pool -> global avg pool, channel check.
sayakpaul Jun 23, 2022
d292694
chore: move channel check to stem.
sayakpaul Jun 23, 2022
9e07109
pr comments - minor refactor and add regnets to doc tests.
sayakpaul Jun 23, 2022
969762f
Update src/transformers/models/regnet/modeling_tf_regnet.py
sayakpaul Jun 23, 2022
0fa00bc
minor fix in the xlayer.
sayakpaul Jun 28, 2022
fd053e8
Merge branch 'main' into aritra-regnets
ariG23498 Jun 28, 2022
a893f3e
Empty-Commit
sayakpaul Jun 28, 2022
0360f45
chore: removed from_pt=True.
sayakpaul Jun 29, 2022
d7d395a
Merge branch 'main' into aritra-regnets
ariG23498 Jun 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
chore: added rest of the files.
sayakpaul committed Jun 13, 2022

Verified

This commit was signed with the committer’s verified signature. The key has expired.
mjcarroll Michael Carroll
commit 419da3789603918100b51d0c227da4484ca704f1
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
Original file line number Diff line number Diff line change
@@ -249,7 +249,7 @@ Flax), PyTorch, and/or TensorFlow.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | | ❌ |
| RegNet | ❌ | ❌ | ✅ | | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
16 changes: 14 additions & 2 deletions docs/source/en/model_doc/regnet.mdx
Original file line number Diff line number Diff line change
@@ -27,7 +27,8 @@ Tips:
- One can use [`AutoFeatureExtractor`] to prepare images for the model.
- The huge 10B model from [Self-supervised Pretraining of Visual Features in the Wild](https://arxiv.org/abs/2103.01988), trained on one billion Instagram images, is available on the [hub](https://huggingface.co/facebook/regnet-y-10b-seer)

This model was contributed by [Francesco](https://huggingface.co/Francesco).
This model was contributed by [Francesco](https://huggingface.co/Francesco). TensorFlow version of the model
was contributed by [sayakpaul](https://github.com/sayakpaul) and [ariG23498](https://github.com/ariG23498).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer HuggingFace usernames and pages here, if you have any.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

The original code can be found [here](https://github.com/facebookresearch/pycls).


@@ -45,4 +46,15 @@ The original code can be found [here](https://github.com/facebookresearch/pycls)
## RegNetForImageClassification

[[autodoc]] RegNetForImageClassification
- forward
- forward

## TFRegNetModel

[[autodoc]] TFRegNetModel
- call


## TFRegNetForImageClassification

[[autodoc]] TFRegNetForImageClassification
- call
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -2224,6 +2224,14 @@
"TFRagTokenForGeneration",
]
)
_import_structure["models.regnet"].extend(
[
"TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRegNetForImageClassification",
"TFRegNetModel",
"TFRegNetPreTrainedModel",
]
)
_import_structure["models.rembert"].extend(
[
"TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4450,6 +4458,12 @@
)
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.regnet import (
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRegNetForImageClassification,
TFRegNetModel,
TFRegNetPreTrainedModel,
)
from .models.rembert import (
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRemBertForCausalLM,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
@@ -61,6 +61,7 @@
("mt5", "TFMT5Model"),
("openai-gpt", "TFOpenAIGPTModel"),
("pegasus", "TFPegasusModel"),
("regnet", "TFRegNetModel"),
("rembert", "TFRemBertModel"),
("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"),
@@ -171,6 +172,7 @@
# Model for Image-classsification
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("regnet", "TFRegNetForImageClassification"),
("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"),
]
29 changes: 27 additions & 2 deletions src/transformers/models/regnet/__init__.py
Original file line number Diff line number Diff line change
@@ -18,8 +18,7 @@
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_torch_available
from ...utils import OptionalDependencyNotAvailable
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available


_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}
@@ -37,6 +36,19 @@
"RegNetPreTrainedModel",
]

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_regnet"] = [
"TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRegNetForImageClassification",
"TFRegNetModel",
"TFRegNetPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
@@ -54,6 +66,19 @@
RegNetPreTrainedModel,
)

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_regnet import (
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRegNetForImageClassification,
TFRegNetModel,
TFRegNetPreTrainedModel,
)


else:
import sys
2 changes: 1 addition & 1 deletion src/transformers/models/regnet/configuration_regnet.py
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ class RegNetConfig(PretrainedConfig):

Args:
image_size (`int`, *optional*, defaults to 224):
Size of the input images.
Size (resolution) of the input images.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
embedding_size (`int`, *optional*, defaults to 64):
24 changes: 24 additions & 0 deletions src/transformers/utils/dummy_tf_objects.py
Original file line number Diff line number Diff line change
@@ -255,6 +255,30 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFRegNetForImageClassification(metaclass=DummyObject):
_backends = ["tf"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFRegNetModel(metaclass=DummyObject):
_backends = ["tf"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFRegNetPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])


TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = None


TF_MODEL_FOR_CAUSAL_LM_MAPPING = None