package backend.analysis.machineLearning;

import weka.attributeSelection.AttributeSelection;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.trees.RandomForest;
import weka.core.FastVector;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.supervised.instance.SMOTE;
import weka.filters.unsupervised.attribute.Normalize;

import java.util.ArrayList;
import java.util.List;

/**
 * Trains a Weka {@link RandomForest} on a training set and evaluates it on a
 * separate test set, collecting per-class metrics and a running confusion
 * tally for the first two class indices.
 *
 * <p>Typical call order: {@link #setData(Instances, Instances)} followed by
 * {@link #runRandomForestWithTestsSetLeaveOneOut(boolean)}. The confusion
 * tally in {@code classResults} accumulates across calls on the same instance.
 */
public class ClassifierLeaveOneOut {

    private static final String CLASSIFIER_NAME = "Random forest";

    /** Class labels reported when the run is evaluated with four classes. */
    private static final String[] GROUP_CLASS_LABELS =
            {"Documentation", "Visual representation", "Structure", "Functional"};

    /** Class labels reported when the run is evaluated with two classes. */
    private static final String[] BINARY_CLASS_LABELS = {"Evolvability", "Functional"};

    private Instances data;

    private Instances trainingData;
    private Instances testData;

    private AttributeSelection attSelect;

    // NOTE(review): these two counters are never written by this class; they are
    // kept because they are public and may be read or written by callers.
    public int correctClassifiedInstances;
    public int wronglyClassifiedInstances;

    /** Running tally for class index 0 ("Evolvability") and 1 ("Functional"). */
    List<ClassResults> classResults;


    /** Creates an instance with an empty tally for the two binary classes. */
    public ClassifierLeaveOneOut() {
        classResults = new ArrayList<>();
        classResults.add(new ClassResults("Evolvability"));
        classResults.add(new ClassResults("Functional"));
    }


    /**
     * Normalizes both data sets and stores them for later training/evaluation.
     *
     * <p>The {@link Normalize} filter is initialised from the training data and
     * then applied to the training and the test set in that order. If filtering
     * fails, the stack trace is printed and the failed assignment is skipped.
     *
     * @param trainingData instances used to initialise the filter and to train on
     * @param testData     instances to evaluate on
     */
    public void setData(Instances trainingData, Instances testData) {
        Normalize normalize = new Normalize();
        try {
            normalize.setInputFormat(trainingData);
            this.trainingData = Filter.useFilter(trainingData, normalize);
            this.testData = Filter.useFilter(testData, normalize);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }


    /**
     * Marks the last attribute of the training data as the class attribute.
     *
     * @param opt currently unused; kept for interface compatibility
     */
    public void filterTrainingData(String[] opt) {
        try {
            // Derive the index from the training data itself, not from the test set.
            trainingData.setClassIndex(trainingData.numAttributes() - 1);
        } catch (Exception e) {
            System.out.println("Problem in Data Filtering. Check option for filtering");
            e.printStackTrace();
        }
    }


    /**
     * Marks the last attribute of the test data as the class attribute.
     *
     * @param opt currently unused; kept for interface compatibility
     */
    public void filterTestData(String[] opt) {
        try {
            testData.setClassIndex(testData.numAttributes() - 1);
        } catch (Exception e) {
            System.out.println("Problem in Data Filtering. Check option for filtering");
            e.printStackTrace();
        }
    }


    /**
     * Sets the class index (last attribute) on both data sets and replaces the
     * training data with the output of a {@link SMOTE} filter configured with
     * {@code -P 100}. Instance counts before and after are printed.
     *
     * @param opt currently unused; the SMOTE options are fixed inside the method
     */
    public void filterTrainingDataWithSMOTE(String[] opt) {
        try {
            System.out.println("Before smote " + trainingData.numInstances());
            trainingData.setClassIndex(trainingData.numAttributes() - 1);
            testData.setClassIndex(testData.numAttributes() - 1);
            SMOTE smote = new SMOTE();
            smote.setOptions(new String[]{"-P", "100"}); // change this parameter
            smote.setInputFormat(trainingData);
            trainingData = Filter.useFilter(trainingData, smote);
            System.out.println("After smote: " + trainingData.numInstances());
        } catch (Exception e) {
            System.out.println("Problem in Data Filtering. Check option for filtering");
            e.printStackTrace();
        }
    }


    /**
     * Applies SMOTE to the training data, builds a random forest on it,
     * evaluates the model on the test data and reports per-class metrics.
     *
     * <p>Besides the returned results, the method prints the evaluation summary,
     * the confusion matrix and each metric to standard output, and updates the
     * {@code classResults} tally from the individual predictions. Any exception
     * is printed and the results collected so far are returned.
     *
     * @param isGroup {@code true} to report four classes (indices 0-3),
     *                {@code false} to report two classes (indices 0-1)
     * @return one {@link ClassifierResult} per reported class; possibly empty
     *         or partial if an exception occurred
     */
    public List<ClassifierResult> runRandomForestWithTestsSetLeaveOneOut(boolean isGroup) {
        List<ClassifierResult> results = new ArrayList<>();
        RandomForest tree = new RandomForest();
        try {
            filterTrainingDataWithSMOTE(new String[]{"-R", "47"});
            tree.buildClassifier(trainingData);
            Evaluation eval = new Evaluation(trainingData);
            eval.evaluateModel(tree, testData);
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toMatrixString());

            String[] labels = isGroup ? GROUP_CLASS_LABELS : BINARY_CLASS_LABELS;
            for (int c = 0; c < labels.length; c++) {
                results.add(new ClassifierResult(CLASSIFIER_NAME, labels[c],
                        eval.precision(c), eval.recall(c), eval.fMeasure(c),
                        eval.areaUnderROC(c), eval.matthewsCorrelationCoefficient(c)));
            }

            printMetrics(eval, labels.length);

            System.out.println("Class Attribute" + testData.classAttribute());
            FastVector predictions = new FastVector();
            predictions.appendElements(eval.predictions());
            for (int i = 0; i < predictions.size(); i++) {
                NominalPrediction np = (NominalPrediction) predictions.elementAt(i);
                System.out.println("np actual: " + np.actual());
                System.out.println("np predicited " + np.predicted());
                tallyBinaryOutcome(np);
            }

        } catch (Exception e) {
            e.printStackTrace();
        }
        return results;
    }


    /** Prints precision, recall, F-measure, AUC and MCC for class indices {@code 0..classCount-1}. */
    private void printMetrics(Evaluation eval, int classCount) {
        System.out.println("\n Random forest \n");

        for (int c = 0; c < classCount; c++) {
            System.out.println("Precision " + c + ": " + eval.precision(c));
        }

        System.out.println("\n");
        for (int c = 0; c < classCount; c++) {
            System.out.println("Recall " + c + ": " + eval.recall(c));
        }

        System.out.println("\n");
        for (int c = 0; c < classCount; c++) {
            System.out.println("F-Measure " + c + ":  " + eval.fMeasure(c));
        }

        System.out.println("\n");
        for (int c = 0; c < classCount; c++) {
            System.out.println("AUC " + c + " : " + eval.areaUnderROC(c));
        }

        System.out.println("\n");
        for (int c = 0; c < classCount; c++) {
            System.out.println("MCC " + c + " : " + eval.matthewsCorrelationCoefficient(c));
        }
    }


    /**
     * Updates the two-class tally from a single prediction. Only predictions
     * whose actual and predicted class indices are both 0 or 1 are counted;
     * everything else is ignored.
     */
    private void tallyBinaryOutcome(NominalPrediction np) {
        double actual = np.actual();
        double predicted = np.predicted();
        if (actual != 0 && actual != 1) {
            return;
        }
        int actualIndex = (int) actual;
        int otherIndex = 1 - actualIndex;

        if (predicted == actual) {
            classResults.get(actualIndex).truePositive++;
        } else if (predicted == otherIndex) {
            // A miss for the actual class is a false positive for the other class.
            classResults.get(actualIndex).falseNegative++;
            classResults.get(otherIndex).falsePositive++;
        }
    }


}
