import numpy as np
import skvideo.io
import skimage
import skimage.transform
from skimage import img_as_ubyte
import os
import pdb
import lmdb 
import sys

sys.path.insert(0, "../Models/")
import helpers

# Spatial size that every kept frame is resized to before being stored.
VIDEO_HEIGHT = 112
VIDEO_WIDTH = 112
# Size of the last (channel) axis of the per-segment frame buffer.
NUM_CHANNELS = 3
# The pending LMDB write transaction is committed every BATCH_SIZE segments.
BATCH_SIZE = 10

def write_clip_lmdb(clip_path, lmdb_env):
    """Split one video clip into 32-frame segments and store them in an LMDB.

    Within each 32-frame segment every second frame (even offsets 0, 2, ...,
    30) is resized to ``(VIDEO_HEIGHT, VIDEO_WIDTH)`` and kept, giving 16
    frames per segment. A finished segment is transposed to
    ``(1, NUM_CHANNELS, 16, VIDEO_HEIGHT, VIDEO_WIDTH)``, converted with
    ``img_as_ubyte`` and written as raw bytes under the key
    ``"<p>_<n>_<v>_<segment>"``, where segments are numbered from 1. Frames
    left over after the last complete segment are dropped.

    Writes are committed every ``BATCH_SIZE`` segments and once more at the
    end. If an error occurs, the pending transaction is aborted and the video
    reader is closed before the exception propagates.

    Args:
        clip_path: Path to the video file. The file name must contain at
            least four ``_``-separated fields; the key uses field 0 and
            field 1 without their first character and the first three
            characters of field 3.
        lmdb_env: Open ``lmdb`` environment to write the segments into.

    Returns:
        None. If opening the video raises ``ValueError`` the clip is skipped
        with a message and nothing is written.

    Raises:
        IndexError: If the file name has fewer than four ``_``-separated
            fields.
    """
    clip_name = os.path.basename(clip_path)
    fields = clip_name.split('_')
    p = fields[0][1:]
    n = fields[1][1:]
    v = fields[3][:3]
    # Everything in the key except the segment number is constant per clip.
    key_prefix = p + '_' + n + '_' + v + '_'

    #Load in video
    try:
        vid = skvideo.io.FFmpegReader(clip_path)
    except ValueError:
        print("\t\tIssue with video encoding. Skipping.")
        sys.stdout.flush()
        return

    lmdb_txn = None
    try:
        # Only whole 32-frame segments are processed.
        num_frames = vid.getShape()[0]
        num_frames -= num_frames % 32

        frames = vid.nextFrame()
        # Buffer for one segment; it is fully overwritten before each write.
        clip = np.empty((1, 16, VIDEO_HEIGHT, VIDEO_WIDTH, NUM_CHANNELS))
        print('\tWriting phase clip ' + clip_name + ' to lmdb with ' + str(num_frames) + ' frames.')
        lmdb_txn = lmdb_env.begin(write=True)

        for i in range(num_frames):
            try:
                frame = next(frames)
            except StopIteration:
                # The reader produced fewer frames than it reported. Keep the
                # segments already completed and drop the unfinished one.
                print('\t\tVideo ended early after ' + str(i) + ' frames. Dropping incomplete segment.')
                sys.stdout.flush()
                break

            segment_idx, frame_idx = divmod(i, 32)

            if frame_idx % 2 == 0:
                # Reverse the channel axis (presumably RGB -> BGR; confirm
                # against the consumer of this database).
                img = np.flip(frame, axis=2)
                clip[:, frame_idx // 2, :, :, :] = skimage.transform.resize(img, (VIDEO_HEIGHT, VIDEO_WIDTH))

            if frame_idx != 31:
                continue

            # Last frame of the segment: serialise and store it.
            cur_segment = segment_idx + 1
            X = np.transpose(clip, (0, 4, 1, 2, 3))
            X = img_as_ubyte(X)
            lmdb_txn.put((key_prefix + str(cur_segment)).encode(), X.tobytes())

            if cur_segment % BATCH_SIZE == 0:
                print('\t\tDone writing segment ' + str(cur_segment))
                sys.stdout.flush()
                lmdb_txn.commit()
                # Cleared first so the cleanup below never touches a
                # transaction that has already been committed.
                lmdb_txn = None
                lmdb_txn = lmdb_env.begin(write=True)

        lmdb_txn.commit()
        lmdb_txn = None
    finally:
        # A transaction is still pending only if an error interrupted the
        # loop; abort it so the environment's write lock is released.
        if lmdb_txn is not None:
            lmdb_txn.abort()
        vid.close()

def write_lmdb():
    """Build per-phase train and test LMDBs from the phase clip directories.

    For each phase ``1`` to ``10`` this opens ``<lmdb_out_dir>/train/<phase>``
    and ``<lmdb_out_dir>/test/<phase>``, then passes every file found in
    ``<phase_clip_dir>/<phase>/`` to ``write_clip_lmdb``. A clip whose file
    name is in the train list goes to the train database; every other clip
    goes to the test database. Both environments are closed once the phase is
    done, even if writing a clip raises.

    Returns:
        None.
    """
    phase_clip_dir = '/home-2/gsilvac1@jhu.edu/work/gsilvac1/CATARACT/optical_flow/data/'
    lmdb_out_dir = '/home-2/gsilvac1@jhu.edu/work/gsilvac1/CATARACT/optical_flow/seg_lmdbs/'
    # Only the train list is consulted; anything not in it is treated as test.
    train_list, _test_list = helpers.read_train_test_split('/home-2/gsilvac1@jhu.edu/work/gsilvac1/CATARACT/train_test_split/')

    lmdb_out_train_dir = lmdb_out_dir + 'train/'
    lmdb_out_test_dir = lmdb_out_dir + 'test/'
    # Make sure the parent directories exist before any database is opened.
    os.makedirs(lmdb_out_train_dir, exist_ok=True)
    os.makedirs(lmdb_out_test_dir, exist_ok=True)

    for i in range(1, 11):
        print('\t\tCurrently on phase ' + str(i))
        print('-----------------------------------------------------------')
        sys.stdout.flush()

        lmdb_file_train = os.path.join(lmdb_out_train_dir, str(i))
        lmdb_file_test = os.path.join(lmdb_out_test_dir, str(i))
        spec_phase_clip_dir = phase_clip_dir + str(i) + '/'

        lmdb_env_train = lmdb.open(lmdb_file_train, map_size=int(1e12))
        try:
            lmdb_env_test = lmdb.open(lmdb_file_test, map_size=int(1e12))
            try:
                for phase_clip in os.listdir(spec_phase_clip_dir):
                    if phase_clip in train_list:
                        target_env = lmdb_env_train
                    else:
                        target_env = lmdb_env_test
                    write_clip_lmdb(spec_phase_clip_dir + phase_clip, target_env)
            finally:
                lmdb_env_test.close()
        finally:
            # Close per phase so open environments do not pile up over the loop.
            lmdb_env_train.close()
        print('')
        
if __name__ == "__main__":
    # Build the databases only when run as a script, not when imported.
    write_lmdb()
