#########################################
# Johns Hopkins University              #
# 601.455 Computer Integrated Surgery 2 #
# Spring 2018                           #
# Query by Video For Surgical Activities#
# Felix Yu                              #
# JHED: fyu12                           #
# Gianluca Silva Croso                  #
# JHED: gsilvac1                        #
#########################################

import os
import numpy as np
import skvideo.io
import sys
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim
import torch.utils.data
import torch.utils.data.distributed

# Whether a CUDA device is available; inputs and the model are moved to the GPU when True.
use_gpu = torch.cuda.is_available()
# Number of output channels of the fine-tuned classification layer.
n_classes = 11

# Identifiers for videos used as validation set
val_list = [
    '164',
    '150',
    '127',
    '114',
    '038',
    '118',
    '117',
    '120',
]

# torchvision renamed transforms.Scale to transforms.Resize and newer releases
# no longer ship Scale. Pick whichever this installation provides so the module
# imports on both old and new torchvision; for a square (256, 256) target the
# two behave the same.
_resize_cls = getattr(transforms, 'Resize', None) or transforms.Scale

# Per-frame preprocessing: HxWxC uint8 array -> normalised 3x224x224 float tensor.
data_transform = transforms.Compose([
  transforms.ToPILImage(),
  _resize_cls((256, 256)),
  transforms.CenterCrop(224),
  transforms.ToTensor(),
  # Channel statistics commonly used with ImageNet-pretrained models.
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225]),
])

# DIRECTORY WHERE PHASE CLIPS ARE LOCATED - MODIFY IF NECESSARY
data_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/raw_data/cataract_phase_separated/'
# DIRECTORY WHERE SPATIAL FEATURE NPY FILES WILL BE WRITTEN - MODIFY IF NECESSARY
output_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/outputs/squeezenet_output_test/'
# DIRECTORY WHERE TEST PHASE CLIPS ARE LOCATED - MODIFY IF NECESSARY
test_dir  = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/new_videos/cataract_phase_separated/'

# Make sure the output root exists, followed by one sub-directory per data split.
# The empty suffix stands for the root itself and is handled first.
for _split_suffix in ('', '/train', '/val', '/test'):
    _split_dir = output_dir + _split_suffix
    if not os.path.exists(_split_dir):
        os.makedirs(_split_dir)

def get_squeezenet_spatial_feats(net, vid_file_path, out_file_path, out_file_name):
    """
    For a specific video, obtain a spatial feature vector for one frame per second
    of the video. Save as a TxF matrix where T is the number of frames used and F
    is the feature vector size (512).

    One frame is sampled per full second of video (frames 0, fps, 2*fps, ...).
    A clip shorter than one second still contributes its first frame, so the
    saved matrix never contains uninitialised rows. If the video cannot be
    opened, or no frame can be read from it, nothing is written.

    :param net: the pytorch model object; its ``features`` sub-network is applied
                to each sampled frame and must produce 512 channels
    :param vid_file_path: path to the video file
    :param out_file_path: path of the directory where the npy file will be saved
    :param out_file_name: name of the npy file to be saved
    :return: None
    """
    # Load in video
    try:
        vid = skvideo.io.FFmpegReader(vid_file_path)
    except ValueError:
        print("\t\tIssue with video encoding. Skipping.")
        sys.stdout.flush()
        return

    # NOTE! THIS ASSUMES 30 FPS! Change the FPS when necessary!
    fps = 30
    feat_dim = 512

    cur_img_num = 0
    try:
        # Set up feature matrix: one row per full second, but at least one row
        # for any clip that contains a frame at all.
        total_frames = int(vid.getShape()[0])
        if total_frames <= 0:
            num_images = 0
        else:
            num_images = max(1, total_frames // fps)
        feats = np.empty((num_images, feat_dim))

        # Get the iterator of video frames.
        frames = vid.nextFrame()

        # Frames after the last sampled one are never used, so stop decoding there.
        last_needed = (num_images - 1) * fps
        for i in range(last_needed + 1):
            try:
                frame = next(frames)
            except (StopIteration, RuntimeError):
                # The reader may deliver fewer frames than it reported up front;
                # keep the rows computed so far instead of aborting the whole run.
                print("\t\tVideo ended early at frame " + str(i) + ". Keeping "
                      + str(cur_img_num) + " feature rows.")
                sys.stdout.flush()
                break

            # Every frame has to be pulled from the decoder, but only one per
            # second is fed to the network.
            if i % fps != 0:
                continue

            # Print progress once every ten seconds of video.
            if i % (fps * 10) == 0:
                print("\t\tProcessing clip " + str(i // fps + 1))
                sys.stdout.flush()

            # Reverse the channel axis (presumably to match the channel order
            # the network was trained with - TODO confirm), then format the
            # input to be put into the squeezenet network.
            img = data_transform(np.flip(frame, axis=2))
            img = Variable(torch.unsqueeze(img, 0)).float()
            if use_gpu:
                img = img.cuda()

            # Feed the input into the network, average-pool the feature map down
            # to a single vector, and place it into the feature matrix.
            feat = net.features(img)
            feat = F.adaptive_avg_pool2d(feat, (1, 1)).view(1, -1)
            feats[cur_img_num, :] = feat.data.cpu().numpy()
            cur_img_num += 1
    finally:
        # Always release the ffmpeg reader, even if feature extraction failed.
        vid.close()

    if cur_img_num == 0:
        print("\t\tNo frames could be read. Skipping.")
        sys.stdout.flush()
        return

    # Save the feature matrix, dropping any rows that were never filled.
    np.save(out_file_path + '/' + out_file_name, feats[:cur_img_num])


def _extract_phase_clips(model, clips_root, fixed_phase=None):
    """
    Extract features for every ``.avi`` clip found in the phase sub-directories
    of clips_root and write them below output_dir.

    :param model: network handed on to get_squeezenet_spatial_feats
    :param clips_root: directory containing one sub-directory per phase; must end
                       with a path separator because paths are built by concatenation
    :param fixed_phase: split name ('test') to use for every clip; when None each
                        clip goes to 'val' if its video number is in val_list and
                        to 'train' otherwise
    """
    for sub_dir in os.listdir(clips_root):
        phase_dir = clips_root + sub_dir + '/'
        # Stray files next to the phase directories would make listdir fail.
        if not os.path.isdir(phase_dir):
            continue
        print("Working on phase: " + sub_dir)
        sys.stdout.flush()
        for vid_name in os.listdir(phase_dir):
            if not vid_name[-3:] == 'avi':
                continue
            # Identifiers are read from fixed character positions counted from
            # the end of the file name (assumes the dataset's naming scheme -
            # TODO confirm against the clip-splitting script).
            vid_num = vid_name[-7:-4]
            sub_num = vid_name[-13:-12]
            print("\tWorking on Video : " + vid_num + " " + sub_num)
            sys.stdout.flush()

            if fixed_phase is not None:
                phase = fixed_phase
            elif vid_num in val_list:
                phase = 'val'
            else:
                phase = 'train'

            input_file_path = phase_dir + vid_name
            output_file_name = 'p' + sub_dir + '_n' + sub_num + '_squeezenet_' + vid_num + '.npy'
            output_sub_dir = output_dir + '/' + phase + '/' + sub_dir
            if not os.path.exists(output_sub_dir):
                os.makedirs(output_sub_dir)
            get_squeezenet_spatial_feats(model, input_file_path, output_sub_dir, output_file_name)


def squeezenet_driver():
    """
    Driver to load squeezenet with pretrained weights and obtain feature matrix for
    each phase clip in data_dir and test_dir, writing it out as an npy file to
    output_dir. Weights file should be given as first system argument when running
    the driver; without it a usage message is printed and nothing is processed.
    """
    if len(sys.argv) < 2:
        print("Usage: python squeezenet_driver.py <weight_file>")
        return
    import pretrainedmodels
    model = pretrainedmodels.__dict__['squeezenet1_0'](num_classes=1000, pretrained='imagenet')
    # Swap the classifier's final convolution for one with n_classes outputs so
    # the layer shapes match the fine-tuned weights loaded below.
    dim_feats = 512
    model.last_conv = nn.Conv2d(dim_feats, n_classes, kernel_size=(1, 1), stride=(1, 1))
    # Deserialise onto the CPU first so weights saved from a GPU also load on a
    # machine without CUDA; the model is moved to the GPU afterwards if present.
    # NOTE: torch.load unpickles the file - only use weight files from a trusted source.
    state_dict = torch.load(sys.argv[1], map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    if use_gpu:
        model.cuda()
    model.eval()

    # Train/val clips: the split is decided per video from val_list.
    print(" ")
    print("Extracting SqueezeNet Features.")
    _extract_phase_clips(model, data_dir)

    # Test clips: everything goes to the 'test' split.
    print(" ")
    print("Extracting SqueezeNet Features for test videos.")
    _extract_phase_clips(model, test_dir, fixed_phase='test')


# Run the feature extraction only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    squeezenet_driver()
