import glob
import os
import torch
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim
from torch.optim import lr_scheduler
import Models
import helpers
import sys

# True when a CUDA device is visible; the learned encoders and their inputs are
# moved to the GPU in that case.
use_gpu = torch.cuda.is_available()

# Hyper-parameters passed to the Models.* constructors in video_encoder; they
# have to agree with the architecture of the checkpoints loaded there.
n_classes, feat_dim, hidden_dim = 10, 512, 128

def _build_model(which_model):
    """Select the temporal encoder for ``which_model`` and load its weights.

    Returns ``(model, model_type, out_dim)``. ``model`` is None for "0", the
    parameter-free temporal mean. ``model_type`` names the output
    sub-directory. ``out_dim`` is the width of one encoded row.

    Raises:
        ValueError: if ``which_model`` is not one of "0", "1", "2", "3".
    """
    if which_model == "0":
        print("Using a mean across the temporal component to generate database.")
        return None, "mean", 512
    if which_model == "1":
        print("Using TCN to generate database.")
        weight_dir = "/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/weights/TCN/154_0.439.pkl"
        model = Models.RESTCN_paper(Models.TCN_Residual, n_classes, feat_dim)
        model_type, out_dim = "TCN", 160
    elif which_model == "2":
        print("Using RNN to generate database.")
        weight_dir = "/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/weights/RNN_TL/147_0.472.pkl"
        model = Models.simple_RNN(n_classes, feat_dim, hidden_dim)
        model_type, out_dim = "RNN_TL", 128
    elif which_model == "3":
        print("Using GAP RNN to generate database.")
        weight_dir = "/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/weights/GAP_RNN_TL/189_0.528.pkl"
        model = Models.GAP_RNN(n_classes, feat_dim, feat_dim, hidden_dim)
        model_type, out_dim = "GAP_RNN_TL", 128
    else:
        raise ValueError(
            "Unknown model selector %r; expected '0', '1', '2' or '3'." % (which_model,))
    # Load onto the CPU first so checkpoints saved from a GPU also load on
    # CPU-only hosts; the model is moved to the GPU afterwards if available.
    model.load_state_dict(torch.load(weight_dir, map_location="cpu"))
    if use_gpu:
        model.cuda()
    model.eval()
    return model, model_type, out_dim


def _collect_npy_names(data_dir, data_phase):
    """List ``<phase>/<activity>/<file>`` relative paths for one data phase.

    Listings are sorted so the row order of the saved database (and of its key
    file) is reproducible across runs and filesystems.
    """
    phase_dir = os.path.join(data_dir, data_phase)
    names = []
    for activity in sorted(os.listdir(phase_dir)):
        for npy_name in sorted(os.listdir(os.path.join(phase_dir, activity))):
            names.append(data_phase + "/" + activity + "/" + npy_name)
    return names


def _parse_key(npy_file_name):
    """Return ``(vid_num, sub_num, activity_num)`` parsed from a feature path.

    Takes the basename and splits it on '_'. The activity and subject numbers
    are the first two fields with their first character dropped. The video
    number is the three characters just before a four-character extension
    (presumably ".npy" -- TODO confirm naming scheme).
    """
    basename = npy_file_name.split('/')[-1]
    fields = basename.split('_')
    return basename[-7:-4], fields[1][1:], fields[0][1:]


def _encode_video(model, model_type, spatial_feats):
    """Encode one video's float32 feature matrix into a single database row.

    ``spatial_feats`` is indexed time-first: the temporal mean averages over
    axis 0. Returns the encoded row, or None when a TCN clip is too short to
    be encoded.
    """
    if model_type == "mean":
        return np.mean(spatial_feats, axis=0)
    feats = torch.from_numpy(spatial_feats)
    if model_type == "TCN":
        # (time, feat) -> (1, feat, time): batch of one, time on the last axis.
        feats = torch.unsqueeze(torch.transpose(feats, 0, 1), 0)
        # Fewer than 4 time steps are not run through the TCN (presumably
        # too short for its temporal receptive field -- TODO confirm).
        if feats.size(2) < 4:
            return None
    elif model_type in ("RNN_TL", "GAP_RNN_TL"):
        # Insert a batch axis of size 1 at dim 1: (time, 1, feat).
        feats = torch.unsqueeze(feats, 1)
    else:
        raise ValueError("Unsupported model_type %r." % (model_type,))
    if use_gpu:
        feats = feats.cuda()
    return model(feats).detach().cpu().numpy()


def video_encoder(which_model,
                  data_dir='/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/outputs/squeezenet_output/',
                  output_dir='/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/databases/'):
    """Encode every video's spatial features into one fixed-length row.

    For each phase ('train', 'val'), every file
    ``<data_dir>/<phase>/<activity>/<name>.npy`` is loaded. Its features go
    through the temporal encoder chosen by ``which_model``, and the result is
    stored as one row of ``<output_dir>/<model_type>/<phase>.npy``. The
    companion ``<phase>_key.txt`` holds one tab-separated
    ``vid_num sub_num activity_num`` line per row, in the same order. TCN
    clips shorter than 4 time steps get a row of -10000 and the key
    ``0 0 0``.

    Args:
        which_model: "0" temporal averaging, "1" TCN, "2" RNN, "3" GAP RNN.
        data_dir: root of the spatial-feature tree.
        output_dir: root under which the ``<model_type>`` directory is created.

    Raises:
        ValueError: if ``which_model`` is not one of "0"-"3".
    """
    print("Running script to encode videos from spatial features.")
    print("--------------------------------")
    print("Constructing network and loading in weights.")
    model, model_type, out_dim = _build_model(which_model)
    sys.stdout.flush()
    print(" ")

    data_phases = ['train', 'val']
    # List every phase up front so a missing directory fails before any
    # database file is written.
    npy_names = {phase: _collect_npy_names(data_dir, phase) for phase in data_phases}

    db_dir = os.path.join(output_dir, model_type)
    os.makedirs(db_dir, exist_ok=True)

    for data_phase in data_phases:
        print('Looking at phase: ' + data_phase)
        npy_file_names = npy_names[data_phase]
        out_mat_feats = np.empty((len(npy_file_names), out_dim))
        vid_num_key = []

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for index, npy_file_name in enumerate(npy_file_names):
                key = _parse_key(npy_file_name)
                spatial_feats = np.load(os.path.join(data_dir, npy_file_name)).astype(np.float32)
                row = _encode_video(model, model_type, spatial_feats)
                if row is None:
                    # Placeholder row/key for clips that could not be encoded.
                    vid_num_key.append(("0", "0", "0"))
                    out_mat_feats[index, :] = -10000
                else:
                    vid_num_key.append(key)
                    out_mat_feats[index, :] = row

        np.save(os.path.join(db_dir, data_phase + ".npy"), out_mat_feats)
        with open(os.path.join(db_dir, data_phase + "_key.txt"), 'w') as vid_num_text:
            for key in vid_num_key:
                vid_num_text.write("\t".join(key) + "\n")

def print_error():
    """Print the command-line usage message and the legend of model selectors."""
    usage_lines = [
        "ERROR! Improper command line input detected.",
        "Usage for temporal encoder: python temporal_encoder.py <model>",
        "-" * 40,
        "Model legend",
        "0: Temporal Averaging",
        "1: TCN",
        "2: RNN",
        "3: GAP RNN\n",
    ]
    print("\n".join(usage_lines))

if __name__ == "__main__":
    # Expect one selector argument: "0" (mean), "1" (TCN), "2" (RNN), "3" (GAP RNN).
    if len(sys.argv) < 2 or sys.argv[1] not in ("0", "1", "2", "3"):
        print_error()
        # Exit non-zero so shell scripts and job schedulers see the bad invocation.
        sys.exit(1)
    video_encoder(sys.argv[1])
