import numpy as np
import pandas as pd
import scipy.stats as stats
import matplotlib.pyplot as plt
def remove_outliers(values, thresh = 4):
	"""Return the values that lie within `thresh` standard deviations of the mean.

	Inputs:
	values- list of numeric values
	thresh- number of (population) standard deviations away from the mean
		at which a value is treated as an outlier; values at or beyond
		this distance are dropped
	Outputs:
	new_values- new list holding the retained values in their original order
	"""
	# len() rather than truthiness so that numpy arrays are accepted as well.
	if len(values) == 0:
		return []

	mean = np.mean(values)
	sd = np.std(values)

	# All values identical: there is no spread, hence no outliers. Without this
	# guard the z-score below is 0/0 = nan, nan < thresh is False, and every
	# value would be discarded.
	if sd == 0:
		return list(values)

	return [val for val in values if (abs(val - mean) / sd) < thresh]

''' Function that gets a length of stay predictor for 4P given a set of ADTs.
	Inputs:
	dataframe_in- a pandas dataframe which contains ADT records
	Outputs:
	duration sample- a single value that is a predictor of length of stat in 4P for the given ADT data
'''
def get_lengthofstay(dataframe_in,num_of_pats):

	ADTs_In = dataframe_in
	ADTs_In.is_copy = False
	
	# Gets list of departments for given ADT set
	Dept_names = ADTs_In.ADT_DEPARTMENT_NAME.unique()
	# Creates a column for the duration of each event
	Duration = ADTs_In['OUT_DTTM'] - ADTs_In['IN_DTTM']
	# Adds this column into the ADT dataframe
	ADTs_In['Duration'] = Duration.astype('timedelta64[m]')

	# ADT_visits is a dataframe containing the VisitID, departments, in-times, and duration
	ADTs_In = ADTs_In.sort_values('IN_DTTM')
	ADT_visits = ADTs_In[['VisitID', 'ADT_DEPARTMENT_NAME', 'IN_DTTM', 'Duration']]
	#print(ADT_visits)

	# Group by VisitID and obtain a list of locations and durations for each unique visit ordered by in-time
	VisitGroup = ADT_visits.groupby('VisitID')
	Locations = ADT_visits.groupby('VisitID')['ADT_DEPARTMENT_NAME'].apply(list)
	Durations_List = ADT_visits.groupby('VisitID')['Duration'].apply(list)

	# Loop through locations and combine duplicate entries in location lists and add consecutive event durations in the same department
	for key, group in Locations.items():
		i = 0
		# Examine each department in the list
		while i<len(group):
			check = i
			# This conditional is satified if there is a duplicate location in the list
			if (check<len(group)-1) and (group[check] == group[check+1]):
				start_this = i
				this = i
				# Finds the last consecutive duplicate location
				while ( (this<len(group)-1) and (group[this] == group[this+1]) ):
					this = this+1

				# Creates a new location list with consecutive duplicates removed
				new_group = group[0:start_this] + [group[this]] + group[this+1:]
				# Updates Locations grouping
				Locations.set_value(key, new_group)

				# Gets the corresponding duration list
				od_group = Durations_List.get(key)
				# Creates an updated list combining consecutive duplicate durations
				new_duration_group = od_group[0:start_this] + [ sum(od_group[start_this:this+1]) ] + od_group[this+1:]
				# Updates Durations_List with corrected duration list
				Durations_List.set_value(key, new_duration_group)

				# Increments location to check and updates group variable
				i = start_this+1
				group = Locations.get(key)
			else:
				i = i+1

	# List to hold durations for 4P events
	FourPDurations = []

	# Loops through location list accounting for patients who leave 4P and return again
	for key, group in Locations.items():
		# Keeps track of indices in list that are 4P
		fourp_index_list = []
		for i, dept in enumerate(group):
			if dept == 'HCGH 4P ACUTE':
				fourp_index_list.append(i)
		# Sums up all 4P durations for a given list
		fourp_sum = 0
		for j in fourp_index_list:
			fourp_sum = fourp_sum + Durations_List.get(key)[j]
		FourPDurations.append(fourp_sum)

	# Removes outliers from the distribution
	#FourPDurations = remove_outliers(FourPDurations,4)

	# Converts from minutes to days
	#FourPDurations_Days = [each / (60*24) for each in FourPDurations]

	# Creates histogram of 4P lengths of stay
	plt.hist(FourPDurations, normed=True, bins=80)
	plt.xlabel('Duration (minutes)')
	plt.ylabel('Probability')
	plt.title('FourPDurations')
	#plt.show()

	# Fits a random variable to the histogram
	xt = plt.xticks()[0]  
	xmin, xmax = min(xt), max(xt)  
	lnspc = np.linspace(xmin, xmax, len(FourPDurations))
	hist = np.histogram(FourPDurations, bins=50)
	hist_dist = stats.rv_histogram(hist)
	plt.plot(lnspc, hist_dist.pdf(lnspc), label='PDF')
	#plt.show()

	# Samples from the random variable
	duration_samples = hist_dist.rvs(size = num_of_pats)
	
	# Returns the duration found
	return duration_samples

''' Function that gets the length of stays of all patients given a set of ADTs.
	Inputs:
	dataframe_in- a pandas dataframe which contains ADT records
	Outputs:
	FourPDurations- a list of every duration in 4P from patients in the ADT list
'''
def get_lengthofstaylist(dataframe_in):
	ADTs_In = dataframe_in
	ADTs_In.is_copy = False
	
	# Gets list of departments for given ADT set
	Dept_names = ADTs_In.ADT_DEPARTMENT_NAME.unique()
	# Creates a column for the duration of each event
	Duration = ADTs_In['OUT_DTTM'] - ADTs_In['IN_DTTM']
	# Adds this column into the ADT dataframe
	ADTs_In['Duration'] = Duration.astype('timedelta64[m]')

	# ADT_visits is a dataframe containing the VisitID, departments, in-times, out-times and duration
	ADT_visits = ADTs_In[['VisitID', 'ADT_DEPARTMENT_NAME', 'IN_DTTM', 'OUT_DTTM', 'Duration']]
	ADT_visits = ADT_visits.sort_values('IN_DTTM')
	#print(ADT_visits)

	# Group by VisitID and obtain a list of attributes for each unique visit ordered by in-time
	VisitGroup     = ADT_visits.groupby('VisitID')
	Locations      = ADT_visits.groupby('VisitID')['ADT_DEPARTMENT_NAME'].apply(list)
	Durations_List = ADT_visits.groupby('VisitID')['Duration'].apply(list)
	InTime_List    = ADT_visits.groupby('VisitID')['IN_DTTM'].apply(list)
	OutTime_List   = ADT_visits.groupby('VisitID')['OUT_DTTM'].apply(list)

	# Loop through locations and combine duplicate entries in location lists and add consecutive event durations in the same department
	for key, group in Locations.items():
		i = 0
		# Examine each department in the list
		while i<len(group):
			check = i
			# This conditional is satified if there is a duplicate location in the list
			if (check<len(group)-1) and (group[check] == group[check+1]):
				start_this = i
				this = i
				# Finds the last consecutive duplicate location
				while ( (this<len(group)-1) and (group[this] == group[this+1]) ):
					this = this+1

				# Creates a new location list with consecutive duplicates removed
				new_group = group[0:start_this] + [group[this]] + group[this+1:]
				# Updates Locations grouping
				Locations.set_value(key, new_group)

				# Gets the corresponding duration list
				od_group = Durations_List.get(key)
				print(od_group)
				# Creates an updated list combining consecutive duplicate durations
				new_duration_group = od_group[0:start_this] + [ sum(od_group[start_this:this+1]) ] + od_group[this+1:]
				print(new_duration_group)
				# Updates Durations_List with corrected duration list
				Durations_List.set_value(key, new_duration_group)

				# Increments location to check and updates group variable
				i = start_this+1
				group = Locations.get(key)
			else:
				i = i+1

	# List to hold durations for 4P events
	FourPDurations = []

	# Loops through location list accounting for patients who leave 4P and return again
	for key, group in Locations.items():
		# Keeps track of indices in list that are 4P
		fourp_index_list = []
		for i, dept in enumerate(group):
			if dept == 'HCGH 4P ACUTE':
				fourp_index_list.append(i)
		# Sums up all 4P durations for a given list
		fourp_sum = 0
		for j in fourp_index_list:
			fourp_sum = fourp_sum + Durations_List.get(key)[j]
		FourPDurations.append(fourp_sum)

	# Returns distribution of 4P durations for given patients in ADTs
	return FourPDurations
	