import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable

# Configuration
# from config import CONFIG
# HIDDEN_DIM_SIZE     = CONFIG['lstm_hidden_dim']
# LSTM_DROPOUT        = CONFIG['lstm_dropout']
# DATASET_NUM_CLASSES = CONFIG['dataset_classes']
# Model hyper-parameters (hard-coded stand-ins for the commented-out CONFIG
# lookups above): LSTM hidden width, LSTM dropout probability, and the number
# of output classes of Model_1's prediction layer.
HIDDEN_DIM_SIZE, LSTM_DROPOUT, DATASET_NUM_CLASSES = 128, 0.5, 10

# Target L2 norm of the embeddings returned by Model_2.
SCALE_FACTOR = 1

class Model_1(nn.Module):
    """Spatial video classifier: ResNet-18 frame features -> LSTM -> class scores.

    Each video is encoded frame by frame with an ImageNet-pretrained ResNet-18
    (final fc layer removed). The sequence of per-frame features is run
    through an LSTM, and the LSTM output at the last time step is mapped to
    ``DATASET_NUM_CLASSES`` scores.

    Args:
        num_layers: Number of stacked LSTM layers. The default of 1 is the
            original architecture. ``LSTM_DROPOUT`` is only applied between
            stacked layers, so it takes effect only when ``num_layers > 1``.
    """

    def __init__(self, num_layers=1):
        super(Model_1, self).__init__()
        # Image feature extractor: ResNet-18 without its final fc layer, so
        # each frame is reduced to the pooled feature vector that fc consumed.
        resnet = models.resnet18(pretrained=True)
        self.base_model = nn.Sequential(*list(resnet.children())[:-1])
        # Take the feature width from the fc layer itself rather than relying
        # on the trunk's last parameter happening to have the same size.
        base_model_fc_size = resnet.fc.in_features

        # Freeze the first 30 parameter tensors. With torchvision's ResNet-18
        # ordering this is conv1/bn1, layer1 and layer2 (re-check if the
        # backbone changes); layer3 and layer4 stay trainable.
        for param in list(self.base_model.parameters())[:30]:
            param.requires_grad = False

        # nn.LSTM applies `dropout` only between stacked layers. A non-zero
        # value with a single layer is a no-op that raises a UserWarning, so
        # pass it only when there is a layer boundary to apply it at.
        self.lstmlayer = nn.LSTM(
            input_size  = base_model_fc_size,
            hidden_size = HIDDEN_DIM_SIZE,
            num_layers  = num_layers,
            dropout     = LSTM_DROPOUT if num_layers > 1 else 0.0,
            batch_first = True
        )

        # Prediction layer applied to the last LSTM output of each video.
        self.preds = nn.Linear(HIDDEN_DIM_SIZE, DATASET_NUM_CLASSES)

    def _encode_video(self, video):
        """Return the LSTM output at the last time step for a single video.

        ``video`` is either a sequence of frame tensors, which are stacked
        along a new dim 0, or an already-stacked tensor whose dim 0 indexes
        frames.
        """
        frames = video if torch.is_tensor(video) else torch.stack(video, 0)
        features = torch.flatten(self.base_model(frames), 1)  # (frames, fc_size)
        output, _ = self.lstmlayer(features.unsqueeze(0))     # (1, frames, HIDDEN_DIM_SIZE)
        return output[0, -1]

    def forward(self, X):
        """Return per-video class scores of shape [len(X), DATASET_NUM_CLASSES].

        Args:
            X: Iterable of videos. Each video is a list of frame tensors or a
                tensor with frames along dim 0. Videos are encoded one at a
                time, so their lengths may differ. Padding frames, if any, are
                NOT stripped here; an earlier unpadding step was disabled.
        """
        preds = [self.preds(self._encode_video(video)) for video in X]
        return torch.stack(preds, 0)


class Model_2(nn.Module): # adapted for triplet loss
    """Spatial video embedding network (variant of Model_1 for triplet loss).

    Each video is encoded frame by frame with an ImageNet-pretrained ResNet-18
    (final fc layer removed), and the feature sequence is run through an LSTM.
    The LSTM output at the last time step is the video's embedding. Each
    embedding is returned rescaled to an L2 norm of ``SCALE_FACTOR``.

    Args:
        num_layers: Number of stacked LSTM layers. The default of 1 is the
            original architecture. ``LSTM_DROPOUT`` is only applied between
            stacked layers, so it takes effect only when ``num_layers > 1``.
    """

    def __init__(self, num_layers=1):
        super(Model_2, self).__init__()
        # Image feature extractor: ResNet-18 without its final fc layer, so
        # each frame is reduced to the pooled feature vector that fc consumed.
        resnet = models.resnet18(pretrained=True)
        self.base_model = nn.Sequential(*list(resnet.children())[:-1])
        # Take the feature width from the fc layer itself rather than relying
        # on the trunk's last parameter happening to have the same size.
        base_model_fc_size = resnet.fc.in_features

        # Freeze the first 30 parameter tensors. With torchvision's ResNet-18
        # ordering this is conv1/bn1, layer1 and layer2 (re-check if the
        # backbone changes); layer3 and layer4 stay trainable.
        for param in list(self.base_model.parameters())[:30]:
            param.requires_grad = False

        # nn.LSTM applies `dropout` only between stacked layers. A non-zero
        # value with a single layer is a no-op that raises a UserWarning, so
        # pass it only when there is a layer boundary to apply it at.
        self.lstmlayer = nn.LSTM(
            input_size  = base_model_fc_size,
            hidden_size = HIDDEN_DIM_SIZE,
            num_layers  = num_layers,
            dropout     = LSTM_DROPOUT if num_layers > 1 else 0.0,
            batch_first = True
        )

    def _encode_video(self, video):
        """Return the LSTM output at the last time step for a single video.

        ``video`` is either a sequence of frame tensors, which are stacked
        along a new dim 0, or an already-stacked tensor whose dim 0 indexes
        frames.
        """
        frames = video if torch.is_tensor(video) else torch.stack(video, 0)
        features = torch.flatten(self.base_model(frames), 1)  # (frames, fc_size)
        output, _ = self.lstmlayer(features.unsqueeze(0))     # (1, frames, HIDDEN_DIM_SIZE)
        return output[0, -1]

    def forward(self, X):
        """Return embeddings of shape [len(X), HIDDEN_DIM_SIZE], each with L2 norm SCALE_FACTOR.

        Args:
            X: Iterable of videos. Each video is a list of frame tensors or a
                tensor with frames along dim 0. Videos are encoded one at a
                time, so their lengths may differ.
        """
        out = torch.stack([self._encode_video(video) for video in X], 0)
        # normalize() clamps the norm at a small eps, so an all-zero embedding
        # stays zero instead of becoming NaN as it would with a plain division.
        return torch.nn.functional.normalize(out, p=2, dim=1) * SCALE_FACTOR
