#########################################
# Johns Hopkins University              #
# 601.455 Computer Integrated Surgery 2 #
# Spring 2018                           #
# Query by Video For Surgical Activities#
# Felix Yu                              #
# JHED: fyu12                           #
# Gianluca Silva Croso                  #
# JHED: gsilvac1                        #
#########################################

import os
import numpy as np
import torch
from torch.autograd import Variable
import torch.optim
import Models
import sys

# Run the GAP_RNN model (and its inputs) on the GPU whenever CUDA is available.
use_gpu = torch.cuda.is_available()

# Hyper-parameters used to rebuild Models.GAP_RNN before loading its weights.
n_classes = 10  # first GAP_RNN constructor argument; presumably the number of activity classes
feat_dim = 512  # width of a single spatial (squeezenet) feature vector per frame
# Width of spatial features with the tool features hstack-ed onto them;
# the extra 16 is presumably the tool-feature width -- TODO confirm.
feat_dim_tool = feat_dim + 16
hidden_dim = 128  # GAP_RNN hidden size; also the width of each encoded video vector


def _collect_npy_names(squeezenet_dir, data_phases):
    """
    Map each data phase to its spatial-feature files under squeezenet_dir.

    Files are expected at <squeezenet_dir>/<phase>/<activity>/<name>.npy.
    :param squeezenet_dir: root directory of the squeezenet outputs
    :param data_phases: iterable of phase sub-directory names (e.g. 'train')
    :return: dict phase -> list of relative paths "<phase>/<activity>/<name>.npy",
             sorted so that the generated database order is reproducible
    """
    names = {}
    for data_phase in data_phases:
        phase_dir = os.path.join(squeezenet_dir, data_phase)
        phase_names = []
        for activity in sorted(os.listdir(phase_dir)):
            activity_dir = os.path.join(phase_dir, activity)
            # Skip stray files (e.g. .DS_Store) sitting next to the activity folders.
            if not os.path.isdir(activity_dir):
                continue
            for npy_name in sorted(os.listdir(activity_dir)):
                if npy_name.endswith('.npy'):
                    phase_names.append(data_phase + "/" + activity + "/" + npy_name)
        names[data_phase] = phase_names
    return names


def _parse_video_key(npy_file_name):
    """
    Extract the (vid_num, sub_num, activity_num) strings from a feature file path.

    The activity and subject numbers are the first and second '_'-separated tokens
    of the file name with their leading letter dropped; the video number is the
    three characters just before the 4-character extension. Presumably names look
    like "a<activity>_s<subject>_..._<vid>.npy" -- TODO confirm naming scheme.
    """
    base_name = npy_file_name.split('/')[-1]
    tokens = base_name.split('_')
    return base_name[-7:-4], tokens[1][1:], tokens[0][1:]


def _load_gap_rnn(weight_file):
    """Rebuild the GAP_RNN model, load its weights from weight_file and switch it to eval mode."""
    model = Models.GAP_RNN(n_classes, feat_dim_tool, feat_dim_tool, hidden_dim)
    # Deserialize onto the CPU first so weights saved from a GPU run also load on
    # CPU-only machines; the model is moved to the GPU afterwards if available.
    state = torch.load(weight_file, map_location=lambda storage, loc: storage)
    model.load_state_dict(state)
    if use_gpu:
        model.cuda()
    model.eval()
    return model


def _write_key_file(path, vid_num_key):
    """Write one tab-separated "vid_num<TAB>sub_num<TAB>activity_num" line per encoded video."""
    with open(path, 'w') as key_file:
        for vid_num, sub_num, activity_num in vid_num_key:
            key_file.write(vid_num + "\t" + sub_num + "\t" + activity_num + "\n")


def video_encoder(which_model,
                  weight_file="/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/weights/GAP_RNN_TOOL/128_0.736.pkl",
                  base_dir='/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/outputs',
                  output_dir='/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/databases/'):
    """
    Create a database of spatio-temporal video features.

    Each video's per-frame spatial features (squeezenet outputs) are collapsed into
    a single vector either by averaging over time (mean mode, spatial features only)
    or by running the GAP_RNN model on the spatial features concatenated with the
    per-frame tool features.

    For every phase in ('train', 'val', 'test') two files are written to
    <output_dir>/<model_type>/ (created if missing), where model_type is "mean" or
    "GAP_RNN_TOOL":
      - <phase>.npy: array of shape (num_videos, out_dim), one row per video
      - <phase>_key.txt: one "vid_num<TAB>sub_num<TAB>activity_num" line per row,
        in the same order as the rows of <phase>.npy

    :param which_model: "0" (or 0) for mean averaging, "1" (or 1) for GAP_RNN
    :param weight_file: GAP_RNN state-dict file (only used when which_model is "1")
    :param base_dir: directory containing 'squeezenet_output/' and 'tools_output/'
    :param output_dir: directory under which the database files are written
    :raises ValueError: if which_model is not "0" or "1"
    """
    which_model = str(which_model)
    if which_model == "0":
        print("Using a mean across the temporal component to generate database.")
        model_type = "mean"
        out_dim = 512  # same as feat_dim: one averaged spatial feature vector
    elif which_model == "1":
        print("Using GAP RNN to generate database.")
        model_type = "GAP_RNN_TOOL"
        out_dim = 128  # same as hidden_dim: GAP_RNN output width
    else:
        raise ValueError("which_model must be '0' (mean) or '1' (GAP_RNN), got %r" % (which_model,))

    print("Running script to encode videos from spatial features.")
    print("--------------------------------")
    print("Constructing network and loading in weights.")
    sys.stdout.flush()
    model = _load_gap_rnn(weight_file) if model_type == "GAP_RNN_TOOL" else None

    squeezenet_dir = os.path.join(base_dir, 'squeezenet_output')
    tool_dir = os.path.join(base_dir, 'tools_output')
    model_output_dir = os.path.join(output_dir, model_type)
    if not os.path.isdir(model_output_dir):
        os.makedirs(model_output_dir)

    data_phases = ['train', 'val', 'test']
    print(" ")
    npy_names_by_phase = _collect_npy_names(squeezenet_dir, data_phases)

    for data_phase in data_phases:
        print('Looking at phase: ' + data_phase)
        npy_file_names = npy_names_by_phase[data_phase]
        out_mat_feats = np.empty((len(npy_file_names), out_dim))
        vid_num_key = []
        for index, npy_file_name in enumerate(npy_file_names):
            spatial_feats = np.load(os.path.join(squeezenet_dir, npy_file_name))
            if model_type == "mean":
                # Temporal average of the per-frame spatial features; tool features are not used.
                out_mat_feats[index, :] = np.mean(spatial_feats, axis=0)
            else:
                tool_name = os.path.splitext(npy_file_name.strip())[0] + '_tools.npy'
                tool_feats = np.load(os.path.join(tool_dir, tool_name))
                feats = np.hstack((spatial_feats, tool_feats)).astype(np.float32)
                # Insert a singleton dimension at axis 1 (a batch of one, sequence-first);
                # presumably the layout GAP_RNN expects -- TODO confirm against Models.
                feats = Variable(torch.unsqueeze(torch.from_numpy(feats), 1), requires_grad=False)
                if use_gpu:
                    feats = feats.cuda()
                out_mat_feats[index, :] = model(feats).data.cpu().numpy().reshape(-1)
            vid_num_key.append(_parse_video_key(npy_file_name))
        np.save(os.path.join(model_output_dir, data_phase + ".npy"), out_mat_feats)
        _write_key_file(os.path.join(model_output_dir, data_phase + "_key.txt"), vid_num_key)


def print_error():
    """
    Print command-line usage help for this script to stdout.

    Called when the model argument is missing or not one of the supported codes.
    Returns None.
    """
    usage_lines = [
        "ERROR! Improper command line input detected.",
        "Usage for temporal encoder: python temporal_encoder.py <model>",
        "-" * 40,
        "Model legend",
        "0: Temporal Averaging",
        # The trailing newline leaves a blank line after the legend.
        "1: GAP RNN\n",
    ]
    print("\n".join(usage_lines))


if __name__ == "__main__":
    # The first positional argument selects the encoder: "0" (mean) or "1" (GAP RNN).
    if len(sys.argv) < 2 or sys.argv[1] not in ("0", "1"):
        print_error()
        # Report the usage error to the calling shell instead of exiting with status 0.
        sys.exit(2)
    video_encoder(sys.argv[1])
