import math
import glob
import os
import torch
import numpy as np
import skvideo.io
import time
import pdb
import sys

from torchvision import transforms

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F

from torch.autograd import Variable

import torch.optim
from torch.optim import lr_scheduler

import torch.utils.data
import torch.utils.data.distributed

# Run on the GPU whenever CUDA is usable; otherwise everything stays on the CPU.
use_gpu = torch.cuda.is_available()

# Number of output channels used for the model's replacement classifier layer.
n_classes = 11

# Video numbers (three-character strings, as they appear in the file names)
# that make up the validation split.
val_list = ['164', '150', '127', '114', '038', '118', '117', '120']

# ``transforms.Scale`` was renamed to ``transforms.Resize`` and has since been
# removed from torchvision.  Prefer ``Resize`` and only fall back to ``Scale``
# on old releases that do not provide it, so the script imports on both.
_resize_transform = getattr(transforms, 'Resize', None) or transforms.Scale

# Pre-processing applied to every sampled frame: convert the array to a PIL
# image, resize to 256x256, centre-crop to 224x224, convert to a tensor and
# normalise each channel with the mean/std constants below.
data_transform = transforms.Compose([
    transforms.ToPILImage(),
    _resize_transform((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

def get_squeezenet_spatial_feats(net, vid_file_path, out_file_path, out_file_name):
    """Extract SqueezeNet features from sampled video frames and save them.

    Every 59th decoded frame is pre-processed with ``data_transform``, passed
    through ``net.features``, average-pooled to a single vector and stored as
    one row of a ``(num_images, 512)`` array.  The rows that were actually
    computed are written with ``np.save`` to
    ``out_file_path + '/' + out_file_name``.

    Args:
        net: Model exposing a ``features`` sub-module.  Its pooled output is
            stored in a 512-wide row, so it must hold 512 values per image.
        vid_file_path: Path of the video, opened with
            ``skvideo.io.FFmpegReader``.
        out_file_path: Directory that receives the output file.  It is
            created when it does not exist yet.
        out_file_name: File name of the saved array.

    Returns:
        None.  When the reader raises ``ValueError`` while opening the video,
        a message is printed and nothing is saved.
    """
    frame_stride = 59  # one frame out of every `frame_stride` is sampled

    # Load in video
    try:
        vid = skvideo.io.FFmpegReader(vid_file_path)
    except ValueError:
        print("\t\tIssue with video encoding. Skipping.")
        sys.stdout.flush()
        return

    cur_img_num = 0
    try:
        # Set up the feature matrix.
        total_frames = vid.getShape()[0]
        # NOTE(review): the 59/600 factor shrinks the reported frame count
        # before sampling; its origin is not visible in this file -- confirm.
        scaled_frames = int(math.ceil(float(total_frames)*59/600))
        # The loop below stops one frame short of a multiple of the stride,
        # so it samples exactly `num_images` frames: 0, 59, 118, ...
        # A video with fewer than `frame_stride` scaled frames yields zero
        # images instead of one never-written (uninitialised) row.
        num_images = scaled_frames // frame_stride
        num_frames = num_images * frame_stride - 1
        feats = np.empty((num_images, 512))

        # Generator over the decoded video frames.
        cap = vid.nextFrame()

        for i in range(num_frames):
            # Print progress once every ten sampled frames.
            if i % (10 * frame_stride) == 0:
                print("\t\tProcessing clip " + str(i // frame_stride + 1))
                sys.stdout.flush()

            # Every frame has to be pulled from the reader to advance it,
            # even though only every `frame_stride`-th one is used.
            try:
                frame = next(cap)
            except StopIteration:
                # The reader ran out of frames before the expected count.
                print("\t\tVideo ended after " + str(i) + " frames. Keeping the features extracted so far.")
                sys.stdout.flush()
                break

            if i % frame_stride != 0:
                continue

            # Reverse the channel axis of the frame (as the original
            # pipeline does), then format it as a one-image batch for the
            # SqueezeNet model.
            img = np.flip(frame, axis=2)
            img = data_transform(img)
            img = torch.unsqueeze(img, 0)
            img = Variable(img).float()
            if use_gpu:
                img = img.cuda()

            # Feed the input into the network, pool the feature map down to
            # one vector and place it into the feature matrix.
            feat = net.features(img)
            feat = F.adaptive_avg_pool2d(feat, (1, 1)).view(1, -1)
            feats[cur_img_num, :] = feat.data.cpu().numpy()
            cur_img_num += 1
    finally:
        # Always release the reader, even if extraction fails midway.
        vid.close()

    # Save the feature matrix.  Only rows that were filled in are written,
    # because `np.empty` leaves untouched rows with arbitrary contents.
    if out_file_path and not os.path.isdir(out_file_path):
        os.makedirs(out_file_path)
    np.save(out_file_path + '/' + out_file_name, feats[:cur_img_num])

def squeezenet_driver():
    """Extract SqueezeNet features for every ``.avi`` video under ``data_dir``.

    Builds a ``squeezenet1_0`` model from ``pretrainedmodels``, replaces its
    ``last_conv`` layer with an ``n_classes``-channel 1x1 convolution and
    loads the weights from the checkpoint whose path is the first
    command-line argument.  It then walks the sub-directories of ``data_dir``
    (one per phase) and calls ``get_squeezenet_spatial_feats`` on each video,
    writing to ``output_dir/<phase>/<sub_dir>``.

    Exits with a usage message when no checkpoint path is given.
    """
    import pretrainedmodels

    if len(sys.argv) < 2:
        sys.exit("Usage: <script> <path to SqueezeNet checkpoint>")

    model = pretrainedmodels.__dict__['squeezenet1_0'](num_classes=1000, pretrained='imagenet')
    dim_feats = 512
    model.last_conv = nn.Conv2d(dim_feats, n_classes, kernel_size=(1, 1), stride=(1, 1))

    # On a machine without CUDA, remap tensors that were saved from the GPU
    # onto the CPU while loading; otherwise torch.load fails to restore them.
    map_location = None if use_gpu else (lambda storage, loc: storage)
    model.load_state_dict(torch.load(sys.argv[1], map_location=map_location))
    if use_gpu:
        model.cuda()
    model.eval()

    data_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/new_videos/cataract_phase_separated/'
    output_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/outputs/squeezenet_output/'

    print(" ")
    print("Extracting SqueezeNet Features.")
    for sub_dir in os.listdir(data_dir):
        phase_dir = data_dir + sub_dir + '/'
        # Skip stray files sitting next to the per-phase directories.
        if not os.path.isdir(phase_dir):
            continue
        print("Working on phase: " + sub_dir)
        sys.stdout.flush()

        for vid_name in os.listdir(phase_dir):
            if not vid_name.endswith('avi'):
                continue
            # File names are sliced from the end; e.g. 'p9_n1_vid_120.avi'
            # gives vid_num '120' and sub_num '1'.
            vid_num = vid_name[-7:-4]
            sub_num = vid_name[-13:-12]
            print("\tWorking on Video : " + vid_num + " " + sub_num)
            sys.stdout.flush()

            # All videos are written to the 'test' split; the val_list based
            # train/val assignment is currently disabled.
            phase = 'test'

            input_file_path = phase_dir + vid_name
            output_file_name = 'p'+sub_dir+'_n'+sub_num+'_squeezenet_'+vid_num+'.npy'
            output_sub_dir = output_dir + '/' + phase + '/' + sub_dir
            get_squeezenet_spatial_feats(model, input_file_path, output_sub_dir, output_file_name)
# Run the feature extraction only when this file is executed as a script.
if __name__ == "__main__":
    squeezenet_driver()
