#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 16 11:08:06 2020

@author: Anton Kondratev
Implementation of the Bayesian Online Change Point Detection algorithm.
This class processes points one after another.
"""
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import pandas as pd
from   matplotlib.colors import LogNorm

class TBayesian:
    '''
    Bayesian Online Change Point Detection for a univariate data stream.

    Points are processed one at a time with ``iterate``. For every candidate
    run length the instance keeps conjugate posterior statistics
    (``mus``, ``kappas``, ``alphas``, ``betas``); their predictive
    distribution is a location-scale Student-t (see ``_ret_ppdf``).

    Attributes
    ----------
    R : list of np.ndarray
        R[t] is the normalised run-length distribution after t observations
        (R[0] == [1]); R[t][r] is the probability that the current run has
        length r.
    maxes : list
        Most probable run length after each observation.
    counter : int
        Number of observations processed so far.
    '''

    def __init__(self, mu, kappa, alpha, beta, hazard, drop_threshold=3):
        '''
        mu, kappa, alpha, beta - prior hyperparameters. Index 0 of the
            per-run-length statistics arrays always keeps these prior values.
        hazard - constant probability that a change point occurs at any step.
        drop_threshold - iterate() flags a change once the most probable run
            length falls by more than this between two consecutive steps.
            Defaults to 3, the value that used to be hard-coded.
        '''
        self.mus = np.array([mu])
        self.kappas = np.array([kappa])
        self.alphas = np.array([alpha])
        self.betas = np.array([beta])
        # Run-length distribution carried from one step to the next.
        self.lengths = np.array([1])
        self.hazard = hazard
        self.drop_threshold = drop_threshold
        self.R = [np.array([1])]
        self.maxes = []
        self.counter = 0
        # Set to True (and kept True) by iterate() once a large drop in
        # `maxes` has been observed.
        self._drop_seen = False

    def _ret_ppdf(self, x):
        '''
        Returns the predictive probability density of x under every run-length
        hypothesis: a location-scale Student-t with 2*alpha degrees of freedom,
        location mu and scale sqrt(beta*(kappa+1)/(alpha*kappa)).
        The result has one entry per element of self.mus.
        '''
        df = 2 * self.alphas
        loc = self.mus
        scale = np.sqrt(self.betas * (self.kappas + 1) / (self.alphas * self.kappas))
        return stats.t.pdf(x=x, df=df, loc=loc, scale=scale)

    def iterate(self, i, x):
        '''
        Process one data point.

        i - loop counter supplied by the caller. It is not used in the
            computation (the instance tracks its own ``counter``) and is kept
            only for backward compatibility.
        x - data point
        return:
        True, counter + 1 - once a change has been found, i.e. the most
            probable run length dropped by more than ``drop_threshold``
            between two consecutive steps. The flag is sticky: every later
            call returns True as well.
        False, 0 - if no change has been found so far
        '''
        self.counter += 1
        pis = self._ret_ppdf(x)
        growth_probs = pis * (1 - self.hazard) * self.lengths
        cp_probs = np.sum(pis * self.hazard * self.lengths)
        # Index 0: a change point just happened (run length 0);
        # index r + 1: the run of length r grew by one.
        joint = np.append(cp_probs, growth_probs)
        posterior = joint / np.sum(joint)
        self.R.append(posterior)
        # Carry the *normalised* distribution forward. The unnormalised joint
        # is a running product of predictive densities and under/overflows
        # after a few hundred points, turning R into NaNs. Rescaling by a
        # constant changes neither R nor the argmax.
        self.lengths = posterior
        self._update_statistics(x)
        self.maxes.append(np.argmax(joint))
        # Only the newest pair of maxes can introduce a new drop, so check it
        # in O(1) and latch the result instead of rescanning the history.
        if len(self.maxes) > 1 and \
                self.maxes[-1] - self.maxes[-2] < -self.drop_threshold:
            self._drop_seen = True
        if self._drop_seen:
            return True, self.counter + 1
        return False, 0

    def _update_statistics(self, x):
        '''
        Conjugate update of every run-length hypothesis with observation x,
        then prepend the prior values (index 0) as the statistics for a new
        run of length 0.
        '''
        new_mu = (self.kappas * self.mus + x) / (self.kappas + 1)
        new_kappa = self.kappas + 1
        new_alpha = self.alphas + (1 / 2)
        new_beta = self.betas + (self.kappas * (x - self.mus) ** 2) / (2. * (self.kappas + 1.))
        self.mus = np.concatenate(([self.mus[0]], new_mu))
        self.kappas = np.concatenate(([self.kappas[0]], new_kappa))
        self.alphas = np.concatenate(([self.alphas[0]], new_alpha))
        self.betas = np.concatenate(([self.betas[0]], new_beta))

    def get_results(self, drop_threshold=5):
        '''
        Collect the detector output.

        drop_threshold - the reported change point is the first step at which
            the most probable run length falls by more than this. Defaults
            to 5, the previously hard-coded value. Note that iterate() uses
            its own threshold (self.drop_threshold, 3 by default).

        return:
        retR - (T+1)x(T+1) array, T = number of processed points; row t
            holds R[t] left-aligned and zero-padded.
        maxes - the list self.maxes itself (not a copy).
        change_point - first k with maxes[k] - maxes[k-1] < -drop_threshold,
            or 0 if there is none.
        '''
        size = len(self.R)
        retR = np.zeros((size, size))
        for row, probs in enumerate(self.R):
            retR[row, :len(probs)] = probs
        change_point = 0
        for k in range(1, len(self.maxes)):
            if self.maxes[k] - self.maxes[k - 1] < -drop_threshold:
                change_point = k
                break
        return retR, self.maxes, change_point




if __name__ == "__main__":
    import os

    # Hard-coded location of the author's data; adjust before running.
    directory = "/media/synology/Documents/TUT/MSThesis/Scripts/Data/Baseline/"
    file1 = 'Lemon peelFlaskBaseline1.mat.csv'
    data = pd.read_csv(os.path.join(directory, file1))
    ims_abs = data['IMS_abs14']
    # First differences of the signal: ts[k] = ims_abs[k+1] - ims_abs[k].
    ts = np.diff(ims_abs.to_numpy())
    bs = TBayesian(0, 1, 1, 1/50, 1/50)
    # Feed every difference to the detector (the previous index loop
    # silently skipped the last one).
    for i, x in enumerate(ts, start=1):
        bs.iterate(i, x)
    res, maxes, cp = bs.get_results()
    print(maxes)
    # The figure size must be given when the figure is created; setting
    # rcParams after plt.subplots() had no effect on this figure.
    fig, ax = plt.subplots(1, 2, figsize=(10, 30))
    norm = LogNorm(vmin=0.0001, vmax=1)
    ax[0].plot(ims_abs)
    ax[0].axvline(cp, color='red')
    # Run-length posterior over time, rotated so time runs left to right.
    ax[1].imshow(np.rot90(res), cmap='gray_r', norm=norm)
    plt.show()


    

    