import av
import torch
import pandas as pd
import torch.nn as nn
import numpy as np 
import warnings
warnings.filterwarnings("ignore")
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import gc
import argparse
from torch.optim import AdamW
import warnings
from sklearn.metrics import classification_report
warnings.filterwarnings("ignore", category=UserWarning, module="transformers.feature_extraction_utils")
from sklearn.model_selection import train_test_split
from transformers import AutoProcessor, XCLIPVisionModel, get_linear_schedule_with_warmup, AutoModel,VivitImageProcessor,VivitModel, AutoImageProcessor,VideoMAEModel,VideoMAEForVideoClassification
from tqdm import tqdm
from collections import defaultdict
import os
import transformers
transformers.logging.set_verbosity_error() 


import time
# NOTE(review): seeds NumPy's global RNG from the wall clock, so any randomness
# would differ run to run (not reproducible). Nothing visible in this file draws
# from np.random (frame sampling below is deterministic), so this currently has
# no observable effect.
np.random.seed(int(time.time()))

def find_video_files(directory, eval_list=None, unsafe_path_len=None):
    """Recursively collect ``.mp4`` file paths under ``directory``.

    Args:
        directory: root directory passed to ``os.walk``.
        eval_list: optional collection of integer file stems to exclude
            (``"12.mp4"`` has stem ``12``). When non-empty, every ``.mp4`` stem
            must parse as an ``int``.
        unsafe_path_len: if given, only the first ``unsafe_path_len`` entries of
            each directory's sorted file listing are considered (non-video
            entries count toward the cap). ``None`` means no cap.

    Returns:
        List of full paths. Directories are visited in sorted order and files
        are sorted within each directory, so the result is deterministic.
    """
    excluded = set(eval_list) if eval_list else set()
    video_files = []
    for root, dirs, files in os.walk(directory):
        # Sorting in place makes os.walk descend deterministically.
        dirs.sort()
        # Slicing with None keeps every file. The cap is applied per directory.
        # The old code stored the first directory's file count back into
        # unsafe_path_len, which wrongly capped every later directory as well.
        for file in sorted(files)[:unsafe_path_len]:
            if not file.endswith(".mp4"):
                continue
            # Only parse the stem when there is something to exclude, so
            # non-numeric names no longer crash an unfiltered scan.
            if excluded and int(file.split('.')[0]) in excluded:
                continue
            video_files.append(os.path.join(root, file))
    return video_files

class CreateDataset(torch.utils.data.Dataset):
    """Map-style dataset that decodes the leading frames of each video file.

    Each item is a dict with:
      * ``'input'``: the ``pixel_values`` entry returned by ``processor`` for the
        sampled frames;
      * ``'label'``: ``labels[item]``;
      * ``'file'``: the file name with directory and extension stripped.
    """

    def __init__(self, videos_file, labels, processor, clip_len=16, frame_sample_rate=1):
        """
        Args:
            videos_file: indexable sequence of video paths.
            labels: indexable sequence of labels, aligned with ``videos_file``.
            processor: callable ``processor(list_of_frames, return_tensors="pt")``
                returning a mapping with ``'pixel_values'``.
            clip_len: number of frames to sample per video (default 16).
            frame_sample_rate: stride between sampled frames (default 1).
        """
        self.videos_file = videos_file
        self.labels = labels
        self.processor = processor
        self.clip_len = clip_len
        self.frame_sample_rate = frame_sample_rate

    def __len__(self):
        return len(self.videos_file)

    def __getitem__(self, item):
        video_file = self.videos_file[item]
        try:
            # The context manager closes the container even if decoding fails.
            # Previously every item leaked an open file handle.
            with av.open(video_file) as container:
                indices = self.sample_frame_indices(
                    clip_len=self.clip_len,
                    frame_sample_rate=self.frame_sample_rate,
                    seg_len=container.streams.video[0].frames,
                )
                video = self.read_video_pyav(container, indices)
            processed_video = self.processor(list(video), return_tensors="pt")
        except av.error.InvalidDataError as e:
            print(f"wrong file {video_file}: {e}")
            # Re-raise: there is no valid tensor to return. The old code hit a
            # NameError here, then an unbound local at the return.
            raise
        except Exception as e:
            print(f"mistake {video_file}: {e}")
            raise
        return {
            'input': processed_video['pixel_values'],
            'label': self.labels[item],
            'file': video_file.split('/')[-1].split('.')[0],
        }

    def read_video_pyav(self, container, indices):
        '''
        Decode the requested frames with the PyAV decoder.

        Args:
            container (`av.container.input.InputContainer`): PyAV container.
            indices (`List[int]`): frame indices to decode (any order).

        Returns:
            np.ndarray of the decoded RGB frames, stacked in decode order,
            shape (num_frames, height, width, 3).

        Raises:
            ValueError: if ``indices`` is empty or no requested frame was decoded.
        '''
        wanted = set(indices)  # O(1) membership instead of scanning the list per frame
        if not wanted:
            raise ValueError("no frame indices to decode")
        last_index = max(wanted)
        frames = []
        container.seek(0)
        for i, frame in enumerate(container.decode(video=0)):
            if i > last_index:
                break
            if i in wanted:
                frames.append(frame.to_ndarray(format="rgb24"))
        if not frames:
            raise ValueError("no requested frames could be decoded")
        return np.stack(frames)

    def sample_frame_indices(self, clip_len, frame_sample_rate, seg_len):
        '''
        Sample the leading frame indices of a video.

        Args:
            clip_len (`int`): maximum number of frames to sample.
            frame_sample_rate (`int`): take every n-th frame.
            seg_len (`int`): number of frames in the video. ``0`` means the
                count is unknown (PyAV reports this for some containers); then
                no cap is applied and decoding simply stops at end of stream.

        Returns:
            indices (`List[int]`): ``0, rate, 2*rate, ...``, all ``< seg_len``
            when ``seg_len`` is known.
        '''
        if seg_len > 0:
            # Count how many strided indices fit inside [0, seg_len).
            # The old min(clip_len, seg_len) overshot the video when rate > 1.
            fits = -(-seg_len // frame_sample_rate)  # ceil division
            clip_len = min(clip_len, fits)
        return list(range(0, clip_len * frame_sample_rate, frame_sample_rate))

def CreateDataLoader(df,processor,batch_size):
    """Build a ``CreateDataset`` from the ``video_path`` and ``labels`` columns of ``df``.

    Despite its name this returns the Dataset itself, not a DataLoader. The
    caller wraps it in ``torch.utils.data.DataLoader``. ``batch_size`` is
    accepted for interface compatibility but not used.
    """
    paths = df['video_path']
    targets = df['labels']
    return CreateDataset(videos_file=paths, labels=targets, processor=processor)

class DetectionSystem:
    """OR-ensemble of fine-tuned VideoMAE classifiers.

    An input is flagged (``1``) as soon as any member model predicts class 1;
    otherwise the result is ``0``.
    """

    def __init__(self, model_paths, num_labels=2, device=None):
        """Load every existing checkpoint in ``model_paths``.

        Args:
            model_paths: iterable of paths to ``state_dict`` files. Missing
                paths are skipped with a printed message.
            num_labels: classifier output size; must match the checkpoints.
            device: torch device string. Defaults to ``f"cuda:{args.gpu_num}"``,
                taken from the module-level ``args`` set in the ``__main__``
                block. It is only resolved when a checkpoint is actually loaded.
        """
        self.models = []
        self.device = device
        for model_path in model_paths:
            if not os.path.exists(model_path):
                # Print instead of warnings.warn: this module disables warnings globally.
                print(f"checkpoint not found, skipping: {model_path}", flush=True)
                continue
            if self.device is None:
                self.device = f"cuda:{args.gpu_num}"
            model = VideoMAEForVideoClassification.from_pretrained(
                "MCG-NJU/videomae-base", num_labels=num_labels
            )
            model.load_state_dict(torch.load(model_path, map_location=self.device))
            model.to(self.device)
            model.eval()
            self.models.append(model)
        if not self.models:
            print("warning: DetectionSystem loaded no models; every input will score 0", flush=True)

    def process_input(self, input_data):
        """Return 1 if any model predicts class 1 for ``input_data``, else 0.

        Models run in order, and evaluation stops at the first positive. For a
        batch, the whole batch is flagged when any sample is predicted 1. For
        batch size 1 this matches the per-sample behaviour.
        """
        with torch.no_grad():
            for model in self.models:
                logits = model(input_data).logits
                _, preds = torch.max(logits, dim=1)
                if bool((preds == 1).any()):
                    return 1
        return 0

def eval_model(model, data_loader, n_examples, data_dict, device=None):
    """Run ``model`` over every batch and record its verdict per label and file.

    Args:
        model: object exposing ``process_input(tensor) -> int`` (a
            ``DetectionSystem`` in this script).
        data_loader: iterable of dicts with ``'input'``, ``'label'`` and
            ``'file'`` entries. Assumes batch size 1, because ``label.item()``
            and ``file[0]`` read a single sample.
        n_examples: unused; kept for interface compatibility.
        data_dict: nested mapping ``data_dict[label_str][file_id] -> list``.
            It must auto-create missing keys (e.g. a defaultdict of
            ``defaultdict(list)``).
        device: torch device for the inputs. Defaults to
            ``f"cuda:{args.gpu_num}"`` from the module-level ``args``.

    Returns:
        The same ``data_dict`` object, updated in place.
    """
    if device is None:
        device = f"cuda:{args.gpu_num}"
    with torch.no_grad():
        for batch in data_loader:
            # squeeze(1) drops a singleton dim the processor presumably adds per
            # video, e.g. (B, 1, T, C, H, W) -> (B, T, C, H, W). Confirm against the processor.
            input_video = batch['input'].to(device).squeeze(1)
            # Labels are only used as dict keys, so they stay on the CPU.
            label_key = str(batch['label'].item())
            file_key = str(batch['file'][0])
            data_dict[label_key][file_key].append(model.process_input(input_video))
    return data_dict

def default_list_dict():
    """Return an empty ``defaultdict(list)``.

    This is a named module-level factory rather than a lambda, so nested
    ``defaultdict(default_list_dict)`` objects stay picklable for ``torch.save``.
    """
    per_file = defaultdict(list)
    return per_file

def _model_paths_for_split(args, i):
    """Return the checkpoint paths that make up the detector for split ``i``."""
    if args.step_accuracy:
        return [f"{args.model_dir}/{args.group_num}/{i}.pth"]
    return [f"{args.model_dir}/{j}/{i}.pth" for j in range(1, 6)]


def _build_split_frame(args, i):
    """Return a ``video_path``/``labels`` DataFrame for split ``i`` (label-1 rows first).

    Label-1 videos come from ``{eval_dir}/{i}``. Label-0 videos come from
    ``{data_dir}/{i}``, with the label-1 count passed as ``unsafe_path_len``.
    """
    label1_path = find_video_files(f"{args.eval_dir}/{i}")
    label0_path = find_video_files(f"{args.data_dir}/{i}", unsafe_path_len=len(label1_path))
    # Plain lists avoid concatenating a str array with an empty float array.
    return pd.DataFrame({
        'video_path': label1_path + label0_path,
        'labels': [1] * len(label1_path) + [0] * len(label0_path),
    })


def main(args):
    """Evaluate the 20 per-split detectors and save their per-file verdicts.

    For each split ``i`` in ``range(20)``, the detector is either the single
    checkpoint ``{model_dir}/{group_num}/{i}.pth`` (``--step_accuracy``) or the
    OR-ensemble of ``{model_dir}/{j}/{i}.pth`` for ``j`` in 1..5. Without
    ``--step_accuracy``, the nested ``data_dict[label][file] -> [verdicts]`` is
    saved to ``{save_dir}/{basename(eval_dir)}_AE.pt``.
    """
    print("load data...", flush=True)
    processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")

    data_dict = defaultdict(default_list_dict)
    all_preds = []
    all_labels = []
    for i in range(20):
        print(f"this is {i}...")
        System = DetectionSystem(_model_paths_for_split(args, i))
        df_data = _build_split_frame(args, i)

        print(f"load data for model {i}", flush=True)
        ds = CreateDataLoader(df_data, processor, 1)
        val_data_loader = torch.utils.data.DataLoader(ds, batch_size=1, num_workers=2, drop_last=True)

        if not args.step_accuracy:
            data_dict = eval_model(System, val_data_loader, len(df_data), data_dict)
        else:
            # NOTE(review): step_accuracy is not defined in the visible source;
            # this branch raises NameError unless it is provided elsewhere.
            # Fixed: the last argument was the builtin `all` instead of all_labels.
            all_preds, all_labels = step_accuracy(System, val_data_loader, all_preds, all_labels)

        # Drop the models and loader first, collect them, and only then release
        # cached CUDA blocks. empty_cache cannot free memory still referenced.
        del System, df_data, val_data_loader, ds
        gc.collect()
        torch.cuda.empty_cache()

    # normpath strips a trailing slash, which would otherwise give an empty basename.
    last_part = os.path.basename(os.path.normpath(args.eval_dir))
    if not args.step_accuracy:
        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
        torch.save(data_dict, f"{args.save_dir}/{last_part}_AE.pt")

def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string as True,
    including ``"False"``. So ``--step_accuracy False`` used to enable the mode.

    Raises:
        argparse.ArgumentTypeError: for unrecognised values.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--model_dir", type=str, default=None)
    parser.add_argument('--train', action='store_true', help='Enable training')
    parser.add_argument('--no-train', action='store_false', dest='train', help='Disable training')
    # NOTE(review): the module sets CUDA_VISIBLE_DEVICES='0' at import, so only
    # one GPU is visible and any --gpu_num other than 0 names a missing device.
    parser.add_argument("--gpu_num", type=int, default=0)
    parser.add_argument("--eval_dir", type=str, default=None)
    parser.add_argument("--save_dir", type=str, default=None)
    parser.add_argument("--group_num", type=int, default=None, required=False)
    # Accepts both `--step_accuracy` and the old `--step_accuracy True/False` forms.
    parser.add_argument("--step_accuracy", type=str2bool, nargs="?", const=True, default=False, required=False)
    args = parser.parse_args()

    main(args)