#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 27 17:59:16 2024

@author: brown

Script to obtain the frequent prophages that are in more than 10% of the genomes per each frequent MLST
"""


import pandas as pd
import re
import csv
import os
import sys


def load_mlst_data(file_path):
    """Group the genomes listed in a tab-separated file by their MLST.

    The first line of the file is treated as a header. For every following
    row, the value in the second column (converted to ``str``) is used as
    the MLST key and the value in the first column is appended to that
    key's list of genomes.

    Parameters
    ----------
    file_path : str
        Path to the tab-separated MLST table.

    Returns
    -------
    dict
        Maps each MLST (``str``) to the list of genomes assigned to it, in
        the order the rows appear in the file.
    """
    genomes_by_mlst = {}
    with open(file_path, mode='r', newline='') as handle:
        reader = csv.DictReader(handle, delimiter="\t")
        # Header names; stays None for an empty file, in which case the loop
        # below never runs and the names are never indexed.
        columns = reader.fieldnames
        for record in reader:
            mlst = str(record[columns[1]])
            genome = record[columns[0]]
            genomes_by_mlst.setdefault(mlst, []).append(genome)
    return genomes_by_mlst
    
    
    
def _frequent_phages(phages_df, genomes, min_fraction=0.1):
    """Return the phages carried by at least *min_fraction* of *genomes*.

    Parameters
    ----------
    phages_df : pandas.DataFrame
        Table with at least the columns ``'strain'`` and ``'phage'``.
    genomes : iterable
        Genome identifiers belonging to one MLST; duplicates are ignored.
    min_fraction : float, optional
        Minimum fraction of the genomes that must carry a phage for it to be
        reported (inclusive). Defaults to 0.1.

    Returns
    -------
    list
        Phages that reach the threshold, ordered by first appearance in
        *phages_df* so that the result is reproducible between runs.
    """
    strains = set(genomes)
    subset = phages_df[phages_df['strain'].isin(strains)]
    # Count distinct genomes per phage, not rows: a genome listing the same
    # phage more than once must still count as a single genome. groupby also
    # leaves out rows whose phage is missing (NaN), which previously caused
    # a KeyError when looked up in value_counts().
    genomes_per_phage = subset.groupby('phage', sort=False)['strain'].nunique()
    threshold = len(strains) * min_fraction
    return genomes_per_phage[genomes_per_phage >= threshold].index.tolist()


def main():
    """Write the frequent prophages of every MLST to a tab-separated file.

    Command-line arguments:

    1. MLST table read with ``load_mlst_data``
       (e.g. ``mlst_ab_freq_wored100.tsv``).
    2. Header-less, tab-separated phage table whose columns are read as
       ``strain``, ``phage`` and ``phigaro``
       (e.g. ``st_phage_phigaro_cl_wored100.tsv``).
    3. Path of the output file.

    For each MLST, a phage is reported when the number of that MLST's
    genomes carrying it is at least 10% of the MLST's genomes. The output
    has one row per MLST with the columns ``MLST`` and ``Phages`` (the list
    of frequent phages), preceded by the DataFrame index column.
    """
    if len(sys.argv) < 4:
        sys.exit("Usage: {} <mlst.tsv> <phages.tsv> <output.tsv>".format(sys.argv[0]))
    mlst_path, phages_path, output_path = sys.argv[1:4]

    # Read strain ids as text: the genomes coming from the MLST table are
    # strings, so letting pandas infer a numeric dtype here would make the
    # membership test silently match nothing.
    phages_df = pd.read_csv(
        phages_path,
        sep="\t",
        names=['strain', 'phage', 'phigaro'],
        dtype={'strain': str},
    )
    mlst_dict = load_mlst_data(mlst_path)

    mlst_output = list(mlst_dict)
    phages_output = [
        _frequent_phages(phages_df, genomes) for genomes in mlst_dict.values()
    ]

    # Output freq prophages per freq MLST
    output = pd.DataFrame({'MLST': mlst_output, 'Phages': phages_output})
    output.to_csv(output_path, sep="\t")
    
    
# Run the analysis only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()