# Loads the features in a directory (stored as NumPy arrays) and visualizes them
# using dimensionality reduction methods (PCA and t-SNE).

import numpy as np
import matplotlib.pyplot as plt
import os
import sys
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import pdb

def visualize_npy_file(input_dir):
    """Scatter-plot a 2-component PCA projection of ``train.npy``, colored by label.

    The array is read from ``input_dir + 'train.npy'`` and one integer label
    per row is read from ``input_dir + 'train_key.txt'`` (the third
    whitespace-separated field of each non-blank line). The labels drive the
    point colors and the "Phase N" legend entries. The figure is displayed
    with ``plt.show()``; nothing is returned.

    Args:
        input_dir: Path prefix of the data files. It is concatenated with the
            file names as-is, so a directory must include its trailing
            separator (e.g. ``'encodings/'``).

    Raises:
        ValueError: If the number of labels in the key file differs from the
            number of rows in ``train.npy``.
    """
    print(' ')
    print('Running script to visualize features from data in ' + input_dir)
    print('--------------------------------------------------------------------')
    data = np.load(input_dir + 'train.npy')
    num_input = data.shape[0]

    # Read every label up front inside a context manager so the file is always
    # closed; blank lines (e.g. a trailing newline) carry no label.
    with open(input_dir + 'train_key.txt', 'r') as label_file:
        labels = [int(line.split()[2]) for line in label_file if line.strip()]

    # A mismatch would otherwise either overflow the color array or leave
    # uninitialized values in it that would be plotted as if they were labels.
    if len(labels) != num_input:
        raise ValueError(
            'train_key.txt has {} labels but train.npy has {} rows'.format(
                len(labels), num_input))
    colors = np.array(labels, dtype=float)

    print('Running PCA')
    data = PCA(n_components = 2).fit_transform(data)

    print('Plotting.')
    plt.figure()
    ax = plt.subplot(111)
    data_plot = ax.scatter(data[:,0], data[:,1], c = colors, cmap = 'tab10', alpha = 0.7)
    plt.title('PCA with 2 Components on Training Encodings')

    # Shrink current axis by 20% to leave room for the legend on the right.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])

    # Build one empty proxy line per distinct label so the legend shows a
    # marker in exactly the color the scatter plot assigned to that label.
    size = 81
    handles = [
        plt.plot([], color=data_plot.cmap(data_plot.norm(label)), ms=np.sqrt(size),
                 mec="none", label="Phase {:g}".format(label), ls="", marker="o")[0]
        for label in np.unique(colors)
    ]
    plt.legend(handles=handles, bbox_to_anchor=(1, 1))
    plt.show()

def visualize_features(input_dir):
    """Scatter-plot a t-SNE embedding of the per-file mean feature vectors.

    Every file whose name ends in ``npy`` inside the hard-coded output
    directory for ``input_dir`` is loaded and averaged over axis 0 to give one
    feature vector per file. The vectors are reduced with PCA (up to 100
    components) and then embedded in 2-D with t-SNE. Points are colored by a
    number parsed from the file name: the first ``_``-separated token with its
    first character dropped, read as an integer. The figure is displayed with
    ``plt.show()``; nothing is returned.

    Args:
        input_dir: Name of the sub-directory (under the hard-coded DenseNet
            output root) holding the ``.npy`` feature files.

    Raises:
        ValueError: If the directory contains no ``npy`` files.
    """
    print(' ')
    print('Running script to visualize features from data in ' + input_dir)
    print('--------------------------------------------------------------------')

    data_dir = '/home-2/fyu12@jhu.edu/scratch/ReiterLab/cataract_densenet_output/' + input_dir + '/'
    all_entries = os.listdir(data_dir)

    print('Constructing raw data matrix.')
    # Build a new filtered list: removing entries from a list while iterating
    # over it skips the element after each removal, so non-npy files could
    # slip through.
    file_list = [f for f in all_entries if f.endswith('npy')]
    if not file_list:
        raise ValueError('No npy feature files found in ' + data_dir)

    num_file = len(file_list)
    # The first file determines the feature dimensionality for all files.
    num_feat = np.load(data_dir + file_list[0]).shape[1]
    data_points = np.empty((num_file, num_feat))
    colors = np.empty(num_file)
    for i, f in enumerate(file_list):
        phase = int(f.split('_')[0][1:])
        colors[i] = phase * 0.1
        # Collapse each file's rows into a single mean feature vector.
        data_points[i, :] = np.mean(np.load(data_dir + f), axis = 0)

    print('Running PCA')
    # PCA cannot extract more components than min(n_samples, n_features), so
    # clamp the target of 100 for small inputs instead of failing.
    num_components = min(100, num_file, num_feat)
    data_points = PCA(n_components = num_components).fit_transform(data_points)

    print('Running t-SNE')
    data_points = TSNE(n_components = 2).fit_transform(data_points)

    print('Plotting.')
    plt.scatter(data_points[:,0], data_points[:,1], c = colors, alpha = 0.5)
    plt.title('t-SNE with 2 components for: ' + input_dir)
    plt.show()

if __name__ == "__main__":
    # Script entry point: the single command-line argument is the path prefix
    # handed to visualize_npy_file (file names are appended to it directly).
    if len(sys.argv) < 2:
        # Exit with a readable message instead of an IndexError traceback.
        sys.exit('Usage: python <this_script> <input_dir/>')
    data_dir = sys.argv[1]
    visualize_npy_file(data_dir)
