from datetime import datetime

import pandas as pd

from agefromname import AgeFromName


class InvalidGenerationBirthYearDefinition(Exception):
	'''Raised by GenerationFromName when its generation_birth_years argument is malformed.'''


class GenerationFromName(object):
	'''Estimate the generation in which people with a given first name and sex were born.

	Per-birth-year estimates from an AgeFromName instance are summed into the
	generation buckets given by ``generation_birth_years``. Whatever falls outside
	every bucket is reported under the ``'_other'`` key.
	'''

	# Default buckets: generation name -> [first, last] year of birth (both inclusive).
	_DEFAULT_GENERATION_BIRTH_YEARS = {'Greatest': [1915, 1929],
	                                   'Silent': [1930, 1945],
	                                   'Baby Boomers': [1946, 1964],
	                                   'Generation X': [1965, 1980],
	                                   'Millenials': [1981, 1995],
	                                   'Generation Z': [1996, 2010],
	                                   'Post Gen Z': [2011, 2025]}

	def __init__(self, generation_birth_years=None, age_from_name=None):
		'''
		:param generation_birth_years: dict, optional, maps generation names to the first and
			last years of birth (inclusive). Ex.: {'Millenials': [1980, 1995], 'Generation X': [1956, 1979]}
			Years outside every range are counted as "_other". Defaults to the buckets in
			_DEFAULT_GENERATION_BIRTH_YEARS (Greatest through Post Gen Z, 1915-2025).
		:param age_from_name: AgeFromName, optional. If not given, a new AgeFromName() is created.
		:raises InvalidGenerationBirthYearDefinition: if generation_birth_years is malformed.
		'''
		if generation_birth_years is None:
			generation_birth_years = self._DEFAULT_GENERATION_BIRTH_YEARS
		self._validate_generation_birth_years(generation_birth_years)
		# Copy, so later changes to the caller's dict (or to the shared default) cannot leak in.
		self._generation_birth_years = {generation: tuple(years)
		                                for generation, years in generation_birth_years.items()}
		self._age_from_name = age_from_name if age_from_name is not None else AgeFromName()

	def _validate_generation_birth_years(self, generation_birth_years):
		'''Raise InvalidGenerationBirthYearDefinition unless generation_birth_years is a dict
		whose values are nondecreasing [first, last] pairs that do not overlap each other.'''
		invalid_type_or_template_error = "generation_birth_years must be a dict, which maps generation names to first and last birth years.  Ex: {'Millenials': [1980, 1995],'Generation X': [1956, 1979]}."
		if not isinstance(generation_birth_years, dict):
			raise InvalidGenerationBirthYearDefinition(invalid_type_or_template_error)
		for years in generation_birth_years.values():
			if not isinstance(years, (list, tuple)):
				raise InvalidGenerationBirthYearDefinition("Values in generation_birth_years must be lists or tuples.")
			if len(years) != 2 or years[0] > years[1]:
				raise InvalidGenerationBirthYearDefinition(
					"Values in generation_birth_years must be nondecreasingly ordered lists of length two.")

		# Normalize to tuples before sorting: Python 3 cannot order a list against a tuple,
		# so a mix of both would otherwise raise TypeError instead of being validated.
		ordered_ranges = sorted(tuple(years) for years in generation_birth_years.values())
		for earlier, later in zip(ordered_ranges, ordered_ranges[1:]):
			# Ranges are inclusive, so sharing a boundary year counts as an overlap.
			if earlier[1] >= later[0]:
				raise InvalidGenerationBirthYearDefinition(
					"Values in generation_birth_years must be non overlapping. %s and %s overlap." % (
						list(earlier), list(later)))

	def get_estimated_counts(self, first_name, sex, current_year=None, minimum_age=0):
		'''
		:param first_name: str, First name
		:param sex: str, m or f for sex
		:param current_year: int, optional, defaults to the current year at call time
		:param minimum_age: int, optional, defaults to 0
		:return: pd.Series named 'estimated_count', indexed by generation name (plus '_other'),
			holding the estimated count of people with that first name and sex born in each generation
		'''
		if current_year is None:
			# Resolved per call: a def-time datetime.now() default would freeze the year at import.
			current_year = datetime.now().year
		year_counts = self._age_from_name.get_estimated_counts(first_name,
		                                                       sex,
		                                                       current_year,
		                                                       minimum_age)
		generation_counts = self._generational_rollup(year_counts)
		generation_counts.name = 'estimated_count'
		return generation_counts

	def get_estimated_distribution(self, first_name, sex,
	                               current_year=None,
	                               minimum_age=0):
		'''
		:param first_name: str, First name
		:param sex: str, m or f for sex
		:param current_year: int, optional, defaults to the current year at call time
		:param minimum_age: int, optional, defaults to 0
		:return: pd.Series named 'estimate_percentage', indexed by generation name (plus '_other'),
			holding the values of AgeFromName.get_estimated_distribution summed per generation,
			i.e. the estimated share of people with that first name and sex born in each generation.
		'''
		if current_year is None:
			current_year = datetime.now().year
		year_distribution = self._age_from_name.get_estimated_distribution(first_name,
		                                                                   sex,
		                                                                   current_year,
		                                                                   minimum_age)
		generation_distribution = self._generational_rollup(year_distribution)
		generation_distribution.name = 'estimate_percentage'
		return generation_distribution

	def argmax(self, first_name, sex, current_year=None, minimum_age=0):
		'''
		:param first_name: str, First name
		:param sex: str, m or f for sex
		:param current_year: int, optional, defaults to the current year at call time
		:param minimum_age: int, optional, defaults to 0
		:return: the generation name (a key of generation_birth_years, or '_other') with the
			largest estimated share of people with that first name and sex
		'''
		distribution = self.get_estimated_distribution(first_name, sex, current_year, minimum_age)
		# idxmax returns the index label; Series.argmax returns an integer position in pandas >= 1.0.
		return distribution.idxmax()

	def _generational_rollup(self, year_counts):
		'''Sum per-birth-year values into the configured generations.

		:param year_counts: pd.Series indexed by year of birth
		:return: pd.Series indexed by generation name, plus '_other' holding the total for
			years outside every configured range
		'''
		birth_years = year_counts.index
		generation_totals = {
			generation: year_counts[(birth_years >= first) & (birth_years <= last)].sum()
			for generation, (first, last) in self._generation_birth_years.items()}
		# Ranges were validated as non-overlapping, so the remainder is the out-of-range total.
		generation_totals['_other'] = year_counts.sum() - sum(generation_totals.values())
		return pd.Series(generation_totals)
