#########################################
# Johns Hopkins University              #
# 601.455 Computer Integrated Surgery 2 #
# Spring 2018                           #
# Query by Video For Surgical Activities#
# Felix Yu                              #
# JHED: fyu12                           #
# Gianluca Silva Croso                  #
# JHED: gsilvac1                        #
#########################################

import os
import sys
import cv2
import skvideo.io 
import numpy as np

# Number of surgical phase labels used in the annotations
# (presumably labels 1..NUM_PHASES - confirm against the annotation files).
NUM_PHASES = 11

# Shared scratch-space root; every data directory below lives under it.
_DATA_ROOT = "/home-2/fyu12@jhu.edu/scratch/ReiterLab/"

# Full surgery video files are read from here - MODIFY IF NECESSARY.
vid_dir = _DATA_ROOT + "cataract/"
# Per-video phase annotation files are read from here - MODIFY IF NECESSARY.
# NOTE(review): "annoations" (sic) is kept verbatim; presumably it matches the
# directory name on disk, so do not "fix" the spelling without renaming it too.
anno_dir = _DATA_ROOT + "cataract_phase_annoations/"
# Phase clips are written here, one sub-directory per phase - MODIFY IF NECESSARY.
phase_dir = _DATA_ROOT + "cataract_phase_separated/"


def _read_annotations(anno_path):
    """
    Parse a phase annotation file.

    Each non-blank line holds three whitespace-separated integers:
        start_frame     end_frame   phase
    Blank lines (e.g. an extra empty line at the end of the file) are skipped.

    :param anno_path: path to the annotation file.
    :return: list of (start_frame, end_frame, phase) int tuples, in file order.
    """
    segments = []
    with open(anno_path) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue
            segments.append((int(fields[0]), int(fields[1]), int(fields[2])))
    return segments


def _video_properties(video_path):
    """
    Return (fps, width, height) for the video at video_path, read via OpenCV.

    fps is rounded to the nearest integer.

    :raises IOError: if OpenCV cannot open the video (otherwise fps would
        silently come back as 0).
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError("Could not open video: " + video_path)
        fps = int(round(cap.get(cv2.CAP_PROP_FPS)))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        cap.release()
    return fps, width, height


def _write_clip(frames, cur_frame, end_frame, clip_path, fps, size):
    """
    Write frames cur_frame..end_frame (1-based, inclusive) to a DIVX clip.

    Frames are pulled from the `frames` iterator, whose next item must be frame
    number `cur_frame`. When fps exceeds 30 the clip is written at fps / 2 and
    only odd-numbered frames are kept.

    :param frames: iterator of decoded frames (as yielded by skvideo).
    :param cur_frame: 1-based number of the next frame `frames` will yield.
    :param end_frame: last frame number (inclusive) to write.
    :param clip_path: output file path.
    :param fps: source frame rate.
    :param size: (width, height) of the output clip.
    :return: the number of the next frame `frames` will yield afterwards.
    :raises StopIteration: if `frames` runs out before end_frame; the partial
        clip is still released first.
    :raises IOError: if the output clip cannot be opened for writing.
    """
    halve_fps = fps > 30
    out_fps = fps / 2 if halve_fps else fps
    fourcc = cv2.VideoWriter_fourcc(*'DIVX')
    out = cv2.VideoWriter(clip_path, fourcc, out_fps, size)
    try:
        # cv2.VideoWriter does not raise on failure; without this check a bad
        # path or missing codec would silently produce no output.
        if not out.isOpened():
            raise IOError("Could not open video writer for: " + clip_path)
        while cur_frame <= end_frame:
            frame = next(frames)
            frame_number = cur_frame
            cur_frame += 1
            if halve_fps and frame_number % 2 == 0:
                continue
            # skvideo decodes to RGB; OpenCV's writer expects BGR channel order.
            out.write(np.flip(frame, axis=2))
    finally:
        out.release()
    return cur_frame


def separate_phases(skip_videos=("159", "125")):
    """
    Split every video in vid_dir into per-phase clips using the annotations in
    anno_dir, writing each clip under phase_dir/<phase>/.

    The video number is taken from characters 4-6 of the video file name, and
    its annotations are read from anno_dir + "vid_<number>_tasks.txt". Each
    annotation line has the format:
        start_frame     end_frame   phase
    Frame numbers are 1-based and inclusive. Segments are consumed in file
    order from a single forward pass over the video, so a segment starting
    before the previous one ended begins wherever the previous one stopped.

    Clips are named p<phase>_n<k>_vid_<number>.avi, where k counts the clips of
    that phase cut from this video. Missing phase sub-directories are created.
    Videos above 30 fps are written at half rate (every other frame).

    If a video ends before its annotations do, the clip in progress is kept,
    the remaining segments are skipped with a warning, and processing moves on
    to the next video.

    :param skip_videos: video numbers (3-character strings) to skip entirely;
        defaults to videos 159 and 125.
    """
    for vid_name in os.listdir(vid_dir):
        if vid_name.startswith("."):
            # Ignore hidden entries (OS/editor metadata) that are not videos.
            continue
        vid_number = vid_name[4:7]
        if vid_number in skip_videos:
            continue
        print("Currently on video: " + vid_number)
        sys.stdout.flush()

        segments = _read_annotations(anno_dir + "vid_" + vid_number + "_tasks.txt")
        video_path = vid_dir + vid_name
        fps, width, height = _video_properties(video_path)
        if fps > 30:
            print("\tFPS will be reduced for this video.")
            sys.stdout.flush()

        # Per-phase clip counter; a dict accepts any phase label, whereas a
        # fixed-size list would silently alias negative labels.
        clip_counts = {}
        reader = skvideo.io.FFmpegReader(video_path)
        try:
            frames = reader.nextFrame()
            cur_frame = 1  # 1-based number of the next frame `frames` yields
            for start_frame, end_frame, phase in segments:
                clip_counts[phase] = clip_counts.get(phase, 0) + 1
                # Advance the video to the start of this phase.
                while cur_frame < start_frame:
                    next(frames)
                    cur_frame += 1
                print('\tWriting out video for phase: ' + str(phase))
                sys.stdout.flush()
                out_dir = phase_dir + str(phase)
                if not os.path.isdir(out_dir):
                    os.makedirs(out_dir)
                clip_path = (out_dir + "/p" + str(phase) + "_n" + str(clip_counts[phase])
                             + '_vid_' + vid_number + '.avi')
                cur_frame = _write_clip(frames, cur_frame, end_frame, clip_path,
                                        fps, (width, height))
        except StopIteration:
            # The video has fewer frames than the annotations reference.
            print("\tWarning: video " + vid_number + " ended before its last "
                  "annotated frame; remaining phases skipped.")
            sys.stdout.flush()
        finally:
            reader.close()


def _main():
    """Script entry point: cut every annotated video into phase clips."""
    separate_phases()


if __name__ == "__main__":
    _main()
