from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
from tqdm import tqdm
from glob import glob
from mikatools import *
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from sklearn.cluster import AffinityPropagation
import os

"""

Running this script will get embeddings for the data in the JSON files that were produced by the crawling script.


embed_files() takes min_messages. By default this is set to 20, so only chatters with at least 20 messages will be considered.

This will create *_clusters.json files, each of which contains a list of clusters. Clusters with only one chatter are excluded.

"""



# Task type passed to every TextEmbeddingInput that embed() builds.
task = "SEMANTIC_SIMILARITY"
# Name of the pretrained text embedding model to load.
model_name = "text-embedding-004"
# Loaded once when the module is executed and shared by all embed() calls.
model = TextEmbeddingModel.from_pretrained(model_name)

def embed(text):
    """Return the embedding values for ``text``.

    The text is wrapped in a single TextEmbeddingInput together with the
    module-level ``task`` and sent to the module-level ``model``; the
    ``values`` attribute of the first returned embedding is handed back.
    """
    request = TextEmbeddingInput(text, task)
    response = model.get_embeddings([request])
    first = response[0]
    return first.values


def embed_files(min_messages=20):
    """Embed the messages of every sufficiently active chatter.

    For each ``*_chat_log.json`` file in the current directory, the messages
    of every user with at least ``min_messages`` entries are joined with
    newlines and embedded as one text.  Two sibling files are written per
    log: ``*_userkey.json`` (the kept user names, in embedding order) via
    json_dump and ``*_embeddings.bin`` (the list of embeddings) via
    pickle_dump.
    """
    for log_path in glob("./*_chat_log.json"):
        print("Embedding", log_path)
        chat_log = json_load(log_path)
        kept_users = []
        vectors = []
        for name, messages in tqdm(list(chat_log.items())):
            # Skip chatters that have too few messages.
            if len(messages) < min_messages:
                continue
            kept_users.append(name)
            joined = "\n".join(entry["message"] for entry in messages)
            vectors.append(embed(joined))
        userkey_path = log_path.replace(".json", "_userkey.json")
        embeddings_path = log_path.replace(".json", "_embeddings.bin")
        json_dump(kept_users, userkey_path)
        pickle_dump(vectors, embeddings_path)

def _get_matrix(embeddings):
    """Build the pairwise cosine-similarity matrix of ``embeddings``.

    Args:
        embeddings: Iterable of equal-length numeric vectors.

    Returns:
        A list of rows (list of lists of floats) in which entry ``[i][j]`` is
        the cosine similarity between ``embeddings[i]`` and ``embeddings[j]``.
        An empty input yields an empty list.
    """
    rows = [np.asarray(e).reshape(1, -1) for e in embeddings]
    # cosine_similarity rejects empty input, so keep the empty case explicit.
    if not rows:
        return []
    # A single vectorized call computes every pair at once instead of issuing
    # one cosine_similarity call per (i, j) pair.
    return cosine_similarity(np.vstack(rows)).tolist()

def _group_chatters(chatters, labels):
    """Group chatters by their cluster label.

    Args:
        chatters: Sequence of chatter identifiers.
        labels: Sequence of cluster labels; ``labels[i]`` is the label of
            ``chatters[i]``.

    Returns:
        A list of groups (lists of chatters), ordered by ascending label.
        For labels that are exactly ``0..k-1`` the group at index ``n`` holds
        the chatters labelled ``n``.

    Raises:
        IndexError: If ``labels`` is longer than ``chatters``.
    """
    # Key the groups by the label value itself rather than using the label as
    # a list index: non-contiguous labels would otherwise raise IndexError and
    # negative labels (e.g. -1) would silently wrap around to the last group.
    groups = {}
    for i, label in enumerate(labels):
        groups.setdefault(label, []).append(chatters[i])
    return [groups[label] for label in sorted(groups)]

def semantic_clusters(embeddings, chatters):
    """Cluster chatters by the cosine similarity of their embeddings.

    Args:
        embeddings: Sequence of embedding vectors, one per chatter.
        chatters: Sequence of chatter identifiers, aligned with ``embeddings``.

    Returns:
        A list of clusters, each a list of chatters.  Returns an empty list
        when there are no embeddings to cluster.
    """
    # AffinityPropagation cannot be fitted on an empty similarity matrix;
    # with nothing to cluster there are simply no clusters.
    if len(embeddings) == 0:
        return []
    similarity = np.array(_get_matrix(embeddings))
    # The matrix already holds pairwise similarities, hence "precomputed".
    clusterer = AffinityPropagation(affinity="precomputed")
    labels = clusterer.fit_predict(similarity)
    return _group_chatters(chatters, labels)

def _clean_clusters(clusters):
    """Return only the clusters that contain more than one member."""
    return [members for members in clusters if len(members) > 1]

def cluster_files():
    """Cluster the chatters of every chat log in the current directory.

    For each ``*_chat_log.json`` file, the sibling ``*_userkey.json`` and
    ``*_embeddings.bin`` files are loaded, the chatters are clustered with
    semantic_clusters(), clusters with a single member are removed, and the
    result is written to ``*_clusters.json``.
    """
    for log_path in glob("./*_chat_log.json"):
        print("Clustering", log_path)
        userkey_path = log_path.replace(".json", "_userkey.json")
        embeddings_path = log_path.replace(".json", "_embeddings.bin")
        output_path = log_path.replace(".json", "_clusters.json")
        names = json_load(userkey_path)
        vectors = pickle_load(embeddings_path)
        kept = _clean_clusters(semantic_clusters(vectors, names))
        json_dump(kept, output_path)



# Script body: executed whenever this module is run (or imported).
# Step 1: write the user keys and embeddings for every chat log.
embed_files()
# Step 2: cluster the chatters using the files written by step 1.
cluster_files()