#########################################
# Johns Hopkins University              #
# 601.455 Computer Integrated Surgery 2 #
# Spring 2018                           #
# Query by Video For Surgical Activities#
# Code by Tae-Soo Kim                   #
#########################################

import lmdb
import os
import random
import pdb
import numpy as np
#import torch.utils.data as data
import sys

sys.path.append('/home-4/tkim60@jhu.edu/scratch/ThirdParty/opencv_install/lib/python3.6/site-packages')
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
# Identifiers assigned to the training split, keyed by zero-padded string id.
# NOTE(review): the integer values (2-5) are presumably a per-video rating or
# grade -- TODO confirm; nothing visible in this file reads this dict.
training_set = {'155': 5,
                '149': 5,
                '147': 5,
                '157': 5,
                '062': 5,
                '161': 5,
                '148': 5,
                '133': 5,
                '008': 5,
                #'159': 5,
                '077': 4,
                '169': 4,
                '124': 4,
                '138': 4,
                '052': 4,
                '122': 4,
                '119': 3,
                '170': 3,
                '042': 3,
                '053': 3,
                '039': 3,
                '123': 3,
                '121': 3,
                '015': 3,
                '014': 3,
                #'125': 3,
                '023': 3,
                '051': 2,
                '058': 2,
                '055': 2,
                '085': 2,
                '026': 2,
                '158': 5,
                '146': 5,
                '154': 5,
                '162': 5,
                '137': 5,
                '145': 5,
                '034': 4,
                '128': 4,
                '181': 4,
                }

# Identifiers assigned to the testing split; same key/value layout as
# training_set (value meaning likewise unconfirmed).
testing_set = {
                '164': 5,
                '150': 5,
                '127': 4,
                '114': 4,
                '038': 3,
                '118': 3,
                '117': 2,
                '120': 2,
               }

# Frame size used to reshape the raw LMDB buffers into (height, width, 3).
height = 224
width = 224
# Number of samples selected by the most recently started loader for each
# phase; the generators/dataset below overwrite these entries.
dataset_sizes = {
  'train': 0,
  'val': 0
}
# Scalar statistics applied to raw pixel values by trans_mean_norm_std and
# test_datagen_CB_aug (subtract the mean, divide by the std).
dataset_mean = 95.888412356227832
dataset_std = 61.551035896724


class CataractDataset(Dataset):
  """Class-balanced torch Dataset over cataract frames stored in LMDB.

  Frames are read from the 'X_train'/'X_test' store and labels from the
  matching 'Y_train'/'Y_test' store under ``root_dir``. On construction the
  label store is scanned once and, for every class, the same number of keys
  (the size of the smallest non-empty class) is randomly kept.
  """

  def __init__(self, train=True, transform=None):
    """
    Args:
        train (bool): read the 'X_train'/'Y_train' stores when true,
            otherwise 'X_test'/'Y_test'.
        transform (callable, optional): Optional transform to be applied
            on a sample (receives the RGB uint8 frame).
    """
    self.root_dir = '/media/tk/EE44DA8044DA4B4B/cataract_phase_img_balance_correct_split'
    #self.root_dir = '/home-4/tkim60@jhu.edu/scratch/data/cataract_phase_img_balance_correct_split/'

    self.transform = transform
    self.n_classes = 11
    self.sample_paths = []  # LMDB keys of the balanced subset
    self.lmdb_cursor_x = None
    self.lmdb_cursor_y = None
    self.train = train

    self._init()

  def __len__(self):
    """Return the number of samples kept after class balancing."""
    return len(self.sample_paths)

  def __getitem__(self, idx):
    """Return ``(frame, label)`` for the ``idx``-th selected key.

    The stored buffer is reshaped to (height, width, 3), converted with
    COLOR_BGR2RGB and passed through ``self.transform`` when one is set.
    ``label`` is the first int64 value stored under the same key.
    """
    key = self.sample_paths[idx]
    value = np.frombuffer(self.lmdb_cursor_x.get(key), dtype=np.uint8)
    label = np.frombuffer(self.lmdb_cursor_y.get(key), dtype=np.int64)

    ## BGR
    # The buffer is only read here (cvtColor allocates its output), so the
    # read-only frombuffer view does not need to be made writeable.
    x = cv2.cvtColor(value.reshape((height, width, 3)), cv2.COLOR_BGR2RGB)

    if self.transform:
      x = self.transform(x)

    return (x, label[0])

  def _init(self):
    """Open both LMDB stores and pick the class-balanced list of keys.

    Side effects: fills ``self.sample_paths`` and the two cursors, and
    records the subset size in ``dataset_sizes`` ('train' for the training
    split, 'val' otherwise).
    """
    split = 'train' if self.train else 'test'

    # The cursors are kept for the lifetime of the dataset, so the
    # environments are intentionally left open.
    lmdb_env_x = lmdb.open(os.path.join(self.root_dir, 'X_' + split))
    lmdb_txn_x = lmdb_env_x.begin()
    self.lmdb_cursor_x = lmdb_txn_x.cursor()

    lmdb_env_y = lmdb.open(os.path.join(self.root_dir, 'Y_' + split))
    lmdb_txn_y = lmdb_env_y.begin()
    self.lmdb_cursor_y = lmdb_txn_y.cursor()

    _keys = []
    vals = []
    for k, v in self.lmdb_cursor_y:
      _keys.append(k)
      vals.append(np.frombuffer(v, dtype=np.int64)[0])
    _keys = np.array(_keys)
    vals = np.array(vals)

    # Every class contributes as many samples as the smallest non-empty class.
    class_indices = [np.where(vals == cl)[0] for cl in range(self.n_classes)]
    populated = [len(ix) for ix in class_indices if len(ix) > 0]
    least_num_class = min(populated) if populated else 0

    for indices in class_indices:
      np.random.shuffle(indices)
      for k in indices[:least_num_class]:
        self.sample_paths.append(_keys[k])

    # Previously the 'train' entry was overwritten even for the test split.
    dataset_sizes['train' if self.train else 'val'] = len(self.sample_paths)



def dataloders(phase, n_classes, batch_size):
  """Return the plain (unbalanced) batch generator for ``phase``.

  'train' maps to ``train_datagen`` over X_train/Y_train and 'val' maps to
  ``test_datagen`` over X_test/Y_test of the csp2 dataset. Any other phase
  yields ``None``.
  """
  root = '/home-4/tkim60@jhu.edu/scratch/data/cataract_phase_img_csp2/'

  if phase == 'train':
    return train_datagen(n_classes,
                         batch_size,
                         (3, height, width),
                         os.path.join(root, 'X_train'),
                         os.path.join(root, 'Y_train'))

  if phase == 'val':
    return test_datagen(n_classes,
                        batch_size,
                        (3, height, width),
                        os.path.join(root, 'X_test'),
                        os.path.join(root, 'Y_test'))

  return None

def dataloders_complete_balance(phase, n_classes, batch_size):
  """Return the class-balanced batch generator for ``phase``.

  'train' maps to ``train_datagen_CB`` and 'val' to ``test_datagen_CB``,
  both reading the balance_correct_split dataset. Any other phase yields
  ``None``.
  """
  # Alternative locations of the same dataset:
  # /home-4/tkim60@jhu.edu/scratch/data/cataract_phase_img_balance_correct_split/
  # /media/tk/EE44DA8044DA4B4B/cataract_phase_img_balance_correct_split
  root = '/home-4/tkim60@jhu.edu/scratch/data/cataract_phase_img_balance_correct_split/'

  if phase == 'train':
    return train_datagen_CB(n_classes,
                            batch_size,
                            (3, height, width),
                            os.path.join(root, 'X_train'),
                            os.path.join(root, 'Y_train'))

  if phase == 'val':
    return test_datagen_CB(n_classes,
                           batch_size,
                           (3, height, width),
                           os.path.join(root, 'X_test'),
                           os.path.join(root, 'Y_test'))

  return None

def dataloders_complete_balance_aug(phase, n_classes, batch_size):
  """Return the balanced/augmented batch generator for ``phase``.

  Args:
      phase: 'train' or 'val'; any other value returns ``None``.
      n_classes: number of label classes, forwarded to the generator.
      batch_size: samples per batch, forwarded to the generator.

  Returns:
      The generator for the requested phase, reading the csp2 dataset.

  Raises:
      NotImplementedError: for 'train' while ``train_datagen_CB_aug`` is not
          defined in this module (its implementation is currently disabled
          inside a string literal, so the old code failed with a NameError).
  """
  data_root = '/home-4/tkim60@jhu.edu/scratch/data/cataract_phase_img_csp2/'
  #data_root = '/media/tk/EE44DA8044DA4B4B/cataract_phase_img_balance_correct_split'

  if phase == 'train':
    # Look the generator up dynamically so this starts working again as soon
    # as the implementation is re-enabled.
    train_gen = globals().get('train_datagen_CB_aug')
    if train_gen is None:
      raise NotImplementedError(
          'train_datagen_CB_aug is disabled in this module; '
          'the augmented training loader is unavailable')
    lmdb_file_train_x = os.path.join(data_root, 'X_train')
    lmdb_file_train_y = os.path.join(data_root, 'Y_train')

    return train_gen(n_classes,
                     batch_size,
                     (3, height, width),
                     lmdb_file_train_x,
                     lmdb_file_train_y)
  elif phase == 'val':
    lmdb_file_test_x = os.path.join(data_root, 'X_test')
    lmdb_file_test_y = os.path.join(data_root, 'Y_test')

    return test_datagen_CB_aug(n_classes,
                               batch_size,
                               (3, height, width),
                               lmdb_file_test_x,
                               lmdb_file_test_y)
  return None

def dataloders_perclass(phase, n_classes, batch_size,klass=1):
  """Return the one-class-vs-background batch generator for ``phase``.

  'train' maps to ``train_datagen_perclass`` and 'val' to
  ``test_datagen_perclass`` over the csp2 dataset; ``klass`` is forwarded
  as the class of interest. Any other phase yields ``None``.
  """
  root = '/home-4/tkim60@jhu.edu/scratch/data/cataract_phase_img_csp2/'

  if phase == 'train':
    return train_datagen_perclass(n_classes,
                                  batch_size,
                                  (3, height, width),
                                  os.path.join(root, 'X_train'),
                                  os.path.join(root, 'Y_train'),
                                  klass)

  if phase == 'val':
    return test_datagen_perclass(n_classes,
                                 batch_size,
                                 (3, height, width),
                                 os.path.join(root, 'X_test'),
                                 os.path.join(root, 'Y_test'),
                                 klass)

  return None


def dataloders_affine(phase, batch_size):
  """Return the affine-prediction batch generator for ``phase``.

  'train' maps to ``train_datagen_affine`` and 'val' to
  ``test_datagen_affine`` over the csp2 dataset. Any other phase yields
  ``None``.
  """
  #root = '/media/tk/EE44DA8044DA4B4B/cataract_phase_img_csp2/'
  root = '/home-4/tkim60@jhu.edu/scratch/data/cataract_phase_img_csp2/'

  if phase == 'train':
    return train_datagen_affine(batch_size,
                                (3, height, width),
                                os.path.join(root, 'X_train'),
                                os.path.join(root, 'Y_train'))

  if phase == 'val':
    return test_datagen_affine(batch_size,
                               (3, height, width),
                               os.path.join(root, 'X_test'),
                               os.path.join(root, 'Y_test'))

  return None


def dataloders_CATARACTS(phase, n_classes, batch_size, split=0):
  """Return the CATARACTS batch generator for ``phase`` and ``split``.

  The stores live under ``.../CATARACTS/split<split>``. 'train' maps to
  ``train_datagen_CATARACTS`` (Xtrain_lmdb/Ytrain_lmdb) and 'val' to
  ``test_datagen_CATARACTS`` (Xtest_lmdb/Ytest_lmdb). Any other phase
  yields ``None``.
  """
  root = '/home-4/tkim60@jhu.edu/scratch/data/CATARACTS/split' + str(split)

  if phase == 'train':
    return train_datagen_CATARACTS(n_classes,
                                   batch_size,
                                   (3, height, width),
                                   os.path.join(root, 'Xtrain_lmdb'),
                                   os.path.join(root, 'Ytrain_lmdb'))

  if phase == 'val':
    return test_datagen_CATARACTS(n_classes,
                                  batch_size,
                                  (3, height, width),
                                  os.path.join(root, 'Xtest_lmdb'),
                                  os.path.join(root, 'Ytest_lmdb'))

  return None


def train_datagen_affine(
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y):
  """Yield training batches for the affine-transform prediction task.

  Every frame is paired with a randomly rotated/translated copy produced by
  ``sample_affine_xform``; the targets are the class indices of the sampled
  x-translation, y-translation and rotation.

  Args:
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffers, (3, height, width).
      lmdb_file_train_x: path of the LMDB store holding raw uint8 frames.
      lmdb_file_train_y: path of the LMDB store whose keys enumerate samples.

  Yields:
      ``(X, X_xform, Y_xtrans, Y_ytrans, Y_rot)``: float64 arrays holding the
      original frames, the transformed frames (both scaled to [0, 1] and
      shifted by -0.485) and the three class-index vectors.

  Side effects:
      Sets ``dataset_sizes['train']`` to the number of keys found.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    keys = [k for k, _ in lmdb_cursor_y]
    dataset_sizes['train'] = len(keys)
    np.random.shuffle(keys)

    X = np.zeros(batch_shape)
    X_xform = np.zeros(batch_shape)
    Y_xtrans = np.zeros(batch_size)
    Y_ytrans = np.zeros(batch_size)
    Y_rot = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)

      ## RGB
      # Read-only view is sufficient: the frame is only read before astype copies it.
      x = value.reshape((height, width, 3))

      xform_img, x_class, y_class, rot_class = sample_affine_xform(x, trans_range=[-3, 3], rot_range=[-30, 30])

      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = np.transpose(x, (2, 0, 1)).astype(np.float64)
      x /= 255.0
      x -= 0.485

      # Normalise the *transformed* image; this used to be rebuilt from the
      # already-normalised `x`, discarding the affine transform.
      xform_img = np.transpose(xform_img, (2, 0, 1)).astype(np.float64)
      xform_img /= 255.0
      xform_img -= 0.485

      X[batch_count] = x
      X_xform[batch_count] = xform_img
      Y_xtrans[batch_count] = x_class
      Y_ytrans[batch_count] = y_class
      Y_rot[batch_count] = rot_class
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, X_xform, Y_xtrans, Y_ytrans, Y_rot)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        X_xform = np.zeros(batch_shape)
        Y_xtrans = np.zeros(batch_size)
        Y_ytrans = np.zeros(batch_size)
        Y_rot = np.zeros(batch_size)
        batch_count = 0
        yield batch

def test_datagen_affine(
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y):
  """Yield validation batches for the affine-transform prediction task.

  Every frame is paired with a randomly rotated/translated copy produced by
  ``sample_affine_xform``; the targets are the class indices of the sampled
  x-translation, y-translation and rotation.

  Args:
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffers, (3, height, width).
      lmdb_file_train_x: path of the LMDB store holding raw uint8 frames
          (the validation store, despite the parameter name).
      lmdb_file_train_y: path of the LMDB store whose keys enumerate samples.

  Yields:
      ``(X, X_xform, Y_xtrans, Y_ytrans, Y_rot)``: float64 arrays holding the
      original frames, the transformed frames (both scaled to [0, 1] and
      shifted by -0.485) and the three class-index vectors.

  Side effects:
      Sets ``dataset_sizes['val']`` to the number of keys found.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    keys = [k for k, _ in lmdb_cursor_y]
    dataset_sizes['val'] = len(keys)
    np.random.shuffle(keys)

    X = np.zeros(batch_shape)
    X_xform = np.zeros(batch_shape)
    Y_xtrans = np.zeros(batch_size)
    Y_ytrans = np.zeros(batch_size)
    Y_rot = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)

      ## RGB
      # Read-only view is sufficient: the frame is only read before astype copies it.
      x = value.reshape((height, width, 3))

      xform_img, x_class, y_class, rot_class = sample_affine_xform(x, trans_range=[-3, 3], rot_range=[-30, 30])

      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = np.transpose(x, (2, 0, 1)).astype(np.float64)
      x /= 255.0
      x -= 0.485

      # Normalise the *transformed* image; this used to be rebuilt from the
      # already-normalised `x`, discarding the affine transform.
      xform_img = np.transpose(xform_img, (2, 0, 1)).astype(np.float64)
      xform_img /= 255.0
      xform_img -= 0.485

      X[batch_count] = x
      X_xform[batch_count] = xform_img
      Y_xtrans[batch_count] = x_class
      Y_ytrans[batch_count] = y_class
      Y_rot[batch_count] = rot_class
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, X_xform, Y_xtrans, Y_ytrans, Y_rot)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        X_xform = np.zeros(batch_shape)
        Y_xtrans = np.zeros(batch_size)
        Y_ytrans = np.zeros(batch_size)
        Y_rot = np.zeros(batch_size)
        batch_count = 0
        yield batch

def train_datagen_perclass(n_classes,
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y,
                  klass):
  """Yield shuffled binary training batches: class ``klass`` vs. class 0.

  All samples labelled ``klass`` are used together with an equally sized
  random subset of the samples labelled 0. Targets are 1 for any non-zero
  label and 0 otherwise.

  Args:
      n_classes: unused; kept for signature parity with the other loaders.
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffer, (3, height, width).
      lmdb_file_train_x: path of the LMDB store holding raw uint8 frames.
      lmdb_file_train_y: path of the LMDB store holding int64 labels.
      klass: label value treated as the positive class.

  Yields:
      ``(X, Y)``: float64 frames (scaled to [0, 1], shifted by -0.485) and
      the 0/1 target vector.

  Side effects:
      Sets ``dataset_sizes['train']`` to the number of selected keys.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    ks = []
    vals = []
    for k, v in lmdb_cursor_y:
      ks.append(k)
      # The label is the first int64 stored under the key.
      vals.append(np.frombuffer(v, dtype=np.int64)[0])
    vals = np.array(vals)
    ks = np.array(ks)

    klass_inds = np.where(vals == klass)[0]
    background_inds = np.where(vals == 0)[0]
    np.random.shuffle(background_inds)
    # Keep as many background samples as there are positives.
    background_inds = background_inds[:len(klass_inds)]

    keys = np.concatenate((ks[klass_inds], ks[background_inds]))
    dataset_sizes['train'] = len(keys)
    np.random.shuffle(keys)

    X = np.zeros(batch_shape)
    Y = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)
      label = np.frombuffer(lmdb_cursor_y.get(index), dtype=np.int64)[0]

      ## RGB
      # Read-only view is sufficient: astype below makes a writable copy.
      x = value.reshape((height, width, 3))

      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = np.transpose(x, (2, 0, 1)).astype(np.float64)
      x /= 255.0
      x -= 0.485

      X[batch_count] = x
      Y[batch_count] = 1 if label != 0 else 0
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, Y)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        Y = np.zeros(batch_size)
        batch_count = 0
        yield batch
def test_datagen_perclass(n_classes,
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y,
                  klass):
  """Yield binary validation batches: class ``klass`` vs. class 0.

  All samples labelled ``klass`` are used together with an equally sized
  random subset of the samples labelled 0. Targets are 1 for any non-zero
  label and 0 otherwise. Unlike the training variant the selected keys are
  not shuffled, so positives are served before background samples.

  Args:
      n_classes: unused; kept for signature parity with the other loaders.
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffer, (3, height, width).
      lmdb_file_train_x: path of the LMDB store holding raw uint8 frames
          (the validation store, despite the parameter name).
      lmdb_file_train_y: path of the LMDB store holding int64 labels.
      klass: label value treated as the positive class.

  Yields:
      ``(X, Y)``: float64 frames (scaled to [0, 1], shifted by -0.485) and
      the 0/1 target vector.

  Side effects:
      Sets ``dataset_sizes['val']`` to the number of selected keys.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    ks = []
    vals = []
    for k, v in lmdb_cursor_y:
      ks.append(k)
      # The label is the first int64 stored under the key.
      vals.append(np.frombuffer(v, dtype=np.int64)[0])
    vals = np.array(vals)
    ks = np.array(ks)

    klass_inds = np.where(vals == klass)[0]
    background_inds = np.where(vals == 0)[0]
    np.random.shuffle(background_inds)
    # Keep as many background samples as there are positives.
    background_inds = background_inds[:len(klass_inds)]

    keys = np.concatenate((ks[klass_inds], ks[background_inds]))
    dataset_sizes['val'] = len(keys)

    X = np.zeros(batch_shape)
    Y = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)
      label = np.frombuffer(lmdb_cursor_y.get(index), dtype=np.int64)[0]

      ## RGB
      # Read-only view is sufficient: astype below makes a writable copy.
      x = value.reshape((height, width, 3))

      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = np.transpose(x, (2, 0, 1)).astype(np.float64)
      x /= 255.0
      x -= 0.485

      X[batch_count] = x
      Y[batch_count] = 1 if label != 0 else 0
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, Y)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        Y = np.zeros(batch_size)
        batch_count = 0
        yield batch


def train_datagen(n_classes,
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y):
  """Yield shuffled ``(frames, labels)`` training batches from two LMDB stores.

  Args:
      n_classes: unused; kept for signature parity with the other loaders.
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffer, (3, height, width).
      lmdb_file_train_x: path of the LMDB store holding raw uint8 frames;
          its keys define the sample set.
      lmdb_file_train_y: path of the LMDB store holding int64 labels under
          the same keys.

  Yields:
      ``(X, Y)``: float64 frames scaled to [0, 1] in channel-first layout
      and the label vector.

  Side effects:
      Sets ``dataset_sizes['train']`` to the number of keys found.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    keys = [k for k, _ in lmdb_cursor_x]
    dataset_sizes['train'] = len(keys)
    np.random.shuffle(keys)

    X = np.zeros(batch_shape)
    Y = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)
      label = np.frombuffer(lmdb_cursor_y.get(index), dtype=np.int64)

      ## RGB
      # Read-only view is sufficient: astype below makes a writable copy.
      x = value.reshape((height, width, 3))

      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = np.transpose(x, (2, 0, 1)).astype(np.float64)
      x /= 255.0

      X[batch_count] = x
      # Store the scalar label rather than a one-element array.
      Y[batch_count] = label[0]
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, Y)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        Y = np.zeros(batch_size)
        batch_count = 0
        yield batch


def test_datagen(n_classes,
                 batch_size,
                 input_shape,
                 lmdb_file_test_x,
                 lmdb_file_test_y):
  """Yield shuffled ``(frames, labels)`` validation batches from two LMDB stores.

  Args:
      n_classes: unused; kept for signature parity with the other loaders.
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffer, (3, height, width).
      lmdb_file_test_x: path of the LMDB store holding raw uint8 frames;
          its keys define the sample set.
      lmdb_file_test_y: path of the LMDB store holding int64 labels under
          the same keys.

  Yields:
      ``(X, Y)``: float64 frames scaled to [0, 1] in channel-first layout
      and the label vector.

  Side effects:
      Sets ``dataset_sizes['val']`` to the number of keys found.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_test_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_test_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    keys = [k for k, _ in lmdb_cursor_x]
    dataset_sizes['val'] = len(keys)
    np.random.shuffle(keys)

    X = np.zeros(batch_shape)
    Y = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)
      label = np.frombuffer(lmdb_cursor_y.get(index), dtype=np.int64)

      ## RGB
      # Read-only view is sufficient: astype below makes a writable copy.
      x = value.reshape((height, width, 3))

      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = np.transpose(x, (2, 0, 1)).astype(np.float64)
      x /= 255.0

      X[batch_count] = x
      # Store the scalar label rather than a one-element array.
      Y[batch_count] = label[0]
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, Y)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        Y = np.zeros(batch_size)
        batch_count = 0
        yield batch


def train_datagen_CB(n_classes,
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y):
  """Yield class-balanced, shuffled training batches from two LMDB stores.

  For each class in ``range(n_classes)`` the same number of samples is
  randomly kept: the size of the smallest non-empty class.

  Args:
      n_classes: number of label classes to balance over.
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffer, (3, height, width).
      lmdb_file_train_x: path of the LMDB store holding raw uint8 frames.
      lmdb_file_train_y: path of the LMDB store holding int64 labels.

  Yields:
      ``(X, Y)``: float64 channel-first frames, converted with
      COLOR_BGR2RGB, scaled to [0, 1] and standardised per channel with
      mean [0.485, 0.456, 0.406] / std [0.229, 0.224, 0.225] (the values
      commonly used with ImageNet-pretrained models), plus the label vector.

  Side effects:
      Sets ``dataset_sizes['train']`` to the number of selected keys and
      prints ``dataset_sizes``.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    _keys = []
    vals = []
    for k, v in lmdb_cursor_y:
      _keys.append(k)
      vals.append(np.frombuffer(v, dtype=np.int64)[0])
    _keys = np.array(_keys)
    vals = np.array(vals)

    # Every class contributes as many samples as the smallest non-empty class.
    class_indices = [np.where(vals == cl)[0] for cl in range(n_classes)]
    populated = [len(ix) for ix in class_indices if len(ix) > 0]
    least_num_class = min(populated) if populated else 0

    keys = []
    for indices in class_indices:
      np.random.shuffle(indices)
      keys.extend(_keys[k] for k in indices[:least_num_class])

    dataset_sizes['train'] = len(keys)
    print(dataset_sizes)
    np.random.shuffle(keys)

    X = np.zeros(batch_shape)
    Y = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)
      label = np.frombuffer(lmdb_cursor_y.get(index), dtype=np.int64)

      ## BGR
      # Read-only view is sufficient: cvtColor allocates its own output.
      x = cv2.cvtColor(value.reshape((height, width, 3)), cv2.COLOR_BGR2RGB)

      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = x.astype(np.float64)
      x /= 255
      x -= [0.485, 0.456, 0.406]
      x /= [0.229, 0.224, 0.225]
      x = np.transpose(x, (2, 0, 1))

      X[batch_count] = x
      # Store the scalar label rather than a one-element array.
      Y[batch_count] = label[0]
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, Y)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        Y = np.zeros(batch_size)
        batch_count = 0
        yield batch

# Disabled implementation of train_datagen_CB_aug, kept as a module-level
# string literal so it is never executed or defined.
# NOTE(review): dataloders_complete_balance_aug('train', ...) still refers to
# train_datagen_CB_aug, and the body below uses an `A` augmentation module for
# which no import is visible in this file.
'''
def train_datagen_CB_aug(n_classes,
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y):
  lmdb_env_x = lmdb.open(lmdb_file_train_x)
  lmdb_txn_x = lmdb_env_x.begin()
  lmdb_cursor_x = lmdb_txn_x.cursor()

  lmdb_env_y = lmdb.open(lmdb_file_train_y)
  lmdb_txn_y = lmdb_env_y.begin()
  lmdb_cursor_y = lmdb_txn_y.cursor()

  X = np.zeros((batch_size, input_shape[0], input_shape[1], input_shape[2]))
  Y = np.zeros(batch_size)

  batch_count = 0
  _keys = []
  vals = []
  for k, v in lmdb_cursor_y:
    _keys.append(k)
    vals.append(np.frombuffer(v,dtype=np.dtype(np.int64))[0])
  _keys = np.array(_keys)
  vals = np.array(vals)

  least_num_class = len(np.where(np.array(vals) == 10)[0])

  keys = []
  for class_ind in range(0,n_classes):
    indices = np.where(vals == class_ind)[0]
    np.random.shuffle(indices)
    for k in indices[:least_num_class]:
      keys.append(_keys[k])


  dataset_sizes['train'] = len(keys)*4
  np.random.shuffle(keys)

  vis_count = 0
  for index in keys:
    value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.dtype(np.uint8))
    label = np.frombuffer(lmdb_cursor_y.get(index), dtype=np.dtype(np.int64))

    ## RGB
    x = value.reshape((height, width, 3))
    x.setflags(write=1)


    x_crop = A.random_resize_crop(x)
    x_aff = A.random_affine(x)
    x_flip = A.horizontal_flip(x)

    x = trans_mean_norm_std(x)
    X[batch_count] = x
    Y[batch_count] = label
    batch_count += 1

    x_crop = trans_mean_norm_std(x_crop)
    X[batch_count] = x_crop
    Y[batch_count] = label
    batch_count += 1

    x_aff = trans_mean_norm_std(x_aff)
    X[batch_count] = x_aff
    Y[batch_count] = label
    batch_count += 1

    x_flip = trans_mean_norm_std(x_flip)
    X[batch_count] = x_flip
    Y[batch_count] = label
    batch_count += 1

    if batch_count == batch_size:
      ret_x = X
      ret_y = Y
      X = np.zeros((batch_size, input_shape[0], input_shape[1], input_shape[2]))
      Y = np.zeros(batch_size)
      batch_count = 0
      yield (ret_x,ret_y)
'''

def test_datagen_CB(n_classes,
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y):
  """Yield class-balanced, shuffled validation batches from two LMDB stores.

  For each class in ``range(n_classes)`` the same number of samples is
  randomly kept: the size of the smallest non-empty class.

  Args:
      n_classes: number of label classes to balance over.
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffer, (3, height, width).
      lmdb_file_train_x: path of the LMDB store holding raw uint8 frames
          (the validation store, despite the parameter name).
      lmdb_file_train_y: path of the LMDB store holding int64 labels.

  Yields:
      ``(X, Y)``: float64 channel-first frames, converted with
      COLOR_BGR2RGB, scaled to [0, 1] and standardised per channel with
      mean [0.485, 0.456, 0.406] / std [0.229, 0.224, 0.225] (the values
      commonly used with ImageNet-pretrained models), plus the label vector.

  Side effects:
      Sets ``dataset_sizes['val']`` to the number of selected keys.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    _keys = []
    vals = []
    for k, v in lmdb_cursor_y:
      _keys.append(k)
      vals.append(np.frombuffer(v, dtype=np.int64)[0])
    _keys = np.array(_keys)
    vals = np.array(vals)

    # Every class contributes as many samples as the smallest non-empty class.
    class_indices = [np.where(vals == cl)[0] for cl in range(n_classes)]
    populated = [len(ix) for ix in class_indices if len(ix) > 0]
    least_num_class = min(populated) if populated else 0

    keys = []
    for indices in class_indices:
      np.random.shuffle(indices)
      keys.extend(_keys[k] for k in indices[:least_num_class])

    dataset_sizes['val'] = len(keys)
    np.random.shuffle(keys)

    X = np.zeros(batch_shape)
    Y = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)
      label = np.frombuffer(lmdb_cursor_y.get(index), dtype=np.int64)

      ## BGR
      # Read-only view is sufficient: cvtColor allocates its own output.
      x = cv2.cvtColor(value.reshape((height, width, 3)), cv2.COLOR_BGR2RGB)

      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = x.astype(np.float64)
      x /= 255
      x -= [0.485, 0.456, 0.406]
      x /= [0.229, 0.224, 0.225]
      x = np.transpose(x, (2, 0, 1))

      X[batch_count] = x
      # Store the scalar label rather than a one-element array.
      Y[batch_count] = label[0]
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, Y)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        Y = np.zeros(batch_size)
        batch_count = 0
        yield batch



def test_datagen_CB_aug(n_classes,
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y):
  """Yield class-balanced validation batches standardised with dataset stats.

  Each class in ``range(n_classes)`` contributes at most as many randomly
  chosen samples as there are samples labelled 10.

  Args:
      n_classes: number of label classes to balance over.
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffer, (3, height, width).
      lmdb_file_train_x: path of the LMDB store holding raw uint8 frames
          (the validation store, despite the parameter name).
      lmdb_file_train_y: path of the LMDB store holding int64 labels.

  Yields:
      ``(X, Y)``: float64 channel-first frames with ``dataset_mean``
      subtracted and divided by ``dataset_std`` (no channel reordering),
      plus the label vector.

  Side effects:
      Sets ``dataset_sizes['val']`` to the number of selected keys.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    _keys = []
    vals = []
    for k, v in lmdb_cursor_y:
      _keys.append(k)
      vals.append(np.frombuffer(v, dtype=np.int64)[0])
    _keys = np.array(_keys)
    vals = np.array(vals)

    # NOTE(review): the per-class cap is the size of class 10, presumably the
    # rarest class in this dataset -- confirm before using other label sets.
    least_num_class = len(np.where(vals == 10)[0])

    keys = []
    for class_ind in range(n_classes):
      indices = np.where(vals == class_ind)[0]
      np.random.shuffle(indices)
      keys.extend(_keys[k] for k in indices[:least_num_class])

    dataset_sizes['val'] = len(keys)
    np.random.shuffle(keys)

    X = np.zeros(batch_shape)
    Y = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)
      label = np.frombuffer(lmdb_cursor_y.get(index), dtype=np.int64)

      ## RGB
      # Read-only view is sufficient: astype below makes a writable copy.
      x = value.reshape((height, width, 3))

      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = np.transpose(x, (2, 0, 1)).astype(np.float64)
      x -= dataset_mean
      x /= dataset_std

      X[batch_count] = x
      # Store the scalar label rather than a one-element array.
      Y[batch_count] = label[0]
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, Y)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        Y = np.zeros(batch_size)
        batch_count = 0
        yield batch


def train_datagen_CATARACTS(n_classes,
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y):
  """Yield shuffled training batches from the CATARACTS LMDB stores.

  Stored frames are 135x240x3 uint8 buffers; each is resized to the spatial
  size given by ``input_shape``. The label buffer is read with
  ``np.frombuffer``'s default dtype and its argmax is used as the class index.

  Args:
      n_classes: unused; kept for signature parity with the other loaders.
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffer, (3, height, width).
      lmdb_file_train_x: path of the LMDB store holding raw uint8 frames.
      lmdb_file_train_y: path of the LMDB store holding the label vectors.

  Yields:
      ``(X, Y)``: float64 channel-first frames, converted with
      COLOR_BGR2RGB, scaled to [0, 1] and standardised per channel with
      mean [0.485, 0.456, 0.406] / std [0.229, 0.224, 0.225], plus the
      class-index vector.

  Side effects:
      Sets ``dataset_sizes['train']`` to the number of keys found.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])
  # cv2.resize takes (width, height); follow input_shape instead of a fixed 224.
  target_size = (input_shape[2], input_shape[1])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    keys = [k for k, _ in lmdb_cursor_y]
    dataset_sizes['train'] = len(keys)
    np.random.shuffle(keys)

    X = np.zeros(batch_shape)
    Y = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)
      label = np.frombuffer(lmdb_cursor_y.get(index))
      y = label.argmax()

      # Read-only view is sufficient: resize allocates its own output.
      x = cv2.resize(value.reshape((135, 240, 3)), target_size)

      ## BGR
      x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = x.astype(np.float64)
      x /= 255
      x -= [0.485, 0.456, 0.406]
      x /= [0.229, 0.224, 0.225]
      x = np.transpose(x, (2, 0, 1))

      X[batch_count] = x
      Y[batch_count] = y
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, Y)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        Y = np.zeros(batch_size)
        batch_count = 0
        yield batch

def test_datagen_CATARACTS(n_classes,
                  batch_size,
                  input_shape,
                  lmdb_file_train_x,
                  lmdb_file_train_y):
  """Yield shuffled validation batches from the CATARACTS LMDB stores.

  Stored frames are 135x240x3 uint8 buffers; each is resized to the spatial
  size given by ``input_shape``. The label buffer is read with
  ``np.frombuffer``'s default dtype and its argmax is used as the class index.

  Args:
      n_classes: unused; kept for signature parity with the other loaders.
      batch_size: number of samples per yielded batch.
      input_shape: per-sample shape of the batch buffer, (3, height, width).
      lmdb_file_train_x: path of the LMDB store holding raw uint8 frames
          (the validation store, despite the parameter name).
      lmdb_file_train_y: path of the LMDB store holding the label vectors.

  Yields:
      ``(X, Y)``: float64 channel-first frames, converted with
      COLOR_BGR2RGB, scaled to [0, 1] and standardised per channel with
      mean [0.485, 0.456, 0.406] / std [0.229, 0.224, 0.225], plus the
      class-index vector.

  Side effects:
      Sets ``dataset_sizes['val']`` to the number of keys found.

  Note:
      Only full batches are yielded; a trailing partial batch is dropped.
  """
  batch_shape = (batch_size, input_shape[0], input_shape[1], input_shape[2])
  # cv2.resize takes (width, height); follow input_shape instead of a fixed 224.
  target_size = (input_shape[2], input_shape[1])

  # Context managers release the LMDB handles once the generator is
  # exhausted or closed (they were previously leaked).
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    lmdb_cursor_y = lmdb_txn_y.cursor()

    keys = [k for k, _ in lmdb_cursor_y]
    dataset_sizes['val'] = len(keys)
    np.random.shuffle(keys)

    X = np.zeros(batch_shape)
    Y = np.zeros(batch_size)
    batch_count = 0

    for index in keys:
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)
      label = np.frombuffer(lmdb_cursor_y.get(index))
      y = label.argmax()

      # Read-only view is sufficient: resize allocates its own output.
      x = cv2.resize(value.reshape((135, 240, 3)), target_size)

      ## BGR
      x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = x.astype(np.float64)
      x /= 255
      x -= [0.485, 0.456, 0.406]
      x /= [0.229, 0.224, 0.225]
      x = np.transpose(x, (2, 0, 1))

      X[batch_count] = x
      Y[batch_count] = y
      batch_count += 1

      if batch_count == batch_size:
        batch = (X, Y)
        # Fresh buffers so the yielded arrays are never overwritten.
        X = np.zeros(batch_shape)
        Y = np.zeros(batch_size)
        batch_count = 0
        yield batch

def sample_affine_xform(img, trans_range=(-3, 3), rot_range=(-30, 30)):
  """Apply a random rotation followed by a random translation to ``img``.

  Args:
      img: image array of shape (rows, cols, channels).
      trans_range: (low, high) integer pixel shift; both bounds inclusive,
          sampled independently for x and y.
      rot_range: (low, high) integer rotation in degrees about the image
          centre; ``high`` is excluded.

  Returns:
      ``(xform_img, x_class, y_class, rot_class)``: the transformed image
      (same size as ``img``) and the class indices of the sampled values.
      Translation classes are the offset from ``trans_range[0]``; rotation
      classes are 3-degree bins counted from ``rot_range[0]``. With the
      defaults this gives 7 translation classes and 20 rotation classes.
  """
  rows, cols, _ = img.shape
  x_trans = np.random.randint(trans_range[0], trans_range[1] + 1)
  y_trans = np.random.randint(trans_range[0], trans_range[1] + 1)
  rot = np.random.randint(rot_range[0], rot_range[1])

  # Rotate about the image centre first, then shift the rotated image.
  rotation = cv2.getRotationMatrix2D((cols / 2, rows / 2), rot, 1)
  rotated_img = cv2.warpAffine(img, rotation, (cols, rows))

  translation = np.float32([[1, 0, x_trans], [0, 1, y_trans]])
  xform_img = cv2.warpAffine(rotated_img, translation, (cols, rows))

  # Offsets come from the range arguments (previously hard-coded to the
  # defaults), so classes start at 0 for any range.
  x_class = x_trans - trans_range[0]
  y_class = y_trans - trans_range[0]
  rot_class = (rot - rot_range[0]) // 3

  return xform_img, x_class, y_class, rot_class

def compute_dataset_stats():
  """Compute the mean pixel value over the csp2 training frames.

  Every frame has the same size, so the average of the per-frame means
  equals the mean over all pixels. Progress is printed while iterating,
  the result is printed, and a pdb session is opened so it can be inspected.

  Returns:
      The mean pixel value (on the raw 0-255 scale).

  Side effects:
      Sets ``dataset_sizes['train']`` to the number of keys found.
  """
  data_root = '/home-4/tkim60@jhu.edu/scratch/data/cataract_phase_img_csp2/'
  lmdb_file_train_x = os.path.join(data_root, 'X_train')
  lmdb_file_train_y = os.path.join(data_root, 'Y_train')

  # Context managers release the LMDB handles when the scan is finished.
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    # The label store is only used to enumerate the sample keys.
    keys = [k for k, _ in lmdb_txn_y.cursor()]
    dataset_sizes['train'] = len(keys)

    mean_sum = 0.0
    for count, index in enumerate(keys, 1):
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)

      ## RGB
      # The reshape validates the buffer size; the data is only read.
      x = value.reshape((height, width, 3))
      mean_sum += x.mean()
      print(str(count),'/',str(len(keys)),end='\r')

  mean = mean_sum/len(keys)
  # Report the result like compute_dataset_stats_std does.
  print(mean)
  pdb.set_trace()
  return mean

def compute_dataset_stats_std():
  """Compute the average per-frame pixel std over the csp2 training frames.

  The result is the mean of the per-frame standard deviations, which in
  general differs from the standard deviation over all pixels. Progress is
  printed while iterating, the result is printed, and a pdb session is
  opened so it can be inspected.

  Returns:
      The averaged standard deviation (on the raw 0-255 scale).

  Side effects:
      Sets ``dataset_sizes['train']`` to the number of keys found.
  """
  data_root = '/home-4/tkim60@jhu.edu/scratch/data/cataract_phase_img_csp2/'
  lmdb_file_train_x = os.path.join(data_root, 'X_train')
  lmdb_file_train_y = os.path.join(data_root, 'Y_train')

  # Context managers release the LMDB handles when the scan is finished.
  with lmdb.open(lmdb_file_train_x) as lmdb_env_x, \
       lmdb.open(lmdb_file_train_y) as lmdb_env_y, \
       lmdb_env_x.begin() as lmdb_txn_x, \
       lmdb_env_y.begin() as lmdb_txn_y:
    lmdb_cursor_x = lmdb_txn_x.cursor()
    # The label store is only used to enumerate the sample keys.
    keys = [k for k, _ in lmdb_txn_y.cursor()]
    dataset_sizes['train'] = len(keys)

    mean_std = 0.0
    for count, index in enumerate(keys, 1):
      value = np.frombuffer(lmdb_cursor_x.get(index), dtype=np.uint8)

      ## RGB
      # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
      x = value.reshape((height, width, 3)).astype(np.float64)
      # Centring does not change the std; kept to mirror the original computation.
      x -= dataset_mean
      mean_std += x.std()
      print(str(count), '/', str(len(keys)), end='\r')

  std = mean_std / len(keys)
  print(std)
  pdb.set_trace()
  return std

def trans_mean_norm_std(x):
  """Return ``x`` as a standardised channel-first float64 array.

  Args:
      x: image array in (height, width, channels) layout.

  Returns:
      A new (channels, height, width) float64 array with ``dataset_mean``
      subtracted and divided by ``dataset_std``; ``x`` itself is not modified.
  """
  # np.float64 replaces the np.float alias, which newer NumPy no longer provides.
  x = np.transpose(x, (2, 0, 1)).astype(np.float64)
  x -= dataset_mean
  x /= dataset_std
  return x


# Manual debugging harness: builds the training dataset with a torchvision
# augmentation pipeline and stops in pdb on every sample.
if __name__ == '__main__':
  print('here')
  #compute_dataset_stats_std()
  #[test for test in dataloders_affine('train', 16)]

  # Augmentation applied to each RGB uint8 frame returned by the dataset:
  # resize to 256x256, random 224 crop, random rotation (15 degrees),
  # random horizontal flip, then tensor conversion and per-channel
  # normalisation.
  train_data_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
  ])
  # train defaults to True, so this reads the training stores.
  test = CataractDataset(transform=train_data_transform)

  # Each `x` is an (image, label) pair; inspect it interactively in pdb.
  for x in test:
    pdb.set_trace()
  #[test for test in dataloders_complete_balance_aug('train', 11, 16)]
