import copy
import random
import geopandas as gpd 
import pandas as pd 
import shapely
import sqlite3

from utils.parse import parse_config

def create_synth_pop(config):
    """Generate the synthetic population for each study site enabled in `config`.

    The SWARM tables are loaded once and the config is parsed once. Then, for each
    of chikwawa, chileka and ndirande whose config flag is not the string 'False':
    the site's SWARM households/individuals and candidate building geometries are
    loaded, buildings are filled with sampled households until '<site>_pop' people
    are housed, and individuals are attached to those households.
    """
    hh_swarm, ind_swarm = load_data()
    params = parse_config(config)

    for site in ('chikwawa', 'chileka', 'ndirande'):
        # Sites switched off in the config are skipped entirely.
        if params[site] == 'False':
            continue

        target_pop = int(params[f"{site}_pop"])
        print(f"\nSite: {site}")

        site_households, site_people = load_swarm_site(site, hh_swarm, ind_swarm)
        print(f"Loaded {len(site_households)} households, and {len(site_people)} people from SWARM data.")

        size_bounds = (params['osm_building_size_filter_min'], params['osm_building_size_filter_max'])
        footprints = load_buildings(site, hh_swarm, size_bounds)
        print(f"Loaded {len(footprints)} building geometries.")

        jitter = int(params['building_jitter'])
        real_count, jittered_count, households_gdf = fill_houses(site, target_pop, footprints,
                                                                 site_households, jitter)
        print(f"Populated {real_count} buildings, and {jittered_count} generated building geometries.")

        complete_households(site, site_people, households_gdf)

    print('-------------------------------------------------------')


def load_buildings(site, swarm, area_filter):
    """Load candidate building geometries for `site`, reprojected to EPSG:3857.

    * 'chileka': footprints from the 'chileka' layer of malawi_buildings_filter.gpkg,
      kept when area_filter[0] < area < area_filter[1] (both bounds passed through int()).
    * 'chikwawa': one point per SWARM household row with loc == 'chikwawa', built
      from its 'lon'/'lat' columns (read as EPSG:4326).
    * 'ndirande': the STRATAA 2016 and 2018 house footprint layers concatenated.

    Args:
        site: one of 'chileka', 'chikwawa', 'ndirande'.
        swarm: SWARM household table with 'loc', 'lat', 'lon' columns (only used for chikwawa).
        area_filter: (min, max) bounds for the chileka 'area' column.

    Returns:
        GeoDataFrame containing only the 'geometry' column.

    Raises:
        ValueError: if `site` is not one of the supported sites.
    """
    if site == 'chileka':
        buildings = gpd.read_file("../data/init_raw_data/osm_buildings/malawi_buildings_filter.gpkg",
                                  layer=site,
                                  driver="GPKG").to_crs(3857)
        # filter by area: keep footprints strictly between the configured bounds
        min_area, max_area = int(area_filter[0]), int(area_filter[1])
        buildings = buildings[(buildings['area'] > min_area) & (buildings['area'] < max_area)]

    elif site == 'chikwawa':
        # No footprint layer here: each SWARM household's recorded coordinates become a building point.
        buildings = swarm[swarm['loc'] == site].reset_index()
        buildings = gpd.GeoDataFrame(buildings, crs='EPSG:4326',
                                     geometry=gpd.points_from_xy(buildings['lon'], buildings['lat']))
        buildings = buildings.to_crs(3857)

    elif site == 'ndirande':
        buildings2016 = gpd.read_file("../data/init_raw_data/strataa/house_footprints/houses_2016_poly.gpkg",
                                      driver="GPKG").to_crs(3857)
        buildings2018 = gpd.read_file("../data/init_raw_data/strataa/house_footprints/houses_2018_poly.gpkg",
                                      driver="GPKG").to_crs(3857)
        buildings = pd.concat([buildings2016, buildings2018], ignore_index=True)
        buildings = gpd.GeoDataFrame(buildings, crs='EPSG:3857', geometry='geometry')

    else:
        # Previously fell through to an UnboundLocalError on `buildings`.
        raise ValueError(f"Unknown site {site!r}; expected 'chileka', 'chikwawa' or 'ndirande'.")

    return buildings[['geometry']]


def load_swarm_site(site, house_swarm, people_swarm):
    """Return the SWARM household and individual rows whose 'loc' equals `site`.

    The 'lat', 'lon' and 'loc' columns are removed from both results, so real-world
    household locations are not carried into the synthetic population.
    """
    identifying_cols = ['lat', 'lon', 'loc']

    def _rows_at_site(frame):
        at_site = frame['loc'] == site
        return frame.loc[at_site].drop(columns=identifying_cols)

    return _rows_at_site(house_swarm), _rows_at_site(people_swarm)


def fill_houses(site, pop, possible_houses, hh_sample, jitter):
    """Place sampled SWARM households into buildings until `pop` people are housed.

    Households are drawn from `hh_sample` with replacement, and their 'hhcount' values
    are summed until the total reaches `pop`. Each household is placed in a building
    drawn without replacement from `possible_houses`. Once every building has been used,
    a new point is generated by offsetting the centroid of an already-used building
    by a random integer in [-jitter, jitter] on each axis. The centroid is used so
    that both point and polygon footprints work.

    The result is written to ../data/synthetic_population/synthetic_{site}_households.csv.

    Args:
        site: site name, used in the output file name.
        pop: number of people to house.
        possible_houses: GeoDataFrame with a 'geometry' column.
        hh_sample: household table with an 'hhcount' column.
        jitter: maximum absolute offset for generated buildings, in the geometries' coordinate units.

    Returns:
        (osm_building_count, generated_building_count, gdf). `gdf` has one row per
        placed household, with 'new_hhid' (0..n-1) and 'geometry' (CRS declared as EPSG:3857).

    Raises:
        ValueError: if `pop` > 0 but `hh_sample` cannot add anyone (empty, or every
            'hhcount' <= 0), or if there is no building at all to place a household in.
    """
    if pop > 0 and (hh_sample.empty or hh_sample['hhcount'].astype(int).max() <= 0):
        # Without this guard the loop below would never reach `pop`.
        raise ValueError(f"Household sample for site '{site}' contains no households with people.")

    # Drawing buildings one at a time without replacement is the same as walking one
    # random permutation. This avoids copying the frame on every draw.
    buildings = list(possible_houses['geometry'].sample(frac=1))

    picked = []       # one single-row household frame per placed household
    geometries = []   # building geometry for each placed household, same order as `picked`
    synth_pop_count = 0
    osm_building_count = 0
    generated_building_count = 0

    while synth_pop_count < pop:
        if osm_building_count < len(buildings):
            geometry = buildings[osm_building_count]
            osm_building_count += 1
        else:
            if not geometries:
                raise ValueError(f"No building geometries available for site '{site}'.")
            # resample buildings: jitter the centroid of a random already-used building
            anchor = random.choice(geometries).centroid
            geometry = shapely.geometry.Point(anchor.x + random.randint(-jitter, jitter),
                                              anchor.y + random.randint(-jitter, jitter))
            generated_building_count += 1

        # get random household
        household = hh_sample.sample(n=1)
        picked.append(household)
        geometries.append(geometry)
        synth_pop_count += int(household['hhcount'].values[0])

    if picked:
        synth_population_hh = pd.concat(picked, ignore_index=True)
    else:
        synth_population_hh = hh_sample.iloc[:0].copy()
    synth_population_hh['new_hhid'] = range(len(synth_population_hh))

    gdf = gpd.GeoDataFrame(synth_population_hh, crs="EPSG:3857", geometry=geometries)
    # gdf.to_file(f"../data/synthetic_population/synthetic_pop.gpkg", layer=f'houses-{site}', driver="GPKG")
    gdf.to_csv(f"../data/synthetic_population/synthetic_{site}_households.csv")

    return osm_building_count, generated_building_count, gdf


def complete_households(site, pp_sample, synth_households):
    """Attach SWARM individuals to each synthetic household and write the people table.

    A household takes its recorded individuals (rows of `pp_sample` with
    hhid == the household's 'hid') as-is when their number equals its 'hhcount'.
    All other households are passed to `complete_incomplete_hhs`, which tops them
    up or trims them. The result is written to
    ../data/synthetic_population/synthetic_{site}_people.csv.

    Returns:
        The written GeoDataFrame of individuals (CRS declared as EPSG:3857).
    """
    # Group individuals by household once, instead of scanning pp_sample for every household.
    people_by_hh = {hid: group for hid, group in pp_sample.groupby('hhid', sort=False)}
    no_people = pp_sample.iloc[:0]

    exact_pieces = []
    incomplete_hhs = []
    count = 0  # total expected people across the incomplete households

    for _, row in synth_households.iterrows():
        # household population size from 'HouseholdEnrolment' table.
        expected_hh_pop = int(row['hhcount'])
        # people tagged with this household in the 'Individual' table.
        recorded_hh_people = people_by_hh.get(row['hid'], no_people)

        if expected_hh_pop != len(recorded_hh_people):
            # collate households whose sampled pop does not match for resampling
            count += expected_hh_pop
            incomplete_hhs.append(row['new_hhid'])
        else:
            exact_pieces.append(recorded_hh_people.reset_index(drop=True)
                                .assign(geometry=row['geometry'], new_hhid=row['new_hhid']))

    # A single concat avoids the quadratic cost of growing a frame inside the loop.
    synth_population_ind = pd.concat(exact_pieces, ignore_index=True) if exact_pieces else pd.DataFrame()

    print(f"Assigned {len(synth_population_ind)} people with exact household composition.")
    print(f"Resampling household top-up/trim based on similarity score for the remaining {count} people ({len(incomplete_hhs)} households).")

    synth_population_ind = complete_incomplete_hhs(synth_households, pp_sample, synth_population_ind, incomplete_hhs)

    print(f"Final pop size: {len(synth_population_ind)} people.")

    gdf = gpd.GeoDataFrame(synth_population_ind, crs="EPSG:3857", geometry=synth_population_ind['geometry'])
    # gdf.to_file(f"../data/synthetic_population/synthetic_pop.gpkg", layer=f'people-{site}', driver="GPKG")
    gdf.to_csv(f"../data/synthetic_population/synthetic_{site}_people.csv")
    return gdf


# Household composition columns, one per age band (see _age_band).
_AGE_BANDS = ('hh_adults', 'hh_adolescents', 'hh_child', 'hh_youngchild')


def _age_band(age):
    """Return the household composition column that an age (years) counts towards.

    hh_adults (>=16), hh_adolescents (>=5 & <16), hh_child (>=2 & <5), hh_youngchild (<2);
    there is some ambiguity in the data dictionaries around the use of < / <=.
    """
    if age >= 16:
        return 'hh_adults'
    if age >= 5:
        return 'hh_adolescents'
    if age >= 2:
        return 'hh_child'
    return 'hh_youngchild'


def _top_up_household(target, members, households, people):
    """Return `members` plus individuals borrowed from other, similar households.

    Candidates are ranked by their household's similarity to `target`, measured as the
    number of columns with equal values. A candidate is taken if its age band is still
    short of `target`'s composition. Once no band is short, the next best-ranked
    candidate fills any remaining slot. Stops when 'hhcount' people are present.
    """
    scores = (households == target).sum(axis=1)
    others = households.assign(scores=scores)
    others = others[others['hid'] != target['hid']].drop_duplicates(subset=['hid'])
    candidates = (people.merge(others[['hid', 'scores']], how='left', left_on='hhid', right_on='hid')
                  .drop(columns=['hid'])
                  .sort_values(by='scores', ascending=False))
    # The household's own members are already in `members`; never add them a second time.
    candidates = candidates[candidates['hhid'] != target['hid']].drop(columns=['scores'])

    present = members['age'].map(_age_band).value_counts()
    need = {band: int(target[band]) - int(present.get(band, 0)) for band in _AGE_BANDS}
    free = int(target['hhcount']) - len(members)

    chosen = []
    for _, person in candidates.iterrows():
        if free <= 0:
            break
        band = _age_band(person['age'])
        if need[band] > 0:
            need[band] -= 1
        elif any(n > 0 for n in need.values()):
            # Composition is still short in another band: hold the slot for that band.
            continue
        # Either fills a short band or, with every band satisfied, is the next best person.
        chosen.append(person)
        free -= 1

    if not chosen:
        return members
    return pd.concat([members, pd.DataFrame(chosen).infer_objects()])


def complete_incomplete_hhs(synth_households, pp_sample, partial_ind_pop, hhs_with_missing_pp):
    """Fill or trim incomplete households and return the complete individual population.

    For each 'new_hhid' in `hhs_with_missing_pp`:
      * if the household has fewer recorded individuals than its 'hhcount', it is topped
        up with individuals from the most similar other households (see _top_up_household);
      * otherwise 'hhcount' of its own individuals are randomly sampled.
    A borrowed individual is added at most once per household, although the same person
    may be reused by different households.

    Returns:
        `partial_ind_pop` with the completed households' individuals appended (fresh
        RangeIndex). Each appended individual carries its household's 'geometry' and 'new_hhid'.
    """
    print(f"Rolling pop count:")

    # Copies and type conversions are loop-invariant: do them once, not per household.
    households = synth_households.copy()
    for col in ('hhcount',) + _AGE_BANDS:
        households[col] = households[col].astype(int)
    people = pp_sample.copy()
    people['age'] = people['age'].astype(int)

    pieces = [partial_ind_pop]
    pop_size = len(partial_ind_pop)

    for i, new_house_id in enumerate(hhs_with_missing_pp):
        # required household composition for this synthetic household
        target = households.loc[households['new_hhid'] == new_house_id].iloc[0]
        required_num_people = target['hhcount']
        members = people[people['hhid'] == target['hid']]

        if len(members) < required_num_people:
            members = _top_up_household(target, members, households, people)
        else:
            # randomly trim hh size
            members = members.sample(n=required_num_people)

        members = members.assign(geometry=target['geometry'], new_hhid=target['new_hhid'])
        pieces.append(members)
        pop_size += len(members)

        if i % 100 == 0:
            print(f"\tPop size: {pop_size}")

    if len(pieces) == 1:
        return partial_ind_pop
    return pd.concat(pieces, ignore_index=True)


_SWARM_DB_PATH = '../data/init_raw_data/swarmdb/swarm.db'

# Columns read by later steps (fill_houses, complete_households, complete_incomplete_hhs);
# the low-information filters (3 and 4) never drop them.
_HH_KEEP_COLS = ('hid', 'hhcount', 'hh_adults', 'hh_adolescents', 'hh_child', 'hh_youngchild')
_IND_KEEP_COLS = ('hhid', 'age')


def _read_swarm_tables(db_path):
    """Return the raw (HouseholdEnrolment, Locations, Individual) tables from the SWARM SQLite db."""
    conn = sqlite3.connect(db_path)
    try:
        # Decode text leniently: undecodable bytes are dropped instead of raising.
        conn.text_factory = lambda b: b.decode(errors='ignore')
        house_data = pd.read_sql_query("select * from HouseholdEnrolment", conn)
        locs_data = pd.read_sql_query("select pid, longitude, latitude from Locations", conn)
        ind_data = pd.read_sql_query("select * from Individual", conn)
    finally:
        # sqlite3's own context manager only commits/rolls back; it never closes the connection.
        conn.close()
    return house_data, locs_data, ind_data


def _site_for_latitude(lat):
    """Classify a household into a study site by latitude band."""
    if lat < -15.90:
        return 'chikwawa'
    if lat < -15.75:
        return 'ndirande'
    return 'chileka'


def _locate_households(house_data, locs_data):
    """Map each 'HH'-prefixed household id that has a Locations row to (lat, lon, site).

    When a pid has several Locations rows, the first one is used.
    """
    first_loc = locs_data.drop_duplicates(subset='pid', keep='first').set_index('pid')
    located = {}
    for hid in set(house_data['hid']):
        if not isinstance(hid, str) or not hid.startswith('HH') or hid not in first_loc.index:
            continue
        lat = float(first_loc.at[hid, 'latitude'])
        lon = float(first_loc.at[hid, 'longitude'])
        located[hid] = (lat, lon, _site_for_latitude(lat))
    return located


def _attach_location(df, id_col, located):
    """Return `df` with 'lat', 'lon', 'loc' looked up by household id in `id_col` (NaN if unlocated)."""
    ids = df[id_col]
    return df.assign(
        lat=ids.map({hid: info[0] for hid, info in located.items()}),
        lon=ids.map({hid: info[1] for hid, info in located.items()}),
        loc=ids.map({hid: info[2] for hid, info in located.items()}),
    )


def _drop_mostly_empty(df, prop_no_val):
    """Keep only the columns in which fewer than `prop_no_val` of the values are the empty string."""
    return df.loc[:, df.eq('').mean() < prop_no_val]


def _drop_dominated(df, sim_col_vals, keep=()):
    """Drop columns (except `keep`) whose most common value covers more than `sim_col_vals` of rows."""
    limit = sim_col_vals * len(df)
    to_drop = [c for c in df.columns if c not in keep and df[c].value_counts().max() > limit]
    return df.drop(columns=to_drop)


def _drop_low_cardinality(df, min_unique, keep=()):
    """Drop columns (except `keep`) that have fewer than `min_unique` distinct non-null values."""
    to_drop = [c for c in df.columns if c not in keep and df[c].nunique() < min_unique]
    return df.drop(columns=to_drop)


def load_data(db_path=_SWARM_DB_PATH):
    """Load, column-filter and site-tag the SWARM household and individual tables.

    Each 'HH'-prefixed household with a Locations row is assigned a site by latitude:
    below -15.90 is chikwawa, below -15.75 is ndirande, anything else is chileka.
    Metadata columns are dropped, then low-information columns, and finally every
    household / individual row gets 'lat', 'lon' and 'loc' (NaN when the household
    has no site).

    Args:
        db_path: path to the SWARM SQLite database.

    Returns:
        (house_data, ind_data) as filtered DataFrames.
    """
    ## Filter Requirements
    prop_no_val = 0.4    # drop columns where at least this share of answers are ''
    sim_col_vals = 0.70  # drop columns where one answer covers more than this share of rows

    house_data, locs_data, ind_data = _read_swarm_tables(db_path)
    located = _locate_households(house_data, locs_data)

    # households
    print(f"Initial Variables: {len(house_data.columns)}")

    ## meat removed as too many diff responses
    house_data = house_data.drop(columns=['crf_ver', 'pin_name', 'hh_meat', 'data_date', 'hh_name', 'pid', 'hh_hhh', 'dob', '_dob_dtype'])
    print(f"Filter 1: drop date, enumerator, crf version etc.: {len(house_data.columns)}")

    house_data = _drop_mostly_empty(house_data, prop_no_val)
    print(f"Filter 2: less than {prop_no_val*100}% 'No Value': {len(house_data.columns)}")

    house_data = _drop_dominated(house_data, sim_col_vals, keep=_HH_KEEP_COLS)
    print(f"Filter 3: less than {sim_col_vals*100}% of respondents giving like answers: {len(house_data.columns)}")

    house_data = _drop_low_cardinality(house_data, 4, keep=_HH_KEEP_COLS)
    print(f"Filter 4: >=4 different responses: {len(house_data.columns)}")
    print("")

    house_data = _attach_location(house_data, 'hid', located)

    # individuals
    print(f"Initial Variables: {len(ind_data.columns)}")

    ind_data = ind_data.drop(columns=['start', 'crf_ver', 'endd', 'ennumerator', 'data_date', 'site', 'pid', 'ip_dob', '_dob_dtype', 'ip_guardian'])
    print(f"Filter 1: drop date, enumerator, crf version etc.: {len(ind_data.columns)}")

    ind_data = _drop_mostly_empty(ind_data, prop_no_val)
    print(f"Filter 2: less than {prop_no_val*100}% 'No Value': {len(ind_data.columns)}")

    ind_data = _drop_dominated(ind_data, sim_col_vals, keep=_IND_KEEP_COLS)
    print(f"Filter 3: less than {sim_col_vals*100}% of respondents giving like answers: {len(ind_data.columns)}")

    ind_data = _attach_location(ind_data, 'hhid', located)

    return house_data, ind_data
