import time
import math
import os
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import Models
import helpers

import pdb
import sys
import gc 

# True when a CUDA device is available; tensors and the model are moved to the GPU in that case.
use_gpu = torch.cuda.is_available()
# Directory where train_model writes model checkpoints.
# NOTE(review): the directory is named 'alpha_05' while ALPHA below is 0.20 -- confirm which is intended.
weight_out_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/weights/C3D/alpha_05'

# Triplet margin: passed to helpers.select_triplets, to nn.TripletMarginLoss and used in the validation loss.
ALPHA = 0.20
# Not referenced anywhere in the visible part of this file; presumably the embedding size -- TODO confirm.
FINAL_LAYER_DIM = 128
# Batch size of each per-activity DataLoader (clips drawn per activity per training iteration).
BATCH_SEGS_PER_ACTIVITY = 4
# Number of activity datasets; their directories are numbered 1..NUM_ACTIVITIES.
NUM_ACTIVITIES = 10
# Passed to helpers.generate_val_exemplars when building the validation triplet set.
VAL_NUM_EXEM_PER_ACT = 72


def _scalar_loss(loss):
    """Return the Python number held by a one-element loss tensor.

    ``loss.data[0]`` fails on 0-dim tensors in newer PyTorch releases, while
    ``.item()`` does not exist in the oldest ones, so both are supported.
    """
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


def _safe_mean(total, count):
    """Return ``total / count``, or 0.0 when nothing has been accumulated yet."""
    return total / float(count) if count else 0.0


def _load_activity_datasets(data_dir, phase):
    """Open one helpers.ActSegmentDataset per activity (1..NUM_ACTIVITIES) for ``phase``."""
    datasets = []
    for act in range(1, NUM_ACTIVITIES + 1):
        print('Currently loading dataset for activity ' + str(act) + ' in phase ' + phase)
        sys.stdout.flush()
        datasets.append(helpers.ActSegmentDataset(os.path.join(data_dir, phase, str(act))))
    return datasets


def _train_one_epoch(model, criterion, optimizer, train_dataloaders, num_iters):
    """Run ``num_iters`` optimisation steps and return the triplet-weighted mean loss."""
    running_loss = 0.0
    running_num_exemplar = 0
    model.train(True)  # Set model to training mode

    # One iterator per loader for the whole epoch. Calling iter() on a
    # DataLoader for every batch would restart (and reshuffle) the loader and
    # its worker process each time and only ever yield its first batch.
    batch_iters = [iter(loader) for loader in train_dataloaders]

    for i in range(num_iters):
        # Assemble one batch holding BATCH_SEGS_PER_ACTIVITY clips per activity.
        seg_vids = torch.FloatTensor(BATCH_SEGS_PER_ACTIVITY * NUM_ACTIVITIES, 3, 16, 112, 112)
        for j, batch_iter in enumerate(batch_iters):
            start = BATCH_SEGS_PER_ACTIVITY * j
            # Element 0 of each loader batch is used as the clip tensor.
            seg_vids[start:start + BATCH_SEGS_PER_ACTIVITY, :] = next(batch_iter)[0]

        optimizer.zero_grad()
        seg_vids = Variable(seg_vids, requires_grad=True)
        if use_gpu:
            seg_vids = seg_vids.cuda()
        seg_batch_feats = model(seg_vids)

        # Triplet selection runs on a detached numpy copy of the features.
        features = seg_batch_feats.data.cpu().numpy()
        trip_a, trip_p, trip_n = helpers.select_triplets(features, ALPHA, BATCH_SEGS_PER_ACTIVITY)
        num_trip = len(trip_a)

        # With no selected triplets there is nothing to optimise; skipping the
        # step avoids a loss over empty tensors polluting the running average.
        if num_trip > 0:
            trip_a = Variable(torch.LongTensor(trip_a))
            trip_p = Variable(torch.LongTensor(trip_p))
            trip_n = Variable(torch.LongTensor(trip_n))
            if use_gpu:
                trip_a = trip_a.cuda()
                trip_p = trip_p.cuda()
                trip_n = trip_n.cuda()
            a = seg_batch_feats.index_select(0, trip_a)
            p = seg_batch_feats.index_select(0, trip_p)
            n = seg_batch_feats.index_select(0, trip_n)
            loss = criterion(a, p, n)
            # Weight the batch loss by its triplet count for a per-triplet average.
            running_loss += _scalar_loss(loss) * num_trip
            running_num_exemplar += num_trip
            loss.backward()
            optimizer.step()
            del a, p, n, loss

        # Clean out memory
        del trip_a, trip_p, trip_n
        del seg_batch_feats
        del features
        gc.collect()
        print("\t batch: {:d}/{:d}, Average Loss: {:3f}".format(i+1, num_iters, _safe_mean(running_loss, running_num_exemplar)), end = '\r')
        sys.stdout.flush()

    return _safe_mean(running_loss, running_num_exemplar)


def _validate(model, val_datasets, val_set):
    """Evaluate the model on the validation triplets.

    Each entry of ``val_set`` is read as
    (anchor activity, negative activity, anchor index, positive index, negative index).

    Returns:
        (accuracy, smaller_accuracy, average_loss), where ``accuracy`` counts
        triplets whose margin loss is <= 0 and ``smaller_accuracy`` counts
        triplets whose loss is <= ALPHA (anchor-positive distance not larger
        than anchor-negative distance).
    """
    import contextlib

    model.train(False)  # Set model to evaluate mode
    success = 0
    smaller_success = 0
    running_loss = 0.0
    avg_loss = 0.0
    num_val = len(val_set)

    # Disable autograd during evaluation when the installed torch supports it;
    # otherwise fall back to a no-op context manager.
    if hasattr(torch, 'no_grad'):
        inference_ctx = torch.no_grad()
    else:
        inference_ctx = contextlib.suppress()

    with inference_ctx:
        for i in range(num_val):
            print('Validating exemplar number: {:d}/{:d}, with average loss of: {:.5f}'.format(i+1, num_val, avg_loss), end = '\r')
            sys.stdout.flush()
            triplet = val_set[i]
            anch_act = triplet[0]
            neg_act = triplet[1]
            anch_ind = triplet[2]
            pos_ind = triplet[3]
            neg_ind = triplet[4]

            # Rows of the batch: 0 = anchor, 1 = positive, 2 = negative.
            seg_batch = torch.FloatTensor(3, 3, 16, 112, 112)
            seg_batch[0,:,:,:,:] = val_datasets[anch_act][anch_ind][0]
            seg_batch[1,:,:,:,:] = val_datasets[anch_act][pos_ind][0]
            seg_batch[2,:,:,:,:] = val_datasets[neg_act][neg_ind][0]
            seg_batch = Variable(seg_batch, requires_grad = False)
            if use_gpu:
                seg_batch = seg_batch.cuda()

            seg_batch_feats = model(seg_batch).data.cpu().numpy()
            dist_1 = np.linalg.norm(seg_batch_feats[0]-seg_batch_feats[1])
            dist_2 = np.linalg.norm(seg_batch_feats[0]-seg_batch_feats[2])
            loss = dist_1 - dist_2 + ALPHA
            running_loss += loss
            avg_loss = running_loss/float(i+1)

            # Remove references to try and free up memory
            del seg_batch
            del seg_batch_feats
            gc.collect()

            if loss <= 0:
                success += 1
            if loss <= ALPHA:
                smaller_success += 1

    epoch_acc = _safe_mean(success, num_val)
    smaller_epoch_acc = _safe_mean(smaller_success, num_val)
    return epoch_acc, smaller_epoch_acc, avg_loss


def train_model(model, criterion, optimizer, scheduler, num_epochs=100):
    """Train ``model`` with a triplet loss and checkpoint it on validation accuracy.

    Every epoch runs as many iterations as the smallest activity dataset allows
    (its length // BATCH_SEGS_PER_ACTIVITY), each iteration drawing one batch
    from every activity loader. After training, the model is evaluated on a
    fixed set of validation triplets; whenever the accuracy matches or beats
    the best seen so far, the state dict is saved under ``weight_out_dir``.

    Args:
        model: network mapping a batch of clips to feature vectors.
        criterion: loss called as ``criterion(anchor, positive, negative)``.
        optimizer: optimizer over the model parameters.
        scheduler: learning-rate scheduler, stepped once per epoch
            (may be ``None`` to disable scheduling).
        num_epochs: number of epochs to run.

    Returns:
        The model in its state after the final epoch.
    """
    data_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/lmdb_files/'

    best_acc = 0.0

    train_datasets = _load_activity_datasets(data_dir, 'train')
    val_datasets = _load_activity_datasets(data_dir, 'val')
    print()
    sys.stdout.flush()

    train_dataloaders = [
        DataLoader(dataset, batch_size = BATCH_SEGS_PER_ACTIVITY,
                   shuffle = True, num_workers = 1)
        for dataset in train_datasets
    ]

    # The smallest activity bounds the epoch so that every loader can supply
    # a full batch for each iteration.
    max_iter_per_epoch = min(
        (len(dataset) // BATCH_SEGS_PER_ACTIVITY for dataset in train_datasets),
        default=0)

    num_clip_per_act = [len(dataset) for dataset in val_datasets]
    val_set = helpers.generate_val_exemplars(num_clip_per_act, VAL_NUM_EXEM_PER_ACT)

    for epoch in range(num_epochs):
        print('\t\t  Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 55)
        sys.stdout.flush()

        # TRAINING
        train_loss = _train_one_epoch(model, criterion, optimizer,
                                      train_dataloaders, max_iter_per_epoch)
        print('Done training epoch {:d}. Average training loss: {:.5f}'.format(epoch+1, train_loss))
        sys.stdout.flush()

        # Advance the learning-rate schedule once per epoch.
        if scheduler is not None:
            scheduler.step()

        # VALIDATION
        epoch_acc, smaller_epoch_acc, avg_loss = _validate(model, val_datasets, val_set)
        # Print statistics
        print("Validation Results for Epoch {:d}, accuracy: {:.3f}, smaller accuracy: {:.3f}, average loss: {:.3f}".format(epoch+1, epoch_acc, smaller_epoch_acc, avg_loss))
        print()
        sys.stdout.flush()

        # Checkpoint whenever the accuracy is at least as good as the best so far.
        # The file name uses the zero-based epoch index.
        if epoch_acc >= best_acc:
            best_acc = epoch_acc
            os.makedirs(weight_out_dir, exist_ok=True)
            torch.save(model.state_dict(), weight_out_dir + '/{:03d}_{:.3f}.pkl'.format(epoch, epoch_acc))
            gc.collect()

    print('Best val Acc: {:4f}'.format(best_acc))
    sys.stdout.flush()

    return model

def main():
    """Build the C3D model, loss, optimizer and LR scheduler, then start training."""
    model = Models.C3D()

    criterion = nn.TripletMarginLoss(margin = ALPHA, p = 2)
    num_epochs = 50
    lr = 0.001
    weight_decay = 1e-4

    print("Starting!")
    # Move the model to the GPU before the optimizer is constructed, as the
    # torch.optim documentation recommends, so the optimizer is built on the
    # parameters that are actually trained.
    if use_gpu:
        print("using gpu")
        model = model.cuda()
    sys.stdout.flush()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # StepLR multiplies the LR by gamma every `step_size` scheduler steps.
    # NOTE(review): step_size (130) exceeds num_epochs (50), so with one
    # scheduler step per epoch the LR never decays; an earlier comment here
    # said "every 40 epochs" -- confirm the intended value.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=130,
                                           gamma=0.1)

    train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=num_epochs)



# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
