import torch
import torch.nn as nn
from .ops import blocks
from typing import List, Any
from .utils import export, load_from_local_or_url


def get_stem(in_channels):
    """Assemble the stem used by ``InceptionV4`` and ``InceptionResNetV2``.

    The stem is three 3x3 convolution blocks followed by three
    ``blocks.ConcatBranches`` groups. The channel counts below are consistent
    with each group concatenating its branches along the channel axis
    (64 + 96 = 160, 96 + 96 = 192) -- NOTE(review): confirm against
    ``blocks.ConcatBranches``.

    Args:
        in_channels: Channel count given to the first convolution block.

    Returns:
        A ``blocks.Stage`` built from the layers in the order created here.
    """
    # Plain convolutional trunk: in_channels -> 32 -> 32 -> 64.
    layers = [
        blocks.Conv2dBlock(in_channels, 32, kernel_size=3, stride=2, padding=0),
        blocks.Conv2dBlock(32, 32, kernel_size=3, padding=0),
        blocks.Conv2dBlock(32, 64, kernel_size=3),
    ]

    # First split: stride-2 max pooling next to a stride-2 3x3 convolution.
    layers.append(
        blocks.ConcatBranches(
            nn.MaxPool2d(kernel_size=3, stride=2),
            blocks.Conv2dBlock(64, 96, kernel_size=3, stride=2, padding=0),
        )
    )

    # Second split: both branches start from the 160-channel input and end in
    # an unpadded 3x3 convolution with 96 output channels.
    short_branch = nn.Sequential(
        blocks.Conv2d1x1Block(160, 64),
        blocks.Conv2dBlock(64, 96, kernel_size=3, padding=0),
    )
    factorized_branch = nn.Sequential(
        blocks.Conv2d1x1Block(160, 64),
        # A 7x1 then 1x7 pair, each padded so the spatial size is kept.
        blocks.Conv2dBlock(64, 64, kernel_size=(7, 1), padding=(3, 0)),
        blocks.Conv2dBlock(64, 64, kernel_size=(1, 7), padding=(0, 3)),
        blocks.Conv2dBlock(64, 96, kernel_size=3, padding=0),
    )
    layers.append(blocks.ConcatBranches(short_branch, factorized_branch))

    # Third split: stride-2 3x3 convolution next to stride-2 max pooling.
    layers.append(
        blocks.ConcatBranches(
            blocks.Conv2dBlock(192, 192, kernel_size=3, stride=2, padding=0),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
        )
    )

    return blocks.Stage(*layers)


class InceptionV4(nn.Module):
    r"""Inception-v4 image classifier.

    Paper: Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning, https://arxiv.org/abs/1602.07261

    Args:
        in_channels: Channel count of the input, forwarded to ``get_stem``.
        num_classes: Output size of the final ``nn.Linear`` layer.
        dropout_rate: Probability given to the ``nn.Dropout`` in the classifier.
        drop_path_rate: Accepted for signature parity; not used by this constructor.
        **kwargs: Accepted and ignored, so factory-level options (e.g. ``url``)
            can be passed through without raising.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 1000,
        dropout_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        **kwargs: Any
    ) -> None:
        super().__init__()

        self.stem = get_stem(in_channels)

        # 4 x Inception-A on 384 channels, then Reduction-A (next stage takes 1024).
        self.stage1 = blocks.Stage(
            *[blocks.InceptionA(384, 96, [64, 96], [64, 96], 96) for _ in range(4)],
            blocks.ReductionA(384, 384, [192, 224, 256]),
        )

        # 7 x Inception-B on 1024 channels, then Reduction-B (next stage takes 1536).
        self.stage2 = blocks.Stage(
            *[blocks.InceptionB(1024, 384, [192, 224, 256], [192, 224, 256], 128) for _ in range(7)],
            blocks.ReductionB(1024, [192, 192], [256, 320])
        )

        # 3 x Inception-C on 1536 channels.
        self.stage3 = blocks.Stage(
            *[blocks.InceptionC(1536, 256, [384, 256], [384, 448, 512, 256], 256) for _ in range(3)],
        )

        # Global *average* pooling: this is what the paper's Inception-v4
        # schema specifies and what the sibling Inception-ResNet classes use.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate, inplace=True),
            nn.Linear(1536, num_classes)
        )

    def forward(self, x):
        """Run stem, the three stages, global pooling and the classifier on ``x``.

        Returns the classifier output for the flattened pooled features.
        """
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.pool(x)
        # Collapse everything after the batch dimension for nn.Linear.
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x


@export
def inception_v4(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """Build an ``InceptionV4`` and optionally load weights into it.

    Args:
        pretrained: When true, ``load_from_local_or_url`` is called on the model.
        pth: Passed to ``load_from_local_or_url`` (presumably a local
            checkpoint path -- confirm against ``.utils``).
        progress: Passed to ``load_from_local_or_url``.
        **kwargs: Forwarded to the ``InceptionV4`` constructor; the ``'url'``
            entry, if present, is also handed to the loader.

    Returns:
        The constructed model.
    """
    model = InceptionV4(**kwargs)
    if not pretrained:
        return model

    weights_url = kwargs.get('url')
    load_from_local_or_url(model, pth, weights_url, progress)
    return model


class InceptionResNetV1(nn.Module):
    r"""Inception-ResNet-v1 image classifier.

    Paper: Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning, https://arxiv.org/abs/1602.07261

    Args:
        in_channels: Channel count of the input, used by the first stem convolution.
        num_classes: Output size of the final ``nn.Linear`` layer.
        dropout_rate: Probability given to the ``nn.Dropout`` in the classifier.
        drop_path_rate: Accepted for signature parity; not used by this constructor.
        **kwargs: Accepted and ignored, so factory-level options (e.g. ``url``)
            can be passed through without raising.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 1000,
        dropout_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        **kwargs: Any
    ) -> None:
        super().__init__()

        # Purely sequential stem (no branch concatenation): in_channels -> 256.
        self.stem = nn.Sequential(
            blocks.Conv2dBlock(in_channels, 32, kernel_size=3, stride=2, padding=0),
            blocks.Conv2dBlock(32, 32, kernel_size=3, padding=0),
            blocks.Conv2dBlock(32, 64, kernel_size=3),
            nn.MaxPool2d(3, stride=2),
            blocks.Conv2d1x1Block(64, 80),
            blocks.Conv2dBlock(80, 192, kernel_size=3, padding=0),
            blocks.Conv2dBlock(192, 256, kernel_size=3, stride=2, padding=0)
        )

        # 5 x Inception-ResNet-A on 256 channels, then Reduction-A (next stage takes 896).
        self.stage1 = blocks.Stage(
            *[blocks.InceptionResNetA(256, 32, [32, 32], [32, 32, 32]) for _ in range(5)],
            blocks.ReductionA(256, 384, [192, 192, 256])
        )

        # 10 x Inception-ResNet-B on 896 channels, then a reduction (next stage takes 1792).
        self.stage2 = blocks.Stage(
            *[blocks.InceptionResNetB(896, 128, [128, 128, 128]) for _ in range(10)],
            blocks.ReductionC(896, [256, 384], [256, 256], [256, 256, 256])
        )

        # 5 x Inception-ResNet-C on 1792 channels. The list must be unpacked so
        # each block is passed as its own positional module, as in every other
        # Stage(...) call in this file.
        self.stage3 = blocks.Stage(
            *[blocks.InceptionResNetC(1792, 192, [192, 192, 192]) for _ in range(5)],
        )

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate, inplace=True),
            nn.Linear(1792, num_classes)
        )

    def forward(self, x):
        """Run stem, the three stages, global pooling and the classifier on ``x``.

        Returns the classifier output for the flattened pooled features.
        """
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.pool(x)
        # Collapse everything after the batch dimension for nn.Linear.
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x


@export
def inception_resnet_v1(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """Build an ``InceptionResNetV1`` and optionally load weights into it.

    Args:
        pretrained: When true, ``load_from_local_or_url`` is called on the model.
        pth: Passed to ``load_from_local_or_url`` (presumably a local
            checkpoint path -- confirm against ``.utils``).
        progress: Passed to ``load_from_local_or_url``.
        **kwargs: Forwarded to the ``InceptionResNetV1`` constructor; the
            ``'url'`` entry, if present, is also handed to the loader.

    Returns:
        The constructed model.
    """
    model = InceptionResNetV1(**kwargs)
    if not pretrained:
        return model

    weights_url = kwargs.get('url')
    load_from_local_or_url(model, pth, weights_url, progress)
    return model


class InceptionResNetV2(nn.Module):
    r"""Inception-ResNet-v2 image classifier.

    Paper: Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning, https://arxiv.org/abs/1602.07261

    Args:
        in_channels: Channel count of the input, forwarded to ``get_stem``.
        num_classes: Output size of the final ``nn.Linear`` layer.
        dropout_rate: Probability given to the ``nn.Dropout`` in the classifier.
        drop_path_rate: Accepted for signature parity; not used by this constructor.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 1000,
        dropout_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        **kwargs: Any
    ) -> None:
        super().__init__()

        self.stem = get_stem(in_channels)

        # Stage 1: ten residual A blocks on 384 channels, closed by Reduction-A.
        a_layers = [blocks.InceptionResNetA(384, 32, [32, 32], [32, 48, 64]) for _ in range(10)]
        a_layers.append(blocks.ReductionA(384, 384, [256, 256, 384]))
        self.stage1 = blocks.Stage(*a_layers)

        # Stage 2: twenty residual B blocks on 1152 channels, closed by a reduction.
        b_layers = [blocks.InceptionResNetB(1152, 192, [128, 160, 192]) for _ in range(20)]
        b_layers.append(blocks.ReductionC(1152, [256, 384], [256, 288], [256, 288, 320]))
        self.stage2 = blocks.Stage(*b_layers)

        # Stage 3: ten residual C blocks on 2144 channels, then a 1x1 block
        # that brings the width down to the 1536 features the classifier takes.
        c_layers = [blocks.InceptionResNetC(2144, 192, [192, 224, 256]) for _ in range(10)]
        c_layers.append(blocks.Conv2d1x1Block(2144, 1536))
        self.stage3 = blocks.Stage(*c_layers)

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate, inplace=True),
            nn.Linear(1536, num_classes),
        )

    def forward(self, x):
        """Run stem, the three stages, global pooling and the classifier on ``x``."""
        features = self.stage3(self.stage2(self.stage1(self.stem(x))))
        pooled = self.pool(features)
        # Collapse everything after the batch dimension for nn.Linear.
        return self.classifier(torch.flatten(pooled, start_dim=1))


@export
def inception_resnet_v2(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    """Build an ``InceptionResNetV2`` and optionally load weights into it.

    Args:
        pretrained: When true, ``load_from_local_or_url`` is called on the model.
        pth: Passed to ``load_from_local_or_url`` (presumably a local
            checkpoint path -- confirm against ``.utils``).
        progress: Passed to ``load_from_local_or_url``.
        **kwargs: Forwarded to the ``InceptionResNetV2`` constructor; the
            ``'url'`` entry, if present, is also handed to the loader.

    Returns:
        The constructed model.
    """
    model = InceptionResNetV2(**kwargs)
    if not pretrained:
        return model

    weights_url = kwargs.get('url')
    load_from_local_or_url(model, pth, weights_url, progress)
    return model