#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 17 12:21:52 2024

@author: brown

Script to build the absolute and relative frequency matrices for Figure 3B.
"""

import pandas as pd
import re
import csv
import os

def load_mlst_data(file_path):
    """Group the first-column values of a TSV file by its second column.

    The file must start with a header row. Columns are located by position,
    not by header name: column 2 is used as the MLST key and column 1 as the
    strain identifier collected under that key (Dictionary of MLST : strains).

    Parameters
    ----------
    file_path : str
        Path to the tab-separated input file.

    Returns
    -------
    dict
        Maps ``str(MLST)`` to the list of strain identifiers, in file order.
        An empty or header-only file yields an empty dict.
    """
    grouped = {}
    with open(file_path, mode='r', newline='') as handle:
        reader = csv.DictReader(handle, delimiter="\t")
        for record in reader:
            # The header is read lazily, so look columns up only once a data
            # row exists; empty files then fall through and return {}.
            columns = reader.fieldnames
            mlst = str(record[columns[1]])
            strain = record[columns[0]]
            grouped.setdefault(mlst, []).append(strain)
    return grouped
    
def load_phages_data(file_path):
    """Group the second-column values of a TSV file by its first column.

    The file must start with a header row. Columns are located by position:
    column 1 is the strain key and column 2 the prophage label collected
    under it (Dictionary of strains : prophages).

    Parameters
    ----------
    file_path : str
        Path to the tab-separated input file.

    Returns
    -------
    dict
        Maps each strain identifier to its list of prophage labels, in file
        order. An empty or header-only file yields an empty dict.
    """
    by_strain = {}
    with open(file_path, mode='r', newline='') as handle:
        reader = csv.DictReader(handle, delimiter="\t")
        for record in reader:
            # Column names are resolved per row so that an empty file (no
            # header, fieldnames is None) is never indexed.
            columns = reader.fieldnames
            strain = record[columns[0]]
            phage = record[columns[1]]
            by_strain.setdefault(strain, []).append(phage)
    return by_strain
 
def build_matrix(mlst_dict, phages_dict):
    """Build absolute and relative prophage frequency matrices per MLST.

    Parameters
    ----------
    mlst_dict : dict
        Maps each MLST to the list of strain identifiers assigned to it
        (as returned by ``load_mlst_data``).
    phages_dict : dict
        Maps each strain identifier to its list of prophage labels (as
        returned by ``load_phages_data``). Strains missing from this dict
        are treated as carrying no prophages.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``(df_matrix, df_matrix_freq)``, both with MLSTs as rows and
        prophage labels as columns; absent combinations are 0.
        ``df_matrix`` holds integer occurrence counts. ``df_matrix_freq``
        holds ``count / number_of_strains_in_MLST * 100`` rounded to two
        decimals.
    """
    from collections import Counter

    matrix = {}
    matrix_freq = {}
    for mlst, strains in mlst_dict.items():
        # Absolute counts: every listed prophage occurrence across the
        # MLST's strains increments its tally.
        counts = Counter()
        for st in strains:
            counts.update(phages_dict.get(st, []))
        matrix[mlst] = dict(counts)

        # Relative frequency: the denominator is every strain of the MLST,
        # including strains without prophages. If `strains` is empty then
        # `counts` is empty too, so the division is never reached.
        total_st = len(strains)
        matrix_freq[mlst] = {
            ph: round((n / total_st) * 100, 2) for ph, n in counts.items()
        }

    # Outer dict keys become columns, so transpose to put MLSTs on rows.
    # fillna(0) promotes columns with missing prophages to float; cast the
    # counts back to int so the counts table is written as "3", not "3.0".
    df_matrix = pd.DataFrame(matrix).fillna(0).astype(int).T
    df_matrix_freq = pd.DataFrame(matrix_freq).fillna(0).T

    return df_matrix, df_matrix_freq
    
    
    
def main(mlst_path="./mlst_ab_freq_wored100.tsv",
         phages_path="./st_phage_phigaro_cl_mlst8_nored100.tsv",
         counts_path='mlst_phages_counts.tsv',
         freq_path='mlst_phages_freq.tsv'):
    """Load the input tables, build both matrices and write them as TSV.

    Every path defaults to the value the script has always used, so
    ``main()`` with no arguments behaves as before. Pass other paths to
    reuse the pipeline on different inputs or outputs.

    Parameters
    ----------
    mlst_path : str
        TSV read by ``load_mlst_data`` (column 1 strain, column 2 MLST).
    phages_path : str
        TSV read by ``load_phages_data`` (column 1 strain, column 2 prophage).
    counts_path : str
        Output TSV for the absolute count matrix.
    freq_path : str
        Output TSV for the relative frequency (%) matrix.
    """
    # NOTE(review): the default MLST file name contains "wored100" while the
    # phage file contains "nored100" — confirm this difference is intended.
    mlst_dict = load_mlst_data(mlst_path)
    phages_dict = load_phages_data(phages_path)

    matrix, matrix_freq = build_matrix(mlst_dict, phages_dict)

    # The DataFrame index (MLSTs) is written as the first column.
    matrix.to_csv(counts_path, sep='\t')
    matrix_freq.to_csv(freq_path, sep='\t')


if __name__ == "__main__":
    main()
