From 6d96ed50a3ed1cd49b9147711925784dae697ba2 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 15 Mar 2022 18:07:46 +0000 Subject: [PATCH] Porting docs, examples, tutorials and galleries (#5620) * Fix examples, tutorials and gallery * Update gallery/plot_optical_flow.py Co-authored-by: Nicolas Hug * Fix import * Revert hardcoded normalization * fix uncommitted changes * Fix bug * Fix more bugs * Making resize optional for segmentation * Fixing preset * Fix mypy * Fixing documentation strings * Fix flake8 * minor refactoring Co-authored-by: Nicolas Hug --- android/test_app/make_assets.py | 13 +++++-- examples/cpp/hello_world/trace_model.py | 2 +- gallery/plot_optical_flow.py | 34 +++++++--------- gallery/plot_repurposing_annotations.py | 8 ++-- gallery/plot_scripted_tensor_transforms.py | 12 ++---- gallery/plot_visualization_utils.py | 39 +++++++++++++------ ios/VisionTestApp/make_assets.py | 13 +++++-- test/tracing/frcnn/trace_model.py | 2 +- torchvision/models/_utils.py | 2 +- .../models/detection/backbone_utils.py | 27 ++++++++----- torchvision/models/detection/faster_rcnn.py | 8 ++-- torchvision/models/detection/fcos.py | 4 +- torchvision/models/detection/keypoint_rcnn.py | 4 +- torchvision/models/detection/mask_rcnn.py | 4 +- torchvision/models/detection/retinanet.py | 4 +- torchvision/models/detection/ssd.py | 2 +- torchvision/models/detection/ssdlite.py | 2 +- torchvision/models/googlenet.py | 2 +- torchvision/models/inception.py | 2 +- torchvision/transforms/_presets.py | 12 +++--- 20 files changed, 115 insertions(+), 81 deletions(-) diff --git a/android/test_app/make_assets.py b/android/test_app/make_assets.py index fedee39fc52..f99933e9a9d 100644 --- a/android/test_app/make_assets.py +++ b/android/test_app/make_assets.py @@ -1,11 +1,18 @@ import torch -import torchvision from torch.utils.mobile_optimizer import optimize_for_mobile +from torchvision.models.detection import ( + fasterrcnn_mobilenet_v3_large_320_fpn, + FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, +) print(torch.__version__) -model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150 +model = fasterrcnn_mobilenet_v3_large_320_fpn( + weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT, + box_score_thresh=0.7, + rpn_post_nms_top_n_test=100, + rpn_score_thresh=0.4, + rpn_pre_nms_top_n_test=150, ) model.eval() diff --git a/examples/cpp/hello_world/trace_model.py b/examples/cpp/hello_world/trace_model.py index c8b8d6911e7..41bbaf8b6dd 100644 --- a/examples/cpp/hello_world/trace_model.py +++ b/examples/cpp/hello_world/trace_model.py @@ -6,7 +6,7 @@ HERE = osp.dirname(osp.abspath(__file__)) ASSETS = osp.dirname(osp.dirname(HERE)) -model = torchvision.models.resnet18(pretrained=False) +model = torchvision.models.resnet18() model.eval() traced_model = torch.jit.script(model) diff --git a/gallery/plot_optical_flow.py b/gallery/plot_optical_flow.py index 505334f36da..770610fb971 100644 --- a/gallery/plot_optical_flow.py +++ b/gallery/plot_optical_flow.py @@ -19,7 +19,6 @@ import torch import matplotlib.pyplot as plt import torchvision.transforms.functional as F -import torchvision.transforms as T plt.rcParams["savefig.bbox"] = "tight" @@ -88,24 +87,19 @@ def plot(imgs, **imshow_kwargs): # reduce the image sizes for the example to run faster. Image dimension must be # divisible by 8. 
+from torchvision.models.optical_flow import Raft_Large_Weights
 
-def preprocess(batch):
-    transforms = T.Compose(
-        [
-            T.ConvertImageDtype(torch.float32),
-            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
-            T.Resize(size=(520, 960)),
-        ]
-    )
-    batch = transforms(batch)
-    return batch
+weights = Raft_Large_Weights.DEFAULT
+transforms = weights.transforms()
 
-# If you can, run this example on a GPU, it will be a lot faster.
-device = "cuda" if torch.cuda.is_available() else "cpu"
+def preprocess(img1_batch, img2_batch):
+    img1_batch = F.resize(img1_batch, size=[520, 960])
+    img2_batch = F.resize(img2_batch, size=[520, 960])
+    return transforms(img1_batch, img2_batch)[:2]
+
 
-img1_batch = preprocess(img1_batch).to(device)
-img2_batch = preprocess(img2_batch).to(device)
+img1_batch, img2_batch = preprocess(img1_batch, img2_batch)
 
 print(f"shape = {img1_batch.shape}, dtype = {img1_batch.dtype}")
 
@@ -121,7 +115,10 @@ def preprocess(batch):
 
 from torchvision.models.optical_flow import raft_large
 
-model = raft_large(pretrained=True, progress=False).to(device)
+# If you can, run this example on a GPU, it will be a lot faster.
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
 model = model.eval()
 
 list_of_flows = model(img1_batch.to(device), img2_batch.to(device))
@@ -182,10 +179,9 @@ def preprocess(batch):
 #     from torchvision.io import write_jpeg
 #     for i, (img1, img2) in enumerate(zip(frames, frames[1:])):
 #         # Note: it would be faster to predict batches of flows instead of individual flows
-#         img1 = preprocess(img1[None]).to(device)
-#         img2 = preprocess(img2[None]).to(device)
+#         img1, img2 = preprocess(img1[None], img2[None])
 
-#         list_of_flows = model(img1_batch, img2_batch)
+#         list_of_flows = model(img1.to(device), img2.to(device))
 #         predicted_flow = list_of_flows[-1][0]
 #         flow_img = flow_to_image(predicted_flow).to("cpu")
 #         output_folder = "/tmp/"  # Update this to the folder of your choice
diff --git a/gallery/plot_repurposing_annotations.py b/gallery/plot_repurposing_annotations.py
index fb4835496c3..a826a2523f2 100644
--- a/gallery/plot_repurposing_annotations.py
+++ b/gallery/plot_repurposing_annotations.py
@@ -139,12 +139,14 @@ def show(imgs):
 # Here is demo with a Faster R-CNN model loaded from
 # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`
 
-from torchvision.models.detection import fasterrcnn_resnet50_fpn
+from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
 
-model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False)
+weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
+model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
 print(img.size())
 
-img = F.convert_image_dtype(img, torch.float)
+transforms = weights.transforms()
+img, _ = transforms(img)
 target = {}
 target["boxes"] = boxes
 target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64)
diff --git a/gallery/plot_scripted_tensor_transforms.py b/gallery/plot_scripted_tensor_transforms.py
index a9205536821..995383d4603 100644
--- a/gallery/plot_scripted_tensor_transforms.py
+++ b/gallery/plot_scripted_tensor_transforms.py
@@ -85,20 +85,16 @@ def show(imgs):
 # Let's define a ``Predictor`` module that transforms the input tensor and then
 # applies an ImageNet model on it.
 
-from torchvision.models import resnet18 +from torchvision.models import resnet18, ResNet18_Weights class Predictor(nn.Module): def __init__(self): super().__init__() - self.resnet18 = resnet18(pretrained=True, progress=False).eval() - self.transforms = nn.Sequential( - T.Resize([256, ]), # We use single int value inside a list due to torchscript type restrictions - T.CenterCrop(224), - T.ConvertImageDtype(torch.float), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ) + weights = ResNet18_Weights.DEFAULT + self.resnet18 = resnet18(weights=weights, progress=False).eval() + self.transforms = weights.transforms() def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): diff --git a/gallery/plot_visualization_utils.py b/gallery/plot_visualization_utils.py index 526c8c32493..27fd97681c0 100644 --- a/gallery/plot_visualization_utils.py +++ b/gallery/plot_visualization_utils.py @@ -73,14 +73,17 @@ def show(imgs): # :func:`~torchvision.models.detection.ssd300_vgg16`. For more details # on the output of such models, you may refer to :ref:`instance_seg_output`. -from torchvision.models.detection import fasterrcnn_resnet50_fpn -from torchvision.transforms.functional import convert_image_dtype +from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights batch_int = torch.stack([dog1_int, dog2_int]) -batch = convert_image_dtype(batch_int, dtype=torch.float) -model = fasterrcnn_resnet50_fpn(pretrained=True, progress=False) +weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT +transforms = weights.transforms() + +batch, _ = transforms(batch_int) + +model = fasterrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() outputs = model(batch) @@ -120,13 +123,15 @@ def show(imgs): # images must be normalized before they're passed to a semantic segmentation # model. -from torchvision.models.segmentation import fcn_resnet50 +from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights +weights = FCN_ResNet50_Weights.DEFAULT +transforms = weights.transforms(resize_size=None) -model = fcn_resnet50(pretrained=True, progress=False) +model = fcn_resnet50(weights=weights, progress=False) model = model.eval() -normalized_batch = F.normalize(batch, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) +normalized_batch, _ = transforms(batch) output = model(normalized_batch)['out'] print(output.shape, output.min().item(), output.max().item()) @@ -262,8 +267,14 @@ def show(imgs): # of them may not have masks, like # :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn`. -from torchvision.models.detection import maskrcnn_resnet50_fpn -model = maskrcnn_resnet50_fpn(pretrained=True, progress=False) +from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights + +weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT +transforms = weights.transforms() + +batch, _ = transforms(batch_int) + +model = maskrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() output = model(batch) @@ -378,13 +389,17 @@ def show(imgs): # Note that the keypoint detection model does not need normalized images. 
# -from torchvision.models.detection import keypointrcnn_resnet50_fpn +from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights from torchvision.io import read_image person_int = read_image(str(Path("assets") / "person1.jpg")) -person_float = convert_image_dtype(person_int, dtype=torch.float) -model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False) +weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT +transforms = weights.transforms() + +person_float, _ = transforms(person_int) + +model = keypointrcnn_resnet50_fpn(weights=weights, progress=False) model = model.eval() outputs = model([person_float]) diff --git a/ios/VisionTestApp/make_assets.py b/ios/VisionTestApp/make_assets.py index 0f46364569b..f14223e6a42 100644 --- a/ios/VisionTestApp/make_assets.py +++ b/ios/VisionTestApp/make_assets.py @@ -1,11 +1,18 @@ import torch -import torchvision from torch.utils.mobile_optimizer import optimize_for_mobile +from torchvision.models.detection import ( + fasterrcnn_mobilenet_v3_large_320_fpn, + FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, +) print(torch.__version__) -model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150 +model = fasterrcnn_mobilenet_v3_large_320_fpn( + weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT, + box_score_thresh=0.7, + rpn_post_nms_top_n_test=100, + rpn_score_thresh=0.4, + rpn_pre_nms_top_n_test=150, ) model.eval() diff --git a/test/tracing/frcnn/trace_model.py b/test/tracing/frcnn/trace_model.py index 8cc1d344936..768954d29b2 100644 --- a/test/tracing/frcnn/trace_model.py +++ b/test/tracing/frcnn/trace_model.py @@ -6,7 +6,7 @@ HERE = osp.dirname(osp.abspath(__file__)) ASSETS = osp.dirname(osp.dirname(HERE)) -model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False) +model = torchvision.models.detection.fasterrcnn_resnet50_fpn() model.eval() traced_model = torch.jit.script(model) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 9e3a81411a1..08c878a8a67 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -32,7 +32,7 @@ class IntermediateLayerGetter(nn.ModuleDict): Examples:: - >>> m = torchvision.models.resnet18(pretrained=True) + >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT) >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, >>> {'layer1': 'feat1', 'layer3': 'feat2'}) diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index cac96b61f64..b767756692b 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -6,7 +6,8 @@ from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool from .. 
import mobilenet, resnet
-from .._utils import IntermediateLayerGetter
+from .._api import WeightsEnum
+from .._utils import IntermediateLayerGetter, handle_legacy_interface
 
 
 class BackboneWithFPN(nn.Module):
@@ -55,9 +56,13 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
         return x
 
 
+@handle_legacy_interface(
+    weights=("pretrained", True),  # type: ignore[arg-type]
+)
 def resnet_fpn_backbone(
+    *,
     backbone_name: str,
-    pretrained: bool,
+    weights: Optional[WeightsEnum],
     norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
     trainable_layers: int = 3,
     returned_layers: Optional[List[int]] = None,
@@ -69,7 +74,7 @@ def resnet_fpn_backbone(
     Examples::
 
         >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
-        >>> backbone = resnet_fpn_backbone('resnet50', pretrained=True, trainable_layers=3)
+        >>> backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=ResNet50_Weights.DEFAULT, trainable_layers=3)
         >>> # get some dummy image
         >>> x = torch.rand(1,3,64,64)
         >>> # compute the output
@@ -85,7 +90,7 @@ def resnet_fpn_backbone(
     Args:
         backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
             'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
-        pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet
+        weights (WeightsEnum, optional): The pretrained weights for the model
         norm_layer (callable): it is recommended to use the default value. For details visit:
             (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
         trainable_layers (int): number of trainable (not frozen) layers starting from final block.
@@ -98,7 +103,7 @@
             a new list of feature maps and their corresponding names. By
             default a ``LastLevelMaxPool`` is used.
""" - backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) + backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks) @@ -135,13 +140,13 @@ def _resnet_fpn_extractor( def _validate_trainable_layers( - pretrained: bool, + is_trained: bool, trainable_backbone_layers: Optional[int], max_value: int, default_value: int, ) -> int: # don't freeze any layers if pretrained model or backbone is not used - if not pretrained: + if not is_trained: if trainable_backbone_layers is not None: warnings.warn( "Changing trainable_backbone_layers has not effect if " @@ -160,16 +165,20 @@ def _validate_trainable_layers( return trainable_backbone_layers +@handle_legacy_interface( + weights=("pretrained", True), # type: ignore[arg-type] +) def mobilenet_backbone( + *, backbone_name: str, - pretrained: bool, + weights: Optional[WeightsEnum], fpn: bool, norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, trainable_layers: int = 2, returned_layers: Optional[List[int]] = None, extra_blocks: Optional[ExtraFPNBlock] = None, ) -> nn.Module: - backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) + backbone = mobilenet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 18872adc029..efbd88906f2 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -117,7 +117,7 @@ class FasterRCNN(GeneralizedRCNN): >>> from torchvision.models.detection.rpn import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # FasterRCNN needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -415,7 +415,7 @@ def fasterrcnn_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT) >>> # For training >>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4) >>> boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4] @@ -532,7 +532,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( Example:: - >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True) + >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(weights=FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) @@ -589,7 +589,7 @@ def fasterrcnn_mobilenet_v3_large_fpn( Example:: - >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True) + >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 8d110d809f7..5627573836a 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -325,7 +325,7 @@ class FCOS(nn.Module): >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # FCOS needs to know the number of >>> # output channels in a backbone. For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -697,7 +697,7 @@ def fcos_resnet50_fpn( Example: - >>> model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.fcos_resnet50_fpn(weights=FCOS_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 3794b253ec7..272a6c3debe 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -119,7 +119,7 @@ class KeypointRCNN(FasterRCNN): >>> >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # KeypointRCNN needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -395,7 +395,7 @@ def keypointrcnn_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 38ba82af01d..04652d5a66f 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -118,7 +118,7 @@ class MaskRCNN(FasterRCNN): >>> >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # MaskRCNN needs to know the number of >>> # output channels in a backbone. For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -373,7 +373,7 @@ def maskrcnn_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index b1c371583bf..da3a521c36a 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -293,7 +293,7 @@ class RetinaNet(nn.Module): >>> from torchvision.models.detection.anchor_utils import AnchorGenerator >>> # load a pre-trained model for classification and return >>> # only the features - >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> # RetinaNet needs to know the number of >>> # output channels in a backbone. 
For mobilenet_v2, it's 1280 >>> # so we need to add it here @@ -642,7 +642,7 @@ def retinanet_resnet50_fpn( Example:: - >>> model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True) + >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index cf3becc5fc4..32158ecfab3 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -596,7 +596,7 @@ def ssd300_vgg16( Example: - >>> model = torchvision.models.detection.ssd300_vgg16(pretrained=True) + >>> model = torchvision.models.detection.ssd300_vgg16(weights=SSD300_VGG16_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index a71da6b29ac..bb471fa6fa8 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -224,7 +224,7 @@ def ssdlite320_mobilenet_v3_large( Example: - >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True) + >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT) >>> model.eval() >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)] >>> predictions = model(x) diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 2cac4a4fbbd..a47b478af03 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -307,7 +307,7 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T aux_logits (bool): If True, adds two auxiliary branches that can improve training. Default: *False* when pretrained is True otherwise *True* transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. + was trained on ImageNet. Default: True if ``weights=GoogLeNet_Weights.IMAGENET1K_V1``, else False. """ weights = GoogLeNet_Weights.verify(weights) diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 1628542482b..44b2bd56feb 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -443,7 +443,7 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo aux_logits (bool): If True, add an auxiliary branch that can improve training. Default: *True* transform_input (bool): If True, preprocesses the input according to the method with which it - was trained on ImageNet. Default: True if ``pretrained=True``, else False. + was trained on ImageNet. Default: True if ``weights=Inception_V3_Weights.IMAGENET1K_V1``, else False. """ weights = Inception_V3_Weights.verify(weights) diff --git a/torchvision/transforms/_presets.py b/torchvision/transforms/_presets.py index a6b85d05597..1776d876ccb 100644 --- a/torchvision/transforms/_presets.py +++ b/torchvision/transforms/_presets.py @@ -78,27 +78,29 @@ def forward(self, vid: Tensor) -> Tensor: class SemanticSegmentationEval(nn.Module): def __init__( self, - resize_size: int, + resize_size: Optional[int], mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] 
= (0.229, 0.224, 0.225), interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation_target: InterpolationMode = InterpolationMode.NEAREST, ) -> None: super().__init__() - self._size = [resize_size] + self._size = [resize_size] if resize_size is not None else None self._mean = list(mean) self._std = list(std) self._interpolation = interpolation self._interpolation_target = interpolation_target def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: - img = F.resize(img, self._size, interpolation=self._interpolation) + if isinstance(self._size, list): + img = F.resize(img, self._size, interpolation=self._interpolation) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) img = F.normalize(img, mean=self._mean, std=self._std) if target: - target = F.resize(target, self._size, interpolation=self._interpolation_target) + if isinstance(self._size, list): + target = F.resize(target, self._size, interpolation=self._interpolation_target) if not isinstance(target, Tensor): target = F.pil_to_tensor(target) target = target.squeeze(0).to(torch.int64) @@ -107,7 +109,7 @@ def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, class OpticalFlowEval(nn.Module): def forward( - self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor] + self, img1: Tensor, img2: Tensor, flow: Optional[Tensor] = None, valid_flow_mask: Optional[Tensor] = None ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask)
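
Appendix: usage sketches for the new weights API (illustrative, not part of the diffs above).

Every call site in this patch follows the same migration: ``pretrained=True`` becomes an explicit weights enum (``<Model>_Weights.DEFAULT``), and hand-rolled preprocessing is replaced by the preset returned from ``weights.transforms()``. A minimal sketch of that pattern for classification, assuming a torchvision build that ships the multi-weight API and the usual ``meta["categories"]`` entry on the enum; the asset path is hypothetical::

    import torch
    from torchvision.io import read_image
    from torchvision.models import resnet18, ResNet18_Weights

    weights = ResNet18_Weights.DEFAULT
    model = resnet18(weights=weights, progress=False).eval()
    preprocess = weights.transforms()  # inference preset bundled with the checkpoint

    img = read_image("assets/dog1.jpg")  # hypothetical path; any uint8 CHW image works
    batch = preprocess(img).unsqueeze(0)  # classification presets return a bare tensor

    with torch.no_grad():
        scores = model(batch).squeeze(0).softmax(dim=0)
    class_id = int(scores.argmax())
    print(weights.meta["categories"][class_id], float(scores[class_id]))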
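The detection, segmentation and keypoint presets differ from the classification one in that they return an ``(image, target)`` pair, which is why the galleries above unpack ``img, _ = transforms(img)``. A sketch of that calling convention, following the ``ObjectDetectionEval`` preset as it stands in this patch (later releases may return a bare tensor instead)::

    import torch
    from torchvision.models.detection import (
        fasterrcnn_resnet50_fpn,
        FasterRCNN_ResNet50_FPN_Weights,
    )

    weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn(weights=weights, progress=False).eval()
    transforms = weights.transforms()

    img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # dummy uint8 image
    img, _ = transforms(img)  # preset converts dtype; the target slot is unused here

    with torch.no_grad():
        prediction = model([img])[0]
    print(prediction["boxes"].shape, prediction["labels"][:3])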
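Finally, the ``resize_size: Optional[int]`` change in ``SemanticSegmentationEval`` is what lets the FCN gallery call ``weights.transforms(resize_size=None)``: with ``None`` the preset skips resizing and only performs dtype conversion and normalization. A sketch exercising that branch directly; ``_presets`` is a private module, so this is for illustration only::

    import torch
    from torchvision.transforms._presets import SemanticSegmentationEval

    preset = SemanticSegmentationEval(resize_size=None)
    img = torch.randint(0, 256, (3, 512, 683), dtype=torch.uint8)

    out, target = preset(img)  # no target passed in, so None comes back
    assert out.shape == img.shape      # spatial size untouched: resize was skipped
    assert out.dtype == torch.float32  # dtype conversion + normalization still ran
    assert target is None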