import glob
import os
import torch
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim
from torch.optim import lr_scheduler
import Models
import helpers
import sys

# Use CUDA whenever a GPU is visible to PyTorch.
use_gpu = torch.cuda.is_available()

# Constructor arguments for Models.RESTCN_paper (presumably the number of
# output classes and the per-frame input feature width -- confirm in Models).
n_classes, feat_dim = 10, 516

# Defaults reproduce the original extraction run; override them through the
# TCN_driver keyword arguments.
_DEFAULT_WEIGHTS_PATH = 'weights/TCN/036_0.312.pkl'
_DEFAULT_DATA_DIR = '/home-2/fyu12@jhu.edu/scratch/ReiterLab/cataract_densenet_output/'
_DEFAULT_OUTPUT_DIR = '/home-2/fyu12@jhu.edu/scratch/ReiterLab/cataract_TCN_output/'
# Input file names carry a phase number in 1.._N_VIDEO_PHASES.
_N_VIDEO_PHASES = 10


def _load_tcn_model(weights_path):
    """Build the RESTCN_paper network, load *weights_path*, and return it in eval mode."""
    model = Models.RESTCN_paper(Models.TCN_Residual, n_classes, feat_dim)
    # Load every tensor onto the CPU first so a checkpoint saved from a GPU
    # run also loads on a CPU-only machine; move to the GPU afterwards.
    state_dict = torch.load(weights_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    if use_gpu:
        model.cuda()
    model.eval()
    return model


def _group_files_by_phase(npy_file_names):
    """Bucket ``.npy`` file names by the phase number parsed from their names.

    The phase is the first ``_``-separated token with its first character
    dropped (presumably names look like ``<letter><phase>_..._<vid>.npy`` --
    confirm against the producer of these files). Order within each bucket
    follows *npy_file_names*.

    Returns:
        dict mapping ``'1'`` .. ``str(_N_VIDEO_PHASES)`` to lists of names.

    Raises:
        ValueError: if a ``.npy`` name encodes a phase outside that range.
    """
    phase_dict = {str(i): [] for i in range(1, _N_VIDEO_PHASES + 1)}
    for name in npy_file_names:
        if not name.endswith('.npy'):
            continue  # ignore stray non-array files in the directory
        phase_num = name.split('_')[0][1:]
        if phase_num not in phase_dict:
            raise ValueError('Unexpected phase %r parsed from file name %r'
                             % (phase_num, name))
        phase_dict[phase_num].append(name)
    return phase_dict


def _extract_phase_features(model, input_dir, npy_files, vid_num_file,
                            out_feat_dim, min_frames):
    """Run *model* over each file in *npy_files* and return the stacked feature rows.

    A file is skipped when its transposed array has fewer than *min_frames*
    columns. For each kept file, the video number (characters ``[-7:-4]`` of
    the name) is written as one line to *vid_num_file*, in the same order as
    the returned rows.

    Returns:
        float64 array of shape ``(n_kept, out_feat_dim)``.
    """
    feat_matrix = np.empty((len(npy_files), out_feat_dim))
    n_kept = 0
    for npy_file in npy_files:
        # Assumes each file is (frames, features), so after the transpose
        # shape[1] is the number of frames -- TODO confirm.
        inputs = np.transpose(np.load(os.path.join(input_dir, npy_file)))
        if inputs.shape[1] < min_frames:
            continue
        vid_num_file.write(npy_file[-7:-4] + '\n')
        inputs = torch.from_numpy(inputs).unsqueeze(0).float()  # add batch dim
        if use_gpu:
            inputs = inputs.cuda()
        # Inference only: skip building an autograd graph for every video.
        with torch.no_grad():
            _, feats = model(inputs)
        # feats is assumed to hold a single row of out_feat_dim values.
        feat_matrix[n_kept, :] = feats.cpu().numpy()
        n_kept += 1
    return feat_matrix[:n_kept, :]


def TCN_driver(weights_path=_DEFAULT_WEIGHTS_PATH, data_dir=_DEFAULT_DATA_DIR,
               output_dir=_DEFAULT_OUTPUT_DIR, data_phases=('train', 'val'),
               out_feat_dim=160, min_frames=4):
    """Extract TCN features for every video of each data split and save them.

    For each split in *data_phases*, the ``.npy`` files in
    ``<data_dir>/<split>/`` are grouped by the phase number in their names.
    For each phase ``i`` in 1..10, every file with at least *min_frames*
    frames is run through the TCN. The feature rows are saved to
    ``<output_dir>/<split>/<i>_feats.npy`` and the matching video numbers to
    ``<output_dir>/<split>/<i>_vid_nums.txt``, one per line and in row order.

    Args:
        weights_path: checkpoint holding the RESTCN_paper state dict.
        data_dir: root directory holding one sub-directory per split.
        output_dir: root output directory; per-split sub-directories are
            created when missing.
        data_phases: names of the splits to process.
        out_feat_dim: width of one feature row returned by the model.
        min_frames: videos shorter than this are skipped (presumably the
            shortest input the TCN accepts -- confirm).
    """
    print("Running script to extract TCN Features.")
    print("--------------------------------")
    print("Constructing network and loading in weights.")
    sys.stdout.flush()
    model = _load_tcn_model(weights_path)

    print(" ")
    for data_phase in data_phases:
        print('Looking at phase: ' + data_phase)
        input_dir = os.path.join(data_dir, data_phase)
        phase_out_dir = os.path.join(output_dir, data_phase)
        if not os.path.isdir(phase_out_dir):
            os.makedirs(phase_out_dir)
        phase_dict = _group_files_by_phase(sorted(os.listdir(input_dir)))
        for i in range(1, _N_VIDEO_PHASES + 1):
            print('Working on video phase: ' + str(i))
            sys.stdout.flush()
            out_prefix = os.path.join(phase_out_dir, str(i))
            # `with` guarantees the id list is closed even if loading or
            # inference raises part-way through the phase.
            with open(out_prefix + '_vid_nums.txt', 'w') as vid_num_file:
                feat_matrix = _extract_phase_features(
                    model, input_dir, phase_dict[str(i)], vid_num_file,
                    out_feat_dim, min_frames)
            np.save(out_prefix + '_feats.npy', feat_matrix)
if __name__ == '__main__':
    # Run the extraction only when executed as a script, not on import.
    TCN_driver()
