# -*- coding: utf-8 -*-
"""
Created on Thu Nov 14 10:03:50 2013

@author: Ali Baghani
@copyright: 2013 Ultrasonix Medical Corporation - Analogic Ultrasound Group
"""

import struct
import numpy
import matplotlib.pyplot as plt
from numpy import array

def _version_tuple(version):
    """Turn a dotted Sonix version string into a tuple of ints for comparison.

    Comparing version *strings* is lexicographic and therefore wrong for
    multi-digit components (e.g. '6.0.10' < '6.0.7' as strings). Each
    dot-separated component contributes its leading digits (0 if it has none),
    and the result is padded with zeros to at least three components so that
    '6.0' and '6.0.0' compare equal.

    >>> _version_tuple('6.0.10') > _version_tuple('6.0.7')
    True
    """
    import re

    parts = []
    for piece in str(version).split('.'):
        leading = re.match(r'\s*(\d+)', piece)
        parts.append(int(leading.group(1)) if leading else 0)
    parts.extend([0] * (3 - len(parts)))
    return tuple(parts)


def _unpack(f, fmt):
    """Read exactly one struct `fmt` record from binary file `f` as a numpy array.

    `fmt` has no byte-order prefix, so native byte order/alignment is used,
    the same as for the file header. struct.error is raised if `f` ends
    before the whole record could be read.
    """
    return array(struct.unpack(fmt, f.read(struct.calcsize(fmt))))


def rp_read(filename, version = '6.0.7'):
    """This function loads the ultrasound RF data saved from the Sonix software

    Inputs:
        filename - The path of the data to open
        version  - The version of the Sonix software used to save the data,
                   as a dotted string such as '6.0.7'. Components are
                   compared numerically, so '6.0.10' is newer than '6.0.7'.

    Outputs:
        Im       - The frame data, with frames along the last axis:
                     .b8/.con/.elo/.bpr/.rf/.m/.pw/.epr/.ecg -> (h, w, numframes)
                     .b32/.col/.el (uint32 XRGB)             -> (h, w, numframes)
                     .mpr -> (h, numframes)      .drf -> (h, 1, numframes)
                     .crf -> (h, w*extra, numframes)
                     .cvv -> (2*h, w, numframes) .t   -> (w, numframes)
                   GPS files return a GPSData object whose attributes hold
                   one entry per frame. Unsupported file types return [].
        header   - The file header information (RPFileHeader instance)

    Raises:
        struct.error if the file is shorter than its header says it should be.
    """

    class RPFileHeader:
        pass

    class GPSData:
        pass

    ver = _version_tuple(version)

    with open(filename, 'rb') as f:

        # Fixed-size header: 19 native unsigned 32-bit integers.
        hinfo = struct.unpack('19I', f.read(19 * 4))
        header = RPFileHeader()
        header.filetype = hinfo[0]
        header.nframes = hinfo[1]
        header.w = hinfo[2]
        header.h = hinfo[3]
        header.ss = hinfo[4]
        header.ul = [hinfo[5], hinfo[6]]
        header.ur = [hinfo[7], hinfo[8]]
        header.br = [hinfo[9], hinfo[10]]
        header.bl = [hinfo[11], hinfo[12]]
        header.probe = hinfo[13]
        header.txf = hinfo[14]
        header.sf = hinfo[15]
        header.dr = hinfo[16]
        header.ld = hinfo[17]
        header.extra = hinfo[18]

        fileType = {
            '.bpr': 2,
            '.b8' : 4,
            '.b32': 8,
            '.rf' : 16,
            '.mpr': 32,
            '.m'  : 64,
            '.drf': 128,
            '.pw' : 256,
            '.crf': 512,
            '.col': 1024,
            '.cvv': 2048,
            '.con': 4096,
            '.el' : 8192,
            '.elo': 16384,
            '.epr': 32768,
            '.ecg': 65536,
            '.gps,.gps1': 131072,
            '.gps2': 262144,
            '.t': 524288
            }

        def codes_for(*exts):
            """Set of numeric file-type codes for the given fileType keys."""
            return {fileType[e] for e in exts}

        # Before 6.0.0 these formats prefixed every frame with a 4-byte tag.
        usedToHaveAFrameTagFileTypes = codes_for(
            '.bpr', '.b8', '.b32', '.rf', '.mpr', '.drf', '.crf', '.con')
        XRGBFileTypes = codes_for('.b32', '.col', '.el')
        GrScFileTypes = codes_for('.b8', '.con', '.elo')
        GPSFileTypes = codes_for('.gps,.gps1', '.gps2')
        # Types whose frames are h*w uint8 values stored row by row (h, w).
        RowMajorByteFileTypes = GrScFileTypes | codes_for('.m', '.pw', '.ecg')
        # Types returned as a plain-int (h, w, nframes) array.
        IntImageFileTypes = codes_for(
            '.bpr', '.b8', '.rf', '.m', '.pw', '.con', '.elo', '.epr', '.ecg')

        ft = header.filetype
        h, w, n = header.h, header.w, header.nframes
        npix = h * w
        # The GPS record layout changed in 6.0.3 and extra fields were added.
        gps_v2 = ver >= (6, 0, 3)

        # --------------  memory initialization for speeding up ---------------
        if ft in IntImageFileTypes:
            Im = numpy.zeros((h, w, n), dtype=int)
        elif ft in XRGBFileTypes:
            Im = numpy.zeros((h, w, n), dtype=numpy.uint32)
        elif ft == fileType['.mpr']:
            Im = numpy.zeros((h, n), dtype=int)
        elif ft == fileType['.drf']:
            Im = numpy.zeros((h, 1, n), dtype=int)
        elif ft == fileType['.crf']:
            Im = numpy.zeros((h, w * header.extra, n), dtype=int)
        elif ft == fileType['.cvv']:
            Im = numpy.zeros((h * 2, w, n), dtype=int)
        elif ft == fileType['.t']:
            Im = numpy.zeros((w, n), dtype=int)
        elif ft in GPSFileTypes:
            Im = GPSData()
            Im.gps_posx = numpy.zeros(n)
            Im.gps_posy = numpy.zeros(n)
            Im.gps_posz = numpy.zeros(n)
            Im.gps_s = numpy.zeros((3, 3, n))
            Im.gps_time = numpy.zeros(n)
            Im.gps_quality = numpy.zeros(n, dtype=int)
            Im.Zeros = numpy.zeros((3, n), dtype=int)
            if gps_v2:
                Im.gps_a = numpy.zeros(n)
                Im.gps_e = numpy.zeros(n)
                Im.gps_r = numpy.zeros(n)
                Im.gps_q = numpy.zeros((4, n))
        else:
            Im = []

        # ------------- Reading the data from the file into the Im -----------
        for frame_count in range(n):

            if ft in usedToHaveAFrameTagFileTypes and ver < (6, 0, 0):
                # Older versions store a 4-byte frame number before each
                # frame; read it only to skip past it.
                struct.unpack('I', f.read(4))

            if ft in RowMajorByteFileTypes:
                Im[:, :, frame_count] = _unpack(f, '%dB' % npix).reshape(h, w)

            elif ft in XRGBFileTypes:
                Im[:, :, frame_count] = _unpack(f, '%dI' % npix).reshape(h, w)

            elif ft == fileType['.bpr']:
                # Stored as w consecutive runs of h samples (column-major).
                Im[:, :, frame_count] = _unpack(f, '%dB' % npix).reshape(w, h).T

            elif ft == fileType['.rf']:
                # Signed 16-bit samples, w consecutive runs of h samples.
                Im[:, :, frame_count] = _unpack(f, '%dh' % npix).reshape(w, h).T

            elif ft == fileType['.mpr']:
                if ver >= (6, 0, 7):  # This feature was broken before this version
                    Im[:, frame_count] = _unpack(f, '%dH' % h)

            elif ft == fileType['.drf']:
                Im[:, :, frame_count] = _unpack(f, '%dh' % h).reshape(h, 1)

            elif ft == fileType['.crf']:
                cols = w * header.extra
                Im[:, :, frame_count] = _unpack(f, '%dh' % (h * cols)).reshape(cols, h).T

            elif ft == fileType['.cvv']:
                Im[:, :, frame_count] = _unpack(f, '%dB' % (npix * 2)).reshape(h * 2, w)

            elif ft == fileType['.epr']:
                # Not working as of 6.0.6: nothing is read, frames stay zero.
                pass

            elif ft in GPSFileTypes:
                if gps_v2:
                    data = _unpack(f, '20d4H')
                    Im.gps_posx[frame_count] = data[0]
                    Im.gps_posy[frame_count] = data[1]
                    Im.gps_posz[frame_count] = data[2]
                    Im.gps_a[frame_count] = data[3]
                    Im.gps_e[frame_count] = data[4]
                    Im.gps_r[frame_count] = data[5]
                    Im.gps_s[:, :, frame_count] = data[6:15].reshape([3, 3])
                    Im.gps_q[:, frame_count] = data[15:19]
                    Im.gps_time[frame_count] = data[19]
                    Im.gps_quality[frame_count] = data[20]
                    Im.Zeros[:, frame_count] = data[21:24]
                else:
                    # Record layout used prior to 6.0.3.
                    data = _unpack(f, '13d4H')
                    Im.gps_posx[frame_count] = data[0]
                    Im.gps_posy[frame_count] = data[1]
                    Im.gps_posz[frame_count] = data[2]
                    Im.gps_s[:, :, frame_count] = data[3:12].reshape([3, 3])
                    Im.gps_time[frame_count] = data[12]
                    Im.gps_quality[frame_count] = data[13]
                    Im.Zeros[:, frame_count] = data[14:17]

            elif ft == fileType['.t']:
                Im[:, frame_count] = _unpack(f, '%dI' % w)

            else:
                # Unknown file type: nothing can be read, so report it once
                # instead of once per frame.
                print('Data not supported')
                break

    return (Im, header)

def plot_xrgb(data):
    """Display a 2-D XRGB-packed image (.b32, .col, .el) in a new figure.

    Each element of ``data`` holds one pixel packed as 0xXXRRGGBB. The R, G
    and B bytes become the colour channels of an RGBA uint8 image, and the X
    byte is inverted (255 - X) to form the alpha channel.
    """
    # Unpack every byte plane before allocating the output image.
    x_plane = (data & 0xFF000000) >> 24
    rgb_planes = [(data & mask) >> shift
                  for mask, shift in ((0x00FF0000, 16),
                                      (0x0000FF00, 8),
                                      (0x000000FF, 0))]

    rgba = numpy.zeros(list(data.shape) + [4], dtype=numpy.uint8)
    for channel, plane in enumerate(rgb_planes):
        rgba[:, :, channel] = plane
    rgba[:, :, 3] = 255 - x_plane

    axes = plt.figure().add_subplot(111)
    axes.imshow(rgba)

def plot_grayscale(data):
    """Display a single-channel image (.b8, .m, .pw, .bpr) in a new figure.

    The data is drawn with ``imshow`` and the 'gray' colormap is applied to
    the resulting image.
    """
    axes = plt.figure().add_subplot(111)
    image = axes.imshow(data)
    image.set_cmap('gray')
