Skip to content

Commit

Permalink
[PT] remove submodule from textnet arch (#1436)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Jan 26, 2024
1 parent f316489 commit abf0571
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
16 changes: 8 additions & 8 deletions doctr/models/classification/textnet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-c23a1b9a.pt&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_tiny-c5970fe0.pt&src=0",
},
"textnet_small": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-775169f7.pt&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_small-6e8ab0ce.pt&src=0",
},
"textnet_base": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 32, 32),
"classes": list(VOCABS["french"]),
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-6121c044.pt&src=0",
"url": "https://doctr-static.mindee.com/models?id=v0.7.0/textnet_base-8295dc85.pt&src=0",
},
}

Expand Down Expand Up @@ -66,13 +66,13 @@ def __init__(
*conv_sequence_pt(
in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
),
nn.Sequential(*[
*[
nn.Sequential(*[
FASTConvLayer(**params) # type: ignore[arg-type]
for params in [{key: stage[key][i] for key in stage} for i in range(len(stage["in_channels"]))]
])
for stage in stages
]),
],
]

if include_top:
Expand Down Expand Up @@ -167,7 +167,7 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
"stride": [2, 1, 1, 1],
},
],
ignore_keys=["4.2.weight", "4.2.bias"],
ignore_keys=["7.2.weight", "7.2.bias"],
**kwargs,
)

Expand Down Expand Up @@ -216,7 +216,7 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
"stride": [2, 1, 1, 1, 1],
},
],
ignore_keys=["4.2.weight", "4.2.bias"],
ignore_keys=["7.2.weight", "7.2.bias"],
**kwargs,
)

Expand Down Expand Up @@ -270,6 +270,6 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
"stride": [2, 1, 1, 1, 1],
},
],
ignore_keys=["4.2.weight", "4.2.bias"],
ignore_keys=["7.2.weight", "7.2.bias"],
**kwargs,
)
10 changes: 5 additions & 5 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,10 @@ def main(args):
# Augmentations
T.RandomApply(T.ColorInversion(), 0.1),
T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
T.RandomApply(T.RandomShadow(), 0.4),
T.RandomApply(GaussianBlur(kernel_size=3), 0.3),
RandomPhotometricDistort(p=0.1),
RandomGrayscale(p=0.1),
T.RandomApply(T.RandomShadow(), 0.1),
T.RandomApply(GaussianBlur(kernel_size=3), 0.1),
RandomPhotometricDistort(p=0.05),
RandomGrayscale(p=0.05),
]),
sample_transforms=T.SampleCompose(
(
Expand Down Expand Up @@ -442,7 +442,7 @@ def parse_args():
action="store_true",
help="metrics evaluation with straight boxes instead of polygons to save time + memory",
)
parser.add_argument("--sched", type=str, default="cosine", help="scheduler to use")
parser.add_argument("--sched", type=str, default="onecycle", help="scheduler to use")
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR")
parser.add_argument("--early-stop", action="store_true", help="Enable early stopping")
Expand Down
6 changes: 3 additions & 3 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,12 @@ def main(args):
T.RandomApply(T.ColorInversion(), 0.1),
T.RandomJpegQuality(60),
T.RandomApply(T.GaussianNoise(mean=0.1, std=0.1), 0.1),
T.RandomApply(T.RandomShadow(), 0.4),
T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.3),
T.RandomApply(T.RandomShadow(), 0.1),
T.RandomApply(T.GaussianBlur(kernel_shape=3, std=(0.1, 0.1)), 0.1),
T.RandomSaturation(0.3),
T.RandomContrast(0.3),
T.RandomBrightness(0.3),
T.RandomApply(T.ToGray(num_output_channels=3), 0.1),
T.RandomApply(T.ToGray(num_output_channels=3), 0.05),
]),
sample_transforms=T.SampleCompose(
(
Expand Down

0 comments on commit abf0571

Please sign in to comment.