import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
from segment import segment_bytime
from processdata import get_lengthofstay

def get_timeleft(simulation_date, start_range, end_range, number_of_patients):
	'''Return remaining durations of stay in 4P for the patients on the unit at simulation start.

	Segments the historical ADT records by arrival time, obtains lengths of
	stay for the requested number of patients via get_lengthofstay, and scales
	each one by an independent Uniform[0, 1) draw to turn a full stay into a
	remaining time.

	Inputs:
	simulation_date-    the current date and time from the simulation
	start_range-        the beginning of the arrival-time window used for
	                    segmenting (presumably an hour of day: the overnight
	                    split below uses 0 and 24 as bounds)
	end_range-          the end of the arrival-time window; a value smaller
	                    than start_range means the window spans midnight
	number_of_patients- the number of patients on the unit that need a
	                    remaining time
	Outputs:
	duration_list- the remaining times, in the container returned by
	               get_lengthofstay, with each element scaled in place
	'''
	# Uses the general ADTs dataframe - all of the relevant historical data.
	# NOTE: read_pickle can execute arbitrary code, so 'ADTs.pkl' must be a
	# trusted file. The path is relative to the current working directory.
	ADTs = pd.read_pickle('ADTs.pkl')

	# Select the ADTs of the relevant patients given the temporal criteria.
	# A window that wraps past midnight is split into [start_range, 24] and
	# [0, end_range], and the two segments are concatenated.
	if end_range < start_range:
		late_ADTs, _ = segment_bytime(ADTs, simulation_date, start_range, 24)
		early_ADTs, _ = segment_bytime(ADTs, simulation_date, 0, end_range)
		Seg_ADTs = pd.concat([late_ADTs, early_ADTs])
	else:
		Seg_ADTs, _ = segment_bytime(ADTs, simulation_date, start_range, end_range)

	# The second value returned by segment_bytime (presumably a patient count)
	# is not used; the caller decides how many patients need a value.
	duration_list = get_lengthofstay(Seg_ADTs, number_of_patients)

	# Scale each full duration by a Uniform[0, 1) draw to approximate the time
	# still remaining for a patient at an unknown point in their stay. Elements
	# are replaced in place so the container type from get_lengthofstay is kept.
	for i, duration in enumerate(duration_list):
		duration_list[i] = duration * np.random.random()

	return duration_list

def get_timeleft_new(simulation_date, start_range, end_range, number_of_patients, lookup,
		*, adjust_for_elapsed=False):
	'''Sample durations of stay in 4P for the patients on the unit at simulation start.

	Uses a lookup table of historical durations for the season, weekday and
	time partition of simulation_date, builds an empirical distribution from a
	100-bin histogram of those durations, and draws one value per patient.

	Inputs:
	simulation_date-    the current date and time from the simulation
	                    (anything with .month and .weekday(), e.g. datetime)
	start_range-        the beginning of the range used to segment time of
	                    arrival; 7, 15 and 19 select time partitions 0, 1 and
	                    2, and any other value selects partition 3
	end_range-          the end of the range used to segment time of arrival;
	                    not used here, kept for signature compatibility
	number_of_patients- the number of patients on the unit that need a value
	lookup-             a 4x7x4 table indexed [season][weekday][time_range];
	                    element [0] of each entry is the list of durations
	                    for that time partition
	adjust_for_elapsed- keyword-only, default False. When True, each sample is
	                    multiplied by |N(0.6, 0.1)| to account for time that has
	                    already passed since the patient arrived. This logic was
	                    previously unreachable, so the default keeps the
	                    existing behaviour.
	Outputs:
	samples- a numpy array of number_of_patients sampled durations
	'''
	# Seasons: 0 = Dec-Feb, 1 = Mar-May, 2 = Jun-Aug, 3 = Sep-Nov.
	season = (simulation_date.month % 12) // 3
	dayofweek = simulation_date.weekday()
	time_range = {7: 0, 15: 1, 19: 2}.get(start_range, 3)

	# Get the list of durations from the lookup table.
	duration_list = lookup[season][dayofweek][time_range][0]

	# Fit an empirical random variable to a 100-bin histogram of the durations
	# and sample one value per patient (uses numpy's global random state).
	hist_dist = stats.rv_histogram(np.histogram(duration_list, bins=100))
	samples = hist_dist.rvs(size=number_of_patients)

	if adjust_for_elapsed:
		# Shrink each duration by a noisy factor around 0.6; the absolute value
		# keeps results non-negative when the factor is drawn below zero.
		factors = np.random.normal(.6, .1, size=len(samples))
		samples = np.abs(samples * factors)

	return samples
	