#########################################
# Johns Hopkins University              #
# 601.455 Computer Integrated Surgery 2 #
# Spring 2018                           #
# Query by Video For Surgical Activities#
# Felix Yu                              #
# JHED: fyu12                           #
# Gianluca Silva Croso                  #
# JHED: gsilvac1                        #
#########################################

import os
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torch.optim

# Number of surgical activity classes; select_triplets iterates over this
# many contiguous activity groups in its feature matrix.
NUM_ACTIVITIES = 10
# True when torch can see a CUDA device.
use_gpu = bool(torch.cuda.is_available())


class SpatialToolFeatDataset(Dataset):
    """
    Dataset class that loads spatial features and tool presence information
    for each clip and concatenates them column-wise
    """
    def __init__(self, spatial_dir, tool_dir, transform=None):
        """
        Initialize the dataset class
        :param spatial_dir: directory containing spatial features npy files
        :param tool_dir:  directory containing tool presence npy files.
        Name of npy files should match corresponding spatial feature npy
        exactly except with _tools appended to the end before .npy
        :param transform: Not applicable to this dataset (stored, never used)
        """
        self.spatial_dir = spatial_dir
        self.tool_dir = tool_dir
        self.transform = transform
        # Items are indexed in os.listdir order (unspecified, platform
        # dependent); kept as-is so any code aligning by this order still works.
        self.spatial_names = os.listdir(self.spatial_dir)

    def __getitem__(self, idx):
        """
        Retrieves item from dataset
        :param idx: int the index of the item
        :return: float32 tensor of shape Nx(F+T) with both spatial and tool
        features. N is number of frames in the clip. F is size of spatial
        feature vectors. T is number of tools possible (size of tool vector)
        """
        spatial_name = self.spatial_names[idx]
        # Strip only the final extension so names that contain extra dots
        # (e.g. "clip.v2.npy") map to "clip.v2_tools.npy" as documented,
        # instead of being truncated at the first dot.
        base_name = os.path.splitext(spatial_name.strip())[0]
        tool_name = base_name + '_tools.npy'
        spatial_feats = np.load(os.path.join(self.spatial_dir, spatial_name))
        tool_feats = np.load(os.path.join(self.tool_dir, tool_name))
        # Column-wise concatenation: both arrays must have the same number
        # of rows (frames) or numpy raises ValueError.
        feats = np.hstack((spatial_feats, tool_feats))
        feats = np.float32(feats)
        return torch.from_numpy(feats)

    def __len__(self):
        """
        :return: Number of examples (files in spatial_dir) in this dataset
        """
        return len(self.spatial_names)


class SpatialFeatDataset(Dataset):
    """
    Dataset of per-clip spatial feature matrices, one npy file per clip.
    Clips are indexed in the order os.listdir returns for root_dir.
    """
    def __init__(self, root_dir, transform=None):
        """
        Initialize the dataset class
        :param root_dir: directory containing spatial features npy files
        :param transform: Not applicable to this dataset (stored, never used)
        """
        self.root_dir, self.transform = root_dir, transform
        self.feat_names = os.listdir(root_dir)

    def __getitem__(self, idx):
        """
        Loads the spatial features of one clip
        :param idx: int the index of the item
        :return: float32 tensor holding the npy file contents, expected to be
        an NxF matrix of spatial features only (no tool features).
        N is number of frames in the clip. F is size of spatial feature
        vectors.
        """
        npy_path = os.path.join(self.root_dir, self.feat_names[idx])
        return torch.from_numpy(np.float32(np.load(npy_path)))

    def __len__(self):
        """
        :return: Number of examples (files in root_dir) in this dataset
        """
        return len(self.feat_names)


def generate_val_exemplars(num_clip_per_act, num_exem_per_act):
    """
    Creates a list of triplets to be used as validation examples during
    training.
    Each element of the returned list is a 5-tuple:
    First element is the class of the anchor and positive
    Second element is the class of the negative
    Third, fourth, fifth elements are the indices of the anchor, positive,
    and negative example respectively within their classes
    Anchors are drawn with replacement, so an anchor index may repeat within
    a class; every draw gets its own positive (distinct from the anchor).
    Negatives for one anchor class are drawn without replacement from each
    other class, num_exem_per_act // (len(num_clip_per_act) - 1) per class.
    :param num_clip_per_act: list in which element at index i-1 represents
    number of available clips for activity i
    :param num_exem_per_act: Int. Number of triplets desired with anchors
    at each activity. If not divisible by len(num_clip_per_act) - 1 it is
    rounded down to a multiple of it (or set to it if that would give 0).
    :return: list of 5 tuples described above ([] if num_clip_per_act is empty)
    :raises ValueError: if only one activity is given, if an activity has
    fewer than 2 clips (no positive distinct from the anchor exists), or if
    an activity has too few clips to provide its distinct negatives
    """
    if len(num_clip_per_act) == 0:
        return []
    multiple = len(num_clip_per_act) - 1
    if multiple == 0:
        raise ValueError("At least 2 activities are required to build triplets")
    if num_exem_per_act % multiple != 0:
        num_exem_per_act -= num_exem_per_act % multiple
        if num_exem_per_act == 0:
            num_exem_per_act = multiple
        print("Number of exemplars was not divisible by " + str(multiple) + ". Setting it to " + str(num_exem_per_act))
    negs_per_act = num_exem_per_act // multiple

    # Validate up front: with too few clips the positive-sampling loop below
    # would never terminate, and the negative draw would fail part-way.
    if num_exem_per_act > 0:
        for act, clip_num in enumerate(num_clip_per_act):
            if clip_num < 2:
                raise ValueError("Activity index " + str(act) + " has " + str(clip_num)
                                 + " clips; at least 2 are needed to pick a positive")
            if int(clip_num) < negs_per_act:
                raise ValueError("Activity index " + str(act) + " has " + str(clip_num)
                                 + " clips; " + str(negs_per_act)
                                 + " distinct negatives are needed")

    exemplar_indices = []
    for i, act_clip_num in enumerate(num_clip_per_act):
        anchors = np.random.choice(act_clip_num, num_exem_per_act, replace=True)
        # One positive per anchor draw. A dict keyed on the anchor value
        # would let repeated anchors overwrite each other's positive.
        positives = []
        for anchor in anchors:
            pos = np.random.choice(act_clip_num, 1)[0]
            while pos == anchor:
                pos = np.random.choice(act_clip_num, 1)[0]
            positives.append(pos)
        j = 0
        for k, neg_clip_num in enumerate(num_clip_per_act):
            if k == i:
                continue
            negatives = np.random.choice(int(neg_clip_num), negs_per_act, replace=False)
            for neg in negatives:
                exemplar_indices.append((i, k, anchors[j], positives[j], neg))
                j += 1
    return exemplar_indices


def select_triplets(features, alpha, num_items_per_act, num_activities=None):
    """ Select semi-hard triplets for training backpropagation
    Semi-hard triplets have distance difference between 0 and alpha
    where distance difference is dist(anchor, negative) - dist(anchor, positive)
    and distances are squared Euclidean. For every anchor/positive pair
    within an activity, one semi-hard negative from a different activity is
    picked uniformly at random; pairs without any semi-hard negative are
    skipped.
    DISCLAIMER: based on triplet loss training suggestion available at
    https://github.com/davidsandberg/facenet/blob/master/src/train_tripletloss.py
    :param features: (num_activities*num_items_per_act)xF array with
    features for each example in the batch, the examples of each activity
    stored contiguously. F is the size of the feature vector.
    :param alpha: margin being used for triplet loss training
    :param num_items_per_act: integer number of examples being considered
    for each activity
    :param num_activities: number of activity groups in features; defaults
    to the module-level NUM_ACTIVITIES
    :return: tuple (triplets_a, triplets_p, triplets_n) of equal-length 1-D
    float64 arrays with the row indices in features of the anchor, positive
    and negative of each selected triplet, in shuffled order
    """
    if num_activities is None:
        num_activities = NUM_ACTIVITIES
    num_rows = len(features)
    triplets = []
    for act in range(num_activities):
        start = act * num_items_per_act
        end = start + num_items_per_act
        # Negatives must come from other activities. A boolean mask replaces
        # the former NaN sentinel: np.NaN was removed in NumPy 2.0, and NaN
        # cannot be stored in an integer distance array.
        other_act = np.ones(num_rows, dtype=bool)
        other_act[start:end] = False
        for a_idx in range(start, end - 1):
            neg_dists_sqr = np.sum(np.square(features[a_idx] - features), 1)
            for p_idx in range(a_idx + 1, end):   # For every possible positive pair.
                pos_dist_sqr = np.sum(np.square(features[a_idx] - features[p_idx]))
                # FaceNet selection: farther than the positive, within alpha.
                semi_hard_negs = np.where(other_act
                                          & (neg_dists_sqr - pos_dist_sqr < alpha)
                                          & (pos_dist_sqr < neg_dists_sqr))[0]
                num_semi_negs = semi_hard_negs.shape[0]
                if num_semi_negs > 0:
                    n_idx = semi_hard_negs[np.random.randint(num_semi_negs)]
                    triplets.append((a_idx, p_idx, n_idx))
    np.random.shuffle(triplets)
    # float64 output dtype kept for backward compatibility with callers.
    triplets_a = np.array([t[0] for t in triplets], dtype=np.float64)
    triplets_p = np.array([t[1] for t in triplets], dtype=np.float64)
    triplets_n = np.array([t[2] for t in triplets], dtype=np.float64)
    return triplets_a, triplets_p, triplets_n
