import time
import math
import os
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim
from torch.optim import lr_scheduler
import Models
import helpers

import pdb
import sys
import gc 

# True when a CUDA device is available; the model and every batch are moved
# to the GPU when this is set.
use_gpu = torch.cuda.is_available()
# Directory that train_model writes its checkpoint files into (created on
# demand).
weight_out_dir = 'weights'

# Triplet margin: used as the margin of nn.TripletMarginLoss, passed to
# helpers.select_triplets, and added to the distance gap during validation.
ALPHA = 0.1
# Not referenced in the visible part of this file; presumably the size of the
# embedding produced by the model -- TODO confirm against Models.C3D.
FINAL_LAYER_DIM = 128
# Number of segments drawn per activity for each training batch.
BATCH_SEGS_PER_ACTIVITY = 5
# Number of activity classes; a training batch holds
# BATCH_SEGS_PER_ACTIVITY * NUM_ACTIVITIES segments.
NUM_ACTIVITIES = 10
# Passed to helpers.generate_val_exemplars together with the per-activity clip
# counts; presumably the number of validation exemplars per activity -- TODO
# confirm in helpers.
VAL_NUM_EXEM_PER_ACT = 9


# Layout every stored segment is reshaped to before being fed to the model.
_CLIP_SHAPE = (3, 16, 112, 112)


def _load_clip(txn, key):
    """Decode the segment stored under ``key`` into a float32 array.

    Only ``txn.get(key)`` is used; the returned bytes are interpreted as
    uint8 values and reshaped to ``_CLIP_SHAPE``.

    Raises:
        KeyError: if ``txn.get`` returns ``None`` for ``key``.
    """
    raw = txn.get(key)
    if raw is None:
        # Fail with the offending key instead of an opaque NumPy error.
        raise KeyError('segment {!r} not found'.format(key))
    # np.frombuffer replaces the deprecated binary mode of np.fromstring; the
    # astype() call makes a writable float32 copy of the read-only view.
    return np.frombuffer(raw, dtype=np.uint8).reshape(_CLIP_SHAPE).astype(np.float32)


def _train_one_epoch(model, criterion, optimizer, train_txn, train_names,
                     no_repeat_tracker, labels, num_iters):
    """Run ``num_iters`` triplet-loss optimisation steps on the training set."""
    model.train(True)  # Set model to training mode
    batch_size = BATCH_SEGS_PER_ACTIVITY * NUM_ACTIVITIES
    for i in range(num_iters):
        print("\tCurrently on batch: " + str(i))
        batch_names = helpers.get_batch(train_names, no_repeat_tracker,
                                        BATCH_SEGS_PER_ACTIVITY)

        # Slot j of the batch belongs to activity labels[j]; slots that
        # get_batch does not fill stay zero, as before.
        seg_batch = np.zeros((batch_size,) + _CLIP_SHAPE, dtype=np.float32)
        for j, name in enumerate(batch_names):
            seg_batch[j] = _load_clip(train_txn[labels[j]], name)

        # No requires_grad on the input: only parameter gradients are used.
        seg_batch = Variable(torch.from_numpy(seg_batch))
        if use_gpu:
            seg_batch = seg_batch.cuda()

        optimizer.zero_grad()
        print("\t\t FORWARDING!")
        seg_batch_feats = model(seg_batch)

        print("\t\t FINDING TRIPLETS!")
        features = seg_batch_feats.detach().cpu().numpy()
        trip_a, trip_p, trip_n = helpers.select_triplets(features, ALPHA)
        if len(trip_a) == 0:
            # Nothing to optimise; a loss over zero triplets would be NaN.
            print("\t\t NO TRIPLETS FOUND, SKIPPING BATCH!")
            continue

        trip_a = Variable(torch.LongTensor(trip_a))
        trip_p = Variable(torch.LongTensor(trip_p))
        trip_n = Variable(torch.LongTensor(trip_n))
        if use_gpu:
            trip_a = trip_a.cuda()
            trip_p = trip_p.cuda()
            trip_n = trip_n.cuda()

        # Gather anchor / positive / negative embeddings from the live graph
        # so the loss back-propagates into the model.
        anchors = seg_batch_feats.index_select(0, trip_a)
        positives = seg_batch_feats.index_select(0, trip_p)
        negatives = seg_batch_feats.index_select(0, trip_n)

        print("\t\t CALCULATING LOSS!")
        loss = criterion(anchors, positives, negatives)
        print("\t\t BACKPROPOGATING!")
        loss.backward()
        optimizer.step()


def _validation_accuracy(model, val_txn, val_names, val_set):
    """Return the fraction of validation triplets satisfying the margin.

    A triplet counts as a success when
    ``||f(anchor) - f(pos)|| - ||f(anchor) - f(neg)|| + ALPHA <= 0``.
    Returns 0.0 when ``val_set`` is empty.
    """
    model.train(False)  # Set model to evaluate mode
    total = len(val_set)
    if total == 0:
        return 0.0

    success = 0
    # No autograd graph is needed here; skipping it keeps memory flat without
    # the manual rebinding / gc.collect() the loop used to need.
    with torch.no_grad():
        for i in range(total):
            print("Validating triplet : " + str(i))
            sys.stdout.flush()
            triplet = val_set[i]
            anch_act, neg_act = triplet[0], triplet[1]
            anch_ind, pos_ind, neg_ind = triplet[2], triplet[3], triplet[4]

            # Row 0: anchor, row 1: positive (same activity), row 2: negative.
            clips = np.stack([
                _load_clip(val_txn[anch_act], val_names[anch_act][anch_ind]),
                _load_clip(val_txn[anch_act], val_names[anch_act][pos_ind]),
                _load_clip(val_txn[neg_act], val_names[neg_act][neg_ind]),
            ])
            batch = Variable(torch.from_numpy(clips))
            if use_gpu:
                batch = batch.cuda()
            feats = model(batch).detach().cpu().numpy()

            dist_pos = np.linalg.norm(feats[0] - feats[1])
            dist_neg = np.linalg.norm(feats[0] - feats[2])
            if dist_pos - dist_neg + ALPHA <= 0:
                success += 1

    return float(success) / float(total)


def train_model(model, criterion, optimizer, scheduler, num_epochs=100,
                data_dir='/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/lmdb_files/'):
    """Train ``model`` with a triplet loss and checkpoint the best epochs.

    Each epoch runs ``min_i(len(train_names[i]) // BATCH_SEGS_PER_ACTIVITY)``
    training batches, steps ``scheduler`` once, then scores the model on a
    fixed set of validation triplets. Whenever the validation accuracy matches
    or beats the best seen so far, the weights are saved under
    ``weight_out_dir`` as ``<epoch>_<accuracy>.pkl``.

    Args:
        model: network mapping a batch of segments to embeddings.
        criterion: loss called as ``criterion(anchor, positive, negative)``.
        optimizer: optimizer over the model's parameters.
        scheduler: learning-rate scheduler stepped once per epoch; may be
            ``None`` to disable scheduling.
        num_epochs: number of epochs to run.
        data_dir: directory handed to ``helpers.load_lmdb_txns``.

    Returns:
        ``model`` with the weights of the best checkpoint loaded (unchanged
        if no epoch was run).
    """
    train_txn, val_txn, train_names, val_names = helpers.load_lmdb_txns(data_dir)

    # labels[j] is the activity of batch slot j: 0,0,..,1,1,..,N-1.
    labels = np.repeat(np.arange(NUM_ACTIVITIES, dtype=np.uint8),
                       BATCH_SEGS_PER_ACTIVITY)

    # Indexed access (rather than direct iteration) mirrors how these
    # per-activity containers are addressed everywhere else in this function.
    train_sizes = [len(train_names[i]) for i in range(len(train_names))]
    # The smallest activity bounds how many full batches one epoch can draw.
    # (A leftover hard-coded override of 2 iterations per epoch was removed.)
    max_iter_per_epoch = min(
        (size // BATCH_SEGS_PER_ACTIVITY for size in train_sizes), default=0)
    # NOTE(review): assumes helpers.get_batch maintains/resets this tracker
    # across epochs itself -- confirm in helpers.
    no_repeat_tracker = [np.zeros(size) for size in train_sizes]

    num_clip_per_act = [len(val_names[i]) for i in range(len(val_names))]
    val_set = helpers.generate_val_exemplars(num_clip_per_act, VAL_NUM_EXEM_PER_ACT)

    best_acc = 0.0
    best_path = None
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # TRAINING
        print("Currently on epoch: " + str(epoch))
        _train_one_epoch(model, criterion, optimizer, train_txn, train_names,
                         no_repeat_tracker, labels, max_iter_per_epoch)
        if scheduler is not None:
            # Stepped after the epoch's optimizer updates.
            scheduler.step()

        # VALIDATION
        epoch_acc = _validation_accuracy(model, val_txn, val_names, val_set)
        print("Epoch {:d}, accuracy {:.3f}".format(epoch, epoch_acc))

        if epoch_acc >= best_acc:
            best_acc = epoch_acc
            os.makedirs(weight_out_dir, exist_ok=True)
            best_path = os.path.join(
                weight_out_dir, '{:03d}_{:.3f}.pkl'.format(epoch, epoch_acc))
            # Persist to disk instead of holding a second copy of the weights
            # in memory; the best checkpoint is reloaded after training.
            torch.save(model.state_dict(), best_path)
        print()
        sys.stdout.flush()

    print('Best val Acc: {:4f}'.format(best_acc))
    sys.stdout.flush()

    # load best model weights
    if best_path is not None:
        model.load_state_dict(torch.load(best_path))
    return model

def main():
    """Build the C3D model, loss, optimizer and scheduler, then train.

    Uses a triplet margin loss with margin ``ALPHA`` and Adam with a fixed
    learning rate of 1e-3 and weight decay of 1e-4, for 100 epochs.
    """
    model = Models.C3D()

    print("Starting!")
    if use_gpu:
        print("using gpu")
        # Move the model before constructing the optimizer, as the torch.optim
        # documentation advises, so the optimizer is built over the parameters
        # that are actually trained.
        model = model.cuda()
    sys.stdout.flush()

    criterion = nn.TripletMarginLoss(margin=ALPHA, p=2)
    num_epochs = 100
    lr = 0.001
    weight_decay = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)

    # Decay LR by a factor of 0.1 every 130 scheduler steps. With 100 epochs
    # this never triggers if the scheduler is stepped once per epoch, so the
    # learning rate is effectively constant.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=130,
                                           gamma=0.1)

    # The result is not needed here; checkpoints are written during training.
    train_model(model, criterion, optimizer, exp_lr_scheduler,
                num_epochs=num_epochs)



# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
