import pandas as pd
import re
import random

class Span:
    """A mutable integer interval described by its ``start`` and ``end``.

    The class-level integer constants name the possible outcomes of
    :meth:`compare`.
    """

    LESS = 0
    GREATER = 1
    EQUAL = 2
    CONTAINED = 3
    CONTAINS = 4
    LESS_WITH_OVERLAP = 5
    GREATER_WITH_OVERLAP = 6

    def __init__(self, start, end):
        """Store the two end points; raise ``ValueError`` if start > end."""
        if end < start:
            raise ValueError()
        self.start, self.end = start, end

    def shift(self, delta):
        """Move both end points by ``delta`` in place."""
        self.start = self.start + delta
        self.end = self.end + delta

    def compare(self, other):
        """Classify this span's position relative to ``other``.

        The checks run in a fixed order and the first that holds wins:
        strictly before, strictly after, identical, inside ``other``,
        around ``other``, then the two partial-overlap cases.  Spans that
        merely share an end point are not "strictly" before/after, so they
        fall through to one of the later categories.
        """
        # Entirely before / entirely after (no shared point).
        if self.end < other.start:
            return Span.LESS
        if self.start > other.end:
            return Span.GREATER

        # Same end points.
        if self.start == other.start and self.end == other.end:
            return Span.EQUAL

        # One span nested in the other.
        if self.start >= other.start and self.end <= other.end:
            return Span.CONTAINED
        if self.start <= other.start and self.end >= other.end:
            return Span.CONTAINS

        # Partial overlap: our end, or our start, falls inside ``other``.
        if other.start <= self.end and self.end <= other.end:
            return Span.LESS_WITH_OVERLAP
        if other.start <= self.start and self.start <= other.end:
            return Span.GREATER_WITH_OVERLAP

        raise AssertionError()
        
class Cleaner:
    """Cleans free-text descriptions using an ordered list of literal
    substitution rules followed by punctuation/sentence normalisation.

    Besides cleaning, the instance keeps bookkeeping that the ``check_*``
    methods report on: how many descriptions each rule fired in, and which
    substitutions were made in each cleaned description.
    """

    def __init__(self):
        self.sub_rules = []    # (src, trg) pairs, in the order they were added.
        self.usage_freqs = {}  # src -> number of descriptions the rule changed.
        self.regex_srcs = []   # Regex source per rule, parallel to sub_rules.
        # One entry per clean() call: a list of
        # (src, trg, sub_starts, description_after_this_rule) tuples.
        self.subs_made = []

    @staticmethod
    def sub(regex_src, trg, desc):
        """Replace every match of ``regex_src`` in ``desc`` by the literal
        string ``trg``.

        Returns ``(new_desc, sub_starts)`` where ``sub_starts`` lists, for
        each replacement, the index in ``new_desc`` at which the inserted
        ``trg`` begins.

        Matching is done over the whole string with ``re.finditer`` rather
        than by repeatedly searching a sliced remainder: slicing hides the
        text to the left of the cut, which changes what ``\\b`` sees, and a
        zero-width match on a slice would never make progress.
        """
        sub_starts = []
        pieces = []
        cursor = 0  # End of the previous match in ``desc``.
        offset = 0  # Length difference accumulated by earlier replacements.
        trg_len = len(trg)
        for m in re.finditer(regex_src, desc):
            pieces.append(desc[cursor:m.start()])
            pieces.append(trg)
            sub_starts.append(m.start() + offset)
            offset += trg_len - (m.end() - m.start())
            cursor = m.end()
        pieces.append(desc[cursor:])
        return (''.join(pieces), sub_starts)

    def add_sub_rule(self, src, trg, with_bounds=True):
        """Register a rule replacing the literal text ``src`` by ``trg``.

        With ``with_bounds`` a ``\\b`` assertion is added on each side of
        the escaped source, except on a side where the source begins/ends
        with one of the characters ``' , - . / ; =``.

        Raises ``ValueError`` if ``src`` is empty, already registered, or
        equal to ``trg``.
        """
        if src == '':
            raise ValueError('Source must not be empty')
        if src in self.usage_freqs:
            raise ValueError(f'Duplicate source: {src}')
        if src == trg:
            raise ValueError(f'Source and target are the same: {src}')
        self.sub_rules.append((src, trg))
        self.usage_freqs[src] = 0
        regex = re.escape(src)
        if with_bounds:
            if src[0] not in '\',-./;=':
                regex = '\\b' + regex
            if src[-1] not in '\',-./;=':
                regex = regex + '\\b'
        self.regex_srcs.append(regex)

    def clean(self, desc):
        """Return the cleaned form of ``desc``.

        Steps: collapse line breaks, apply every substitution rule in
        order, normalise punctuation and spacing, insert sentence breaks
        before "he"/"she", and capitalise sentence starts.

        Returns ``''`` when the description is empty, either on input or
        once cleaning has removed everything.  Exactly one entry is
        appended to ``self.subs_made`` per call, so the positions in that
        list follow the order in which descriptions were cleaned.
        """
        desc = re.sub(r'[\r\n]+', ' ', desc)  # Remove new lines.
        desc = desc.strip()  # Remove spaces at edges.

        if desc == '':
            # Keep subs_made aligned with the sequence of calls.
            self.subs_made.append([])
            return ''

        # Record each rule that fired together with the text it produced, so
        # that later checks can see when a substitute is modified by a later
        # substitute.
        subs_made_in_desc = []
        for ((src, trg), regex) in zip(self.sub_rules, self.regex_srcs):
            (desc_, sub_starts) = Cleaner.sub(regex, trg, desc)
            if sub_starts:
                self.usage_freqs[src] += 1
                subs_made_in_desc.append((src, trg, sub_starts, desc_))
                desc = desc_
        self.subs_made.append(subs_made_in_desc)

        # Fix punctuation and spaces.
        desc = desc.strip()  # Remove spaces at the edges.
        desc = re.sub(r' {2,}', ' ', desc)  # Remove consecutive spaces.
        desc = re.sub(r'( ?\.){2,}', '.', desc)  # Remove consecutive fullstops.
        desc = re.sub(r'( ?,){2,}', ',', desc)  # Remove consecutive commas.
        desc = desc.replace(',.', '.')  # Commas in front of fullstops are mistyped.
        desc = desc.replace(' .', '.')  # Remove spaces in front of fullstops.
        desc = desc.replace(' ,', ',')  # Remove spaces in front of commas.
        desc = re.sub(r'\.([^ ])', r'. \1', desc)  # Add missing space after fullstops.
        desc = re.sub(r',([^ ])', r', \1', desc)  # Add missing space after commas.
        desc = re.sub(r'[,/]$', '.', desc)  # Replace , and / at the end with a fullstop.
        desc = re.sub(r'^\)', '', desc)  # Remove a closing parenthesis at the front.

        if desc == '':
            # The substitutions/clean-up removed all the text; indexing
            # desc[-1] below would otherwise raise IndexError.
            return ''

        # Add missing fullstops.
        if desc[-1] != '.':
            desc = desc + '.'
        # Start a new sentence before " he " / " she " unless the preceding
        # text ends in punctuation or in one of the listed words/phrases, or
        # the pronoun is followed by "has on".
        desc = re.sub(
            '(?<![.,;:])'
            '(?<! as)'
            '(?<! and)'
            '(?<! yet)'
            '(?<! but)'
            '(?<! like)'
            '(?<! that)'
            '(?<! plus)'
            '(?<! when)'
            '(?<! tell)'
            '(?<! thus)'
            '(?<! hand)'
            '(?<! while)'
            '(?<! where)'
            '(?<! which)'
            '(?<! since)'
            '(?<! point)'
            '(?<! though)'
            '(?<! because)'
            '(?<! whereas)'
            '(?<! although)'
            '(?<!round his lips)'
            ' (s?he )'
            '(?!has on)', '. \\1', desc, flags=re.IGNORECASE
        )

        # Capitalise first letter of each sentence.
        desc = re.sub(r'^[a-z]|\. [a-z]', lambda x: x[0].upper(), desc)

        return desc

    def check_usages(self):
        """Print the sources of all rules that never changed a description."""
        zero_uses = [
            src for (src, _) in self.sub_rules if self.usage_freqs[src] == 0
        ]
        if zero_uses:
            print('The following sources were never used:')
            for x in zero_uses:
                print('*', x)
            print()

    def _any_sub_overlaps(self, i):
        """Return True if, in the ``i``-th cleaned description, a
        substitution touched text produced by an earlier substitution.

        Each replacement is tracked as a ``Span`` covering the inserted
        target.  Before a new replacement is compared with an earlier span,
        an earlier span lying after it is shifted by the length change the
        new replacement caused.  Anything other than strictly-before or
        strictly-after counts as an overlap; since ``Span.compare`` uses
        strict inequalities for those two, spans that share an end point
        are reported as overlapping.
        """
        change_spans = []
        for (src, trg, sub_starts, _) in self.subs_made[i]:
            length_delta = len(trg) - len(src)
            for sub_start in sub_starts:
                new_change_span = Span(sub_start, sub_start + len(trg))
                for change_span in change_spans:
                    if change_span.compare(new_change_span) in [Span.GREATER, Span.GREATER_WITH_OVERLAP]:
                        change_span.shift(length_delta)
                    if change_span.compare(new_change_span) not in [Span.LESS, Span.GREATER]:
                        return True
                change_spans.append(new_change_span)
        return False

    def check_sub_overlaps(self):
        """Print the positions (in call order) of the cleaned descriptions
        for which ``_any_sub_overlaps`` reports an overlap."""
        overlaps = [
            i for i in range(len(self.subs_made)) if self._any_sub_overlaps(i)
        ]
        if overlaps:
            print('The following data frame row indexes have overlapping substitutions:')
            for i in overlaps:
                print('*', i)
            print()
            
# Script body: build the cleaner, clean the raw descriptions, report on rule
# usage, and write test/dev/train splits grouped by filename.
cleaner = Cleaner()

# Individual characters cleaners.
cleaner.add_sub_rule('\x08', '', with_bounds=False)
cleaner.add_sub_rule('"', '', with_bounds=False)
cleaner.add_sub_rule('#', '', with_bounds=False)
cleaner.add_sub_rule('\\', '', with_bounds=False)
cleaner.add_sub_rule(']', '', with_bounds=False)
cleaner.add_sub_rule('ħ', '', with_bounds=False)
cleaner.add_sub_rule('ż', '', with_bounds=False)
cleaner.add_sub_rule('3', '', with_bounds=False)
cleaner.add_sub_rule('7', '', with_bounds=False)
cleaner.add_sub_rule('’', '\'', with_bounds=False)
cleaner.add_sub_rule('&', ' and ', with_bounds=False)


def _add_sub_rules_from_file(target_cleaner, path):
    """Add one rule per ``source<TAB>target`` line of the file at ``path``.

    Lines are read one at a time and only the line terminator is removed,
    so the last rule is read correctly whether or not the file ends with a
    newline.  Blank lines are skipped.  Raises ``ValueError`` naming the
    offending line if it does not have exactly two tab-separated fields.
    """
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.rstrip('\n')
            if line == '':
                continue
            fields = line.split('\t')
            if len(fields) != 2:
                raise ValueError(
                    f'{path}:{line_no}: expected "source<TAB>target", '
                    f'got {len(fields)} field(s)'
                )
            (src, trg) = fields
            target_cleaner.add_sub_rule(src, trg)


# Phrase cleaners.
_add_sub_rules_from_file(cleaner, 'phrase_fixes.txt')

# Word cleaners.
_add_sub_rules_from_file(cleaner, 'word_fixes.txt')

df = pd.read_json('../raw_2.1.json')
df['description'] = df['description'].map(cleaner.clean)
df = df[df['description'] != '']  # Drop rows whose description cleaned to nothing.
del df['user_id']

cleaner.check_usages()
cleaner.check_sub_overlaps()

# Split into 80/10/10 data splits.
# The split is made over unique filenames so that all rows sharing a filename
# land in the same split; the split size is therefore a tenth of the number
# of filenames, not of the number of rows.
filenames = df['filename'].unique().tolist()
random.seed(0)  # Fixed seed keeps the split reproducible.
random.shuffle(filenames)
tenth_size = int(len(filenames)*0.1)
df_test = df[df['filename'].isin(filenames[0:tenth_size])]
df_dev = df[df['filename'].isin(filenames[tenth_size:2*tenth_size])]
df_train = df[df['filename'].isin(filenames[2*tenth_size:])]
for (split_name, split_df) in (
    ('test', df_test),
    ('dev', df_dev),
    ('train', df_train),
):
    split_df.to_json(
        f'../clean_{split_name}_2.1.json',
        force_ascii=False,
        orient='records',
        indent=4,
    )
