import pandas as pd
from collections import Counter
from ucimlrepo import fetch_ucirepo

def formatData(name_ds, uci_id=148, save=False):
    """Fetch a UCI dataset and split it into a feature table and a label table.

    Steps:
      1. Download the dataset with ``fetch_ucirepo(id=uci_id)``.
      2. Drop every row that has a missing (NaN) feature value; when
         ``name_ds == "cirrhosis"`` also drop rows containing the literal
         placeholder string ``"NaNN"`` in any feature.
      3. Prepend a 0..n-1 ``id`` column to both tables, rename the last
         target column to ``label`` and re-encode the label values as
         consecutive integers following their sorted order.

    Args:
        name_ds: Dataset name. Selects the ``./data/<name_ds>/`` output folder
            and enables the cirrhosis-specific cleaning; it does NOT choose
            which dataset is downloaded (that is ``uci_id``).
        uci_id: UCI repository id passed to ``fetch_ucirepo``. Defaults to 148,
            the id that used to be hard-coded here (the ``__main__`` block
            uses it with ``name_ds="shuttle"``).
        save: If True, write ``df_allfeatures.csv`` and ``labels.csv`` into
            ``./data/<name_ds>/``, creating the folder if needed.

    Returns:
        ``(df, labels)``: the feature table and the encoded label table.
    """
    import os

    # Legacy (disabled) workflow: integer-encode categorical feature columns
    # of an already exported CSV in place.
    # path = "./data/" + name_ds + "/df_allfeatures.csv"
    # df = pd.read_csv(path)
    # labels = pd.read_csv("./data/" + name_ds + "/labels.csv")
    # col =  []  # names of columns holding categorical values (if any)
    # for c in col:
    #     l = list(df[c].unique())
    #     l.sort()
    #     df[c] = df[c].replace(l, list(range(0, len(df[c].unique()))))
    # print(df)
    # df.to_csv(path, index=False)

    path = "./data/" + name_ds + "/"
    ds = fetch_ucirepo(id=uci_id)
    # Copy so that inserting the id column does not mutate the fetched object.
    df = ds.data.features.copy()
    labels = ds.data.targets.copy()

    # Rows to discard: any genuine NaN in the features ...
    drop_mask = df.isna().any(axis=1)
    if drop_mask.any():
        print("NaN alert!")
        print(df.columns[df.isna().any()].tolist())
    if name_ds == "cirrhosis":
        # ... plus rows using the literal "NaNN" placeholder. This is checked
        # independently of real NaNs so placeholder rows are always removed.
        drop_mask |= (df == "NaNN").any(axis=1)
    if drop_mask.any():
        to_drop = df.index[drop_mask]
        print(len(to_drop))
        df = df.drop(index=to_drop)
        labels = labels.drop(index=to_drop)

    ids = list(range(len(labels)))
    df.insert(0, "id", ids)
    labels.insert(0, "id", ids)
    labels = labels.rename(columns={labels.columns[-1]: "label"})
    print(df)
    print(labels)
    print(Counter(labels["label"]))
    classes = sorted(labels["label"].unique())
    print(classes)
    # Encode ONLY the label column: DataFrame.replace would also rewrite
    # matching values in the id column when class values are numeric.
    codes = {value: code for code, value in enumerate(classes)}
    labels = labels.assign(label=labels["label"].map(codes))
    print(labels)

    if save:
        os.makedirs(path, exist_ok=True)
        df.to_csv(path + "df_allfeatures.csv", index=False)
        labels.to_csv(path + "labels.csv", index=False)
    return df, labels

def _parse_dat_row(line):
    """Split one line of a ``.dat`` data section into a list of field values.

    Fields are separated by ", " (or by "," when the line contains no ", ").
    A trailing chunk such as "x,label" left over from mixed separators is
    split once more. Each field keeps only its text up to the first space and
    is converted to float when possible; otherwise it is kept as a string with
    any newline/tab suffix removed.
    """
    fields = line.split(", ")
    if len(fields) == 1:
        fields = line.split(",")
    if "," in fields[-1]:
        head, tail = fields[-1].split(",")[:2]
        fields[-1] = head
        fields.append(tail)

    values = []
    for field in fields:
        token = field.split(" ")[0]
        try:
            values.append(float(token))
        except ValueError:
            values.append(token.split("\n")[0].split("\t")[0].split(" ")[0])
    return values


def process_dat_file(name_ds, pause=False):
    """Convert ``./data_bin/<name_ds>/<name_ds>.dat`` into feature/label CSVs.

    The feature column names come from an ``@inputs a, b, ...`` (", "
    separated) or ``@input a,b,...`` ("," separated) header line. Every
    non-blank line after the ``@data`` marker is a sample whose last field is
    the class label. Both output tables get a leading 0..n-1 ``id`` column and
    the labels are re-encoded as consecutive integers in sorted order.

    Args:
        name_ds: Dataset folder/file stem under ``./data_bin/``.
        pause: If True, wait for Enter after writing the CSVs (this pause used
            to be unconditional and blocked non-interactive runs).

    Returns:
        ``(df, labels)``: the feature table and the encoded label table, which
        are also written to ``df_allfeatures.csv`` and ``labels.csv`` next to
        the ``.dat`` file.

    Raises:
        ValueError: if no ``@inputs``/``@input`` header precedes the data.
    """
    path = "./data_bin/" + name_ds + "/"
    p_file = path + name_ds + ".dat"

    col_ft = None
    rows = []
    found = False
    with open(p_file, "r") as file:
        for line in file:
            stripped = line.strip()
            if found:
                # Blank lines (e.g. a trailing newline) carry no sample.
                if stripped:
                    rows.append(_parse_dat_row(line))
                continue
            if "@inputs" in stripped:
                header = line.split("@inputs ")[1].split("\n")[0].split(", ")
                col_ft = [name for name in header if name != ""]
            elif "@input" in stripped:
                header = line.split("@input ")[1].split("\n")[0].split(",")
                col_ft = [name for name in header if name != ""]
            elif "@data" in stripped:
                found = True

    if col_ft is None:
        raise ValueError(f"{p_file}: no '@inputs'/'@input' header found")

    # Build the frame once instead of growing it row by row (quadratic).
    data = pd.DataFrame(rows, columns=col_ft + ["label"])
    data.insert(0, "id", list(range(len(data.index))))
    print(data)
    df = data.drop(columns=["label"])
    labels = data.drop(columns=col_ft)
    print(df)
    print(labels)
    print(Counter(labels["label"]))
    classes = sorted(labels["label"].unique())
    print(classes)
    # Encode ONLY the label column: DataFrame.replace would also rewrite id
    # values equal to a (numeric) class value, e.g. id 1 vs label 1.0.
    codes = {value: code for code, value in enumerate(classes)}
    labels = labels.assign(label=labels["label"].map(codes))
    print(labels)
    df.to_csv(path + "df_allfeatures.csv", index=False)
    labels.to_csv(path + "labels.csv", index=False)
    if pause:
        input()
    return df, labels

def getWallFollow(path="./data/wallfollowing/", filename="sensor_readings_24.data"):
    """Format the Wallfollowing dataset into feature/label CSVs.

    The input is a headerless comma-separated file whose last column is the
    class label; the other columns are named ``sensor1 .. sensorN``. Both
    output tables get a leading 0..n-1 ``id`` column and the labels are
    re-encoded as consecutive integers in sorted order of the class values.

    Args:
        path: Folder holding the input file and receiving the outputs.
        filename: Name of the raw data file inside ``path``.

    Returns:
        ``(df, labels)``: the feature table and the encoded label table, which
        are also written to ``df_allfeatures.csv`` and ``labels.csv`` in
        ``path``.
    """
    import os

    data = pd.read_csv(os.path.join(path, filename), sep=",", header=None)
    feature_cols = ["sensor" + str(i + 1) for i in range(len(data.columns) - 1)]
    data.columns = feature_cols + ["label"]
    data.insert(0, "id", list(range(len(data))))
    df = data[["id"] + feature_cols]
    print(df)
    labels = data[["id", "label"]]
    print(Counter(labels["label"]))
    classes = sorted(labels["label"].unique())
    print(classes)
    # Encode ONLY the label column so numeric class values can never remap
    # entries of the id column (DataFrame.replace acts on every column).
    codes = {value: code for code, value in enumerate(classes)}
    labels = labels.assign(label=labels["label"].map(codes))
    print(labels)
    print(Counter(labels["label"]))
    df.to_csv(os.path.join(path, "df_allfeatures.csv"), index=False)
    labels.to_csv(os.path.join(path, "labels.csv"), index=False)
    return df, labels

if __name__ == "__main__":
    # Script entry point: format the dataset named here.
    dataset_name = "shuttle"
    formatData(dataset_name)
