import json
import unittest
from unittest.mock import MagicMock
from unittest.mock import patch

import pandas as pd
import pytest

try:
    from trulens.connectors.snowflake.dao.run import RunDao
    from trulens.core.enums import Mode

except Exception:
    RunDao = None
    Mode = None


@pytest.mark.snowflake
class TestRunDao(unittest.TestCase):
    """Unit tests for ``RunDao`` backed by a mocked Snowpark session.

    Tests that exercise SQL-issuing methods patch
    ``trulens.connectors.snowflake.dao.run.execute_query``, so no real
    Snowflake connection is needed.
    """

    # Query text the create tests expect ``create_new_run`` to issue.
    _CREATE_QUERY = "SELECT SYSTEM$AIML_RUN_OPERATION('CREATE', ?);"

    def setUp(self):
        if RunDao is None or Mode is None:
            self.skipTest(
                "RunDao or Mode is not available because optional tests are disabled."
            )

        self.sf_session = MagicMock()
        self.sf_session.get_current_database.return_value = "DB"
        self.sf_session.get_current_schema.return_value = "SCH"
        # Any direct ``session.sql(...).collect()`` made by RunDao (presumably
        # during construction -- confirm) returns no rows.
        dummy_sql = MagicMock()
        dummy_sql.collect.return_value = []
        self.sf_session.sql.return_value = dummy_sql
        self.dao = RunDao(snowpark_session=self.sf_session)

    # ------------------------------------------------------------------
    # Helpers shared by the create_new_run tests
    # ------------------------------------------------------------------

    def _create_run(self, run_name, **extra_kwargs):
        """Call ``create_new_run`` with the fixture values shared by the create tests.

        Args:
            run_name: Name of the run to create.
            **extra_kwargs: Extra keyword arguments (e.g. ``mode``). Anything
                omitted falls back to ``create_new_run``'s own defaults.
        """
        self.dao.create_new_run(
            object_name="MY_AGENT",
            object_type="EXTERNAL AGENT",
            object_version="V1",
            dataset_name="db.schema.table",
            source_type="TABLE",
            dataset_spec={"col1": "col1"},
            description="desc",
            label="label",
            llm_judge_name="mistral-large2",
            run_name=run_name,
            **extra_kwargs,
        )

    @staticmethod
    def _expected_create_payload(run_name, mode):
        """Return the CREATE request payload expected from ``_create_run``.

        Args:
            run_name: Run name that was passed to ``_create_run``.
            mode: Expected serialized value of ``run_metadata.mode``.
        """
        return {
            "object_name": "MY_AGENT",
            "object_type": "EXTERNAL AGENT",
            "object_version": "V1",
            "run_name": run_name,
            "description": "desc",
            "run_metadata": {
                "labels": ["label"],
                "llm_judge_name": "mistral-large2",
                "mode": mode,
            },
            "source_info": {
                "name": "db.schema.table",
                "column_spec": {"col1": "col1"},
                "source_type": "TABLE",
            },
        }

    def _assert_create_payload(self, mock_execute_query, expected_payload):
        """Assert that exactly one CREATE query was issued with ``expected_payload``.

        The helper fails when the CREATE call is missing or has no
        ``parameters``, so the payload comparison can never be skipped
        silently.

        Args:
            mock_execute_query: The patched ``execute_query`` mock.
            expected_payload: Dict the JSON-encoded first parameter must equal.
        """
        # create_new_run issues two queries in total; only the CREATE one is
        # inspected here.
        self.assertEqual(mock_execute_query.call_count, 2)

        create_calls = [
            recorded
            for recorded in mock_execute_query.call_args_list
            if len(recorded.args) > 1 and recorded.args[1] == self._CREATE_QUERY
        ]
        self.assertEqual(
            len(create_calls), 1, "expected exactly one CREATE query"
        )

        parameters = create_calls[0].kwargs.get("parameters")
        self.assertTrue(parameters, "CREATE query was issued without parameters")

        # Dict equality is order-insensitive, so key ordering in the JSON
        # does not matter.
        self.assertEqual(json.loads(parameters[0]), expected_payload)

    # ------------------------------------------------------------------
    # create_new_run
    # ------------------------------------------------------------------

    @patch("trulens.connectors.snowflake.dao.run.execute_query")
    def test_create_new_run(self, mock_execute_query):
        self._create_run("my_run", mode=Mode.APP_INVOCATION)
        self._assert_create_payload(
            mock_execute_query,
            self._expected_create_payload("my_run", "APP_INVOCATION"),
        )

    @patch("trulens.connectors.snowflake.dao.run.execute_query")
    def test_create_new_run_log_ingestion_mode(self, mock_execute_query):
        self._create_run("my_run_log_ingestion", mode=Mode.LOG_INGESTION)
        self._assert_create_payload(
            mock_execute_query,
            self._expected_create_payload(
                "my_run_log_ingestion", "LOG_INGESTION"
            ),
        )

    @patch("trulens.connectors.snowflake.dao.run.execute_query")
    def test_create_new_run_default_mode(self, mock_execute_query):
        """Test that when no mode is specified, it defaults to APP_INVOCATION."""
        # ``mode`` is deliberately omitted to exercise the default.
        self._create_run("my_run_default")
        self._assert_create_payload(
            mock_execute_query,
            self._expected_create_payload("my_run_default", "APP_INVOCATION"),
        )

    # ------------------------------------------------------------------
    # get_run / delete_run
    # ------------------------------------------------------------------

    @patch("trulens.connectors.snowflake.dao.run.execute_query")
    def test_get_run_no_result(self, mock_execute_query):
        # Simulate that get_run returns an empty DataFrame (no run exists).
        mock_execute_query.return_value = pd.DataFrame()
        result_df = self.dao.get_run(
            run_name="nonexistent_run",
            object_name="MY_AGENT",
            object_type="EXTERNAL AGENT",
        )
        self.assertTrue(result_df.empty)

    @patch("trulens.connectors.snowflake.dao.run.execute_query")
    def test_get_run_with_result(self, mock_execute_query):
        # Simulate that get_run returns a DataFrame with a single row.
        mock_execute_query.return_value = pd.DataFrame([
            {"run_name": "my_run", "run_status": "ACTIVE"}
        ])
        result_df = self.dao.get_run(
            run_name="my_run",
            object_name="MY_AGENT",
            object_type="EXTERNAL AGENT",
        )
        self.assertIsInstance(result_df, pd.DataFrame)
        self.assertEqual(result_df.iloc[0]["run_name"], "my_run")
        self.assertEqual(result_df.iloc[0]["run_status"], "ACTIVE")

    @patch("trulens.connectors.snowflake.dao.run.execute_query")
    def test_delete_run(self, mock_execute_query):
        req_payload = {
            "run_name": "my_run",
            "object_name": "MY_AGENT",
            "object_type": "EXTERNAL AGENT",
        }
        req_payload_json = json.dumps(req_payload)
        expected_query = "SELECT SYSTEM$AIML_RUN_OPERATION('DELETE', ?);"
        self.dao.delete_run("my_run", "MY_AGENT", "EXTERNAL AGENT")
        mock_execute_query.assert_called_once_with(
            self.sf_session, expected_query, parameters=(req_payload_json,)
        )

    # ------------------------------------------------------------------
    # _set_nested_value
    # ------------------------------------------------------------------

    def test_set_nested_value_simple(self):
        # Test with an empty dictionary and a simple nested key path.
        d = {}
        keys = ["a", "b", "c"]
        self.dao._set_nested_value(d, keys, 42)
        expected = {"a": {"b": {"c": 42}}}
        self.assertEqual(d, expected)

    def test_set_nested_value_existing(self):
        # Test updating a nested value where intermediate dicts exist.
        d = {"a": {"b": {"x": 100}}}
        keys = ["a", "b", "c"]
        self.dao._set_nested_value(d, keys, "new")
        expected = {"a": {"b": {"x": 100, "c": "new"}}}
        self.assertEqual(d, expected)

    def test_set_nested_value_overwrite(self):
        # Test overwriting an existing value.
        d = {"a": {"b": {"c": "old"}}}
        keys = ["a", "b", "c"]
        self.dao._set_nested_value(d, keys, "new")
        expected = {"a": {"b": {"c": "new"}}}
        self.assertEqual(d, expected)

    # ------------------------------------------------------------------
    # _update_run_metadata_field_masks
    # ------------------------------------------------------------------

    def test_update_run_metadata_field_masks(self):
        # Prepare an empty run metadata and a set of field updates.
        existing_run_metadata = {}
        field_updates = {
            "invocations.invocation_1.completion_status.record_count": 1,
            "metrics.metric_1.completion_status.status": "PARTIALLY_COMPLETED",
            "labels": ["new_label"],
            "llm_judge_name": "j1",
        }
        (
            updated_run_metadata,
            invocation_masks,
            metric_masks,
            computation_masks,
            non_map_masks,
        ) = self.dao._update_run_metadata_field_masks(
            existing_run_metadata, field_updates
        )

        expected_run_metadata = {
            "invocations": {
                "invocation_1": {
                    "id": "invocation_1",
                    "completion_status": {"record_count": 1},
                }
            },
            "metrics": {
                "metric_1": {
                    "id": "metric_1",
                    "completion_status": {"status": "PARTIALLY_COMPLETED"},
                }
            },
            "labels": ["new_label"],
            "llm_judge_name": "j1",
        }
        self.assertEqual(updated_run_metadata, expected_run_metadata)
        self.assertEqual(
            invocation_masks,
            {"invocation_1": ["completion_status.record_count"]},
        )
        self.assertEqual(
            metric_masks, {"metric_1": ["completion_status.status"]}
        )
        self.assertEqual(computation_masks, {})
        # non_map_masks is returned as a list but order doesn't matter.
        self.assertCountEqual(non_map_masks, ["labels", "llm_judge_name"])

    def test_update_run_metadata_field_masks_invalid(self):
        # Test that an invalid update key (with insufficient parts) raises a ValueError.
        existing_run_metadata = {}
        field_updates = {
            "invocations.invocation_1": 1  # Should be at least "invocations.<id>.<field>"
        }
        with self.assertRaises(ValueError):
            self.dao._update_run_metadata_field_masks(
                existing_run_metadata, field_updates
            )

    def test_update_run_metadata_multiple_updates(self):
        # Start with an empty run metadata.
        existing_run_metadata = {}

        # Step 1: Add new invocation and metric entries.
        field_updates_1 = {
            "invocations.inv_1.completion_status.record_count": 100,
            "metrics.met_1.completion_status.status": "COMPLETED",
        }
        (
            metadata_1,
            invocation_masks_1,
            metric_masks_1,
            computation_masks_1,
            non_map_masks_1,
        ) = self.dao._update_run_metadata_field_masks(
            existing_run_metadata, field_updates_1
        )

        expected_metadata_1 = {
            "invocations": {
                "inv_1": {
                    "id": "inv_1",
                    "completion_status": {"record_count": 100},
                }
            },
            "metrics": {
                "met_1": {
                    "id": "met_1",
                    "completion_status": {"status": "COMPLETED"},
                }
            },
        }
        self.assertEqual(metadata_1, expected_metadata_1)
        self.assertEqual(
            invocation_masks_1, {"inv_1": ["completion_status.record_count"]}
        )
        self.assertEqual(
            metric_masks_1, {"met_1": ["completion_status.status"]}
        )
        self.assertEqual(computation_masks_1, {})
        self.assertEqual(set(non_map_masks_1), set())

        # Step 2: Update the existing invocation and metric entries.
        field_updates_2 = {
            "invocations.inv_1.completion_status.record_count": 150,
            "metrics.met_1.completion_status.status": "FAILED",
        }
        (
            metadata_2,
            invocation_masks_2,
            metric_masks_2,
            computation_masks_2,
            non_map_masks_2,
        ) = self.dao._update_run_metadata_field_masks(
            metadata_1, field_updates_2
        )

        expected_metadata_2 = {
            "invocations": {
                "inv_1": {
                    "id": "inv_1",
                    "completion_status": {"record_count": 150},
                }
            },
            "metrics": {
                "met_1": {
                    "id": "met_1",
                    "completion_status": {"status": "FAILED"},
                }
            },
        }
        self.assertEqual(metadata_2, expected_metadata_2)
        self.assertEqual(
            invocation_masks_2, {"inv_1": ["completion_status.record_count"]}
        )
        self.assertEqual(
            metric_masks_2, {"met_1": ["completion_status.status"]}
        )
        self.assertEqual(computation_masks_2, {})
        self.assertEqual(set(non_map_masks_2), set())

        # Step 3: Update only a top-level field (e.g. labels).
        field_updates_3 = {"labels": ["final_label"]}
        (
            metadata_3,
            invocation_masks_3,
            metric_masks_3,
            computation_masks_3,
            non_map_masks_3,
        ) = self.dao._update_run_metadata_field_masks(
            metadata_2, field_updates_3
        )

        expected_metadata_3 = {
            "invocations": {
                "inv_1": {
                    "id": "inv_1",
                    "completion_status": {"record_count": 150},
                }
            },
            "metrics": {
                "met_1": {
                    "id": "met_1",
                    "completion_status": {"status": "FAILED"},
                }
            },
            "labels": ["final_label"],
        }
        self.assertEqual(metadata_3, expected_metadata_3)
        # In this step, no nested update for invocations/metrics, so those masks should be empty.
        self.assertEqual(invocation_masks_3, {})
        self.assertEqual(metric_masks_3, {})
        self.assertEqual(computation_masks_3, {})
        self.assertEqual(set(non_map_masks_3), {"labels"})
