#include "otbSLICFilter.h"
#include "otbImageFileReader.h"
#include "otbImageFileWriter.h"
#include "otbSegmentCharacteristicsFilter.h"
#include "otbPersistentFilterStreamingDecorator.h"
#include "itkLabelToRGBImageFilter.h"
// ---- Pixel and image types -------------------------------------------------
// Multi-band input image; pixels are read as unsigned short.
typedef otb::VectorImage<unsigned short> VectorImageType;
// Label value and the single-band label image built from it.
typedef unsigned int LabelType;
typedef otb::Image<LabelType> LabelImageType;
// Pixel type of a double-precision vector image.
typedef otb::VectorImage<double>::PixelType SampleType;
// 8-bit vector image (and its pixel type) used as the target of the
// label-to-RGB conversion.
typedef otb::VectorImage<unsigned char> RGBImageType;
typedef RGBImageType::PixelType RGBPixelType;

// ---- Readers and writers ---------------------------------------------------
typedef otb::ImageFileReader<VectorImageType> ReaderType;
typedef otb::ImageFileReader<LabelImageType> LabelReaderType;
typedef otb::ImageFileWriter<LabelImageType> WriterType;
typedef otb::ImageFileWriter<RGBImageType> RGBWriterType;

// ---- Filters ---------------------------------------------------------------
// SLIC filter taking the vector image as input and labelling into LabelImageType.
typedef otb::SLICFilter<VectorImageType, LabelImageType> SLICFilterType;
// Segmentation contrast maximisation LUT
typedef itk::LabelToRGBImageFilter<LabelImageType, RGBImageType> LabelToRGBFilterType;

/**
 * Command-line driver for the tiled SLIC segmentation filter.
 *
 * Two invocation forms are accepted (argv[0] excluded):
 *   - Auto tiling   (8 args): infname outfname spwidth d_weight maxit thresh
 *                             margin maxMemory
 *   - Manual tiling (9 args): infname outfname spwidth d_weight maxit thresh
 *                             tilesX tilesY margin
 *
 * In auto tiling mode the number of tiles along each axis is estimated from
 * the image size, the number of bands and the memory budget (in MB); the
 * larger of the "reading" and "algorithm" estimates is kept for each axis.
 *
 * The prefix handed to the filter is the output file name stripped of its
 * extension, keeping its directory part if there is one.
 *
 * Returns EXIT_FAILURE on a wrong argument count or an invalid
 * spwidth / maxMemory value, EXIT_SUCCESS otherwise.
 */
int main(int argc, char * argv[]){
  
  if(argc != 9  && argc != 10)
    {
      std::cerr<<"Auto tiling mode: "<<argv[0]<<" infname outfname spwidth d_weight maxit thresh margin maxMemory(Mb, per process)"<<std::endl;
      std::cerr<<"Manual tiling: "<<argv[0]<<" infname outfname spwidth d_weight maxit thresh tilesX tilesY margin(expressed in SPwidths)"<<std::endl;
      return EXIT_FAILURE;
    }

  // The argument count alone selects the tiling mode.
  const bool autoTiling = (argc == 9);
  
  const std::string inputName = std::string(argv[1]);
  const std::string outputName = std::string(argv[2]);

  // spwidth is used as a divisor in the auto tiling estimate below, so a zero
  // (or unparsable, which atoi maps to 0) value must be rejected; a negative
  // value would silently wrap to a huge unsigned number.
  const int spatialWidthArg = atoi(argv[3]);
  if(spatialWidthArg <= 0)
    {
      std::cerr<<"Invalid spwidth \""<<argv[3]<<"\": a strictly positive integer is expected"<<std::endl;
      return EXIT_FAILURE;
    }
  const unsigned int spatialWidth = static_cast<unsigned int>(spatialWidthArg);

  const double distanceWeight = atof(argv[4]);
  const unsigned int maxIterations = atoi(argv[5]);
  const double threshold = atof(argv[6]);
  unsigned int margin = 0;
  unsigned int nbTilesX = 1;
  unsigned int nbTilesY = 1;
  unsigned int maxMemory = 0; // only meaningful in auto tiling mode
  if(autoTiling){
    //Auto Tiling mode
    margin = atoi(argv[7]);
    // A zero memory budget combined with a zero margin makes the denominator
    // of the tile-count estimate vanish; negative values would wrap.
    const int maxMemoryArg = atoi(argv[8]);
    if(maxMemoryArg <= 0)
      {
        std::cerr<<"Invalid maxMemory \""<<argv[8]<<"\": a strictly positive number of MB is expected"<<std::endl;
        return EXIT_FAILURE;
      }
    maxMemory = static_cast<unsigned int>(maxMemoryArg);
  }
  else{
    //Manual Tiling mode
    nbTilesX = atoi(argv[7]);
    nbTilesY = atoi(argv[8]);
    margin = atoi(argv[9]);
  }
  
  // Build "<directory>/<name without extension>". GetFilenamePath returns an
  // empty string when outputName has no directory part; in that case no
  // separator is added, so the prefix stays relative to the working directory
  // instead of pointing at the filesystem root.
  std::string prefix = itksys::SystemTools::GetFilenamePath(outputName);
  if(!prefix.empty() && prefix[prefix.size()-1] != '/')
    {
      prefix.append("/");
    }
  prefix.append(itksys::SystemTools::GetFilenameWithoutExtension(outputName));

  otb::MPIConfig::Pointer mpiConfig = otb::MPIConfig::Instance();
  mpiConfig->Init(argc,argv,true);

  // Set the GDAL raster block cache limit to 0 bytes.
  GDALSetCacheMax(0);  
  // const unsigned int myRank = mpiConfig->GetMyRank();
  // Read the input image
  ReaderType::Pointer reader = ReaderType::New();
  reader->SetFileName(inputName);
  reader->SetReleaseDataFlag(true);
  // Only the image metadata (size, number of bands) is needed at this point.
  reader->UpdateOutputInformation();

  if(autoTiling){
    //Auto Tiling mode
    const unsigned int X = reader->GetOutput()->GetLargestPossibleRegion().GetSize()[0];
    const unsigned int Y = reader->GetOutput()->GetLargestPossibleRegion().GetSize()[1];
    const unsigned int nbComps = reader->GetOutput()->GetNumberOfComponentsPerPixel();
    const unsigned int pin = 2*nbComps; //because reader casts to unsigned short
    const unsigned int pout = 4; //unsigned int label
    const unsigned int pc = (nbComps+2)*8+4; //centroid size

    // Width of the margin in pixels on both sides of a tile, computed in
    // double so that large spwidth*margin products cannot wrap around.
    const double marginExtent = 2.0*spatialWidth*margin;

    //Reading config
    double alpha = 2*pin;
    double beta = static_cast<double>(2*pin + pout)*marginExtent;
    // The denominator is identical for both axes; evaluate it once.
    double denominator = vcl_sqrt(4*beta*beta + alpha*maxMemory*1024*1024) - beta;
    const unsigned int readingNx = vcl_ceil(X*alpha/denominator);
    const unsigned int readingNy = vcl_ceil(Y*alpha/denominator);
    
    //Algo config
    alpha = pin+pout+2.0*pc/spatialWidth/spatialWidth;
    beta = static_cast<double>(pin + pout)*marginExtent;
    denominator = vcl_sqrt(4*beta*beta + alpha*maxMemory*1024*1024) - beta;
    const unsigned int algoNx = vcl_ceil(X*alpha/denominator);
    const unsigned int algoNy = vcl_ceil(Y*alpha/denominator);
    
    // Keep the more demanding (larger) tile count of the two estimates.
    nbTilesX = readingNx > algoNx ? readingNx : algoNx;  
    nbTilesY = readingNy > algoNy ? readingNy : algoNy;

    std::cout << "Auto tiling mode selected, available RAM = " << maxMemory << "MB" << std::endl;
    std::cout << "Number of x tiles = "<<  nbTilesX << std::endl;
    std::cout << "Number of y tiles = "<<  nbTilesY << std::endl;
  }
  
  // Configure and run the segmentation.
  SLICFilterType::Pointer slicFilter = SLICFilterType::New();
  slicFilter->SetInputImage(reader->GetOutput());
  slicFilter->SetInputName(inputName); 
  slicFilter->SetPrefix(prefix);
  slicFilter->SetSpatialWidth(spatialWidth);
  slicFilter->SetSpatialDistanceWeight(distanceWeight);
  slicFilter->SetMaxIterationNumber(maxIterations);
  slicFilter->SetThreshold(threshold);
  slicFilter->SetNbTilesX(nbTilesX);
  slicFilter->SetNbTilesY(nbTilesY);
  slicFilter->SetMargin(margin);
  slicFilter->Run();

  // Disabled: colourised rendering of the label image on rank 0.
  // NOTE(review): this snippet refers to an undeclared `output` variable and
  // would need fixing before being re-enabled.
  // if(myRank == 0){
  //   LabelReaderType::Pointer labelReader = LabelReaderType::New();
  //   std::stringstream outputName;
  //   outputName << prefix << ".vrt";
  //   labelReader->SetFileName(outputName.str());
  //   LabelToRGBFilterType::Pointer colorFilter = LabelToRGBFilterType::New();
  //   colorFilter->SetInput(labelReader->GetOutput());
  
  //   RGBWriterType::Pointer colorWriter = RGBWriterType::New();
  //   colorWriter->SetInput(colorFilter->GetOutput());
  //   colorWriter->SetFileName(output);
  //   colorWriter->Update();
  // }
  return EXIT_SUCCESS;
}
