import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.nn.functional as F
import torchvision.models as models
import math 
from collections import OrderedDict
from torch.autograd import Variable
import pdb

# Scale applied to every embedding produced in this file: each embedding is
# divided by (its L2 norm / SCALE_FACTOR), so a non-zero embedding ends up
# with an L2 norm equal to SCALE_FACTOR.
SCALE_FACTOR = 1

class C3D(nn.Module):
    """C3D video encoder producing a 128-d, L2-normalised clip embedding.

    Five 3-D convolution stages (each convolution followed by ReLU, each stage
    closed by a max-pool) feed three fully connected layers
    (8192 -> 2048 -> 512 -> 128). The shape comments in ``forward`` assume an
    input of N x 3 x 16 x 112 x 112.
    """

    def __init__(self):
        super(C3D, self).__init__()

        self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        # The first pool only halves the spatial dims; the temporal dim is kept.
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        # Spatial padding turns the 7x7 map into 4x4 (see shape comments below).
        self.pool5 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))

        self.fc6 = nn.Linear(8192, 2048)
        self.fc7 = nn.Linear(2048, 512)
        self.fc8 = nn.Linear(512, 128)

        self.dropout = nn.Dropout(p=0.5)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Embed a batch of clips.

        Args:
            x: input tensor; the shape comments assume N x 3 x 16 x 112 x 112.

        Returns:
            A 2-D tensor with 128 columns. Each row is the ReLU-activated fc8
            output divided by (its L2 norm / SCALE_FACTOR). A row that ReLU
            zeroed completely is returned as zeros instead of NaN.
        """
        # Nx3x16x112x112
        h = self.relu(self.conv1(x))
        # Nx64x16x112x112
        h = self.pool1(h)
        # Nx64x16x56x56
        h = self.relu(self.conv2(h))
        # Nx128x16x56x56
        h = self.pool2(h)
        # Nx128x8x28x28
        h = self.relu(self.conv3a(h))
        # Nx256x8x28x28
        h = self.relu(self.conv3b(h))
        # Nx256x8x28x28
        h = self.pool3(h)
        # Nx256x4x14x14
        h = self.relu(self.conv4a(h))
        # Nx512x4x14x14
        h = self.relu(self.conv4b(h))
        # Nx512x4x14x14
        h = self.pool4(h)
        # Nx512x2x7x7
        h = self.relu(self.conv5a(h))
        # Nx512x2x7x7
        h = self.relu(self.conv5b(h))
        # Nx512x2x7x7
        h = self.pool5(h)
        # Nx512x1x4x4 -> Nx8192
        h = h.view(-1, 8192)

        h = self.dropout(self.relu(self.fc6(h)))
        h = self.dropout(self.relu(self.fc7(h)))
        fc_8_feats = self.relu(self.fc8(h))

        # ReLU can zero an entire row; clamping the norm keeps the division
        # finite (0 / eps == 0) without changing any row with a non-zero norm.
        norm = fc_8_feats.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12) / SCALE_FACTOR
        return fc_8_feats.div(norm)

class C3D_trunc(nn.Module):
    """C3D encoder truncated after fc6, returning an L2-normalised embedding.

    Uses the same five 3-D convolution stages as the full C3D network but
    keeps a single fully connected layer (8192 -> ``fc6_out``). The shape
    comments in ``forward`` assume an input of N x 3 x 16 x 112 x 112.

    Args:
        fc6_out: number of output features of fc6, i.e. the embedding size.
    """

    def __init__(self, fc6_out):
        super(C3D_trunc, self).__init__()

        self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        # The first pool only halves the spatial dims; the temporal dim is kept.
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        # Spatial padding turns the 7x7 map into 4x4 (see shape comments below).
        self.pool5 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))

        self.fc6 = nn.Linear(8192, fc6_out)

        # Not applied in forward(); kept so the module's attributes are unchanged.
        self.dropout = nn.Dropout(p=0.5)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Embed a batch of clips.

        Args:
            x: input tensor; the shape comments assume N x 3 x 16 x 112 x 112.

        Returns:
            A 2-D tensor with ``fc6_out`` columns. Each row is the
            ReLU-activated fc6 output divided by (its L2 norm / SCALE_FACTOR).
            A row that ReLU zeroed completely is returned as zeros instead of
            NaN.
        """
        # Nx3x16x112x112
        h = self.relu(self.conv1(x))
        # Nx64x16x112x112
        h = self.pool1(h)
        # Nx64x16x56x56
        h = self.relu(self.conv2(h))
        # Nx128x16x56x56
        h = self.pool2(h)
        # Nx128x8x28x28
        h = self.relu(self.conv3a(h))
        # Nx256x8x28x28
        h = self.relu(self.conv3b(h))
        # Nx256x8x28x28
        h = self.pool3(h)
        # Nx256x4x14x14
        h = self.relu(self.conv4a(h))
        # Nx512x4x14x14
        h = self.relu(self.conv4b(h))
        # Nx512x4x14x14
        h = self.pool4(h)
        # Nx512x2x7x7
        h = self.relu(self.conv5a(h))
        # Nx512x2x7x7
        h = self.relu(self.conv5b(h))
        # Nx512x2x7x7
        h = self.pool5(h)
        # Nx512x1x4x4 -> Nx8192
        h = h.view(-1, 8192)

        fc_6_feats = self.relu(self.fc6(h))

        # ReLU can zero an entire row; clamping the norm keeps the division
        # finite (0 / eps == 0) without changing any row with a non-zero norm.
        norm = fc_6_feats.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12) / SCALE_FACTOR
        return fc_6_feats.div(norm)


class C3D_soft(nn.Module):
    """C3D classifier that returns softmax class probabilities.

    Five 3-D convolution stages feed three fully connected layers
    (8192 -> 2048 -> 512 -> ``num_classes``); the last layer's output goes
    through a softmax. The shape comments in ``forward`` assume an input of
    N x 3 x 16 x 112 x 112.

    Args:
        num_classes: number of output classes (default 10, the value that was
            previously hard-coded).
    """

    def __init__(self, num_classes=10):
        super(C3D_soft, self).__init__()

        self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        # The first pool only halves the spatial dims; the temporal dim is kept.
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        # Spatial padding turns the 7x7 map into 4x4 (see shape comments below).
        self.pool5 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1))

        self.fc6 = nn.Linear(8192, 2048)
        self.fc7 = nn.Linear(2048, 512)
        self.fc8 = nn.Linear(512, num_classes)

        self.dropout = nn.Dropout(p=0.5)

        self.relu = nn.ReLU()
        # Explicit class dimension: the logits are 2-D (rows x classes), and
        # omitting ``dim`` relies on a deprecated implicit choice that warns.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: input tensor; the shape comments assume N x 3 x 16 x 112 x 112.

        Returns:
            A 2-D tensor of class probabilities; every row sums to 1.
        """
        # Nx3x16x112x112
        h = self.relu(self.conv1(x))
        # Nx64x16x112x112
        h = self.pool1(h)
        # Nx64x16x56x56
        h = self.relu(self.conv2(h))
        # Nx128x16x56x56
        h = self.pool2(h)
        # Nx128x8x28x28
        h = self.relu(self.conv3a(h))
        # Nx256x8x28x28
        h = self.relu(self.conv3b(h))
        # Nx256x8x28x28
        h = self.pool3(h)
        # Nx256x4x14x14
        h = self.relu(self.conv4a(h))
        # Nx512x4x14x14
        h = self.relu(self.conv4b(h))
        # Nx512x4x14x14
        h = self.pool4(h)
        # Nx512x2x7x7
        h = self.relu(self.conv5a(h))
        # Nx512x2x7x7
        h = self.relu(self.conv5b(h))
        # Nx512x2x7x7
        h = self.pool5(h)
        # Nx512x1x4x4 -> Nx8192
        h = h.view(-1, 8192)

        h = self.dropout(self.relu(self.fc6(h)))
        h = self.dropout(self.relu(self.fc7(h)))

        logits = self.fc8(h)
        return self.softmax(logits)


class TCN_Residual(nn.Module):
    """Pre-activation residual block for 1-D temporal convolutions.

    The main path is BatchNorm -> ReLU -> Dropout -> Conv1d (kernel 9,
    padding 4). Its output is added to the input, which is first passed
    through a 1x1 convolution whenever the block changes the channel count or
    the temporal stride.

    Args:
        inplanes: number of input channels.
        outplanes: number of output channels.
        dropout: dropout probability applied before the convolution.
        stride: stride of the convolution (and of the shortcut projection).
    """

    def __init__(self, inplanes, outplanes, dropout=0.5, stride=1):
        super(TCN_Residual, self).__init__()
        self.inplanes = inplanes
        self.outplanes = outplanes
        self.stride = stride
        self.dropout = dropout
        self._make_layer()

    def _make_layer(self):
        """Create the sub-modules of the block."""
        self.bn1 = nn.BatchNorm1d(self.inplanes)
        self.dr1 = nn.Dropout(p=self.dropout)
        self.conv1 = nn.Conv1d(self.inplanes,
                               self.outplanes,
                               kernel_size=9,
                               padding=4,
                               stride=self.stride,
                               bias=False)

        # A strided convolution shortens the sequence, so the identity branch
        # needs a projection in that case too, not only when channels differ;
        # otherwise the addition in forward() fails on mismatched lengths.
        needs_projection = self.inplanes != self.outplanes or self.stride != 1
        if needs_projection:
            self.shortcut = nn.Conv1d(self.inplanes, self.outplanes, kernel_size=1,
                                      stride=self.stride, bias=False)
        else:
            self.shortcut = None

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the block to ``x`` (channels on dim 1, time on dim 2).

        Returns:
            The convolution output plus the (optionally projected) input.
        """
        # Pre-activation: normalise and activate before the convolution.
        out = self.conv1(self.dr1(self.relu(self.bn1(x))))

        residual = x if self.shortcut is None else self.shortcut(x)

        out += residual
        return out

class GAP_RNN(nn.Module):
    """Single-layer LSTM whose outputs are average-pooled over time.

    The pooled LSTM output goes through ``fc1``. With ``triplet_training`` the
    result is returned as an L2-normalised embedding; otherwise it is mapped
    to ``n_classes`` scores by ``fc2`` followed by ReLU.

    The LSTM state is stored on the module (``self.hidden``) and carried over
    from one ``forward`` call to the next; call ``reset_hidden`` to clear it.

    Args:
        n_classes: number of outputs of the classification head.
        feat_dim: size of each input feature vector.
        hidden_dim: LSTM hidden size.
        out_dim: output size of ``fc1`` (the embedding size).
        triplet_training: if True, ``forward`` returns the normalised
            embedding instead of class scores.
    """

    def __init__(self, n_classes, feat_dim, hidden_dim, out_dim, triplet_training = True):
        super(GAP_RNN, self).__init__()
        self.n_classes = n_classes
        self.feat_dim = feat_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.triplet_training = triplet_training

        self.lstm = nn.LSTM(input_size=self.feat_dim, hidden_size=self.hidden_dim)
        self.fc1 = nn.Linear(self.hidden_dim, self.out_dim)
        self.gap = nn.AdaptiveAvgPool1d(output_size=1)
        self.fc2 = nn.Linear(self.out_dim, self.n_classes)
        self.relu = nn.ReLU()

        # Created after the LSTM so the state can follow the LSTM's device.
        self.hidden = self.init_hidden()

    def init_hidden(self, batch_size=1):
        """Return a zero (h_0, c_0) pair on the LSTM's current device.

        Args:
            batch_size: batch size the state is created for (default 1).
        """
        weight = next(self.lstm.parameters())
        shape = (1, batch_size, self.hidden_dim)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

    def forward(self, x):
        """Run the model on ``x`` of shape (T, batch, feat_dim).

        The batch size must match the stored hidden state (1 unless
        ``reset_hidden(batch_size=...)`` was called).

        Returns:
            (batch, out_dim) embedding with L2 norm SCALE_FACTOR when
            ``triplet_training`` is set, else (batch, n_classes) scores.
        """
        # The state is a plain attribute, so Module.to()/.cuda() does not move
        # it; align it with the input instead of assuming a CUDA device.
        hidden = tuple(h.to(device=x.device, dtype=x.dtype) for h in self.hidden)
        out, self.hidden = self.lstm(x, hidden)

        # (T, batch, hidden) -> (batch, hidden, T) so the pool averages over time.
        out = out.permute(1, 2, 0)
        out = self.gap(out)
        out = torch.squeeze(out, 2)
        out = self.fc1(out)

        if self.triplet_training:
            # Clamp so an all-zero embedding yields zeros instead of NaN.
            norm = out.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12) / SCALE_FACTOR
            return out.div(norm)

        # The original line called the non-existent ``self.fc``; ReLU over the
        # fc2 scores mirrors the classification head of simple_RNN.
        return self.relu(self.fc2(out))

    def reset_hidden(self, batch_size=1):
        """Replace the stored LSTM state with zeros for ``batch_size`` sequences."""
        self.hidden = self.init_hidden(batch_size)


class simple_RNN(nn.Module):
    """Single-layer LSTM that summarises a sequence by its final hidden state.

    With ``triplet_training`` the final hidden state is returned as an
    L2-normalised embedding; otherwise it is mapped to ``n_classes`` scores by
    ``fc`` followed by ReLU.

    The LSTM state is stored on the module (``self.hidden``) and carried over
    from one ``forward`` call to the next; call ``reset_hidden`` to clear it.

    Args:
        n_classes: number of outputs of the classification head.
        feat_dim: size of each input feature vector.
        hidden_dim: LSTM hidden size (also the embedding size).
        triplet_training: if True, ``forward`` returns the normalised
            embedding instead of class scores.
    """

    def __init__(self, n_classes, feat_dim, hidden_dim, triplet_training = True):
        super(simple_RNN, self).__init__()
        self.n_classes = n_classes
        self.feat_dim = feat_dim
        self.hidden_dim = hidden_dim
        self.triplet_training = triplet_training

        self.lstm = nn.LSTM(input_size=self.feat_dim, hidden_size=self.hidden_dim)
        self.fc = nn.Linear(self.hidden_dim, self.n_classes)
        self.relu = nn.ReLU()

        # Created after the LSTM so the state can follow the LSTM's device.
        self.hidden = self.init_hidden()

    def init_hidden(self, batch_size=1):
        """Return a zero (h_0, c_0) pair on the LSTM's current device.

        Args:
            batch_size: batch size the state is created for (default 1).
        """
        weight = next(self.lstm.parameters())
        shape = (1, batch_size, self.hidden_dim)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))

    def forward(self, x):
        """Run the model on ``x`` of shape (T, batch, feat_dim).

        The batch size must match the stored hidden state (1 unless
        ``reset_hidden(batch_size=...)`` was called).

        Returns:
            (1, batch, hidden_dim) embedding with L2 norm SCALE_FACTOR along
            the last dim when ``triplet_training`` is set, else
            (1, batch, n_classes) ReLU-activated scores.
        """
        # The state is a plain attribute, so Module.to()/.cuda() does not move
        # it; align it with the input instead of assuming a CUDA device.
        hidden = tuple(h.to(device=x.device, dtype=x.dtype) for h in self.hidden)
        _, self.hidden = self.lstm(x, hidden)
        last_hidden = self.hidden[0]

        if self.triplet_training:
            # Clamp so an all-zero state yields zeros instead of NaN.
            norm = last_hidden.norm(p=2, dim=2, keepdim=True).clamp(min=1e-12) / SCALE_FACTOR
            return last_hidden.div(norm)

        # The original returned the undefined name ``output`` here.
        return self.relu(self.fc(last_hidden))

    def reset_hidden(self, batch_size=1):
        """Replace the stored LSTM state with zeros for ``batch_size`` sequences."""
        self.hidden = self.init_hidden(batch_size)

class RESTCN_paper(nn.Module):
    """Residual temporal convolutional network over per-frame features.

    A kernel-9 input convolution is followed by three residual stages with
    96, 128 and 160 output channels (max-pooling by 2 after the first two),
    then BatchNorm, ReLU and global average pooling over time.

    Args:
        res: factory for one residual block, called as
            ``res(inplanes, outplanes, dropout=..., stride=...)`` and returning
            an ``nn.Module`` (e.g. ``TCN_Residual``).
        n_classes: number of outputs of the classification layer.
        feat_dim: number of input channels (feature size per time step).
        dropout: dropout probability handed to every residual block.
        triplet_training: if True, ``forward`` returns only the L2-normalised
            pooled features.
    """

    def __init__(self, res, n_classes, feat_dim, dropout=0.5, triplet_training = True):
        super(RESTCN_paper, self).__init__()
        self.n_classes = n_classes
        self.feat_dim = feat_dim
        self.dropout = dropout
        self.triplet_training = triplet_training
        self.num_filters = [96, 96, 128, 160]

        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv1 = nn.Conv1d(self.feat_dim,
                               self.num_filters[0],
                               kernel_size=9,
                               stride=1,
                               padding=4,
                               bias=False)

        self.layer0 = self._make_layer(res,
                                       self.num_filters[0],
                                       self.num_filters[1],
                                       stride=1)
        self.layer1 = self._make_layer(res,
                                       self.num_filters[1],
                                       self.num_filters[2],
                                       stride=1)
        self.layer2 = self._make_layer(res,
                                       self.num_filters[2],
                                       self.num_filters[3],
                                       stride=1)

        self.last_bn = nn.BatchNorm1d(self.num_filters[-1])
        self.last_relu = nn.ReLU(inplace=True)
        self.gap = nn.AdaptiveAvgPool1d(output_size=1)
        self.fc = nn.Linear(self.num_filters[-1], self.n_classes)
        # Not applied in forward(); ``dim`` is explicit so that using it on
        # the (N, n_classes) scores does not hit the deprecated implicit choice.
        self.softmax = nn.Softmax(dim=1)

    def _make_layer(self, res, inplanes, outplanes, stride=1):
        """Build one stage: a single ``res`` block wrapped in ``nn.Sequential``."""
        # Forward the configured dropout; previously it was stored but never
        # passed on, so the block silently used its own default.
        block = res(inplanes, outplanes, dropout=self.dropout, stride=stride)
        return nn.Sequential(block)

    def forward(self, x):
        """Run the network on ``x`` of shape N x feat_dim x T.

        Returns:
            With ``triplet_training``: the N x 160 pooled features divided by
            (their L2 norm / SCALE_FACTOR). Otherwise a tuple
            ``(scores, features)`` with the N x n_classes output of ``fc`` and
            the unnormalised N x 160 pooled features.
        """
        # N x feat_dim x T
        x = self.conv1(x)
        # N x 96 x T
        x = self.layer0(x)
        # N x 96 x T
        x = self.pool(x)
        # N x 96 x T/2
        x = self.layer1(x)
        # N x 128 x T/2
        x = self.pool(x)
        # N x 128 x T/4
        x = self.layer2(x)
        # N x 160 x T/4
        x = self.last_relu(self.last_bn(x))
        x = self.gap(x)
        # N x 160 x 1 -> N x 160
        feats = x.view(x.size(0), -1)

        if self.triplet_training:
            # The features come out of a ReLU, so a row can be all zeros;
            # clamping keeps the division finite (0 / eps == 0).
            norm = feats.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12) / SCALE_FACTOR
            return feats.div(norm)

        scores = self.fc(feats)
        return scores, feats

