import os
import numpy as np
import skvideo.io
import skimage
import skimage.transform
from skimage import img_as_ubyte, io
import torch
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
import torch.optim
import Models
import lmdb
import pdb
import sys
import random
# Spatial size that get_C3D_spatial_feats resizes every video frame to before
# packing it into a clip.
VIDEO_WIDTH = 112
VIDEO_HEIGHT = 112
# Size of the trailing (colour channel) axis of the clip buffer built in
# get_C3D_spatial_feats.
NUM_CHANNELS = 3
# Number of activity classes that get_val_indexes and select_triplets loop over.
NUM_ACTIVITIES = 10
# True when torch reports a usable CUDA device; get_C3D_spatial_feats moves its
# network input to the GPU when this is set.
use_gpu = torch.cuda.is_available()

class SpatialToolFeatDataset(Dataset):
    """Dataset that joins each spatial feature file with its tool feature file.

    Every entry of ``spatial_dir`` is one sample. The matching tool features
    are read from ``tool_dir`` under the name ``<stem>_tools.npy``, where
    ``<stem>`` is the spatial file name up to its first ``.``.
    """

    def __init__(self, spatial_dir, tool_dir, transform = None):
        """Remember both directories and list the spatial feature files.

        ``transform`` is only stored; ``__getitem__`` never applies it.
        """
        self.spatial_dir = spatial_dir
        self.tool_dir = tool_dir
        self.transform = transform
        self.spatial_names = os.listdir(spatial_dir)

    def __len__(self):
        """Return the number of spatial feature files found."""
        return len(self.spatial_names)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as a float32 tensor.

        The spatial and tool arrays are combined with ``np.hstack`` (spatial
        first) before the conversion to float32.
        """
        fname = self.spatial_names[idx]
        stem, _, _ = fname.strip().partition('.')
        spatial_path = '/'.join((self.spatial_dir, fname))
        tool_path = '/'.join((self.tool_dir, stem + '_tools.npy'))
        stacked = np.hstack((np.load(spatial_path), np.load(tool_path)))
        return torch.from_numpy(np.float32(stacked))

class SpatialFeatDataset(Dataset):
    """Dataset yielding one float32 tensor per ``.npy`` file in a directory."""

    def __init__(self, root_dir, transform = None):
        """List the feature files in ``root_dir``.

        ``transform`` is only stored; ``__getitem__`` never applies it.
        """
        self.root_dir, self.transform = root_dir, transform
        self.feat_names = os.listdir(root_dir)

    def __len__(self):
        """Return the number of entries found in ``root_dir``."""
        return len(self.feat_names)

    def __getitem__(self, idx):
        """Load the ``idx``-th feature file and return it as a float32 tensor."""
        feat_path = os.path.join(self.root_dir, self.feat_names[idx])
        return torch.from_numpy(np.float32(np.load(feat_path)))

class MNistDataset(Dataset):
    """Dataset of single-channel images stored as files in one directory.

    The last component of ``root_dir`` is used as the label (``self.act``)
    returned with every sample. At most ``dataset_size`` images are kept; when
    the directory holds more, a random subset is selected.
    """

    def __init__(self, root_dir, dataset_size = 300, C3D = True, transform = None):
        """List the images under ``root_dir`` and cap them at ``dataset_size``.

        Args:
            root_dir: directory containing the image files.
            dataset_size: maximum number of images to keep.
            C3D: when true, samples are tiled into a ``(3, 16, 112, 112)``
                float32 tensor; otherwise into an ``(H, W, 3)`` uint8 array.
            transform: optional callable applied to the uint8 array in the
                non-C3D branch only.
        """
        self.root_dir = root_dir
        self.C3D = C3D
        self.act = root_dir.strip('/').split('/')[-1]
        self.dataset_size = dataset_size
        self.image_names = os.listdir(self.root_dir)
        self.transform = transform
        if len(self.image_names) > self.dataset_size:
            random.shuffle(self.image_names)
            self.image_names = self.image_names[:self.dataset_size]
        else:
            # Fewer images than requested: record the real size on the
            # instance (the old code only rebound the local argument).
            self.dataset_size = len(self.image_names)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for the ``idx``-th image.

        The image is expected to be 2-D (single channel): it is copied into
        every channel (and, in the C3D branch, every one of the 16 frames).
        """
        img_name = os.path.join(self.root_dir, self.image_names[idx])
        img = io.imread(img_name)
        if self.C3D:
            img = skimage.transform.resize(img, (112, 112))
            out = np.empty((3, 16, 112, 112))
            # One broadcast assignment fills all 3 x 16 slices with the image.
            out[...] = img
            out = torch.from_numpy(np.float32(out))
        else:
            out = np.empty((img.shape[0], img.shape[1], 3), dtype = np.uint8)
            # Replicate the single channel across the trailing axis.
            out[...] = img[:, :, np.newaxis]
            if self.transform:
                out = self.transform(out)
        return out, self.act

    def __len__(self):
        """Return the number of images kept."""
        return len(self.image_names)

class ActImageDataset(Dataset):
    """Dataset of raw uint8 images stored as values of an LMDB database.

    The last component of ``root_dir`` is used as the label (``self.act``)
    returned with every sample.
    """

    def __init__(self, root_dir, transform = None, feat_dim = (256, 256, 3)):
        """Remember the configuration and collect all record keys.

        Args:
            root_dir: path of the LMDB environment.
            transform: optional callable applied to each transposed image.
            feat_dim: shape each record is reshaped to before the
                ``(2, 0, 1)`` transpose (presumably height, width, channel).
        """
        self.root_dir = root_dir
        self.keys = []
        self.act = root_dir.strip('/').split('/')[-1]
        self.feat_dim = feat_dim
        self.transform = transform
        self._init()

    def _init(self):
        """Read every key of the database into ``self.keys``."""
        lmdb_env = lmdb.open(self.root_dir)
        try:
            with lmdb_env.begin() as lmdb_txn:
                # Keys-only iteration: record values are not needed here.
                self.keys = list(lmdb_txn.cursor().iternext(values=False))
        finally:
            # Release the environment handle explicitly, not via GC.
            lmdb_env.close()

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th key.

        The record is read as a flat uint8 buffer, reshaped to ``feat_dim``
        and transposed with axes ``(2, 0, 1)``; ``transform`` is applied to
        the result when one was given.
        """
        key = self.keys[idx]
        lmdb_env = lmdb.open(self.root_dir)
        try:
            with lmdb_env.begin() as lmdb_txn:
                raw = lmdb_txn.get(key)
        finally:
            lmdb_env.close()
        img = np.frombuffer(raw, dtype = np.uint8).reshape(self.feat_dim)
        img = np.transpose(img, (2, 0, 1))
        if self.transform:
            img = self.transform(img)
        return img, self.act

    def __len__(self):
        """Return the number of records found in the database."""
        return len(self.keys)

class ActSegmentDataset(Dataset):
    """Dataset of raw uint8 video segments stored as values of an LMDB database.

    The last component of ``root_dir`` is used as the label (``self.act``)
    returned with every sample.
    """

    def __init__(self, root_dir, transform = None, feat_dim = (3, 16, 112, 112)):
        """Remember the configuration and collect all record keys.

        Args:
            root_dir: path of the LMDB environment.
            transform: stored for parity with the sibling datasets but NOT
                applied by ``__getitem__`` (the argument was previously
                ignored; outputs are unchanged).
            feat_dim: shape each record is reshaped to.
        """
        self.root_dir = root_dir
        self.keys = []
        self.feat_dim = feat_dim
        self.transform = transform
        self.act = root_dir.strip('/').split('/')[-1]
        self._init()

    def _init(self):
        """Read every key of the database into ``self.keys``."""
        lmdb_env = lmdb.open(self.root_dir)
        try:
            with lmdb_env.begin() as lmdb_txn:
                # Keys-only iteration: record values are not needed here.
                self.keys = list(lmdb_txn.cursor().iternext(values=False))
        finally:
            # Release the environment handle explicitly, not via GC.
            lmdb_env.close()

    def __getitem__(self, idx):
        """Return ``(segment, label)`` for the ``idx``-th key.

        The record is read as a flat uint8 buffer, reshaped to ``feat_dim``
        and converted to a float32 tensor.
        """
        key = self.keys[idx]
        lmdb_env = lmdb.open(self.root_dir)
        try:
            with lmdb_env.begin() as lmdb_txn:
                raw = lmdb_txn.get(key)
        finally:
            lmdb_env.close()
        seg_vid = np.frombuffer(raw, dtype=np.uint8).reshape(self.feat_dim)
        # astype copies, so the tensor owns writable float32 memory.
        seg_vid = torch.from_numpy(seg_vid.astype(np.float32))
        return seg_vid, self.act

    def __len__(self):
        """Return the number of records found in the database."""
        return len(self.keys)


class FlowSegmentDataset(Dataset):
    """Dataset of LMDB-stored uint8 segments, transformed frame by frame.

    The last component of ``root_dir`` is used as the label (``self.act``)
    returned with every sample.
    """

    def __init__(self, root_dir, transform, feat_dim = (3, 16, 112, 112), out_dim = (16, 3, 224, 224)):
        """Remember the configuration and collect all record keys.

        Args:
            root_dir: path of the LMDB environment.
            transform: callable applied to every frame; its result is written
                into one slot of the output tensor.
            feat_dim: shape each record is reshaped to before axes 0 and 1
                are swapped.
            out_dim: four sizes of the float tensor returned per sample.
        """
        self.root_dir = root_dir
        self.keys = []
        self.feat_dim = feat_dim
        self.out_dim = out_dim
        self.transform = transform
        self.act = root_dir.strip('/').split('/')[-1]
        self._init()

    def _init(self):
        """Read every key of the database into ``self.keys``."""
        lmdb_env = lmdb.open(self.root_dir)
        try:
            with lmdb_env.begin() as lmdb_txn:
                # Keys-only iteration: record values are not needed here.
                self.keys = list(lmdb_txn.cursor().iternext(values=False))
        finally:
            # Release the environment handle explicitly, not via GC.
            lmdb_env.close()

    def __getitem__(self, idx):
        """Return ``(segment, label)`` for the ``idx``-th key.

        The record is read as a flat uint8 buffer, reshaped to ``feat_dim``
        and transposed with axes ``(1, 0, 2, 3)``. Each slice along the new
        leading axis is passed through ``transform`` and stored in the output.
        """
        key = self.keys[idx]
        lmdb_env = lmdb.open(self.root_dir)
        try:
            with lmdb_env.begin() as lmdb_txn:
                raw = lmdb_txn.get(key)
        finally:
            lmdb_env.close()
        seg_vid = np.frombuffer(raw, dtype=np.uint8).reshape(self.feat_dim)
        seg_vid = np.transpose(seg_vid, (1, 0, 2, 3))
        # Uninitialised buffer: only the slots written by the loop below are
        # defined, exactly as before.
        seg_vid_t = torch.FloatTensor(self.out_dim[0], self.out_dim[1], self.out_dim[2], self.out_dim[3])
        for i, img in enumerate(seg_vid):
            seg_vid_t[i] = self.transform(img)
        return seg_vid_t, self.act

    def __len__(self):
        """Return the number of records found in the database."""
        return len(self.keys)


def generate_val_exemplars(num_clip_per_act, num_exem_per_act):
    """Randomly draw (anchor, positive, negative) validation triplets.

    For every activity ``i``, ``num_exem_per_act`` anchors are drawn (with
    replacement) from its clips, each with a positive from the same activity
    that differs from the anchor. The anchors are then split evenly over the
    other activities, and each gets one negative drawn (without replacement)
    from that activity.

    Args:
        num_clip_per_act: number of clips available for each activity.
        num_exem_per_act: triplets to generate per activity. If it is not a
            multiple of ``len(num_clip_per_act) - 1`` it is rounded down to
            one (or up to that value if rounding gives 0) and a message is
            printed.

    Returns:
        A list of 5-tuples ``(anchor_class, negative_class, anchor_idx,
        positive_idx, negative_idx)``.

    Raises:
        ValueError: with fewer than two activities (no negative class), or
            when triplets are requested and some activity has fewer than two
            clips (no positive distinct from the anchor; this used to loop
            forever). ``np.random.choice`` also raises it when an activity
            has too few clips for the negatives requested from it.
    """
    num_acts = len(num_clip_per_act)
    multiple = num_acts - 1
    if multiple < 1:
        raise ValueError("At least two activities are required to draw negatives.")
    if num_exem_per_act % multiple != 0:
        num_exem_per_act -= num_exem_per_act % multiple
        if num_exem_per_act == 0:
            num_exem_per_act = multiple
        print("Number of exemplars was not divisible by " + str(multiple)+ ". Setting it to " + str(num_exem_per_act))
    if num_exem_per_act > 0 and any(int(count) < 2 for count in num_clip_per_act):
        raise ValueError("Every activity needs at least two clips to form an anchor/positive pair.")

    negs_per_act = int(num_exem_per_act // multiple)
    exemplar_indices = []
    for i in range(num_acts):
        act_clip_num = num_clip_per_act[i]
        anchors = np.random.choice(act_clip_num, num_exem_per_act, replace=True)
        # Anchors may repeat; a repeated anchor keeps the last positive drawn.
        positive_of = {}
        for anchor in anchors:
            pos = np.random.choice(act_clip_num, 1)[0]
            while pos == anchor:
                pos = np.random.choice(act_clip_num, 1)[0]
            positive_of[anchor] = pos
        # j walks through the anchors as they are handed out to the other
        # activities, negs_per_act at a time.
        j = 0
        for k in range(num_acts):
            if k == i:
                continue
            negatives = np.random.choice(int(num_clip_per_act[k]), negs_per_act, replace=False)
            for neg in negatives:
                exemplar_indices.append((i, k, anchors[j], positive_of[anchors[j]], neg))
                j += 1
    return exemplar_indices


def get_val_indexes(num_clip_per_act, batch_size):
    """Draw an equal number of random clip indices for every activity.

    The per-activity sample count is the smallest entry of
    ``num_clip_per_act`` rounded down to a multiple of ``batch_size``.

    Returns:
        ``(ind_matrix, labels)``: two float arrays of shape
        ``(NUM_ACTIVITIES, count)``. Row ``i`` of ``ind_matrix`` holds indices
        drawn without replacement from activity ``i``; row ``i`` of ``labels``
        is filled with ``i``.
    """
    smallest = min(num_clip_per_act)
    per_act = smallest - (smallest % batch_size)
    grid_shape = (NUM_ACTIVITIES, per_act)

    ind_matrix = np.zeros(grid_shape)
    for act in range(NUM_ACTIVITIES):
        ind_matrix[act] = np.random.choice(num_clip_per_act[act], per_act, replace=False)

    # Broadcasting a column of class ids over a zero grid labels each row.
    labels = np.zeros(grid_shape) + np.arange(NUM_ACTIVITIES).reshape(-1, 1)
    return ind_matrix, labels


def select_triplets(features, alpha, num_segs_per_act):
    """Select semi-hard triplets for training.

    DISCLAIMER: https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py

    ``features`` is indexed as ``NUM_ACTIVITIES`` consecutive groups of
    ``num_segs_per_act`` rows, one group per activity. Within a group, every
    row except the last serves as an anchor and is paired with each later row
    as a positive. For each pair, one negative is picked at random among rows
    of other activities whose squared distance ``d_neg`` to the anchor
    satisfies ``d_pos < d_neg`` and ``d_neg - d_pos < alpha``; pairs with no
    such negative are skipped.

    Args:
        features: 2-D array with one embedding per row.
        alpha: margin of the semi-hard condition.
        num_segs_per_act: number of rows belonging to each activity.

    Returns:
        Three float arrays ``(anchors, positives, negatives)`` of equal
        length holding row indices into ``features``, in shuffled order.
    """
    triplets = []
    feat_start_idx = 0
    for _ in range(NUM_ACTIVITIES):
        feat_end_idx = feat_start_idx + num_segs_per_act
        for j in range(1, num_segs_per_act):
            a_idx = feat_start_idx + j - 1
            neg_dists_sqr = np.sum(np.square(features[a_idx] - features), 1)
            # Rows of the anchor's own activity can never be negatives; NaN
            # makes both comparisons below evaluate to False for them.
            # (np.nan, not np.NaN: the latter alias was removed in NumPy 2.0.)
            neg_dists_sqr[feat_start_idx:feat_end_idx] = np.nan
            for p_idx in range(feat_start_idx + j, feat_end_idx):   # For every possible positive pair.
                pos_dist_sqr = np.sum(np.square(features[a_idx] - features[p_idx]))
                semi_hard_negs = np.where(np.logical_and(neg_dists_sqr - pos_dist_sqr < alpha, pos_dist_sqr < neg_dists_sqr))[0]  # FaceNet selection
                num_semi_negs = semi_hard_negs.shape[0]
                if num_semi_negs > 0:
                    n_idx = semi_hard_negs[np.random.randint(num_semi_negs)]
                    triplets.append((a_idx, p_idx, n_idx))
        feat_start_idx = feat_end_idx
    np.random.shuffle(triplets)
    # reshape keeps the column split valid when no triplet was found.
    trip_arr = np.array(triplets, dtype=np.float64).reshape(-1, 3)
    return trip_arr[:, 0].copy(), trip_arr[:, 1].copy(), trip_arr[:, 2].copy()


def _resized_c3d_frame(frame):
    """Reverse the channel order of ``frame`` and resize it to the clip resolution."""
    # We flip because skvideo color channels are backwards from cv2 channels.
    return skimage.transform.resize(np.flip(frame, axis = 2), (VIDEO_HEIGHT, VIDEO_WIDTH))


def get_C3D_spatial_feats(net, vid_file_path, out_file_path, out_file_name):
    """Extract per-clip C3D features from a video and save them as ``.npy`` files.

    The video is processed in steps of 16 frames. Each clip has 16 slots: the
    first 8 hold every other frame of the previous 16 video frames, the last
    8 hold every other frame of the current 16. After each step the clip is
    fed to ``net`` and the newer half becomes the older half of the next clip.

    Args:
        net: callable returning ``(prediction, pool_5, fc_6, fc_7)`` for an
            input of shape ``(1, NUM_CHANNELS, 16, VIDEO_HEIGHT, VIDEO_WIDTH)``.
        vid_file_path: path of the video to read.
        out_file_path: directory that already contains ``pool5``, ``fc6`` and
            ``fc7`` sub-directories.
        out_file_name: file name used inside each of those sub-directories.

    Returns:
        None. The video is skipped (with a printed message, nothing written)
        when the reader raises ``ValueError`` or it has fewer than 16 frames.
    """
    # Load in video
    try:
        vid = skvideo.io.FFmpegReader(vid_file_path)
    except ValueError:
        print("\t\tIssue with video encoding. Skipping.")
        sys.stdout.flush()
        return

    try:
        # Set up feature matrices. Trailing frames that do not fill a whole
        # group of 16 are dropped.
        num_frames = vid.getShape()[0]
        num_frames -= num_frames % 16
        num_clips = num_frames // 16 - 1
        if num_clips < 0:
            # Not even the first half-clip can be filled; previously this
            # crashed on a negative array dimension.
            print("\t\tVideo has fewer than 16 frames. Skipping.")
            sys.stdout.flush()
            return
        pool5_feats = np.empty((num_clips, 8192))
        fc6_feats = np.empty((num_clips, 4096))
        fc7_feats = np.empty((num_clips, 4096))

        # Get the iterator of video frames, as well as set up our clip matrix.
        cap = vid.nextFrame()
        clip = np.empty((1, 16, VIDEO_HEIGHT, VIDEO_WIDTH, NUM_CHANNELS))

        # Fill the first 8 slots of the clip with every other frame of the
        # first 16 video frames.
        for i in range(16):
            frame = next(cap)
            if i % 2 == 0:
                clip[:, i // 2] = _resized_c3d_frame(frame)

        # We then iterate through the rest of the video.
        for i in range(num_frames - 16):
            # Print progress, as well as get the next video frame.
            if i % 160 == 0:
                print("\t\tProcessing clip " + str(i // 16 + 1))
                sys.stdout.flush()
            frame = next(cap)
            cur_frame = i % 16  # position inside the current group of 16 frames

            # Fill in the last 8 slots of the clip by using every other frame.
            if cur_frame % 2 == 0:
                clip[:, cur_frame // 2 + 8] = _resized_c3d_frame(frame)

            # After 16 video frames the clip is complete: run it through the
            # network and store the features.
            if cur_frame == 15:
                # Format the input to be put into the C3D network.
                X = np.transpose(clip, (0, 4, 1, 2, 3))
                X = img_as_ubyte(X)
                X = Variable(torch.from_numpy(np.float32(X)))
                if use_gpu:
                    X = X.cuda()

                # Feed the input into the network and place the extracted
                # features into the feature matrices.
                prediction, pool_5, fc_6, fc_7 = net(X)
                clip_idx = i // 16
                pool5_feats[clip_idx, :] = pool_5.data.cpu().numpy()
                fc6_feats[clip_idx, :] = fc_6.data.cpu().numpy()
                fc7_feats[clip_idx, :] = fc_7.data.cpu().numpy()

                # Slide the last 8 slots of the clip to the first 8 slots.
                clip[:, 0:8] = clip[:, 8:16]
    finally:
        # Always release the reader, even if a frame or the network fails.
        vid.close()

    # Save the feature matrices
    np.save(out_file_path+'/pool5/'+out_file_name, pool5_feats)
    np.save(out_file_path+'/fc6/'+out_file_name, fc6_feats)
    np.save(out_file_path+'/fc7/'+out_file_name, fc7_feats)

# NOTE: the triple-quoted block below is a bare string literal, not live code;
# get_video_spatial_feats is therefore never defined or executed. It refers to
# cv2, which is not among the imports at the top of this file.
'''
def get_video_spatial_feats(vid_file_path, net):
    #Set up the video and grab necessary video elements
    cap = skvideo.io.FFmpegReader(vid_file_path)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    num_frames = num_frames - num_frames%16 #cut off last bit of video that won't be used
    num_clips = int(num_frames/16)

    #Set up variables that will contain the clips.
    clips = np.empty((num_clips, 16, VIDEO_HEIGHT, VIDEO_WIDTH, NUM_CHANNELS))
    clip = np.empty((16, VIDEO_HEIGHT, VIDEO_WIDTH, NUM_CHANNELS))
    
    for i in range(16):
        frame = cap.read()[1]
        if i%2 == 0:
            clip[int(i/2),:,:,:] = cv2.resize(frame, (VIDEO_WIDTH, VIDEO_HEIGHT))
    
    curFrame = 0
    curClip = 0
    while cap.isOpened():
        frame = cap.read()[1]
        if curFrame%2 == 0:
            clip[int(curFrame/2)+8, :,:,:] = cv2.resize(frame, (VIDEO_WIDTH, VIDEO_HEIGHT))
        curFrame += 1
        if curFrame == 16:
            clips[curClip, :, :, :, :] = clip
            clip[0:8, :, :, :] = clip [8:16, :, :, :]
            curFrame = 0
            curClip += 1

    return clips
'''
def _read_phase_label_file(label_path):
    """Return the video file names described by one phase-label file."""
    names = []
    # The context manager closes the file even if a line is malformed.
    with open(label_path) as label_file:
        for line in label_file:
            components = line.split()
            if not components:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            names.append('p' + components[0] + '_n' + components[1] + '_vid_' + components[2] + '.avi')
    return names


def read_train_test_split(file_dir):
    """Read the train/test split and return the video names of each part.

    ``file_dir`` is used as a plain string prefix, so it must end with a path
    separator. It must contain ``train_phase_labels.txt`` and
    ``test_phase_labels.txt``. The first three whitespace-separated fields
    ``a b c`` of each non-blank line become ``p<a>_n<b>_vid_<c>.avi``.

    Returns:
        ``(train_list, test_list)``: two lists of video file names.

    Raises:
        OSError: if either label file cannot be opened.
        IndexError: if a non-blank line has fewer than three fields.
    """
    train_list = _read_phase_label_file(file_dir + 'train_phase_labels.txt')
    test_list = _read_phase_label_file(file_dir + 'test_phase_labels.txt')
    return train_list, test_list


