#########################################
# Johns Hopkins University              #
# 601.455 Computer Integrated Surgery 2 #
# Spring 2018                           #
# Query by Video For Surgical Activities#
# Felix Yu                              #
# JHED: fyu12                           #
# Gianluca Silva Croso                  #
# JHED: gsilvac1                        #
#########################################

import os
import sys

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


def _read_labels(key_path, expected_count):
    """
    Read one integer label per non-blank line of a key file.

    The label is the third whitespace-separated field of each line.
    :param key_path: Path to the key file (e.g. train_key.txt)
    :param expected_count: Number of labels expected (one per feature row)
    :return: float64 numpy array of length expected_count holding the labels
    :raises ValueError: if a line has fewer than three fields, or if the number
        of labels does not equal expected_count
    """
    labels = []
    with open(key_path, 'r') as label_file:
        for line_no, line in enumerate(label_file, start=1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line at EOF).
                continue
            if len(fields) < 3:
                raise ValueError('{}:{}: expected at least 3 fields, got {}'.format(
                    key_path, line_no, len(fields)))
            labels.append(int(fields[2]))
    if len(labels) != expected_count:
        # A mismatch would otherwise leave unlabeled rows (or crash on extra labels).
        raise ValueError('{} has {} labels but the feature file has {} rows'.format(
            key_path, len(labels), expected_count))
    # Float dtype matches the original np.empty-based array used for coloring.
    return np.asarray(labels, dtype=float)


def visualize_npy_file(input_dir):
    """
    Project the training encodings in a directory onto 2 PCA components and show
    a scatter plot colored by label.

    Reads ``train.npy`` (one feature row per sample; presumably a 2-D
    (n_samples, n_features) array, as PCA requires) and ``train_key.txt``.
    The key file must contain one line per row, with an integer label in the
    third whitespace-separated field.
    :param input_dir: Directory containing train.npy and train_key.txt
    :raises ValueError: if train_key.txt is malformed or its label count does
        not match the number of rows in train.npy
    """
    print(' ')
    print('Running script to visualize features from data in ' + input_dir)
    print('--------------------------------------------------------------------')
    # os.path.join works whether or not input_dir ends with a separator.
    data = np.load(os.path.join(input_dir, 'train.npy'))
    colors = _read_labels(os.path.join(input_dir, 'train_key.txt'), data.shape[0])
    print('Running PCA')
    data = PCA(n_components=2).fit_transform(data)
    print('Plotting.')
    plt.figure()
    ax = plt.subplot(111)
    data_plot = ax.scatter(data[:, 0], data[:, 1], c=colors, cmap='tab10', alpha=0.7)
    plt.title('PCA with 2 Components on Training Encodings')
    # Shrink the axis width by 20% to leave room for the legend anchored at (1, 1).
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    marker_area = 81  # converted to a marker size via sqrt below

    def _legend_proxy(label):
        # An empty line with only a marker, colored like the scatter points for
        # this label; it exists solely to serve as a legend handle.
        return plt.plot([], color=data_plot.cmap(data_plot.norm(label)),
                        ms=np.sqrt(marker_area), mec="none",
                        label="Phase {:g}".format(label), ls="", marker="o")[0]

    handles = [_legend_proxy(label) for label in np.unique(colors)]
    plt.legend(handles=handles, bbox_to_anchor=(1, 1))
    plt.show()


if __name__ == "__main__":
    # The first command-line argument is the directory holding the feature files.
    if len(sys.argv) < 2:
        # sys.exit with a string writes it to stderr and exits with status 1,
        # so shell scripts can detect the usage error.
        sys.exit("Usage: python visualize_features.py <input_dir>")
    visualize_npy_file(sys.argv[1])
