import numpy as np
import random
import math
import sys
import os

# TOOL_DICT maps a phase identifier (a string, '1' to '10') to another dict, which
# maps tool identifiers (ints, 1 to NUM_TOOLS) to a 2-tuple (prob, length) of floats
# between 0 and 1. The first is the probability that the tool appears at all in
# that phase; the second is the fraction of the phase's frames in which the tool is
# present. If the second is 0.0, a random fraction between MIN_RAND_FRAMES and
# MAX_RAND_FRAMES is used instead.
TOOL_DICT = {
    '1': {1: (1.0, 1.0)},
    '2': {2: (1.0, 1.0), 4: (0.1, 0.0)},
    '3': {6: (1.0, 0.2), 7: (1.0, 0.8), 3: (0.1, 0.0), 15: (0.4, 0.0)},
    '4': {8: (1.0, 1.0)},
    '5': {9: (1.0, 1.0), 15: (0.3, 0.0)},
    '6': {10: (1.0, 1.0)},
    '7': {11: (1.0, 1.0), 12: (1.0, 0.4)},
    '8': {10: (1.0, 1.0), 3: (0.1, 0.0)},
    '9': {10: (1.0, 1.0), 14: (1.0, 0.2)},
    '10': {16: (1.0, 1.0), 4: (1.0, 1.0)},
}
# NOTE: the order of tools inside each phase dict matters -- the label
# generator walks it in insertion order, drawing random numbers as it goes.

# Bounds, as fractions of a phase's frames, for the random run length used
# when a tool's length entry in TOOL_DICT is 0.0.
MIN_RAND_FRAMES = 0.1
MAX_RAND_FRAMES = 0.3

# Width of the generated label matrix; tool id t is stored in column t - 1.
NUM_TOOLS = 16

# Root of the input feature files and of the generated label files; both are
# organised as <root>/<split>/<activity>/<file>.
data_dir = '/home-2/gsilvac1@jhu.edu/work/gsilvac1/CATARACT/outputs/squeezenet_output/'
output_dir = '/home-2/gsilvac1@jhu.edu/work/gsilvac1/CATARACT/outputs/tools_output/'

# Data splits to process, in this order.
data_phases = ['train', 'val', 'test']

def append_tool_labels(features, phase):
    """
    Simulate per-frame tool-presence labels for one video segment.

    For each tool listed for ``phase`` in TOOL_DICT, a draw with probability
    ``prob`` decides whether the tool appears at all. If it does, the tool is
    marked present on one contiguous run of ``floor(N * length)`` frames
    starting at a uniformly random frame. When ``length`` is 0.0, a random
    fraction in [MIN_RAND_FRAMES, MAX_RAND_FRAMES] is used instead.

    :param features: NxM spatial feature matrix (N is number of frames,
        M is spatial feature vector size); only its first dimension is read
    :param phase: identifier of the phase, either a TOOL_DICT key string
        such as '3' or the equivalent int such as 3
    :return: N x NUM_TOOLS float matrix in which column ``t - 1`` is 1 on the
        frames where tool ``t`` is simulated as present and 0 elsewhere.
        Despite the function name, ``features`` itself is NOT included in
        the result (stack it with np.hstack if that is wanted).
    :raises KeyError: if ``phase`` does not name a phase in TOOL_DICT
    """
    n_frames = features.shape[0]
    tool_matrix = np.zeros((n_frames, NUM_TOOLS))
    # TOOL_DICT is keyed by strings; str() lets callers pass an int phase too,
    # while string keys map to themselves unchanged.
    phase_tools = TOOL_DICT[str(phase)]
    for tool, (prob, duration) in phase_tools.items():
        if random.random() > prob:
            continue
        if duration == 0.0:
            duration = random.uniform(MIN_RAND_FRAMES, MAX_RAND_FRAMES)
        num_frames = int(math.floor(n_frames * duration))
        # randint is inclusive on both ends, so the run always fits in range.
        start_frame = random.randint(0, n_frames - num_frames)
        tool_matrix[start_frame:start_frame + num_frames, tool - 1] = 1
    return tool_matrix

def _collect_npy_files(data_phase):
    """
    List the feature files of one data split.

    :param data_phase: split name, one of data_phases (e.g. 'train')
    :return: list of (activity, npy_name) pairs, one per ``.npy`` file found
        in ``data_dir/<data_phase>/<activity>/``, in sorted order
    """
    phase_dir = os.path.join(data_dir, data_phase)
    found = list()
    for activity in sorted(os.listdir(phase_dir)):
        activity_dir = os.path.join(phase_dir, activity)
        # Skip stray files next to the activity folders; listing them would fail.
        if not os.path.isdir(activity_dir):
            continue
        for npy_name in sorted(os.listdir(activity_dir)):
            # Only .npy files hold a single loadable feature matrix.
            if npy_name.endswith('.npy'):
                found.append((activity, npy_name))
    return found


def main():
    """
    Generate simulated tool labels for every feature file in every data split.

    Each ``data_dir/<split>/<activity>/<file>.npy`` is loaded and passed to
    append_tool_labels together with a phase id parsed from the file name.
    The resulting matrix is saved as
    ``output_dir/<split>/<activity>/<file>_tools.npy``, creating output
    directories as needed.
    """
    # List every split up front so a missing input directory fails before any
    # output is written.
    files_by_phase = dict((data_phase, _collect_npy_files(data_phase))
                          for data_phase in data_phases)

    for data_phase in data_phases:
        print('Looking at phase: ' + data_phase)
        for activity, npy_name in files_by_phase[data_phase]:
            # Phase id = the file name's text before the first '_', minus its
            # first character (e.g. 'x3_abc.npy' -> '3').
            # NOTE(review): confirm this matches the dataset naming scheme.
            activity_num = npy_name.split('_')[0][1:]
            spatial_feats = np.load(
                os.path.join(data_dir, data_phase, activity, npy_name))
            tool_labels = append_tool_labels(spatial_feats, activity_num)

            out_dir = os.path.join(output_dir, data_phase, activity)
            # np.save does not create missing parent directories.
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)
            # splitext strips only the final extension, so dots elsewhere in
            # the path or file stem are preserved.
            out_name = os.path.splitext(npy_name)[0] + '_tools.npy'
            np.save(os.path.join(out_dir, out_name), tool_labels)



# Run the label generation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
