#!/usr/bin/env python3

import argparse
import sys
from typing import List

import sinter


def main():
    """Combine per-basis sinter statistics into joint 'XZ' rows.

    Reads the stats named by ``--stats``, groups the rows by their JSON
    metadata with the ``'b'`` entry removed, and prints the sinter CSV header
    followed by one combined row per group to stdout.

    Within a group of two rows with error rates ``p_a`` and ``p_b``, the
    combined error rate is ``1 - (1 - p_a) * (1 - p_b)``. The combined row
    reports the smaller of the two shot counts, with the error count scaled
    to that shot count and rounded to an integer. A group with a single row
    is combined with itself, and a warning is written to stderr.

    Raises:
        KeyError: If a row's metadata has no ``'b'`` entry.
        ValueError: If a group contains more than two rows, or if any row
            being combined has a nonzero discard count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--stats",
        type=str,
        required=True,
        help="Path of the sinter CSV stats file to combine.",
    )
    args = parser.parse_args()

    stats: List[sinter.TaskStats] = sinter.stats_from_csv_files(args.stats)

    def all_except_basis(stat: sinter.TaskStats) -> str:
        """Return a grouping key built from the metadata minus the 'b' entry."""
        # Sort so that key order in the metadata does not affect grouping.
        metadata = dict(sorted(stat.json_metadata.items()))
        del metadata['b']
        return repr(metadata)

    print(sinter.CSV_HEADER)
    # Only the grouped rows are needed; the grouping key itself is unused.
    for pair in sinter.group_by(stats, key=all_except_basis).values():
        if len(pair) > 2:
            raise ValueError("More than two bases?")
        if len(pair) == 1:
            print("WARNING: duplicating unpaired value with metadata ", pair[0].json_metadata, file=sys.stderr)
            a, b = pair[0], pair[0]
        else:
            a, b = pair
        # Make `a` the row with fewer shots; its shot count is the one reported.
        if a.shots > b.shots:
            a, b = b, a
        # Explicit check rather than `assert`, which is removed under `python -O`.
        if a.discards != 0 or b.discards != 0:
            raise ValueError(
                f"Cannot combine stats that have discards: {a.json_metadata!r}"
            )
        if a.shots == 0:
            # No shots to scale against; avoid dividing by zero.
            new_errors = 0
        else:
            # a.shots <= b.shots here, so b.shots is nonzero as well.
            new_errors = round((1 - (1 - a.errors / a.shots) * (1 - b.errors / b.shots)) * a.shots)

        new_metadata = dict(a.json_metadata)
        new_metadata['b'] = 'XZ'
        combo = sinter.TaskStats(
            strong_id=a.strong_id + '*' + b.strong_id,
            decoder=a.decoder,
            json_metadata=new_metadata,
            shots=a.shots,
            errors=new_errors,
            discards=0,
            seconds=a.seconds + b.seconds,
        )
        print(combo)


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
