from datetime import datetime
import logging
from typing import List

from git import GitError
from mysql.connector import connect, MySQLConnection
from pytz import utc

from utils.core import Status, Update
from utils.runner import Runner
from utils.vcs import traverse_pom_history, clone_repository, delete_repository


class RepoScanner(Runner):
    """Scan queued repositories and record the dependency updates found in them.

    Each run claims a batch of rows from the ``repos`` table whose status is
    ``Status.NEW`` or ``Status.CHANGED``, clones every claimed repository, passes
    the clone to ``traverse_pom_history`` and stores the returned updates in the
    ``updates`` table.
    """

    def __init__(self, connection: MySQLConnection, logger: logging.Logger, batch_size: int):
        """Create a scanner.

        Args:
            connection: Database connection, forwarded to ``Runner``.
            logger: Logger, forwarded to ``Runner``.
            batch_size: Maximum number of repos claimed per ``run_once`` call.
        """
        super().__init__(connection, logger)
        self.batch_size = batch_size
        # Cache of vulnerability details keyed by CVE, filled lazily by run_once.
        self.vulnerabilities = {}

    def run_once(self):
        """Claim one batch of repos and scan each of them.

        A repo that cannot be cloned is marked ``Status.UNAVAILABLE`` and skipped;
        the remaining repos of the batch are still processed.
        """
        for repo in self._get_repos():
            full_name = repo['full_name']
            cve = repo['cve']
            repo_id = repo['repo_id']
            self._logger.info('Start scanning: %s', full_name)

            if cve not in self.vulnerabilities:
                self.vulnerabilities[cve] = self._get_vulnerability_information(cve)
            vulnerability = self.vulnerabilities[cve]

            try:
                cloned_repo = clone_repository(full_name)
            except GitError:
                self._logger.warning('Could not clone %s', full_name)
                with self._connection.cursor() as cursor:
                    cursor.execute(f"UPDATE repos SET status='{Status.UNAVAILABLE}' WHERE id = %s", [repo_id])
                self._connection.commit()
                # Skip only this repo: the rest of the batch has already been
                # marked IN_PROGRESS and would otherwise never be scanned.
                continue

            try:
                # Only the first element of the returned pair is used here.
                (updates, _) = traverse_pom_history(
                    cloned_repo,
                    vulnerability['package_coords'],
                    vulnerability['fix_release_date']
                )
            finally:
                # Remove the clone even when the traversal raises.
                delete_repository(full_name)

            self._logger.debug('Updates for %s: %s', full_name, updates)
            self._write_update_to_db(cve, repo_id, updates)

    def _get_repos(self):
        """Select up to ``batch_size`` pending repos and mark them ``Status.IN_PROGRESS``.

        Returns:
            A list of dicts with the keys ``repo_id``, ``cve``, ``full_name`` and
            ``pom_path``; empty when no repo has status NEW or CHANGED.
        """
        with self._connection.cursor() as cursor:
            cursor.execute(
                f"SELECT id, cve, full_name, pom_path FROM repos "
                f"WHERE status = '{Status.NEW}' OR status = '{Status.CHANGED}' "
                f"LIMIT {self.batch_size}"
            )
            repos = [
                {'repo_id': repo_id, 'cve': cve, 'full_name': full_name, 'pom_path': pom_path}
                for (repo_id, cve, full_name, pom_path) in cursor
            ]
            if repos:
                repo_ids = [repo['repo_id'] for repo in repos]
                placeholders = ', '.join(['%s'] * len(repo_ids))
                cursor.execute(
                    f"UPDATE repos SET status='{Status.IN_PROGRESS}' WHERE id IN ({placeholders})",
                    repo_ids
                )
            else:
                self._logger.info('No repos to scan yet')
        # Commit on both paths. On the empty path this ends the read transaction so
        # that a later poll on the same connection is not answered from the same
        # snapshot (relevant with InnoDB REPEATABLE READ and autocommit disabled).
        self._connection.commit()
        return repos

    def _get_vulnerability_information(self, cve):
        """Load the package coordinates and first fix release date for ``cve``.

        Returns:
            A dict with the keys ``cve``, ``package_coords`` and ``fix_release_date``;
            the date is turned into a datetime at 00:00 with the pytz ``utc`` tzinfo.

        Raises:
            TypeError: If no row matches, because ``fetchone()`` then returns None,
                which cannot be unpacked.
        """
        with self._connection.cursor() as cursor:
            cursor.execute("SELECT package_coords, first_fix_release_date FROM vulnerabilities WHERE cve = %s", [cve])
            (package_coords, fix_release_date) = cursor.fetchone()
        fix_release_date = datetime.combine(fix_release_date, datetime.min.time(), utc)
        return {'cve': cve, 'package_coords': package_coords, 'fix_release_date': fix_release_date}

    def _write_update_to_db(self, cve, repo_id, updates: List[Update]):
        """Insert one ``updates`` row per update and mark the repo ``Status.DONE``.

        The inserts and the status change are committed together. The status is
        set to DONE regardless of whether ``updates`` contains any entries.
        """
        date_format = '%Y-%m-%d'
        prepared_updates = [{
            'cve': cve,
            'repo_id': repo_id,
            'package': update.package,
            'update_delay': update.delay.days,
            'commit_hash': update.commit.hash,
            'commit_date': update.commit.author_date.strftime(date_format),
            'commit_author': update.commit.author.name,
            'is_fix_update': update.is_fix_update,
            'old_version': update.old_version,
            'old_release_date': update.old_release_date.strftime(date_format),
            'new_version': update.new_version,
            'new_release_date': update.new_release_date.strftime(date_format),
        } for update in updates]
        with self._connection.cursor() as cursor:
            cursor.executemany("""
                INSERT INTO updates
                (cve, repo_id, package, update_delay, commit_hash, commit_date, commit_author, is_fix_update, old_version, old_release_date, new_version, new_release_date)
                VALUES 
                (%(cve)s, %(repo_id)s, %(package)s, %(update_delay)s, %(commit_hash)s, %(commit_date)s, %(commit_author)s, %(is_fix_update)s, %(old_version)s, %(old_release_date)s, %(new_version)s, %(new_release_date)s)
            """, prepared_updates)
            cursor.execute(f"UPDATE repos SET status='{Status.DONE}' WHERE id = %s", [repo_id])
        self._connection.commit()


if __name__ == '__main__':
    # Script entry point: run a single scan with a batch size of 1.
    # basicConfig installs a handler on the root logger. Without any handler,
    # logging falls back to its last-resort handler, which only emits WARNING and
    # above, so the scanner's INFO messages would not be shown.
    logging.basicConfig(level=logging.INFO)
    # basicConfig does nothing if the root logger already has handlers, so the
    # level is still set explicitly.
    logging.getLogger().setLevel(logging.INFO)
    # NOTE(review): connection settings, including the password, are hard-coded.
    with connect(host='127.0.0.1', port=33062, user='vulnerability-history', database='vulnerability-history',
                 password='secret') as conn:
        repo_scanner = RepoScanner(conn, logging.getLogger(), 1)
        repo_scanner.run_once()
