#########################################
# Johns Hopkins University              #
# 601.455 Computer Integrated Surgery 2 #
# Spring 2018                           #
# Query by Video For Surgical Activities#
# Felix Yu                              #
# JHED: fyu12                           #
# Gianluca Silva Croso                  #
# JHED: gsilvac1                        #
#########################################

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
from torch.autograd import Variable

# Target L2 norm of every output embedding in triplet-loss training mode:
# GAP_RNN.forward rescales each row of the fc1 output to this length.
SCALE_FACTOR = 1


class GAP_RNN(nn.Module):
    """
    PyTorch model for Gap RNN.

    The model has three stages:
    1. An LSTM runs over the input sequence.
    2. Its per-timestep outputs are averaged over time (global average pooling).
    3. A linear layer (``fc1``) projects the pooled vector.

    There are two output modes, selected by ``triplet_training``:

    * ``True``: return the ``fc1`` embedding with every row rescaled to an
      L2 norm of ``SCALE_FACTOR`` (for triplet-loss training).
    * ``False``: return ``n_classes`` scores from ``fc1 -> ReLU -> fc2``.

    The LSTM state persists across calls. Each ``forward`` starts from
    ``self.hidden`` and stores the final ``(h, c)`` back into it, so call
    ``reset_hidden()`` between independent sequences.

    Args:
        n_classes: number of output scores in classification mode.
        feat_dim: size of each per-timestep input feature vector.
        hidden_dim: LSTM hidden-state size.
        out_dim: size of the ``fc1`` embedding.
        triplet_training: selects the output mode described above.
    """

    # Keeps the triplet-mode normalization finite when an embedding is all
    # zeros (same epsilon torch.nn.functional.normalize uses).
    _NORM_EPS = 1e-12

    def __init__(self, n_classes, feat_dim, hidden_dim, out_dim, triplet_training=True):
        super(GAP_RNN, self).__init__()
        self.n_classes = n_classes
        self.feat_dim = feat_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.triplet_training = triplet_training
        self.hidden = self.init_hidden()

        # nn.LSTM defaults to batch_first=False: inputs are (seq_len, batch, feat_dim).
        self.lstm = nn.LSTM(input_size=self.feat_dim, hidden_size=self.hidden_dim)
        self.fc1 = nn.Linear(self.hidden_dim, self.out_dim)
        self.gap = nn.AdaptiveAvgPool1d(output_size=1)
        self.fc2 = nn.Linear(self.out_dim, self.n_classes)
        self.relu = nn.ReLU()

    def init_hidden(self, batch_size=1):
        """
        Return a zero ``(h_0, c_0)`` pair for the single-layer LSTM.

        Each tensor has shape ``(1, batch_size, hidden_dim)``. The tensors are
        placed on the GPU when CUDA is available and on the CPU otherwise, so
        the model can also be built on machines without a GPU.

        Args:
            batch_size: batch dimension of the inputs this state will be used
                with. Defaults to 1.
        """
        def zeros():
            state = torch.zeros(1, batch_size, self.hidden_dim)
            if torch.cuda.is_available():
                state = state.cuda()
            return Variable(state)

        return zeros(), zeros()

    def forward(self, x):
        """
        Run the LSTM over ``x`` and return an embedding or class scores.

        Args:
            x: input sequence in nn.LSTM's default ``(seq_len, batch, feat_dim)``
                layout. Its batch size must match the batch dimension of
                ``self.hidden`` (see ``reset_hidden``).

        Returns:
            ``(batch, out_dim)`` embeddings of L2 norm ``SCALE_FACTOR`` when
            ``triplet_training`` is set, else ``(batch, n_classes)`` scores.
        """
        out, self.hidden = self.lstm(x, self.hidden)
        # (seq, batch, hidden) -> (batch, hidden, seq) so the pool averages over time.
        out = torch.transpose(out, 0, 1)
        out = torch.transpose(out, 1, 2)
        out = self.gap(out)          # (batch, hidden, 1)
        out = torch.squeeze(out, 2)  # (batch, hidden)
        out = self.fc1(out)
        if self.triplet_training:
            # Rescale each row to L2 norm SCALE_FACTOR. The clamp turns the
            # 0/0 of an all-zero row into 0 instead of NaN.
            norm = out.norm(p=2, dim=1, keepdim=True).clamp(min=self._NORM_EPS) / SCALE_FACTOR
            return out.div(norm)

        # NOTE(review): the classification head is assumed to be
        # fc1 -> ReLU -> fc2 (``self.relu`` is otherwise unused); confirm
        # against the training code.
        return self.fc2(self.relu(out))

    def reset_hidden(self, batch_size=1):
        """Reset the carried LSTM state to zeros for ``batch_size`` new sequences."""
        self.hidden = self.init_hidden(batch_size)
