import asyncio
import logging
import platform
import sys
from queue import Queue
import threading
from asyncio import Queue
from multiprocessing import Process
import time
import os
import urllib.parse
from asyncio.exceptions import CancelledError
import tldextract
import hashlib
import whois
import redis
import datetime
import json
import ssl, socket
from collections import namedtuple


# Give every socket created without an explicit timeout a 5 second default,
# so a blocking network call cannot stall indefinitely.
socket.setdefaulttimeout(5)

# Module-level logger named after this module.
clogger = logging.getLogger(__name__)


# Immutable record describing the outcome of one WHOIS lookup.
#   fqdn      - the target string that was looked up
#   sld       - "<domain>.<suffix>" derived from the target
#   whois     - the lookup result (or a placeholder mapping on failure)
#   type      - caller-supplied label for the target list
#   status    - status code stored alongside the result
#   exception - error information, if any
WhoisStatistic = namedtuple(
    "WhoisStatistic",
    "fqdn sld whois type status exception",
)


class ncrawler:
    """Base crawler that owns the task queue and the Redis result store.

    Subclasses pull work items from ``task_queue`` and call
    ``record_statistic`` for every finished item.
    """

    def __init__(self,max_tasks=10,max_tries=4,loop=None,task_queue=None,redis_db = None,redis_set=None,page_timeout = 20,timeout_index = 1,target_type=None):
        """Set up the queue, the Redis connection pool and bookkeeping state.

        Args:
            max_tasks: Stored on the instance; subclasses use it as the
                number of worker tasks.
            max_tries: Stored on the instance; not used by this class.
            loop: Event loop to use; defaults to ``asyncio.get_event_loop()``.
            task_queue: Queue of pending targets; a new ``asyncio.Queue`` is
                created when omitted.
            redis_db: Database passed to the Redis connection pool
                (host 127.0.0.1, port 6379).
            redis_set: Name of the sorted set that ``record_statistic``
                adds each key to.
            page_timeout: Accepted for interface compatibility; unused here.
            timeout_index: Accepted for interface compatibility; unused here.
            target_type: Label stored on the instance for subclasses.
        """
        self.max_tasks = max_tasks
        self.max_tries = max_tries
        self.loop = loop or asyncio.get_event_loop()
        self.task_queue = task_queue or self._new_task_queue(self.loop)
        self.done = []
        self.pool = redis.ConnectionPool(host="127.0.0.1",port=6379,db=redis_db)
        self.t0 = time.time()
        # Score assigned to the next member added to the sorted set.
        self.set_index = 0
        self.redis_set = redis_set
        self.target_type = target_type

    @staticmethod
    def _new_task_queue(loop=None):
        """Create an ``asyncio.Queue`` in a way that works on every version.

        ``asyncio.Queue`` stopped accepting ``loop=`` in Python 3.10 (passing
        it raises ``TypeError``), so the argument is only forwarded on older
        interpreters where it is still needed to bind a non-default loop.
        """
        if loop is None or sys.version_info >= (3, 10):
            return Queue()
        return Queue(loop=loop)

    def default(self,o):
        """``json.dumps`` fallback: ISO-format dates, everything else -> null.

        Returns:
            ``o.isoformat()`` for ``date``/``datetime`` values, otherwise
            ``None`` (which ``json.dumps`` emits as ``null``).
        """
        if isinstance(o, (datetime.date, datetime.datetime)):
            return o.isoformat()
        return None

    def record_statistic(self,key,info_statistic):
        """Persist one result and remember it in ``self.done``.

        The record is stored as JSON under the MD5 hex digest of ``key``;
        ``key`` itself is added to the ``redis_set`` sorted set with an
        incrementing score.

        Args:
            key: String identifying the record.
            info_statistic: Namedtuple to serialise (via ``_asdict()``).
        """
        # Save the information for each URL to redis (or mysql).
        resdb = redis.Redis(connection_pool=self.pool)
        res_id = hashlib.md5(key.encode()).hexdigest()
        payload = json.dumps(info_statistic._asdict(),default=self.default)
        resdb.set(res_id,payload)
        resdb.zadd(self.redis_set,{key:self.set_index})
        self.set_index += 1
        print(info_statistic)
        self.done.append(info_statistic)

    async def is_seen(self,key):
        """Report whether ``key`` still needs to be collected.

        NOTE: despite the name, the result is inverted. It returns ``True``
        when no record exists yet for ``key`` (not collected) and ``False``
        when one already exists. Callers rely on this contract.
        """
        resdb = redis.Redis(connection_pool=self.pool)
        res_id = hashlib.md5(key.encode()).hexdigest()
        return not resdb.get(res_id)

    def report(self):
        """Placeholder for reporting; intentionally does nothing."""
        pass


class Whoiscrawler(ncrawler):
    """Crawler that performs a WHOIS lookup for every queued target.

    Targets are read from ``task_queue``; the string ``"done"`` is the
    end-of-work sentinel. Each result is stored through
    ``record_statistic`` keyed by the target's "<domain>.<suffix>".
    """

    def __init__(self,targets,max_tries=3,max_tasks = 1,loop=None,task_queue=None,redis_db = None,redis_set=None,page_timeout = 20,timeout_index = 1,target_type=None):
        """Initialise the base crawler and enqueue ``targets``.

        Args:
            targets: Iterable of target strings to enqueue (may be empty or
                ``None``). The caller is responsible for including the
                ``"done"`` sentinel.
            Remaining arguments are forwarded unchanged to ``ncrawler``.
        """
        super().__init__(max_tries= max_tries,max_tasks=max_tasks,loop=loop,task_queue=task_queue,redis_db = redis_db,redis_set=redis_set,page_timeout = page_timeout,timeout_index = timeout_index,target_type=target_type)
        for target in targets or ():
            self.task_queue.put_nowait(target)

    async def fetch(self,target):
        """Look up WHOIS data for ``target`` and record the outcome.

        Nothing is recorded when a result for the same "<domain>.<suffix>"
        already exists. Lookup failures are recorded rather than raised: the
        ``whois`` field holds a placeholder and ``exception`` holds the
        error text.
        """
        tlde = tldextract.extract(target)
        sld = tlde.domain+"."+tlde.suffix
        # Results are stored under the sld, so the duplicate check must use
        # the same key. NOTE: is_seen() returns True when NOT yet collected.
        if not await self.is_seen(sld):
            print("have done")
            return

        exception = None
        try:
            # NOTE(review): this call appears to be blocking and therefore
            # stalls the event loop while it runs - confirm before relying
            # on worker concurrency.
            tar_whois = whois.whois(target)
        except whois.parser.PywhoisError as e:
            tar_whois = {target:"unknown"}
            exception = "{}: {}".format(type(e).__name__, e)
        except Exception as e:
            # Best effort: keep crawling, but do not swallow cancellation or
            # KeyboardInterrupt the way a bare ``except:`` would.
            tar_whois = {target:"failed"}
            exception = "{}: {}".format(type(e).__name__, e)

        self.record_statistic(sld,WhoisStatistic(fqdn=target,
                                                sld=sld,
                                                whois=tar_whois,
                                                type = self.target_type,
                                                status=200,
                                                exception=exception))
        return

    async def work(self):
        """Worker loop: process targets until the ``"done"`` sentinel arrives.

        Returns:
            ``1`` when the sentinel was received, ``None`` when the worker
            was cancelled.
        """
        try:
            while True:
                target = await self.task_queue.get()
                try:
                    if target =="done":   # end-of-work sentinel
                        # Only one sentinel is queued, so pass it on to let
                        # every sibling worker terminate as well. A slot was
                        # just freed by get(), so put_nowait cannot overflow.
                        self.task_queue.put_nowait(target)
                        return 1
                    await self.fetch(target)
                finally:
                    # Always balance get() with task_done(), even on errors.
                    self.task_queue.task_done()
                # Pause between lookups (presumably to rate-limit queries).
                await asyncio.sleep(5)

        except asyncio.CancelledError as e:
            print(e)
            return

    async def crawl(self):
        """Run ``max_tasks`` workers until all of them have finished.

        Sets ``self.t0`` at start and ``self.t1`` at the end. Any worker
        still running when this coroutine exits (error or cancellation) is
        cancelled.
        """
        self.t0 = time.time()
        workers = [self.loop.create_task(self.work()) for _ in range(self.max_tasks)]
        try:
            # Every worker exits on the re-queued sentinel, so waiting for
            # all of them cannot hang on an idle worker.
            await asyncio.gather(*workers)
        finally:
            for worker in workers:
                worker.cancel()   # no-op for workers that already finished
            self.t1 = time.time()

    

if __name__=="__main__":
    # Script entry point: read the domain list, crawl WHOIS data for every
    # entry and store the results in Redis database 10.
    logging.basicConfig(level=logging.INFO)

    # Create the loop explicitly; asyncio.get_event_loop() without a current
    # loop is deprecated on recent Python versions.
    if platform.system()=="Windows":
        from asyncio.windows_events import ProactorEventLoop
        loop = ProactorEventLoop()
    else:
        loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)


    target_path = "domains.txt"  # the domain list

    with open(target_path,"r") as fp:
        # Skip blank lines so they are not queued as empty targets.
        data = [name for name in (line.strip() for line in fp) if name]

    target_type = "epp"
    data.append("done")  # end-of-work sentinel consumed by the workers
    whois_target = data
    # asyncio.Queue no longer accepts ``loop=`` (removed in Python 3.10);
    # the loop set as current above is the one it will be used with.
    whois_queue = Queue()
    whois_redis = 10
    whoiscrawler = Whoiscrawler(targets=whois_target,max_tasks=6,task_queue=whois_queue,redis_db = whois_redis,redis_set="epp",target_type=target_type)

    try:
        loop.run_until_complete(whoiscrawler.crawl())
    except KeyboardInterrupt:
        sys.stderr.flush()
        print("\nInterrupted\n")
    finally:
        # stop() followed by run_forever() runs one more loop iteration so
        # already-scheduled callbacks (e.g. cancellations) get processed
        # before the loop is closed.
        loop.stop()
        loop.run_forever()
        loop.close()
    