import os
import numpy as np
import skvideo.io
import skimage
import skimage.transform
from skimage import img_as_ubyte
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim
from torch.optim import lr_scheduler
import Models
import lmdb
import pdb
import sys

# Frame size (pixels) every video frame is resized to before entering the network.
VIDEO_HEIGHT = 112
VIDEO_WIDTH = 112
# Colour channels per frame.
NUM_CHANNELS = 3
# Number of activity classes.
NUM_ACTIVITIES = 10
# Whether a CUDA device is available; evaluated once, at import time.
use_gpu = torch.cuda.is_available()

def generate_val_exemplars(num_clip_per_act, num_exem_per_act):
    """Sample validation triplets (anchor, positive, negative) for every activity.

    For each activity ``i``, ``num_exem_per_act`` distinct anchor clips are drawn.
    Each anchor gets a different clip of the same activity as its positive. The
    anchors are then split evenly over the other ``NUM_ACTIVITIES - 1`` activities,
    each of which contributes distinct negative clips.

    Args:
        num_clip_per_act: Sequence giving the number of clips available for each
            activity (entries 0..NUM_ACTIVITIES-1 are used).
        num_exem_per_act: Requested number of exemplars per activity. It is clamped
            to ``min(num_clip_per_act)``. It is then rounded down to a multiple of
            ``NUM_ACTIVITIES - 1``, or set to ``NUM_ACTIVITIES - 1`` if that would
            give 0.

    Returns:
        List of 5-tuples ``(anchor_class, negative_class, anchor_idx,
        positive_idx, negative_idx)``. Each index is a clip index within its own
        class.

    Raises:
        ValueError: If rounding up to one exemplar per negative class needs more
            clips than some activity has.
    """
    num_neg_classes = NUM_ACTIVITIES - 1
    exemplar_indices = list()

    if min(num_clip_per_act) < num_exem_per_act:
        num_exem_per_act = min(num_clip_per_act)
        print("Number of exemplars per activity too large. Setting it to " + str(num_exem_per_act))
    # Anchors are divided evenly among the other activities (one group of negatives
    # per activity), so the exemplar count must be a multiple of num_neg_classes.
    if num_exem_per_act % num_neg_classes != 0:
        num_exem_per_act -= num_exem_per_act % num_neg_classes
        if num_exem_per_act == 0:
            num_exem_per_act = num_neg_classes
        print("Number of exemplars was not divisible by " + str(num_neg_classes)
              + ". Setting it to " + str(num_exem_per_act))
    num_exem_per_act = int(num_exem_per_act)
    negs_per_class = num_exem_per_act // num_neg_classes

    # Only reachable when the round-up above exceeds the available clips; fail
    # clearly instead of with numpy's generic "larger sample than population" error.
    smallest = min(int(num_clip_per_act[i]) for i in range(NUM_ACTIVITIES))
    if smallest < num_exem_per_act:
        raise ValueError("Each activity needs at least " + str(num_exem_per_act)
                         + " clips to draw distinct anchors; the smallest has "
                         + str(smallest) + ".")

    for i in range(NUM_ACTIVITIES):
        act_clip_num = int(num_clip_per_act[i])
        anchors = np.random.choice(act_clip_num, num_exem_per_act, replace=False)

        # Pick a positive for every anchor. Reject the anchor itself, and reject the
        # mirror of an existing pair (a clip already paired *to* this anchor).
        pair_dict = dict()
        for anchor in anchors:
            pos = np.random.choice(act_clip_num, 1)[0]
            while pos == anchor or pair_dict.get(pos) == anchor:
                pos = np.random.choice(act_clip_num, 1)[0]
            pair_dict[anchor] = pos

        # Walk through the anchors in order, giving each other activity an equal
        # share and one distinct negative per anchor.
        anchor_iter = iter(anchors)
        for k in range(NUM_ACTIVITIES):
            if k == i:
                continue
            negatives = np.random.choice(int(num_clip_per_act[k]), negs_per_class, replace=False)
            for neg in negatives:
                anchor = next(anchor_iter)
                exemplar_indices.append((i, k, anchor, pair_dict[anchor], neg))
    return exemplar_indices

def select_triplets(features, alpha, num_segs_per_act, num_activities=None):
    """Select semi-hard triplets for training, following the FaceNet strategy.

    ``features`` rows are grouped by activity: rows
    ``[a * num_segs_per_act, (a + 1) * num_segs_per_act)`` belong to activity ``a``.
    Every within-activity (anchor, positive) pair is considered. A negative is drawn
    uniformly at random from the rows of *other* activities that are farther from
    the anchor than the positive, but by less than ``alpha`` (squared Euclidean
    distance). Pairs with no such negative are skipped.

    DISCLAIMER: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py

    Args:
        features: 2-D array with one embedding per row.
        alpha: Triplet-loss margin used for the semi-hard criterion.
        num_segs_per_act: Number of consecutive rows belonging to each activity.
        num_activities: Number of activity groups; defaults to ``NUM_ACTIVITIES``.

    Returns:
        ``(triplets_a, triplets_p, triplets_n)``: equal-length float64 arrays of row
        indices for anchors, positives and negatives, in shuffled order.
    """
    if num_activities is None:
        num_activities = NUM_ACTIVITIES

    triplets = []
    for act in range(num_activities):
        start = act * num_segs_per_act
        stop = start + num_segs_per_act
        for a_idx in range(start, stop - 1):
            # Squared distances from the anchor to every row. Same-activity rows are
            # set to NaN so both comparisons below are False for them.
            # (np.nan: the np.NaN alias was removed in NumPy 2.0.)
            neg_dists_sqr = np.sum(np.square(features[a_idx] - features), 1)
            neg_dists_sqr[start:stop] = np.nan
            for p_idx in range(a_idx + 1, stop):  # every possible positive pair
                pos_dist_sqr = np.sum(np.square(features[a_idx] - features[p_idx]))
                # FaceNet semi-hard selection: farther than the positive, within alpha.
                is_semi_hard = np.logical_and(neg_dists_sqr - pos_dist_sqr < alpha,
                                              pos_dist_sqr < neg_dists_sqr)
                semi_hard_negs = np.where(is_semi_hard)[0]
                if semi_hard_negs.shape[0] > 0:
                    n_idx = semi_hard_negs[np.random.randint(semi_hard_negs.shape[0])]
                    triplets.append((a_idx, p_idx, n_idx))

    np.random.shuffle(triplets)
    # reshape keeps the (0, 3) shape when no triplet was found.
    triplet_arr = np.array(triplets, dtype=np.float64).reshape(-1, 3)
    return triplet_arr[:, 0].copy(), triplet_arr[:, 1].copy(), triplet_arr[:, 2].copy()


def load_lmdb_txns(file_path):
    """Open the per-activity LMDB databases for the train and test phases.

    Each database is expected at ``<file_path>/<phase>/<k>``, where phase is
    ``train`` or ``test`` and ``k`` runs from 1 to NUM_ACTIVITIES. Each environment is
    opened read-only and a transaction is begun on it. The transactions are returned
    open so callers can fetch records later.

    Args:
        file_path: Root directory containing the ``train`` and ``test`` folders.

    Returns:
        ``(train_txn, test_txn, train_names, test_names)``. The ``*_txn`` entries are
        per-activity lists of transactions. The ``*_names`` entries are per-activity
        lists of the keys stored in each database.
    """
    txns = {'train': list(), 'test': list()}
    names_by_phase = {'train': list(), 'test': list()}

    for phase in ('train', 'test'):
        for i in range(NUM_ACTIVITIES):
            print("Loading activity " + str(i) + " from phase " + phase)
            lmdb_path = os.path.join(file_path, phase, str(i + 1))
            env = lmdb.open(lmdb_path, readonly=True)
            txn = env.begin()
            # Only the keys are needed here; don't copy every value out of the DB.
            names = list(txn.cursor().iternext(keys=True, values=False))
            txns[phase].append(txn)
            names_by_phase[phase].append(names)

    return txns['train'], txns['test'], names_by_phase['train'], names_by_phase['test']


def get_batch(train_names, no_repeat_tracker, batch_seg_per_activity, num_activities=None):
    """Draw a batch of segment names, sampling each activity without repetition.

    For each activity, ``batch_seg_per_activity`` names are drawn without replacement
    from those whose tracker entry is still 0, and those entries are set to 1. When
    fewer than ``batch_seg_per_activity`` unused names remain, that activity's tracker
    is reset to all zeros and sampling starts again from every name.

    Args:
        train_names: Per-activity lists of segment names.
        no_repeat_tracker: Per-activity numpy arrays marking already-used names with 1.
            It is updated in place; a reset replaces the activity's entry with a fresh
            zero array.
        batch_seg_per_activity: Number of names drawn per activity.
        num_activities: Number of activities to sample; defaults to ``NUM_ACTIVITIES``.

    Returns:
        Flat list of the sampled names, grouped by activity in activity order.
    """
    if num_activities is None:
        num_activities = NUM_ACTIVITIES

    segment_names = list()
    for i in range(num_activities):
        names = train_names[i]
        possible_sample_ind = np.flatnonzero(no_repeat_tracker[i] == 0)

        if possible_sample_ind.shape[0] < batch_seg_per_activity:
            no_repeat_tracker[i] = np.zeros(len(names))
            possible_sample_ind = np.arange(len(names))

        # Fetch the tracker *after* a possible reset. Otherwise the marks below land
        # in the discarded old array and this batch can be drawn again next time.
        tracker = no_repeat_tracker[i]
        sample_ind = np.random.choice(possible_sample_ind, size=batch_seg_per_activity, replace=False)
        tracker[sample_ind] = 1
        segment_names.extend(names[k] for k in sample_ind)
    return segment_names


def _read_label_file(label_file_path):
    """Return the video file names ``p<a>_n<b>_vid_<c>.avi`` listed in a label file."""
    video_names = list()
    with open(label_file_path) as label_file:
        for line in label_file:
            components = line.split()
            if not components:  # tolerate blank lines, e.g. a trailing newline
                continue
            video_names.append('p' + components[0] + '_n' + components[1] + '_vid_' + components[2] + '.avi')
    return video_names


def read_train_test_split(file_dir):
    """Read the train/test split label files and build the video file names.

    Args:
        file_dir: Directory prefix, concatenated directly with the file names, so it
            should end with a path separator. It must contain
            ``train_phase_labels.txt`` and ``test_phase_labels.txt``.

    Returns:
        ``(train_list, test_list)``: lists of ``p<a>_n<b>_vid_<c>.avi`` names built
        from the first three whitespace-separated fields of each non-blank line.
    """
    train_list = _read_label_file(file_dir + 'train_phase_labels.txt')
    test_list = _read_label_file(file_dir + 'test_phase_labels.txt')
    return train_list, test_list


def get_C3D_spatial_feats(net, vid_file_path, out_file_path, out_file_name):
    """Extract pool5/fc6/fc7 features from a video with ``net`` and save them as .npy.

    Frames are read with skvideo, their channel order is reversed (to match cv2's
    convention), they are resized to VIDEO_HEIGHT x VIDEO_WIDTH, and every other frame
    is kept. Each clip holds 16 kept frames, covering 32 video frames. Consecutive
    clips start 16 video frames apart, so they overlap by half. Each clip is fed
    through ``net`` and its features are stored as one row per clip.

    Args:
        net: Callable returning ``(prediction, pool_5, fc_6, fc_7)`` for an input of
            shape ``(1, NUM_CHANNELS, 16, VIDEO_HEIGHT, VIDEO_WIDTH)``.
        vid_file_path: Path of the video to read.
        out_file_path: Directory containing ``pool5``, ``fc6`` and ``fc7`` subfolders.
        out_file_name: File name used for the saved array in each subfolder.

    Returns:
        None. Nothing is saved if opening the video raises ValueError or if the video
        has fewer than 16 frames; a message is printed instead.
    """
    # Load in video
    try:
        vid = skvideo.io.FFmpegReader(vid_file_path)
    except ValueError:
        print("\t\tIssue with video encoding. Skipping.")
        sys.stdout.flush()
        return

    try:
        # Drop trailing frames that don't complete a 16-frame step. The first 16
        # frames only seed the clip, hence the "- 1".
        num_frames = vid.getShape()[0]
        num_frames -= num_frames % 16
        num_clips = num_frames // 16 - 1
        if num_clips < 0:
            print("\t\tVideo shorter than 16 frames. Skipping.")
            sys.stdout.flush()
            return

        pool5_feats = np.empty((num_clips, 8192))
        fc6_feats = np.empty((num_clips, 4096))
        fc7_feats = np.empty((num_clips, 4096))

        # Get the iterator of video frames, as well as set up our clip matrix.
        cap = vid.nextFrame()
        clip = np.empty((1, 16, VIDEO_HEIGHT, VIDEO_WIDTH, NUM_CHANNELS))

        # Fill the first 8 clip slots with every other frame of the first 16 frames.
        for i in range(16):
            img = np.flip(next(cap), axis=2)  # skvideo channel order is reversed w.r.t. cv2
            if i % 2 == 0:
                clip[:, i // 2, :, :, :] = skimage.transform.resize(img, (VIDEO_HEIGHT, VIDEO_WIDTH))

        for i in range(num_frames - 16):
            if i % 160 == 0:
                print("\t\tProcessing clip " + str(i // 16 + 1))
                sys.stdout.flush()
            img = np.flip(next(cap), axis=2)
            cur_frame = i % 16  # features are produced once every 16 video frames

            # Fill the last 8 clip slots with every other frame of this 16-frame step.
            if cur_frame % 2 == 0:
                clip[:, cur_frame // 2 + 8, :, :, :] = skimage.transform.resize(img, (VIDEO_HEIGHT, VIDEO_WIDTH))

            if cur_frame == 15:
                # (1, 16, H, W, C) -> (1, C, 16, H, W), as uint8-scaled float32.
                X = np.transpose(clip, (0, 4, 1, 2, 3))
                X = np.float32(img_as_ubyte(X))
                X = Variable(torch.from_numpy(X))
                if use_gpu:
                    X = X.cuda()

                # Inference only: no autograd graph is needed for feature extraction.
                with torch.no_grad():
                    _, pool_5, fc_6, fc_7 = net(X)
                clip_idx = i // 16
                pool5_feats[clip_idx, :] = pool_5.data.cpu().numpy()
                fc6_feats[clip_idx, :] = fc_6.data.cpu().numpy()
                fc7_feats[clip_idx, :] = fc_7.data.cpu().numpy()

                # Slide the last 8 frames of the clip to the first 8 slots.
                clip[:, 0:8, :, :, :] = clip[:, 8:16, :, :, :]
    finally:
        vid.close()

    # Save the feature matrices
    np.save(out_file_path + '/pool5/' + out_file_name, pool5_feats)
    np.save(out_file_path + '/fc6/' + out_file_name, fc6_feats)
    np.save(out_file_path + '/fc7/' + out_file_name, fc7_feats)

# NOTE(review): The triple-quoted string below is disabled legacy code, not a live
# definition. Python evaluates it as a bare string expression and discards it.
# It uses ``cv2``, which is not among this file's imports. It also calls
# cv2-style ``get``/``read``/``isOpened`` methods on an skvideo ``FFmpegReader``;
# whether that reader supports them is unverified, so it would need rework before
# being re-enabled. It presumably predates get_C3D_spatial_feats; consider deleting.
'''
def get_video_spatial_feats(vid_file_path, net):
    #Set up the video and grab necessary video elements
    cap = skvideo.io.FFmpegReader(vid_file_path)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    num_frames = num_frames - num_frames%16 #cut off last bit of video that won't be used
    num_clips = int(num_frames/16)

    #Set up variables that will contain the clips.
    clips = np.empty((num_clips, 16, VIDEO_HEIGHT, VIDEO_WIDTH, NUM_CHANNELS))
    clip = np.empty((16, VIDEO_HEIGHT, VIDEO_WIDTH, NUM_CHANNELS))
    
    for i in range(16):
        frame = cap.read()[1]
        if i%2 == 0:
            clip[int(i/2),:,:,:] = cv2.resize(frame, (VIDEO_WIDTH, VIDEO_HEIGHT))
    
    curFrame = 0
    curClip = 0
    while cap.isOpened():
        frame = cap.read()[1]
        if curFrame%2 == 0:
            clip[int(curFrame/2)+8, :,:,:] = cv2.resize(frame, (VIDEO_WIDTH, VIDEO_HEIGHT))
        curFrame += 1
        if curFrame == 16:
            clips[curClip, :, :, :, :] = clip
            clip[0:8, :, :, :] = clip [8:16, :, :, :]
            curFrame = 0
            curClip += 1

    return clips
'''
