#########################################
# Johns Hopkins University              #
# 601.455 Computer Integrated Surgery 2 #
# Spring 2018                           #
# Query by Video For Surgical Activities#
# Code by Tae-Soo Kim                   #
#########################################

import os
import time
import matplotlib.pyplot as plt
import numpy as np
import pdb
import sys
from torchvision import transforms
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.parallel

from torch.autograd import Variable

import torch.optim
from torch.optim import lr_scheduler

import torch.utils.data
import torch.utils.data.distributed

from cataract_data_utils import CataractDataset

# Move the model, loss and every batch to CUDA whenever a GPU is visible.
use_gpu = torch.cuda.is_available()
# NOTE(review): not referenced in this script; presumably the class count of a
# related Cataract label set -- confirm or remove.
CAT_n_classes = 22
# Output classes of the fine-tuned final conv; also used to reshape the logits.
n_classes = 11
# Mini-batch size for both the train and val DataLoaders.
batch_size = 128
# Checkpoints are written under weights/<weight_out_dir>/.
weight_out_dir = 'SqueezeNet_phase'

# NOTE(review): not referenced in this script.
dropout = 0.0


def main():
  """Fine-tune an ImageNet-pretrained SqueezeNet 1.0 for `n_classes` classes.

  Replaces the network's final 1x1 classifier conv with an `n_classes`-way one,
  moves model and loss to the GPU when available, builds SGD with a step LR
  schedule, and hands everything to `train_model`.
  """
  ## Train params
  num_epochs = 100
  import pretrainedmodels
  model = pretrainedmodels.__dict__['squeezenet1_0'](num_classes=1000, pretrained='imagenet')
  ## FINETUNING
  # dim_feats must equal the channel count entering last_conv
  # (512 for SqueezeNet 1.0).
  dim_feats = 512
  model.last_conv = nn.Conv2d(dim_feats, n_classes, kernel_size=(1, 1), stride=(1, 1))

  # To resume from a checkpoint:
  #state_dict = torch.load('weights/SqueezeNet_phase/004_0.441.pkl',
  #                        map_location=lambda storage, loc: storage)
  #model.load_state_dict(state_dict)

  criterion = nn.CrossEntropyLoss()
  print("Starting!")
  if use_gpu:
    print("using gpu")
    model = model.cuda()
    criterion = criterion.cuda()
  sys.stdout.flush()
  base_lr = 0.001
  momentum = 0.9
  weight_decay = 1e-4

  # Only optimize parameters that are not frozen.
  parameters = [p for p in model.parameters() if p.requires_grad]

  optimizer = torch.optim.SGD(parameters, lr=base_lr, momentum=momentum, weight_decay=weight_decay)

  # Decay LR by a factor of 0.1 every num_epochs // 4 epochs (25 for 100
  # epochs). Clamp to 1: StepLR takes `epoch % step_size`, so 0 would crash.
  exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                         step_size=max(1, num_epochs // 4),
                                         gamma=0.1)

  model = train_model(model, criterion, optimizer,
                      exp_lr_scheduler, num_epochs=num_epochs)


def train_model(model, criterion, optimizer, scheduler, num_epochs=100):
  """Train `model` on the Cataract dataset and return it with its best weights.

  Each epoch runs one training pass and then one validation pass. Whenever the
  validation accuracy improves, a copy of the weights is kept in memory and
  saved to weights/<weight_out_dir>/<epoch>_<acc>.pkl.

  Args:
    model: network to train (already moved to the GPU when `use_gpu`).
    criterion: loss applied to the (N, n_classes) logits and integer labels.
    optimizer: optimizer over the model's trainable parameters.
    scheduler: LR scheduler, stepped once per epoch after the training pass.
    num_epochs: number of train+val epochs to run.

  Returns:
    `model` with the best-validation-accuracy weights loaded.
  """
  import copy

  since = time.time()
  # state_dict() returns references to the live parameter tensors; without a
  # deep copy the "best" snapshot would silently track every later update.
  best_model_wts = copy.deepcopy(model.state_dict())
  best_acc = 0.0

  dataloaders = _build_dataloaders()
  out_dir = os.path.join('weights', weight_out_dir)

  for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
      epoch_loss, epoch_acc = _run_phase(model, criterion, optimizer,
                                         dataloaders[phase], phase)
      if phase == 'train':
        # Since PyTorch 1.1 the scheduler must be stepped after
        # optimizer.step(); stepping first skips the initial learning rate.
        scheduler.step()

      print('---------  {} Loss: {:.4f} Acc: {:.4f} -----------'.format(phase, epoch_loss, epoch_acc))

      # Keep (and checkpoint) the weights with the best validation accuracy.
      if phase == 'val' and epoch_acc > best_acc:
        best_acc = epoch_acc
        best_model_wts = copy.deepcopy(model.state_dict())

        os.makedirs(out_dir, exist_ok=True)
        torch.save(best_model_wts,
                   os.path.join(out_dir, '{:03d}_{:.3f}.pkl'.format(epoch, epoch_acc)))

    print()

  time_elapsed = time.time() - since
  print('Training complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
  print('Best val Acc: {:.4f}'.format(best_acc))

  # load best model weights
  model.load_state_dict(best_model_wts)
  return model


def _build_dataloaders():
  """Build the 'train' / 'val' Cataract DataLoaders with their transforms."""
  train_data_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25),
    transforms.RandomCrop(224),
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
  ])
  val_data_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
  ])
  cat_train = CataractDataset(train=True, transform=train_data_transform)
  cat_val = CataractDataset(train=False, transform=val_data_transform)
  return {
    'train': DataLoader(cat_train, batch_size=batch_size, shuffle=True, num_workers=2),
    # Evaluation order does not affect the metrics, so do not shuffle.
    'val': DataLoader(cat_val, batch_size=batch_size, shuffle=False, num_workers=2),
  }


def _run_phase(model, criterion, optimizer, loader, phase):
  """Run one pass over `loader`; return (mean per-sample loss, accuracy).

  When `phase == 'train'` the model is put in training mode and the optimizer
  is stepped after every batch. Otherwise the model is put in eval mode and no
  autograd graph is built.
  """
  is_train = phase == 'train'
  model.train(is_train)  # training mode for 'train', evaluation mode otherwise

  num_samples = len(loader.dataset)
  running_loss = 0.0  # sum of per-sample losses seen so far
  running_corrects = 0
  n_seen = 0

  for inputs, labels in loader:
    if use_gpu:
      inputs, labels = inputs.cuda(), labels.cuda()
    inputs, labels = inputs.float(), labels.long()
    batch_n = inputs.size(0)  # the last batch may be smaller than batch_size
    n_seen += batch_n

    # zero the parameter gradients
    optimizer.zero_grad()
    # Only record the autograd graph when we are going to backpropagate.
    with torch.set_grad_enabled(is_train):
      outputs = model(inputs)
      # Reshape to (N, n_classes) logits; presumably the head emits
      # (N, n_classes, 1, 1) -- TODO confirm.
      outputs = outputs.view(batch_n, n_classes)
      loss = criterion(outputs, labels)
    _, preds = torch.max(outputs.detach(), 1)

    # backward + optimize only if in training phase
    if is_train:
      loss.backward()
      optimizer.step()

    # criterion returns the batch mean, so weight it by the batch size.
    running_loss += loss.item() * batch_n
    running_corrects += (preds == labels).sum().item()

    print('{:d}/{:d}:  {:s}_loss: {:.3f}, {:s}_acc: {:.3f} \r'.format(n_seen,
                                                                      num_samples,
                                                                      phase,
                                                                      running_loss / n_seen,
                                                                      phase,
                                                                      running_corrects / n_seen),
          end='\r')
    sys.stdout.flush()

  denom = max(n_seen, 1)  # avoid ZeroDivisionError on an empty loader
  return running_loss / denom, running_corrects / denom


# Run fine-tuning only when executed as a script, not when imported.
if __name__ == "__main__":
  main()
