import torch

class UpDimV3(torch.nn.Module):
    """Residual CNN that repeatedly "up-dims" its input: a 1D signal is
    processed by 1D blocks, then reshaped so 2D and finally 3D blocks can
    convolve across the channel axis as well, before an attention-pooled
    fully connected head produces the class logits."""

    def __init__(self, num_class):
        super().__init__()
        self.activation = torch.nn.LeakyReLU(0.001, inplace=True)
        self.dropout = torch.nn.Dropout()  # p=0.5 by default
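
        # Shape bookkeeping for input (B, 1, L): the 1D blocks reduce L by a
        # factor of 32 and emit 128 channels; reshaping to (B, 1, 128, L/32)
        # lets the 2D blocks convolve across channels too, and the same trick
        # repeats in 3D, ending at (B, 256, 16, 2, L/32768) before the head.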

        # Block 1D 1: two conv/pool stages, overall stride 4, channels
        # 1 -> 32; the skip branch matches them with a 1x1 conv + stride-4 pool.
        self.seq1 = torch.nn.Sequential(torch.nn.Conv1d(1, 32, 3, 1, 1),
                                        torch.nn.MaxPool1d(3, 2, 1),
                                        torch.nn.BatchNorm1d(32),
                                        self.activation,
                                        torch.nn.Conv1d(32, 32, 3, 1, 1),
                                        torch.nn.MaxPool1d(3, 2, 1),
                                        torch.nn.BatchNorm1d(32))
        self.skip11 = torch.nn.Conv1d(1, 32, 1, 1)
        self.skip_pool11 = torch.nn.MaxPool1d(5, 4, 2)

        # Block 1D 2: overall stride 8, channels 32 -> 128
        self.seq2 = torch.nn.Sequential(torch.nn.Conv1d(32, 64, 3, 1, 1),
                                        torch.nn.MaxPool1d(3, 2, 1),
                                        torch.nn.BatchNorm1d(64),
                                        self.activation,
                                        torch.nn.Conv1d(64, 128, 5, 1, 2),
                                        torch.nn.MaxPool1d(5, 4, 2),
                                        torch.nn.BatchNorm1d(128))
        self.skip12 = torch.nn.Conv1d(32, 128, 1, 1)
        self.skip_pool12 = torch.nn.MaxPool1d(9, 8, 4)

        # Block 2D 1: operates on (B, 1, 128, L/32); overall stride (2, 8), channels 1 -> 32
        self.seq3 = torch.nn.Sequential(torch.nn.Conv2d(1, 32, (3, 5), 1, (1, 2)),
                                        torch.nn.MaxPool2d((1, 3), (1, 2), (0, 1)),
                                        torch.nn.BatchNorm2d(32),
                                        self.activation,
                                        torch.nn.Conv2d(32, 32, (3, 5), 1, (1, 2)),
                                        torch.nn.MaxPool2d((3, 5), (2, 4), (1, 2)),
                                        torch.nn.BatchNorm2d(32))
        self.skip21 = torch.nn.Conv2d(1, 32, 1)
        self.skip_pool21 = torch.nn.MaxPool2d((3, 9), (2, 8), (1, 4))

        # Block 2D 2: overall stride (4, 8), channels 32 -> 128
        self.seq4 = torch.nn.Sequential(torch.nn.Conv2d(32, 64, (3, 5), 1, (1, 2)),
                                        torch.nn.MaxPool2d((3, 5), (2, 4), (1, 2)),
                                        torch.nn.BatchNorm2d(64),
                                        self.activation,
                                        torch.nn.Conv2d(64, 128, (3, 5), 1, (1, 2)),
                                        torch.nn.MaxPool2d(3, 2, 1),
                                        torch.nn.BatchNorm2d(128))
        self.skip22 = torch.nn.Conv2d(32, 128, 1)
        self.skip_pool22 = torch.nn.MaxPool2d((5, 9), (4, 8), (2, 4))

        # Block 3D 1: operates on (B, 1, 128, 16, L/2048); overall stride (2, 2, 4), channels 1 -> 64
        self.seq5 = torch.nn.Sequential(torch.nn.Conv3d(1, 32, (3, 5, 9), 1, (1, 2, 4)),
                                        torch.nn.MaxPool3d((1,1,3), (1, 1, 2), (0,0,1)),
                                        torch.nn.BatchNorm3d(32),
                                        self.activation,
                                        torch.nn.Conv3d(32, 64, (3, 5, 9), 1, (1, 2, 4)),
                                        torch.nn.MaxPool3d(3, 2, 1),
                                        torch.nn.BatchNorm3d(64))
        self.skip31 = torch.nn.Conv3d(1, 64, 1)
        self.skip_pool31 = torch.nn.MaxPool3d((3, 3, 5), (2, 2, 4), (1, 1, 2))

        # Block 3D 2: overall stride (4, 4, 4), channels 64 -> 256
        self.seq6 = torch.nn.Sequential(torch.nn.Conv3d(64, 128, (3, 5, 15), 1, (1, 2, 7)),
                                        torch.nn.MaxPool3d(3, 2, 1),
                                        torch.nn.BatchNorm3d(128),
                                        self.activation,
                                        torch.nn.Conv3d(128, 256, (3, 5, 15), 1, (1, 2, 7)),
                                        torch.nn.MaxPool3d(3, 2, 1),
                                        torch.nn.BatchNorm3d(256))
        self.skip32 = torch.nn.Conv3d(64, 256, 1)
        self.skip_pool32 = torch.nn.MaxPool3d(5, 4, 2)

        # Fully connected head: the 1x1 convolution stack scores each
        # remaining time step with a single value.
        self.obo = torch.nn.Sequential(torch.nn.Conv1d(8192, 1024, 1, 1),
                                       torch.nn.Conv1d(1024, 1024, 1, 1),
                                       torch.nn.Conv1d(1024, 1024, 1, 1),
                                       self.activation,
                                       torch.nn.Conv1d(1024, 1024, 1, 1),
                                       torch.nn.Conv1d(1024, 1024, 1, 1),
                                       torch.nn.Conv1d(1024, 1, 1, 1))
        # Softmax over time turns the scores into attention weights. If the
        # overall time stride is too large, the time axis is a singleton and
        # the softmax always outputs 1.
        self.soft_max = torch.nn.Softmax(-1)
        self.fc1 = torch.nn.Linear(8192, 1024)
        self.fc2 = torch.nn.Linear(1024, 512)
        self.fc3 = torch.nn.Linear(512, num_class)

    def forward(self, x):
        # x: (batch, 1, length) one-channel 1D signal
        # Block 1D 1
        skip = self.skip_pool11(self.skip11(x))
        out = self.seq1(x)
        out = self.activation(out + skip)

        # Block 1D 2
        skip = self.skip_pool12(self.skip12(out))
        out = self.seq2(out)
        out = self.activation(out + skip)

        # Block 2D 1
        out = out.unsqueeze(1)  # (B, C, L) -> (B, 1, C, L): channels become 2D height
        skip = self.skip_pool21(self.skip21(out))
        out = self.seq3(out)
        out = self.activation(out + skip)

        # Block 2D 2
        skip = self.skip_pool22(self.skip22(out))
        out = self.seq4(out)
        out = self.activation(out + skip)

        # Block 3D 1
        out = out.unsqueeze(1)  # (B, C, H, W) -> (B, 1, C, H, W): channels become 3D depth
        skip = self.skip_pool31(self.skip31(out))
        out = self.seq5(out)
        out = self.activation(out + skip)

        # Block 3D 2
        skip = self.skip_pool32(self.skip32(out))
        out = self.seq6(out)
        out = self.activation(out + skip)

        # Fully connected: flatten (B, 256, 16, 2, T) to (B, 8192, T), then
        # average over time with the attention weights from self.obo.
        out = out.reshape(len(out), 8192, -1)
        out = torch.mean(out * self.soft_max(self.obo(out)), -1)
        out = self.dropout(self.activation(self.fc1(out)))
        out = self.dropout(self.activation(self.fc2(out)))
        return self.fc3(out)
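

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the original module). The input
    # length 65536 and num_class=10 are assumptions: 65536 is chosen so the
    # overall time stride of 32768 leaves a time axis of 2, and the final
    # (B, 256, 16, 2, 2) feature map flattens to the 8192 features the head
    # expects. eval() keeps BatchNorm/Dropout deterministic for the check.
    model = UpDimV3(num_class=10).eval()
    with torch.no_grad():
        logits = model(torch.randn(2, 1, 65536))
    print(logits.shape)  # expected: torch.Size([2, 10])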

