#include <cmath>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <unordered_map>
#include <utility>

#include "otbVectorImage.h"
#include "otbImage.h"
#include "otbImageFileReader.h"
#include "otbImageFileWriter.h"
#include "otbSegmentCharacteristicsFilter.h"
#include "otbUnsupervisedSegmentationCriteriaFilter.h"
#include "otbPersistentFilterStreamingDecorator.h"

typedef unsigned int                                                         LabelType;
typedef otb::Image<LabelType>                                                LabelImageType;
typedef otb::ImageFileReader<LabelImageType>                                 LabelReaderType;

typedef unsigned short                                                       ComponentType;
typedef otb::VectorImage<ComponentType>                                      VectorImageType;
typedef otb::ImageFileReader<VectorImageType>                                VectorReaderType;

typedef otb::ImageFileWriter<LabelImageType>                                 LabelWriterType;


typedef typename otb::VectorImage<double>::PixelType                         SampleType;

typedef otb::SegmentCharacteristicsFilter<VectorImageType,LabelImageType>      CharFilter;
typedef otb::UnsupervisedSegmentationCriteriaFilter<VectorImageType,LabelImageType>      CritFilter;

typedef otb::PersistentFilterStreamingDecorator<CharFilter> StreamingCharFilterType;
typedef otb::PersistentFilterStreamingDecorator<CritFilter> StreamingCritFilterType;

typedef std::unordered_map<LabelType, unsigned int> LabeledIntContainerType;
typedef typename LabeledIntContainerType::iterator LabeledIntContainerIteratorType;
typedef std::pair<LabelType,unsigned int> LabelIntPairType;
typedef std::unordered_map<LabelType,SampleType> LabeledSampleContainerType;
typedef typename LabeledSampleContainerType::iterator LabeledSampleContainerIteratorType;
typedef std::pair<LabelType,SampleType> LabelSamplePairType;
typedef itk::VariableSizeMatrix<double> Matrix;
typedef std::unordered_map<LabelType,Matrix> LabeledMatrixContainerType;
typedef typename LabeledMatrixContainerType::iterator LabeledMatrixContainerIteratorType;
typedef std::pair<LabelType,Matrix> LabelMatrixPairType;




int main(int argc, char* argv[])
{
  
  if(argc!=4)
    {
      std::cerr<<"Usage: "<<argv[0]<<" image image_seg outlier_result"<<std::endl;
      return EXIT_FAILURE;
    }
  //Read the input image
  VectorReaderType::Pointer vreader = VectorReaderType::New();
  vreader->SetFileName(argv[1]);
  vreader->Update();
  // Read the input segmentation
  LabelReaderType::Pointer lreader = LabelReaderType::New();
  lreader->SetFileName(argv[2]);
  lreader->Update();
  itk::ImageRegionConstIterator<VectorImageType> vit(vreader->GetOutput(),vreader->GetOutput()->GetLargestPossibleRegion());
  itk::ImageRegionConstIterator<LabelImageType> lit(lreader->GetOutput(),lreader->GetOutput()->GetLargestPossibleRegion());
  unsigned int nbComps = vreader->GetOutput()->GetNumberOfComponentsPerPixel();
  unsigned int nbPixels = vreader->GetOutput()->GetLargestPossibleRegion().GetNumberOfPixels();
    
  typedef  std::unordered_map<LabelType,SampleType> LabeledSampleContainerType;
  typedef typename LabeledSampleContainerType::iterator LabeledSampleContainerIteratorType;
  typedef std::pair<LabelType,SampleType> LabelSamplePairType;
  LabeledSampleContainerType means;

  typedef itk::VariableSizeMatrix<double> Matrix;
  typedef std::unordered_map<LabelType,Matrix> LabeledMatrixContainerType;
  typedef typename LabeledMatrixContainerType::iterator LabeledMatrixContainerIteratorType;
  typedef std::pair<LabelType,Matrix> LabelMatrixPairType;
  LabeledMatrixContainerType vars;
  LabeledMatrixContainerType invVariances;
  // ListSampleType::Pointer segList = ListSampleType::New();
  // segList->SetMeasurementVectorSize(2*nbComps);
  SampleType max(nbComps);
  max.Fill(0.0);
  for(vit.GoToBegin(),lit.GoToBegin();!vit.IsAtEnd() && !lit.IsAtEnd();++vit,++lit)
    {
      for(unsigned int i = 0 ; i < nbComps ; ++i)
  	{
  	  if(vit.Get()[i] > max[i])
  	    max[i]=vit.Get()[i];
  	}
      LabelType label = lit.Get();
      LabeledSampleContainerIteratorType search = means.find(label);
      if(search == means.end())
  	{
  	  //Label is not in list
  	  //Add new label to means
  	  //Create new element in vars
  	  SampleType s(nbComps+3);
  	  //means
  	  for(unsigned int i = 0 ; i < nbComps ; ++i)
  	    {
  	      s[i] = vit.Get()[i];
  	    }
  	  s[nbComps]=vit.GetIndex()[0];
  	  s[nbComps+1]=vit.GetIndex()[1];
  	  //count
  	  s[nbComps+2]=1;
  	  LabelSamplePairType newPairMean(label,s);
  	  means.insert(newPairMean);
	  
  	  //variances
  	  Matrix m(nbComps+2,nbComps+2);
  	  m.Fill(0.0);
  	  LabelMatrixPairType newPairVar(label,m);
  	  vars.insert(newPairVar);
  	}
      else
  	{
  	  //Element is already in list
  	  //Update means and vars with running formulas	    
  	  SampleType oldMeans = search->second;
  	  SampleType newMeans = oldMeans;
  	  SampleType currentPixel = vit.Get();
  	  unsigned int N = oldMeans[nbComps+2];
  	  for(unsigned int i = 0 ; i < nbComps ; ++i)
  	    {
  	      newMeans[i] += (currentPixel[i]-oldMeans[i])/(N+1);
  	    }
  	  newMeans[nbComps] += (vit.GetIndex()[0]-oldMeans[nbComps])/(N+1);
  	  newMeans[nbComps+1] += (vit.GetIndex()[1]-oldMeans[nbComps+1])/(N+1);

  	  newMeans[nbComps+2]++;
  	  search->second = newMeans;
  	  LabeledMatrixContainerIteratorType searchVar = vars.find(label);
  	  Matrix oldCovar = searchVar->second;
  	  Matrix newCovar = oldCovar;
  	  for(unsigned int i = 0 ; i < nbComps ; ++i)
  	    {
  	      for(unsigned int j = 0 ; j < nbComps ; ++j)
  		{
  		  //Update running covar matrix
  		  newCovar[i][j] += (currentPixel[i]-newMeans[i])*(currentPixel[j]-oldMeans[j]);
  		}	      
  	    }
	  newCovar[nbComps][nbComps] += (vit.GetIndex()[0]-newMeans[nbComps])*(vit.GetIndex()[0]-oldMeans[nbComps]);
	  newCovar[nbComps+1][nbComps+1] += (vit.GetIndex()[1]-newMeans[nbComps+1])*(vit.GetIndex()[1]-oldMeans[nbComps+1]);
  	  searchVar->second = newCovar;
  	}     
    }
  LabeledMatrixContainerIteratorType varit = vars.begin();
  double llSpectral = 0.0;
  double llSpatial = 0.0;
  for(; varit != vars.end();++varit)
    {
      Matrix var = varit->second;
      SampleType mean = means.find(varit->first)->second;    
      var/=mean[nbComps+2];
      varit->second = var;
      //Calculate log likelihoods
      for(unsigned int comp = 0; comp < nbComps; comp++)
  	{
  	  llSpectral += var[comp][comp]*mean[nbComps+2];
  	}
      llSpatial += var[nbComps][nbComps]*mean[nbComps+2];
      llSpatial += var[nbComps+1][nbComps+1]*mean[nbComps+2];

      Matrix varInv(nbComps,nbComps);

      //discard variances of x and y
      for(unsigned int i = 0 ; i < nbComps ; ++i)
	for(unsigned int j  = 0 ; j < nbComps ; ++j)
	  varInv[i][j] = var[i][j];
      varInv = varInv.GetInverse();
      LabelMatrixPairType newInv(varit->first,varInv);
      invVariances.insert(newInv);	 
    }
  LabelImageType::Pointer labelImage = LabelImageType::New();
  labelImage->SetRegions(lreader->GetOutput()->GetLargestPossibleRegion());
  labelImage->Allocate();
  labelImage->FillBuffer(0);
  labelImage->SetOrigin(lreader->GetOutput()->GetOrigin());
  labelImage->SetSpacing(lreader->GetOutput()->GetSpacing());

  itk::ImageRegionIterator<LabelImageType> litOut(labelImage,labelImage->GetLargestPossibleRegion());
  // unsigned int outlierCount = 0;
  double sprime =0.0;
  for(vit.GoToBegin(),lit.GoToBegin(),litOut.GoToBegin();!vit.IsAtEnd() && !lit.IsAtEnd() && !litOut.IsAtEnd();++vit,++lit,++litOut)
    {
      const LabelType label = lit.Get();	 
      const Matrix varInv = invVariances.find(label)->second;
      const Matrix var = vars.find(label)->second;
      const SampleType mean = means.find(label)->second;
      const SampleType cp = vit.Get();
      // // Euclidian distance
      double d = 0.0;
      double mdiff = 0.0;
      double mcp = 0.0;
      double mm = 0.0;
      for(unsigned int i = 0 ; i < nbComps ; ++i)
      	{
      	  d += (cp[i]-mean[i])*(cp[i]-mean[i]);
	  mdiff += (cp[i]-mean[i]);
	  mcp += cp[i];
	  mm += mean[i];
      	}
      // mcp/=nbComps;
      // mm/=nbComps;
      litOut.Set(vcl_sqrt(d-mdiff*mdiff/nbComps));
      unsigned int count = mean[nbComps+2];
      for(unsigned int i = 0 ; i < nbComps ; ++i)
	{
	  sprime+=var[i][i]/nbComps;
	}
      sprime+=count*(mcp*mcp-mm*mm)/(count-1)/nbComps;
      // litOut.Set(vcl_sqrt(d));
      // // Modified maha distance
      // double mvar = 0.0;
      // for(unsigned int i = 0 ; i < nbComps ; ++i)
      // 	{
      // 	  mvar += var[i][i]/nbComps;
      // 	}
      // double dmaha = 0.0;
      // for(unsigned int i = 0; i < nbComps; i++)
      // 	{
      // 	  for(unsigned int j = 0; j < nbComps; j++)
      // 	    {
      // 	      dmaha += (cp[i]-mean[i])*varInv[i][j]*(cp[j]-mean[j]);
      // 	    }
      // 	}
      // litOut.Set(vcl_sqrt(vcl_sqrt(mvar))*100*vcl_sqrt(dmaha));
      // if(dmaha > 200.0/vcl_sqrt(mvar))
      // 	{
      // 	  litOut.Set(255);
      // 	  outlierCount++;
      // 	}

      // //Spectral angle distance
      // double a1 = 0.0;
      // double a2 = 0.0;
      // double nm = 0.0;
      // double ncp = 0.0;
      // double nm2 = 0.0;
      // double ncp2 = 0.0;
      // for(unsigned int i = 0; i < nbComps; i++)
      // 	{
      // 	  a1 += cp[i]*mean[i]/max[i]/max[i];
      // 	  a2 += (max[i]-cp[i])*(max[i]-mean[i])/max[i]/max[i];
      // 	  nm += mean[i]*mean[i]/max[i]/max[i];
      // 	  ncp += cp[i]*cp[i]/max[i]/max[i];
      // 	  nm2 += (max[i]-mean[i])*(max[i]-mean[i])/max[i]/max[i];
      // 	  ncp2 += (max[i]-cp[i])*(max[i]-cp[i])/max[i]/max[i];	  
      // 	}
      
      // double out = (vcl_acos(a1/vcl_sqrt(nm*ncp))+vcl_acos(a2/vcl_sqrt(nm2*ncp2)))*255/3.1415;
      //std::cout << vcl_acos(a2/vcl_sqrt(nm2*ncp2)) << " " << vcl_acos(a1/vcl_sqrt(nm*ncp)) << std::endl;
      // litOut.Set(out);      
    }
  std::cout << "Number of superpixels = " << means.size() << std::endl;
  std::cout << "Sprime = " <<  (float) sprime/nbPixels << std::endl;
  // std::cout << "Spatial negative LL = " << (float) llSpatial/means.size() << std::endl;
   std::cout << "Spectral negative LL = " << (float) llSpectral/nbPixels << std::endl;
  
  LabelWriterType::Pointer lwriter = LabelWriterType::New();
  lwriter->SetInput(labelImage);
  lwriter->SetFileName(argv[3]);
  lwriter->Update();
    
  return 0;
}
