#########################################
# Johns Hopkins University              #
# 601.455 Computer Integrated Surgery 2 #
# Spring 2018                           #
# Query by Video For Surgical Activities#
# Felix Yu                              #
# JHED: fyu12                           #
# Gianluca Silva Croso                  #
# JHED: gsilvac1                        #
#########################################

import math
import os
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import Models
import helpers
import sys
import gc 

# True when a CUDA device is available; the model and every tensor fed to it
# are moved to the GPU when this is set.
use_gpu = torch.cuda.is_available()
# DIRECTORY WHERE WEIGHT FILES WILL BE WRITTEN - MODIFY IF NECESSARY
weight_out_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/weights/GAP_RNN_TOOL/'

# Training parameters
# Number of epochs passed to train_model by main().
NUM_EPOCHS = 200
# Learning rate and L2 weight decay handed to the Adam optimizer in main().
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1e-4
# Triplet margin: used as the TripletMarginLoss margin, passed to
# helpers.select_triplets, and added in the validation hinge loss.
ALPHA = 0.2
# Size of the embedding produced by the model (out_dim of GAP_RNN).
FINAL_LAYER_DIM = 128
# Size of one input feature vector; main() also uses it as the hidden size.
INPUT_FEATURE_DIM = 528
# Number of samples drawn from each activity's DataLoader per training batch.
BATCH_SEGS_PER_ACTIVITY = 10
# Number of activities; feature directories are numbered 1..NUM_ACTIVITIES.
NUM_ACTIVITIES = 10
# Passed to helpers.generate_val_exemplars; presumably the number of
# validation exemplars generated per activity - confirm in helpers.
VAL_NUM_EXEM_PER_ACT = 16 


def _scalar(value):
    """
    Return a one-element loss tensor/Variable as a plain Python number.
    :param value: the loss returned by the criterion
    """
    # Tensor.item() exists from PyTorch 0.4 on; indexing a 0-dim tensor
    # (value.data[0]) raises on newer releases, while older releases only
    # support the indexing form.
    if hasattr(value, 'item'):
        return value.item()
    return value.data[0]


def _load_datasets(spatial_dir, tool_dir):
    """
    Build one SpatialToolFeatDataset per activity for training and validation.
    :param spatial_dir: root directory of the spatial features
    :param tool_dir: root directory of the tool features
    :return: (train_datasets, val_datasets); the validation list does not
    contain the last activity
    """
    train_datasets = []
    val_datasets = []
    for phase in ('train', 'val'):
        for act in range(1, NUM_ACTIVITIES + 1):
            print('Currently loading dataset for activity ' + str(act) + ' in phase ' + phase)
            sys.stdout.flush()
            # The last activity is deliberately left out of validation.
            if phase == 'val' and act == NUM_ACTIVITIES:
                continue
            spatial_npy_dir = os.path.join(spatial_dir, phase, str(act))
            tool_npy_dir = os.path.join(tool_dir, phase, str(act))
            dataset = helpers.SpatialToolFeatDataset(spatial_npy_dir, tool_npy_dir)
            if phase == 'train':
                train_datasets.append(dataset)
            else:
                val_datasets.append(dataset)
    return train_datasets, val_datasets


def _train_one_epoch(model, criterion, optimizer, train_dataloaders, iters_per_epoch):
    """
    Run one training epoch and return the triplet-weighted average loss.
    :param model: the model being trained
    :param criterion: the triplet loss function
    :param optimizer: torch optimizer
    :param train_dataloaders: one DataLoader per activity
    :param iters_per_epoch: number of batches to run
    :return: average loss per selected triplet (0.0 if none were selected)
    """
    running_loss = 0.0
    running_num_exemplar = 0
    model.train(True)  # Set model to training mode

    # One iterator per activity for the whole epoch. Re-creating the iterator
    # for every sample would start a new worker each time and only ever use
    # the first item of a fresh shuffle (sampling with replacement).
    # iters_per_epoch * BATCH_SEGS_PER_ACTIVITY never exceeds the smallest
    # dataset, so these iterators are not exhausted within the epoch.
    batch_iters = [iter(loader) for loader in train_dataloaders]

    for i in range(iters_per_epoch):
        optimizer.zero_grad()
        # Buffer that receives one embedding per sampled segment; rows are
        # grouped by activity, BATCH_SEGS_PER_ACTIVITY rows each.
        features = Variable(torch.randn(BATCH_SEGS_PER_ACTIVITY * NUM_ACTIVITIES, FINAL_LAYER_DIM))
        if use_gpu:
            features = features.cuda()
        for j, batch_iter in enumerate(batch_iters):
            for k in range(BATCH_SEGS_PER_ACTIVITY):
                x = next(batch_iter)
                x = torch.transpose(x, 0, 1)
                x = Variable(x, requires_grad=True)
                if use_gpu:
                    x = x.cuda()
                features[j * BATCH_SEGS_PER_ACTIVITY + k, :] = model(x)

        # Triplet mining works on a detached numpy copy of the embeddings.
        feats = features.data.cpu().numpy()
        trip_a, trip_p, trip_n = helpers.select_triplets(feats, ALPHA, BATCH_SEGS_PER_ACTIVITY)
        num_trip = len(trip_a)

        # A batch without any selected triplet gives nothing to optimise;
        # skip the update instead of feeding an empty selection to the loss.
        if num_trip > 0:
            anchor_idx = Variable(torch.LongTensor(trip_a))
            positive_idx = Variable(torch.LongTensor(trip_p))
            negative_idx = Variable(torch.LongTensor(trip_n))
            if use_gpu:
                anchor_idx = anchor_idx.cuda()
                positive_idx = positive_idx.cuda()
                negative_idx = negative_idx.cuda()
            a = features.index_select(0, anchor_idx)
            p = features.index_select(0, positive_idx)
            n = features.index_select(0, negative_idx)
            loss = criterion(a, p, n)
            running_loss += _scalar(loss) * num_trip
            running_num_exemplar += num_trip
            loss.backward()
            optimizer.step()
            del anchor_idx, positive_idx, negative_idx
            del a, p, n

        model.reset_hidden()
        # Clean out memory
        del trip_a, trip_p, trip_n
        del feats
        del features
        gc.collect()

        avg_loss = running_loss / float(running_num_exemplar) if running_num_exemplar else 0.0
        print("\t batch: {:d}/{:d}, Average Loss: {:3f}".format(i+1, iters_per_epoch, avg_loss), end = '\r')
        sys.stdout.flush()

    if running_num_exemplar == 0:
        return 0.0
    return running_loss / float(running_num_exemplar)


def _validate(model, val_datasets, val_set):
    """
    Evaluate the model on the fixed validation triplets.
    :param model: the model being evaluated
    :param val_datasets: one dataset per validation activity
    :param val_set: sequence of triplets indexed as
    (anchor activity, negative activity, anchor index, positive index,
    negative index)
    :return: (accuracy, smaller accuracy, average loss)
    """
    # NOTE(review): as before, the hidden state is not reset between
    # validation clips - confirm against Models.GAP_RNN that this is intended.
    model.train(False)  # Set model to evaluate mode
    success = 0
    smaller_success = 0
    running_loss = 0.0
    avg_loss = 0.0
    num_val = len(val_set)
    for i in range(num_val):
        print('Validating exemplar number: {:d}/{:d}, with average loss of: {:.5f}'.format(i+1, num_val, avg_loss), end = '\r')
        sys.stdout.flush()
        triplet = val_set[i]
        anch_act = triplet[0]
        neg_act = triplet[1]
        inputs = [val_datasets[anch_act][triplet[2]],
                  val_datasets[anch_act][triplet[3]],
                  val_datasets[neg_act][triplet[4]]]

        # Each embedding is copied straight to a flat numpy vector, so no
        # CPU/GPU buffer mismatch can occur and no autograd graph is kept.
        feats = []
        for single_in in inputs:
            single_in = torch.unsqueeze(single_in, 1)
            single_in = Variable(single_in, requires_grad=False)
            if use_gpu:
                single_in = single_in.cuda()
            feats.append(model(single_in).data.cpu().numpy().reshape(-1))

        dist_1 = np.linalg.norm(feats[0] - feats[1])
        dist_2 = np.linalg.norm(feats[0] - feats[2])
        loss = max(0, dist_1 - dist_2 + ALPHA)
        running_loss += loss
        avg_loss = running_loss / float(i + 1)
        # Remove references to try and free up memory
        del feats
        del inputs
        gc.collect()

        # loss <= 0: the positive is closer than the negative by the margin.
        if loss <= 0:
            success += 1
        # loss <= ALPHA: the positive is at least as close as the negative.
        if loss <= ALPHA:
            smaller_success += 1

    if num_val == 0:
        return 0.0, 0.0, avg_loss
    return float(success) / float(num_val), float(smaller_success) / float(num_val), avg_loss


def train_model(model, criterion, optimizer, scheduler, num_epochs=100):
    """
    Train GAP_RNN model with triplet loss
    :param model: the model object - here, should be of class GAP_RNN
    from models.py
    :param criterion: the loss function - here, should be nn.TripletMarginLoss
    :param optimizer: torch optimizer
    :param scheduler: torch scheduler, stepped once after every training
    epoch (may be None to keep the learning rate fixed)
    :param num_epochs: number of epochs for training

    After every epoch the model is validated on a fixed set of triplets and
    its weights are written to weight_out_dir whenever the validation
    accuracy is at least as good as the best seen so far.
    """
    # DIRECTORY CONTAINING SPATIAL AND TOOL FEATURES - MODIFY IF NECESSARY
    spatial_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/outputs/squeezenet_output/'
    tool_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/outputs/tools_output/'
    best_acc = 0.0

    train_datasets, val_datasets = _load_datasets(spatial_dir, tool_dir)
    print()
    sys.stdout.flush()
    train_dataloaders = [DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
                         for dataset in train_datasets]

    # An epoch is bounded by the smallest activity so that every batch can
    # draw BATCH_SEGS_PER_ACTIVITY segments from each activity.
    max_iter_per_epoch = min(len(dataset) // BATCH_SEGS_PER_ACTIVITY
                             for dataset in train_datasets)

    num_clip_per_act = [len(dataset) for dataset in val_datasets]
    for num_clip in num_clip_per_act:
        print(num_clip)
    val_set = helpers.generate_val_exemplars(num_clip_per_act, VAL_NUM_EXEM_PER_ACT)

    for epoch in range(num_epochs):
        print('\t\t  Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 55)
        sys.stdout.flush()

        # TRAINING
        train_loss = _train_one_epoch(model, criterion, optimizer,
                                      train_dataloaders, max_iter_per_epoch)
        print('Done training epoch {:d}. Average training loss: {:.5f}'.format(epoch+1, train_loss))
        sys.stdout.flush()
        # Advance the learning-rate schedule once per epoch, after the
        # optimizer updates of that epoch.
        if scheduler is not None:
            scheduler.step()

        # VALIDATION
        epoch_acc, smaller_epoch_acc, avg_loss = _validate(model, val_datasets, val_set)
        # Print statistics
        print("Validation Results for Epoch {:d}, accuracy: {:.3f}, smaller accuracy: {:.3f}, average loss: {:.3f}".format(epoch+1, epoch_acc, smaller_epoch_acc, avg_loss))
        print()
        sys.stdout.flush()

        # Save the weights whenever validation accuracy does not get worse.
        # The file name carries the zero-based epoch index and the accuracy.
        if epoch_acc >= best_acc:
            best_acc = epoch_acc
            os.makedirs(weight_out_dir, exist_ok=True)
            torch.save(model.state_dict(),
                       os.path.join(weight_out_dir, '{:03d}_{:.3f}.pkl'.format(epoch, epoch_acc)))
            gc.collect()

    print('Best val Acc: {:4f}'.format(best_acc))
    sys.stdout.flush()

def main():
    """
    Set up model training for GAP_RNN

    Builds the GAP_RNN model, the triplet margin loss, an Adam optimizer and
    a StepLR scheduler from the module-level training parameters, then hands
    them to train_model.
    """
    n_classes = NUM_ACTIVITIES
    feat_dim = INPUT_FEATURE_DIM
    hidden_dim = feat_dim  # hidden size is tied to the input feature size
    out_dim = FINAL_LAYER_DIM
    model = Models.GAP_RNN(n_classes, feat_dim, hidden_dim, out_dim)
    criterion = nn.TripletMarginLoss(margin=ALPHA, p=2)

    print("Starting!")
    # Move the model to the GPU before the optimizer is constructed, as the
    # torch.optim documentation recommends, so the optimizer is built from
    # the parameters that are actually trained.
    if use_gpu:
        print("using gpu")
        model = model.cuda()
    sys.stdout.flush()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=LEARNING_RATE,
                                 weight_decay=WEIGHT_DECAY)

    # Multiply the LR by 0.1 every 130 scheduler steps
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=130,
                                           gamma=0.1)
    train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=NUM_EPOCHS)



# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
