
from main_functions import *


if __name__ == "__main__":
    # Entry point of the experiment driver: pick one `action` below and the
    # matching study (or plot) from main_functions is executed.

    # Silence warning categories that would otherwise clutter the console
    # output of the runs.
    # NOTE(review): `warnings` and `ConstantInputWarning` are not imported here;
    # they are expected to arrive through `from main_functions import *`.
    for ignored_category in (ConstantInputWarning, RuntimeWarning, UserWarning):
        warnings.simplefilter('ignore', category=ignored_category)

    # Possible actions are:
    #   - table_ds: get descriptive tables of the real datasets that are used
    #   - synth_data: run study on synthetic datasets with varying data distribution
    #   - synth_ft: run study on synthetic datasets with varying feature components
    #   - synth_data_plot: plot results of action synth_data
    #   - synth_ft_plot: plot results of action synth_ft
    #   - real_data: run study on the 70 real datasets
    #   - real_data_plot: plot results of action real_data (regression lines for all, binary, and
    #                     multiclass datasets)
    #   - ablation: run ablation study with 4 versions of SIMBA: whole, no normalization, no feature
    #               importance, and no feature redundancy
    #   - ablation_plot: plot results of action ablation
    #   - complexity: run the control study with data complexity measures

    action = 'complexity'
    # Named `eval_method` (not `eval`) so the builtin `eval` is not shadowed;
    # it is also the keyword name every helper below receives it under.
    eval_method = "f1"  # "f1" or "gmean"
    N = 100  # number of runs for synthetic datasets
    scenarios = 'all'  # 'all', 'a', 'b', or 'c': scenarios to run on synthetic datasets
    check_all(action, eval_method, N, scenarios)

    # The 5 common ML classifiers used by both real-dataset actions; defined
    # once so that the study and its plots always cover the same classifiers.
    real_data_classifiers = ["SVM", "LDA", "RF", "kNN", "MLP"]

    if action == "table_ds":
        only_class_freq = False  # if False, generates table with only mean, sd for class frequency
                                 # if True, generates table with detailed class frequencies
        # List of all datasets, separated between multiclass & binary datasets
        l_datasets = ['abalone', 'balance', 'cardio10', 'cardio3', 'chess',
                      'connect-4', 'contraceptive', 'dermatology', 'drybean', 'ecoli',
                      'glass', 'hayes-roth', 'knowledge', 'landsat', 'lenses',
                      'loc_build', 'loc_floor', 'lymphography', 'new-thyroid', 'obesity',
                      'pageblocks', 'penbased', 'room', 'shuttle', 'soybean',
                      'steel', 'student', 'theorem', 'thyroid', 'wallfollowing',
                      'webphishing', 'wholesale', 'wine', 'wine-quality', 'yeast']
        l_datasets_bin = ['abalone-20_vs_8-9-10', 'abalone9-18', 'adult', 'banknote', 'bankruptcy',
                          'breastcancer', 'cleveland-0_vs_4', 'credit', 'dermatology-6',
                          'ecoli-0-1-4-7_vs_2-3-5-6',
                          'ecoli1', 'glass-0-1-5_vs_2', 'glass2', 'glass4', 'glass5',
                          'glass6', 'htru2', 'ionosphere', 'led7digit-0-2-4-5-6-7-8-9_vs_1',
                          'new-thyroid1',
                          'page-blocks0', 'poker-9_vs_7', 'purchase', 'segment0', 'skin',
                          'spambase', 'spect-heart', 'vehicle0', 'vowel0', 'winequality-red-3_vs_5',
                          'wisconsin', 'yeast-1-2-8-9_vs_7', 'yeast3', 'yeast4', 'yeast5']
        # Merge both groups into one alphabetically sorted list for the table.
        all_ds = sorted(l_datasets + l_datasets_bin)
        print(all_ds)
        make_DS_descriptive_table(all_ds, only_class_freq=only_class_freq)

    elif action == "synth_data":
        # SYNTHETIC DATASETS
        # VARYING DATA DISTRIBUTION, BUT FIXED FEATURES (10 FT, 2 INFORMATIVE)
        # One entry per scenario: (scenario letter, data_distrib, plot title).
        synth_data_runs = (
            ('a', [100, 5, 0], 'Datasets (a): (100, 5, i)'),
            ('b', [50, 50, 0], 'Datasets (b): (50, 50, i)'),
            ('c', [10, 50, 100], 'Datasets (c): (10i, 50i, 100i)'),
        )
        for scenario, data_distrib, title in synth_data_runs:
            # Run the scenario when it is explicitly selected or when 'all' is requested.
            if scenarios in ('all', scenario):
                # The list literals are built per call so no run shares a list with another.
                run_synthdata(N=N, data_distrib=data_distrib, feature_distrib=[2, 0, 8], tot_ft=10,
                              i_range=[5, 101, 1], title=title, corr=True, varying="data",
                              eval_method=eval_method)

        show_synthdata_res(varying='data', eval_method=eval_method, scenarios=scenarios)

    elif action == "synth_ft":
        # SYNTHETIC DATASETS
        # VARYING FEATURE DISTRIBUTION, BUT FIXED DATA DISTRIBUTION (400, 75, 25)
        # One entry per scenario: (scenario letter, feature_distrib, tot_ft, i_range, plot title).
        # NOTE(review): the -1 entries presumably mark the quantity that varies with i;
        # confirm against run_synthdata.
        synth_ft_runs = (
            ('a', [-1, 0, 0], 50, [2, 51, 1],
             'Datasets with i out of 50 informative features, dataset (400, 75, 25)'),
            ('b', [5, -1, 0], -1, [0, 46, 1],
             'Datasets with 5+i features, 5 informative, and i redundant, dataset (400, 75, 25)'),
            ('c', [0, -1, 0], 50, [0, 49, 1],
             'Datasets with 50 features, 50-i informative, and i redundant, dataset (400, 75, 25)'),
        )
        for scenario, feature_distrib, tot_ft, i_range, title in synth_ft_runs:
            # Run the scenario when it is explicitly selected or when 'all' is requested.
            if scenarios in ('all', scenario):
                run_synthdata(N=N, data_distrib=[400, 75, 25], feature_distrib=feature_distrib,
                              tot_ft=tot_ft, i_range=i_range, title=title, corr=True, varying="ft",
                              eval_method=eval_method)

        show_synthdata_res(varying='ft', eval_method=eval_method, scenarios=scenarios)

    elif action == "synth_data_plot":
        # PLOTS
        run_synthdata_plot(N=N, varying='data', eval_method=eval_method, from_file=True)

    elif action == "synth_ft_plot":
        # PLOTS
        run_synthdata_plot(N=N, varying='ft', eval_method=eval_method, from_file=True)

    elif action == "real_data":
        # REAL DATASETS with 5 common ML classifiers: SVM, LDA, RF, kNN, MLP
        for clf in real_data_classifiers:
            print(clf + " running...")
            run_realdata(classif=clf, eval_method=eval_method)
            print(clf + " done!")
            print()

    elif action == "real_data_plot":
        # Regression plot real DS, MULTI+BIN, BIN, MULTI
        for clf in real_data_classifiers:
            show_realdata_res(classif=clf, eval_method=eval_method)

    elif action == 'ablation':
        print("Ablation study running...")
        run_real_ablation(classif='SVM', eval_method=eval_method)

    elif action == 'ablation_plot':
        plot_real_ablation(classif="SVM", eval_method=eval_method)

    elif action == "complexity":
        print('Data complexity analysis running...')
        clf = 'SVM'
        # Compute the complexity measures first, then display them for the same classifier.
        data_complexity_analysis(classif=clf, eval_method=eval_method)
        show_data_complexity(classif=clf, eval_method=eval_method)