#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 17 10:32:21 2025

@author: brown

This script reads CD-HIT output files and builds a variant matrix with the most prevalent clusters. 
Additionally, it relates the most frequent variants to each major MLST.
"""
import pandas as pd
import re
import os
import csv


# Read .clstr files
def parse_clstr(filename, system_name, percentage=0.01):
    """Map each strain in a CD-HIT ``.clstr`` file to its variant label(s).

    The file is split into clusters at every line starting with ``>Cluster``.
    From each member line containing ``>``, the first ``ab`` + five digits
    match is taken as the strain ID. Member lines without such an ID are
    ignored (they still count towards the size threshold, see below).

    Clusters with at least ``max(1, int(total * percentage))`` members, where
    ``total`` is the number of non-header lines in the file (the same count
    ``count_total_sequences`` produces), are numbered from 1 in file order
    and labelled ``"<system_name>-<n>"``. Strains that only occur in smaller
    clusters are labelled ``"Other_<system_name>"``.

    Parameters
    ----------
    filename : path of the ``.clstr`` file to read.
    system_name : prefix used when building variant labels.
    percentage : fraction of the total line count a cluster must reach to
        get its own numbered variant.

    Returns
    -------
    dict mapping strain ID to a label string. A strain present in several
    large clusters gets the distinct labels joined with ``";"``.
    """
    strain_pattern = re.compile(r'ab\d{5}')

    all_clusters = []
    members = []
    total_seqs = 0

    # Single pass: collect cluster members and count non-header lines.
    with open(filename, 'r') as f:
        for line in f:
            if line.startswith('>Cluster'):
                if members:
                    all_clusters.append(members)
                members = []
                continue

            total_seqs += 1
            if '>' not in line:
                continue
            match = strain_pattern.search(line)
            # Only record lines that really carry a strain ID; otherwise the
            # ID of the previous line would be reused (or be undefined).
            if match:
                members.append(match.group())

    # Last cluster
    if members:
        all_clusters.append(members)

    min_size = max(1, int(total_seqs * percentage))  # at least 1
    clusters = [c for c in all_clusters if len(c) >= min_size]
    small_clusters = [c for c in all_clusters if len(c) < min_size]

    # Track labels per strain as a list so that repeated hits in the same
    # cluster never duplicate a label, even when the strain already has
    # several variants.
    labels_by_strain = {}
    for idx, cluster in enumerate(clusters, start=1):
        variant_name = f"{system_name}-{idx}"  # RM-1, CRISPR-1, ...
        for strain in cluster:
            labels = labels_by_strain.setdefault(strain, [])
            if variant_name not in labels:
                labels.append(variant_name)

    # If a strain has multiple variants, they are separated by ";"
    variant_map = {
        strain: ";".join(labels) for strain, labels in labels_by_strain.items()
    }

    # Strains seen only in small clusters are grouped as "Other_<system>"
    for cluster in small_clusters:
        for strain in cluster:
            variant_map.setdefault(strain, f"Other_{system_name}")

    return variant_map

def count_total_sequences(clstr_path):
    """Return the number of lines in ``clstr_path`` that are not cluster headers.

    A cluster header is any line starting with ``>Cluster``; every other
    line (including blank ones) is counted.
    """
    with open(clstr_path, 'r') as handle:
        return sum(1 for entry in handle if not entry.startswith(">Cluster"))

# Build clstr matrix
def build_variant_matrix(system_clstrs):
    """Combine the per-system variant maps into one strain-indexed table.

    ``system_clstrs`` maps a system name to the path of its ``.clstr`` file.
    Each file is parsed with ``parse_clstr`` and the results are merged.

    Returns a tuple ``(matrix, systems, strains)`` where ``matrix`` is
    ``{strain: {system_name: variant_label}}``, ``systems`` is the sorted
    list of system names and ``strains`` the sorted list of strain IDs.
    """
    per_strain = {}

    for system_name, clstr_path in system_clstrs.items():
        parsed = parse_clstr(clstr_path, system_name)
        for strain, variant in parsed.items():
            if strain not in per_strain:
                per_strain[strain] = {}
            per_strain[strain][system_name] = variant

    # Every strain seen is a key of per_strain, so its keys are the full
    # set of strain IDs.
    system_names = sorted(system_clstrs)
    strain_names = sorted(per_strain)
    return per_strain, system_names, strain_names

# Write clstr matrix
def write_variant_matrix(matrix, systems, strains, out_file):
    """Write the strain-by-system variant table as a tab-separated file.

    Parameters
    ----------
    matrix : dict of ``{strain: {system_name: variant_label}}``.
    systems : list of system names; defines the column order.
    strains : iterable of strain IDs; defines the row order.
    out_file : path of the file to create (overwritten if it exists).

    The header row is ``Strain`` followed by the system names. A strain
    with no entry for a system gets ``"NA"`` in that column.
    """
    # newline="" hands line-ending control to the csv module; without it
    # platforms that translate "\n" would emit "\r\r\n" row terminators.
    with open(out_file, "w", newline="") as out:
        writer = csv.writer(out, delimiter="\t")
        writer.writerow(["Strain"] + systems)
        for strain in strains:
            variants = matrix.get(strain, {})
            writer.writerow(
                [strain] + [variants.get(system, "NA") for system in systems]
            )
 
# Load MLST files
def load_mlst(mlst_file):
    """Read a tab-separated file into a ``{first column: second column}`` dict.

    Blank lines and lines with fewer than two tab-separated fields (after
    stripping surrounding whitespace) are skipped. If a first-column value
    appears more than once, the last occurrence wins.
    """
    mlst_map = {}
    with open(mlst_file, 'r') as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped == "":
                continue  # skip empty lines
            fields = stripped.split('\t')
            if len(fields) < 2:
                continue
            mlst_map[fields[0]] = fields[1]
    return mlst_map

# Write per-strain MLST and variant matrix
def write_strain_mlst_variant_matrix(matrix, mlst_map, systems, output_file):
    """Write one tab-separated row per strain with its MLST and variants.

    Parameters
    ----------
    matrix : dict of ``{strain: {system_name: variant_label}}``.
    mlst_map : dict of ``{strain: mlst}``.
    systems : list of system names; defines the column order.
    output_file : path of the file to create (overwritten if it exists).

    The header row is ``Strain``, ``MLST`` and then the system names. Rows
    are written in sorted strain order. A strain missing from ``mlst_map``,
    or lacking an entry for a system, gets ``"-"`` in that column.
    """
    # newline="" hands line-ending control to the csv module; without it
    # platforms that translate "\n" would emit "\r\r\n" row terminators.
    with open(output_file, "w", newline="") as out:
        writer = csv.writer(out, delimiter="\t")
        writer.writerow(["Strain", "MLST"] + systems)
        for strain in sorted(matrix):
            variants = matrix[strain]
            row = [strain, mlst_map.get(strain, "-")]
            row.extend(variants.get(system, "-") for system in systems)
            writer.writerow(row)

if __name__ == "__main__":
    # One .clstr file per system; the key is used as the variant-label
    # prefix and as the column header in the output tables.
    clstr_files = {
        "RM": "./R-M_TypeI_clstr.fasta.clstr",
        "RosmerTA": "./RosmerTA.clstr.clstr",
        "SspBCDE": "./SspBCDE_clstr.fasta.clstr",
        "PD-T4": "./PD-T4_clstr.fasta.clstr",
        "PD-T7": "./PD-T7_clstr.fasta.clstr",
        "Cas": "./Cas_clstr.fasta.clstr",
    }

    # Input MLST table and the two output tables.
    mlst_file = "./mlst_ab_freq_wored100.tsv"
    output_matrix = "variant_matrix.tsv"
    output_by_mlst = "mlst_defense_variants.tsv"

    # Step 1: build and write the matrix of variants per system
    matrix, systems, strains = build_variant_matrix(clstr_files)
    write_variant_matrix(matrix, systems, strains, output_matrix)

    # Step 2: write the same matrix with each strain's MLST added
    write_strain_mlst_variant_matrix(
        matrix, load_mlst(mlst_file), systems, output_by_mlst
    )
    
    
    

