#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <map>
#include <mutex>
#include <random>
#include <string>
#include <thread>
#include <vector>
#include <boost/program_options.hpp>

#include "schnyder.h"

// Number of worker threads (and of per-thread RNGs). std::thread::hardware_concurrency()
// returns 0 when the value is not computable, which would leave no workers at all,
// so fall back to a single thread in that case.
int const threadCount = []() -> int {
	unsigned const hardwareThreads = std::thread::hardware_concurrency();
	return hardwareThreads == 0 ? 1 : static_cast<int>(hardwareThreads);
}();
// Which cells the breadth-first search counts, i.e. which notion of distance is measured.
// The fixed underlying type makes every int value a valid Distance, so sentinel values
// outside the enumerators (such as Distance(-1)) are well-defined. Without a fixed type,
// converting an out-of-range value to the enum is undefined behaviour.
enum Distance : int {
	Triangle, Vertex, Edge/*, Edge2*/
};

// Distance histogram of one or more BFS runs, plus running sums from which the caller
// derives the mean and variance of log(Σ_r r·N(r)) across runs.
struct Histogram {
	// Adds this histogram, treated as a single BFS sample, to the log-statistics:
	// x = Σ_r r·data[r]; sum_1 += 1, sum_logx += log x, sum_logx2 += (log x)².
	// NOTE: if x == 0 (all mass at r = 0), log(x) is -inf and the sums become non-finite.
	void doSum() {
		double weightedDistanceSum = 0;
		for (size_t r = 0; r < data.size(); ++r) {
			weightedDistanceSum += r*data[r];
		}
		double const logx = log(weightedDistanceSum);
		sum_1 += 1;
		sum_logx += logx;
		sum_logx2 += logx*logx;
	}
	// Element-wise sum of the histograms (growing `data` as needed) and of the
	// log-statistics. numSomething is left untouched: it describes the map size and is
	// set once by the caller rather than accumulated.
	Histogram &operator +=(Histogram const &other) {
		if (data.size() < other.data.size()) {
			data.resize(other.data.size(), 0);
		}
		for (size_t i = 0; i < other.data.size(); ++i) {
			data[i] += other.data[i];
		}
		sum_1 += other.sum_1;
		sum_logx += other.sum_logx;
		sum_logx2 += other.sum_logx2;
		return *this;
	}

	std::vector<uint64_t> data;  // data[r] = number of cells at distance r (summed over merged samples)
	uint64_t sum_1 = 0;
	double sum_logx = 0, sum_logx2 = 0;  // sum_1 = Σ 1, sum_logx = Σ log(Σ_r r Nρ(r)), sum_logx2 = Σ (log(Σ_r r Nρ(r)))²
	// Initialized so that copying a Histogram that never had it assigned (e.g. the one
	// returned by process()) does not read an indeterminate value.
	int numSomething = 0;  // numSomething = numPolygons, numVertices or numEdges
};

// Runs one breadth-first search on `map` from a random root. It returns a single-sample
// Histogram whose data[r] is the number of cells first reached at distance r, where the
// cells are:
//   Triangle       - faces; the frontier crosses to getAdjacent() of each face half-edge
//                    (dual-graph distance);
//   Vertex         - vertices; the root vertex is sampled uniformly (see below) and the
//                    frontier follows getNext() of each half-edge around the vertex
//                    (graph distance);
//   Distance::Edge - edges; the neighbours are the half-edges reached by getRotateCW()
//                    from either half-edge of the edge.
// The final frontier usually holds only already-visited half-edges, so data typically
// ends with a 0 entry. The sample is also folded into the log-statistics via doSum().
// Sets data().visited on every reached half-edge. This assumes all visited flags start
// false (e.g. a freshly generated map) - TODO confirm against EdgeData.
Histogram process(Map<EdgeData> &map, Distance distanceMeasure, std::mt19937_64 &rng) {
	using Uni = std::uniform_int_distribution<>;
	Uni uni;
	// Uniform integer in [0, mod).
	auto rnd = [&](int mod) {
		return uni(rng, Uni::param_type{0, mod - 1});
	};
	Histogram hist;
	// Current BFS frontier and the frontier being built for the next layer. They are
	// swapped (not moved) after each layer so that both buffers keep their capacity.
	std::vector<EdgeHandle::Id> edgeList;
	std::vector<EdgeHandle::Id> tmpEdgeList;
	switch (distanceMeasure) {
		case Triangle:
			edgeList.push_back(map.getEdge(rnd(map.numHalfEdges()))->getId());
			// One iteration per distance layer: layer r appends hist.data[r].
			while (!edgeList.empty()) {
				tmpEdgeList.clear();
				uint64_t trisAtDist = 0;
				for (auto edgeId : edgeList) {
					auto edge = map.getEdge(edgeId);
					trisAtDist += !edge->data().visited;
					// Follow getNext() until reaching a visited half-edge. Each half-edge is
					// marked, and its getAdjacent() is queued for the next layer.
					while (!edge->data().visited) {
						tmpEdgeList.push_back(edge->getAdjacent()->getId());
						edge->data().visited = true;
						edge = edge->getNext();
					}
				}
				hist.data.push_back(trisAtDist);
				std::swap(edgeList, tmpEdgeList);
			}
			break;
		case Vertex: {
			// Rejection sampling: a uniformly random half-edge is kept with probability
			// 1/vertexDegree(startEdge). If vertexDegree counts the half-edges at that
			// vertex, this makes every vertex an equally likely root.
			for (;;) {
				auto startEdge = map.getEdge(rnd(map.numHalfEdges()));
				if (rnd(map.vertexDegree(startEdge)) == 0) {
					edgeList.push_back(startEdge->getId());
					break;
				}
			}
			while (!edgeList.empty()) {
				tmpEdgeList.clear();
				uint64_t vertsAtDist = 0;
				for (auto edgeId : edgeList) {
					auto edge = map.getEdge(edgeId);
					vertsAtDist += !edge->data().visited;
					// Step with getRotateCW() until reaching a visited half-edge. Each half-edge
					// is marked, and its getNext() is queued for the next layer.
					while (!edge->data().visited) {
						tmpEdgeList.push_back(edge->getNext()->getId());
						edge->data().visited = true;
						edge = edge->getRotateCW();
					}
				}
				hist.data.push_back(vertsAtDist);
				std::swap(edgeList, tmpEdgeList);
			}
			break;
		}
		case Distance::Edge: {
			edgeList.push_back(map.getEdge(rnd(map.numHalfEdges()))->getId());
			while (!edgeList.empty()) {
				tmpEdgeList.clear();
				uint64_t edgesAtDist = 0;
				for (auto edgeId : edgeList) {
					auto edge = map.getEdge(edgeId);
					edgesAtDist += !edge->data().visited;
					// Visit the half-edge and its getAdjacent() twin. From each one, queue every
					// other half-edge reached by repeated getRotateCW().
					while (!edge->data().visited) {
						for (auto nearby = edge->getRotateCW(); nearby != edge; nearby = nearby->getRotateCW()) {
							tmpEdgeList.push_back(nearby->getId());
						}
						edge->data().visited = true;
						edge = edge->getAdjacent();
					}
				}
				hist.data.push_back(edgesAtDist);
				std::swap(edgeList, tmpEdgeList);
			}
			break;
		}
		/*case Distance::Edge2: {
			edgeList.push_back(map.getEdge(rnd(map.numHalfEdges()))->getId());
			for (unsigned dist = 0; !edgeList.empty(); ++dist) {
				tmpEdgeList.clear();
				uint64_t edgesAtDist = 0;
				for (auto edgeId : edgeList) {
					auto edge = map.getEdge(edgeId);
					edgesAtDist += !edge->data().visited;
					while (!edge->data().visited) {
						tmpEdgeList.push_back(edge->getRotateCW()->getId());
						tmpEdgeList.push_back(edge->getRotateCCW()->getId());
						edge->data().visited = true;
						edge = edge->getAdjacent();
					}
				}
				hist.data.push_back(edgesAtDist);
				std::swap(edgeList, tmpEdgeList);
			}
			break;
		}*/
	}
	hist.doSum();
	return hist;
}

// Bob Jenkins' three-word "mix" (the lookup2 shift schedule 13,8,13,12,16,5,3,10,15),
// applied to unsigned long words. Returns the final value of the third word. Used to
// combine several weak entropy sources into one seed.
unsigned long seedMix(unsigned long a, unsigned long b, unsigned long c) {
	a -= b; a -= c; a ^= c >> 13;
	b -= c; b -= a; b ^= a << 8;
	c -= a; c -= b; c ^= b >> 13;

	a -= b; a -= c; a ^= c >> 12;
	b -= c; b -= a; b ^= a << 16;
	c -= a; c -= b; c ^= b >> 5;

	a -= b; a -= c; a ^= c >> 3;
	b -= c; b -= a; b ^= a << 10;
	c -= a; c -= b; c ^= b >> 15;

	return c;
}

// Builds a per-run seed by mixing processor time, wall-clock time and the process id
// through seedMix(). Runs started within the same second therefore usually get
// different seeds.
unsigned long getseed()
{
#ifdef _WIN32
	unsigned int const processId = GetCurrentProcessId();
#else
	unsigned int const processId = getpid();
#endif
	unsigned long const cpuClock = clock();
	unsigned long const wallClock = time(nullptr);
	return seedMix(cpuClock, wallClock, processId);
}

// Signature shared by the map generators listed in main(): build a random map with
// `size` faces, drawing all randomness from `rng`.
using MakeMap = auto (*)(int size, std::mt19937_64 &rng) -> Map<EdgeData>;

// Samples `count` random maps from `makeMap` (each with `numFaces` faces) on threadCount
// worker threads, runs process() on each, and returns the summed histogram.
// hist.numSomething is set to the number of faces, vertices or edges matching
// `distanceMeasure`, derived from Euler's formula for genus 0.
// `rng` must point to at least threadCount generators; worker t uses only rng[t].
Histogram generateHistogram(MakeMap makeMap, int polygonSides, Distance distanceMeasure, int numFaces, int count, std::mt19937_64 *rng) {
	Histogram hist;

	// Every face has polygonSides sides and every edge borders two faces.
	int const numEdges = polygonSides*numFaces/2;
	// 2*genus - 2 = numEdges - numFaces - numVertices, genus = 0
	int const numVertices = numEdges - numFaces + 2;
	switch (distanceMeasure) {
		case Triangle: hist.numSomething = numFaces; break;
		case Vertex: hist.numSomething = numVertices; break;
		case Distance::Edge: hist.numSomething = numEdges; break;
		//case Distance::Edge2: hist.numSomething = numEdges; break;
	}

	// Maps handed out to workers so far. Both `claimed` and `hist` are shared and
	// guarded by `mutex`.
	int claimed = 0;
	std::mutex mutex;
	// A std::vector instead of a runtime-sized array: VLAs are a non-standard extension,
	// and clang rejects them for non-POD element types such as std::thread.
	std::vector<std::thread> threads;
	threads.reserve(threadCount);
	for (int t = 0; t < threadCount; ++t) {
		threads.emplace_back([&, t]() {
			Histogram threadHist;
			// RAII lock: it is held only while touching shared state and is released
			// automatically when the worker returns.
			std::unique_lock<std::mutex> lock(mutex);
			while (claimed < count) {
				// Claim half of this worker's fair share of the remaining maps (at least one).
				int batch = std::max((count - claimed)/threadCount/2, 1);
				claimed += batch;
				lock.unlock();
				while (batch--) {
					auto edges = makeMap(numFaces, rng[t]);
					threadHist += process(edges, distanceMeasure, rng[t]);
				}
				lock.lock();
			}
			hist += threadHist;
		});
	}
	for (auto &thread : threads) {
		thread.join();
	}

	return hist;
}

int main(int argc, char **argv) {
	MakeMap makeMap = nullptr;
	int polygonSides = 0;
	Distance distance = Distance(-1);

	std::map<std::string, std::pair<MakeMap, int>> const maps = {
		{"schnyder", {makeSchnyderMap, 3}},  // c = -12.5, γ = 1
		{"bipolar", {makeBipolarMap, 3}},  // c = -7, γ = √(4/3)
		{"spanning", {makeSpanningTreeMap, 4}},  // c = -2, γ = √2
		{"quadrangulation", {makeUniformQuadrangulation, 4}},  // c = 0, γ = √(8/3)
	};
	std::map<std::string, Distance> const distances = {
		{"tri", Triangle},
		{"vert", Vertex},
		{"edge", Distance::Edge},
		//{"edge2", Distance::Edge},
	};

	std::ofstream statfile;

	std::string classStr, metricStr = "vert", statfileStr;
	int faces = 64, count = 1;
	namespace po = boost::program_options;
	po::options_description desc("Allowed options");
	desc.add_options()
		("help", "produce help message")
		("class,c", po::value(&classStr), "d-angulation class (quadrangulation (4), spanning (4), bipolar (3), schnyder (3))")
		("metric,m", po::value(&metricStr)->default_value(metricStr), "distance metric (vert = graph distance, tri = dual graph distance, edge)")
		("statfile,s", po::value(&statfileStr), "optional statistics file (while generating a histogram)")
		("faces,f", po::value(&faces)->default_value(faces), "number of faces in d-angulations")
		("count,n", po::value(&count)->default_value(count), "number of d-angulations to generate")
		("histogram,h", "output histogram data instead of individual d-angulations")
	;

	po::variables_map vm;
	po::store(po::parse_command_line(argc, argv, desc), vm);
	po::notify(vm);

	bool histogram = vm.count("histogram");

	auto c = maps.find(classStr);
	if ( c != maps.end()) {
		std::tie(makeMap, polygonSides) = c->second;
	}
	auto m = distances.find(metricStr);
	if ( m != distances.end()) {
		distance = m->second;
	}
	if (!statfileStr.empty()) {
		statfile = std::ofstream(statfileStr, std::ios::binary);
	}

	if (vm.count("help") || makeMap == nullptr || (histogram && distance == Distance(-1))) {
		std::cerr << desc << "\n";
		return 1;
	}

	srand(time(nullptr));
	long seed = getseed();
	std::mt19937_64 rng[threadCount];
	for (int t = 0; t < threadCount; ++t) {
		rng[t].seed(seed + t);
	}

	std::cerr << "threadCount = " << threadCount << '\n';

	auto start = std::chrono::high_resolution_clock::now();
	std::cerr << "generateHistogram(" << faces << ", " << count << ", rng)";
	if (histogram) {
		auto hist = generateHistogram(makeMap, polygonSides, distance, faces, count, rng);
		std::cout.precision(16);
		for (unsigned i = 0; i < hist.data.size(); ++i) {
			std::cout << i << ' ' << hist.data[i] << ' ' << hist.numSomething << ' ' << (1.*hist.data[i]/count/hist.numSomething) << ' ' << count << '\n';
		}
		std::cout.flush();
		if (statfile) {
			statfile.precision(17);
			// Column 5 is per-triangulation variance of log k.
			statfile << hist.numSomething << ' ' << hist.sum_1 << ' ' << hist.sum_logx << ' ' << hist.sum_logx2 << ' ' << (hist.sum_1*hist.sum_logx2 - hist.sum_logx*hist.sum_logx)/hist.sum_1/hist.sum_1 << '\n';
			statfile.flush();
		}
	} else {
		std::cout << count << '\n';
		for (int i = 0; i < count; ++i) {
			auto map = makeMap(faces, rng[0]);
			std::cout << map.numHalfEdges() << "\n";
			for (auto edge = map.begin(); edge != map.end(); ++edge) {
				std::cout << (*edge)->getNext()->getId() + 1 << ' '
				          << (*edge)->getPrevious()->getId() + 1 << ' '
				          << (*edge)->getAdjacent()->getId() + 1 << '\n';
			}
		}
	}
	auto stop = std::chrono::high_resolution_clock::now();
	std::cerr << ' ' << std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count()/1.e6 << "s\n";
}
