import copy
import gc
import math
import os
import pdb
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

import Models
import helpers

use_gpu = torch.cuda.is_available()
# Directory that receives a checkpoint each time validation accuracy improves.
weight_out_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/weights/ResNet_CE/'

NUM_ACTIVITIES = 10           # number of activity classes; training labels are 0..NUM_ACTIVITIES-1
BATCH_IMGS_PER_ACTIVITY = 10  # images drawn from each activity per training step
VAL_BATCH_SIZE = 30           # max images per validation forward pass
dropout = 0.5                 # NOTE(review): not referenced anywhere else in this file
FINAL_LAYER_DIM = 128         # output width of the replacement fc layer

# Standard ImageNet channel statistics, matching the pretrained torchvision weights.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training preprocessing. Resize is the supported name for the old Scale
# transform; Scale was deprecated and then removed from torchvision.
train_data_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    # transforms.RandomHorizontalFlip(),  # augmentation currently disabled
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Validation preprocessing: same as training, with no augmentation.
val_data_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

def _load_activity_datasets(data_dir):
    """Load one ``helpers.ActImageDataset`` per activity for each phase.

    Activity ``k`` (1-based) of phase ``p`` is read from ``<data_dir>/<p>/<k>``.

    Returns:
        ``(train_datasets, val_datasets)``: two lists of length NUM_ACTIVITIES;
        index ``i`` holds the dataset for activity ``i + 1``.
    """
    phase_transforms = {'train': train_data_transform, 'val': val_data_transform}
    datasets = {'train': [], 'val': []}
    for phase in ('train', 'val'):
        for act in range(1, NUM_ACTIVITIES + 1):
            print('Currently loading dataset for activity ' + str(act) + ' in phase ' + phase)
            sys.stdout.flush()
            lmdb_dir = os.path.join(data_dir, phase, str(act))
            datasets[phase].append(helpers.ActImageDataset(lmdb_dir, phase_transforms[phase]))
    print()
    sys.stdout.flush()
    return datasets['train'], datasets['val']


def _run_train_epoch(model, criterion, optimizer, train_dataloaders, train_labels, num_iters):
    """Run one class-balanced training epoch and return its accuracy.

    Each step takes the next batch from every activity's loader and
    concatenates them in activity order, so ``train_labels`` (activity index
    repeated BATCH_IMGS_PER_ACTIVITY times) lines up with the stacked inputs.
    """
    model.train(True)
    # Keep one iterator per loader for the whole epoch. Calling
    # next(iter(loader)) every step reshuffles from scratch and, with
    # num_workers > 0, respawns a worker process for every batch.
    loader_iters = [iter(loader) for loader in train_dataloaders]
    labels = train_labels.cuda() if use_gpu else train_labels

    running_loss = 0.0
    running_corrects = 0
    total_predictions = 0
    for i in range(num_iters):
        # Element [0] of each loader batch is the image tensor.
        inputs = torch.cat([next(it)[0] for it in loader_iters], 0).float()
        if use_gpu:
            inputs = inputs.cuda()
        batch_size = inputs.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns a per-batch mean, so weight it by batch size
        # to get a true per-sample average when dividing by total_predictions.
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum().item()
        total_predictions += batch_size

        print("\t batch: {:d}/{:d}, Average Loss: {:3f}".format(
            i + 1, num_iters, running_loss / float(total_predictions)), end='\r')
        sys.stdout.flush()

    return running_corrects / float(total_predictions) if total_predictions else 0.0


def _run_val_epoch(model, criterion, val_datasets, val_inds, val_labels):
    """Evaluate on the validation clips selected by ``val_inds`` and return accuracy.

    ``val_inds[i]`` / ``val_labels[i]`` hold the dataset indices and labels
    for activity ``i``. They are walked in chunks of at most VAL_BATCH_SIZE.
    The final chunk may be shorter and is sized accordingly.
    """
    model.train(False)
    running_loss = 0.0
    running_corrects = 0
    total_predictions = 0
    num_cols = val_inds.shape[1]
    num_batches = (num_cols + VAL_BATCH_SIZE - 1) // VAL_BATCH_SIZE

    # No gradients are needed for evaluation, so don't build autograd graphs.
    with torch.no_grad():
        for i in range(NUM_ACTIVITIES):
            for j in range(0, num_cols, VAL_BATCH_SIZE):
                inds = val_inds[i, j:j + VAL_BATCH_SIZE]
                labels = torch.from_numpy(val_labels[i, j:j + VAL_BATCH_SIZE]).long()
                # Size the buffer by the actual chunk length, not VAL_BATCH_SIZE.
                inputs = torch.FloatTensor(inds.shape[0], 3, 224, 224)
                for x in range(inds.shape[0]):
                    inputs[x, :] = val_datasets[i][int(inds[x])][0]
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                batch_size = inputs.size(0)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * batch_size
                running_corrects += (preds == labels).sum().item()
                total_predictions += batch_size

                print("Act {:d}/{:d}, batch: {:d}/{:d}, Average Loss: {:3f}".format(
                    i + 1, NUM_ACTIVITIES, j // VAL_BATCH_SIZE + 1, num_batches,
                    running_loss / float(total_predictions)), end='\r')
                sys.stdout.flush()

    return running_corrects / float(total_predictions) if total_predictions else 0.0


def train_model(model, criterion, optimizer, scheduler, num_epochs=100):
    """Train ``model`` on class-balanced batches and keep the best validation weights.

    Args:
        model: network to train. Presumably already on the GPU when ``use_gpu``
            is true, since inputs are moved there.
        criterion: loss taking ``(logits, integer labels)``.
        optimizer: optimizer over the trainable parameters of ``model``.
        scheduler: learning-rate scheduler stepped once per epoch after
            training, or ``None``.
        num_epochs: number of train-then-validate passes.

    Returns:
        ``model`` with the weights of the best-validation-accuracy epoch
        loaded. If no epoch scores above 0.0, the initial weights are loaded.

    Side effects:
        Saves a checkpoint under ``weight_out_dir`` whenever validation
        accuracy improves.
    """
    data_dir = '/home-2/fyu12@jhu.edu/work3/fyu/CATARACT/lmdb_files/res_net_images'

    train_datasets, val_datasets = _load_activity_datasets(data_dir)

    train_dataloaders = [
        DataLoader(ds, batch_size=BATCH_IMGS_PER_ACTIVITY, shuffle=True, num_workers=1)
        for ds in train_datasets
    ]
    # Balanced sampling: an epoch ends when the smallest activity runs out
    # of full batches, so every step draws a full batch from every loader.
    max_iter_per_epoch_train = min(
        (len(ds) // BATCH_IMGS_PER_ACTIVITY for ds in train_datasets), default=0)

    # Label vector matching the concatenation order used in _run_train_epoch.
    train_labels = torch.LongTensor(NUM_ACTIVITIES * BATCH_IMGS_PER_ACTIVITY)
    for i in range(NUM_ACTIVITIES):
        train_labels[i * BATCH_IMGS_PER_ACTIVITY:(i + 1) * BATCH_IMGS_PER_ACTIVITY] = i

    num_clip_per_act = [len(ds) for ds in val_datasets]
    val_inds, val_labels = helpers.get_val_indexes(num_clip_per_act, VAL_BATCH_SIZE)

    best_acc = 0.0
    # state_dict() returns references to the live tensors, so snapshot a deep
    # copy. Otherwise later training would overwrite the "best" weights.
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # TRAINING
        epoch_acc = _run_train_epoch(model, criterion, optimizer, train_dataloaders,
                                     train_labels, max_iter_per_epoch_train)
        print('Done training epoch {:d}. Accuracy: {:.5f}'.format(epoch + 1, epoch_acc))
        if scheduler is not None:
            scheduler.step()

        # VALIDATION
        epoch_acc = _run_val_epoch(model, criterion, val_datasets, val_inds, val_labels)
        print('Done validating epoch {:d}. Accuracy: {:.5f}'.format(epoch + 1, epoch_acc))

        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            os.makedirs(weight_out_dir, exist_ok=True)
            torch.save(best_model_wts,
                       os.path.join(weight_out_dir, '{:03d}_{:.3f}.pkl'.format(epoch, epoch_acc)))
        print()

    print('Best val Acc: {:4f}'.format(best_acc))

    # Restore the best weights seen during validation.
    model.load_state_dict(best_model_wts)
    return model


def main():
    """Fine-tune an ImageNet-pretrained ResNet-18 with cross-entropy loss.

    The first six top-level children of the network are frozen. For
    torchvision's ResNet these are expected to be the stem plus layer1/layer2;
    confirm against the installed version. The final fc layer is replaced
    with a fresh FINAL_LAYER_DIM-wide linear layer, and only the trainable
    parameters are optimized with Adam.
    """
    model = models.resnet18(pretrained=True)
    for idx, child in enumerate(model.children()):
        if idx < 6:
            for param in child.parameters():
                param.requires_grad = False
    # NOTE(review): frozen BatchNorm layers still update their running stats
    # while model.train(True) is active; confirm this is intended.

    # Parameters of newly constructed modules have requires_grad=True by default.
    num_ftrs = model.fc.in_features
    # NOTE(review): training labels only span 0..NUM_ACTIVITIES-1, so with
    # CrossEntropyLoss the remaining FINAL_LAYER_DIM - NUM_ACTIVITIES logits
    # are never targets; confirm 128 outputs are intended.
    model.fc = nn.Linear(num_ftrs, FINAL_LAYER_DIM)

    criterion = nn.CrossEntropyLoss()
    num_epochs = 100
    lr = 0.001
    weight_decay = 1e-4

    print("Starting!")
    # Move the model before building the optimizer, so the optimizer is
    # created over the parameters' final device placement.
    if use_gpu:
        print("using gpu")
        model = model.cuda()
    sys.stdout.flush()

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr, weight_decay=weight_decay)

    # Decay LR by a factor of 0.1 every 130 epochs; with num_epochs = 100
    # this never fires.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                           step_size=130,
                                           gamma=0.1)
    model = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=num_epochs)


if __name__ == "__main__":
    main()
