from torch import nn

class Conv3by3(nn.Module):
    """A 3x3 convolution followed by a ReLU activation.

    The convolution uses ``padding=1`` with the default stride, so the
    spatial size of the input is preserved; only the channel count changes,
    from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Number of channels in the input tensor.
        out_ch: Number of channels produced by the convolution.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        activation = nn.ReLU()
        # Stored as a Sequential under this attribute name so parameter keys
        # stay "conv3by3.0.weight" / "conv3by3.0.bias".
        self.conv3by3 = nn.Sequential(conv, activation)

    def forward(self, x):
        """Run ``x`` through the convolution and ReLU and return the result."""
        activated = self.conv3by3(x)
        return activated

class Resnet(nn.Module):
    """Small residual-style CNN classifier with three stages (64/128/256 channels).

    Layout:
        * ``firstconv``: 7x7 convolution (no padding, so each spatial dimension
          shrinks by 6), ReLU, then 2x2 average pooling.
        * Stage 1: three blocks of two ``Conv3by3`` layers at 64 channels.
        * Stage 2: ``block2_1`` widens to 128 channels, ``blockshort_1to2``
          halves the resolution, then two more blocks at 128 channels.
        * Stage 3: ``block3_1`` widens to 256 channels, ``blockshort_2to3``
          halves the resolution and applies a 1x1 convolution, then two more
          blocks at 256 channels.
        * Global average pooling to 1x1, flatten, and a linear layer that
          produces ``classes`` outputs.

    Args:
        classes: Number of output units of the final linear layer.
        in_ch: Number of channels of the input tensor.
    """

    def __init__(self, classes=2, in_ch=3):
        super().__init__()

        def conv_pair(c_in, c_out):
            # Two stacked Conv3by3 layers; only the first one changes width.
            return nn.Sequential(Conv3by3(c_in, c_out), Conv3by3(c_out, c_out))

        # Submodules are created in the same order and under the same
        # attribute names as before, so state_dict keys and seeded
        # initialisation are unaffected.
        stem_conv = nn.Conv2d(in_ch, 64, kernel_size=7)
        stem_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.firstconv = nn.Sequential(stem_conv, nn.ReLU(), stem_pool)

        # Stage 1: 64 channels throughout.
        self.block1_1 = conv_pair(64, 64)
        self.block1_2 = conv_pair(64, 64)
        self.block1_3 = conv_pair(64, 64)

        # Transition 1 -> 2: spatial downsampling only.
        self.blockshort_1to2 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        # Stage 2: the first block widens 64 -> 128.
        self.block2_1 = conv_pair(64, 128)
        self.block2_2 = conv_pair(128, 128)
        self.block2_3 = conv_pair(128, 128)

        # Transition 2 -> 3: downsampling plus a 1x1 convolution. It takes 256
        # channels because forward() applies it after block3_1.
        self.blockshort_2to3 = nn.Sequential(
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 256, kernel_size=1, stride=1),
        )

        # Stage 3: the first block widens 128 -> 256.
        self.block3_1 = conv_pair(128, 256)
        self.block3_2 = conv_pair(256, 256)
        self.block3_3 = conv_pair(256, 256)

        # Classification head.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, classes)

    def forward(self, x):
        """Compute class scores for ``x``.

        ``x`` is presumably a batched image tensor ``(N, in_ch, H, W)``; the
        final flatten keeps dimension 0 and merges the rest, so the result has
        shape ``(N, classes)``.
        """
        # NOTE(review): within each stage the SAME stage-entry tensor is added
        # after every block (not each block's own input, as in the usual
        # residual formulation). Behaviour is kept as-is; confirm it is intended.
        stage1_in = self.firstconv(x)
        out = self.block1_1(stage1_in) + stage1_in
        out = self.block1_2(out) + stage1_in
        out = self.block1_3(out) + stage1_in

        # NOTE(review): the blockshort_* modules run on the main path after
        # the widening block; there is no skip connection across a transition.
        stage2_in = self.blockshort_1to2(self.block2_1(out))
        out = self.block2_2(stage2_in) + stage2_in
        out = self.block2_3(out) + stage2_in

        stage3_in = self.blockshort_2to3(self.block3_1(out))
        out = self.block3_2(stage3_in) + stage3_in
        out = self.block3_3(out) + stage3_in

        pooled = self.global_pool(out)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)