"""
Tests for OTEL Feedback Computation.
"""

import gc
import time
from typing import Callable, List, Tuple
import weakref

import numpy as np
import pandas as pd
import pytest
from trulens.apps.app import TruApp
from trulens.core.feedback import Feedback
from trulens.core.feedback.feedback_function_input import FeedbackFunctionInput
from trulens.core.feedback.selector import Selector
from trulens.core.feedback.selector import Trace
from trulens.core.otel.instrument import instrument
from trulens.core.session import TruSession
from trulens.feedback.computer import MinimalSpanInfo
from trulens.feedback.computer import RecordGraphNode
from trulens.feedback.computer import _compute_feedback
from trulens.feedback.computer import _flatten_inputs
from trulens.feedback.computer import _group_kwargs_by_selectors
from trulens.feedback.computer import _remove_already_computed_feedbacks
from trulens.feedback.computer import _validate_unflattened_inputs
from trulens.feedback.computer import compute_feedback_by_span_group
from trulens.otel.semconv.trace import SpanAttributes

from tests.util.mock_otel_feedback_computation import (
    all_retrieval_span_attributes,
)
from tests.util.mock_otel_feedback_computation import feedback_function
from tests.util.otel_test_case import OtelTestCase

try:
    # These imports require optional dependencies to be installed.
    from trulens.apps.langchain import TruChain

    import tests.unit.test_otel_tru_chain
except Exception:
    pass


def _convert_events_to_MinimalSpanInfos(
    events: pd.DataFrame,
) -> List[MinimalSpanInfo]:
    ret = []
    for _, row in events.iterrows():
        span = MinimalSpanInfo(
            span_id=row["trace"]["span_id"],
            parent_span_id=row["record"]["parent_span_id"] or None,
            attributes=row["record_attributes"],
            resource_attributes=row["resource_attributes"],
        )
        ret.append(span)
    return ret


class _TestApp:
    @instrument(attributes={SpanAttributes.SPAN_GROUPS: "span_group"})
    def call0(self, span_group: str, a0: int, b0: int, c0: int) -> List[int]:
        return [a0, b0, c0]

    @instrument()
    def call1(self, a1: int, b1: int, c1: int) -> List[int]:
        return [a1, b1, c1]

    @instrument(
        attributes=lambda ret, exception, *args, **kwargs: {
            SpanAttributes.SPAN_GROUPS: [kwargs["span_group"]]
        }
    )
    def call2(self, span_group: str, a2: int, b2: int, c2: int) -> List[int]:
        self.call0(span_group, 2 * a2, 2 * b2, 2 * c2)
        return [a2, b2, c2]

    @instrument(attributes={SpanAttributes.SPAN_GROUPS: "span_group"})
    def call3(
        self, span_group: str, a3: int, b3: int, c3: int, num_call3_calls: int
    ) -> List[int]:
        for _ in range(num_call3_calls):
            self.call0(span_group, 3 * a3, 3 * b3, 3 * c3)
        return [a3, b3, c3]

    @instrument(attributes={SpanAttributes.SPAN_GROUPS: "span_group"})
    def call4(
        self, span_group: str, a4: int, b4: int, c4: int, num_call3_calls: int
    ) -> List[int]:
        for _ in range(num_call3_calls):
            self.call0(span_group, 4 * a4, 4 * b4, 4 * c4)
        return [a4, b4, c4]

    @instrument()
    def query(self, question: str) -> str:
        # 1. Many attributes from one span that there are many of.
        self.call1(1, 1, 1)
        self.call1(2, 2, 2)
        self.call1(3, 3, 3)
        # 2. Attributes across spans where span groups are used.
        self.call2("10", 10, 10, 10)
        self.call2("11", 11, 11, 11)
        self.call2("12", 12, 12, 12)
        # 3. Attributes across spans where only one has more than one.
        self.call3("20", 20, 20, 20, 3)
        # 4. Attributes across spans where two have more than one (error case).
        self.call4("30", 30, 30, 30, 1)
        self.call4("30", 30, 30, 30, 1)
        # Return value is irrelevant.
        return "Kojikun"


@pytest.mark.optional
class TestOtelFeedbackComputation(OtelTestCase):
    @pytest.mark.skip(reason="Golden file content mismatch in CI")
    def test_feedback_computation(self) -> None:
        # Create app.
        rag_chain = (
            tests.unit.test_otel_tru_chain.TestOtelTruChain._create_simple_rag()
        )
        tru_recorder = TruChain(
            rag_chain,
            app_name="Simple RAG",
            app_version="v1",
            main_method=rag_chain.invoke,
        )
        # Record and invoke.
        tru_recorder.instrumented_invoke_main_method(
            run_name="test run",
            input_id="42",
            ground_truth_output="Like attention but with more heads.",
            main_method_args=("What is multi-headed attention?",),
        )
        TruSession().force_flush()
        # Compute feedback on record we just ingested.
        events = self._get_events()
        spans = _convert_events_to_MinimalSpanInfos(events)
        record_root = RecordGraphNode.build_graph(spans)
        f_baby = Feedback(
            feedback_function, name="baby_grader", higher_is_better=True
        )
        _compute_feedback(
            record_root,
            "baby_grader",
            f_baby,
            True,
            all_retrieval_span_attributes,
        )
        TruSession().force_flush()
        # Compare results to expected.
        self._compare_events_to_golden_dataframe(
            "tests/unit/static/golden/test_otel_feedback_computation__test_feedback_computation.csv"
        )

    def test_compute_feedback_by_span_group(self) -> None:
        # Create app.
        app = _TestApp()
        tru_app = TruApp(
            app, app_name="Test App", app_version="v1", main_method=app.query
        )
        # Record and invoke.
        tru_app.instrumented_invoke_main_method(
            run_name="test run",
            input_id="42",
            main_method_kwargs={
                "question": "What's the population of the capital of Japan?"
            },
        )
        TruSession().force_flush()
        events = self._get_events()
        # Compute feedback on record we just ingested.
        get_selector = lambda s: Selector(
            span_name=f"tests.unit.test_otel_feedback_computation._TestApp.call{s[1]}",
            span_attribute=f"{SpanAttributes.CALL.KWARGS}.{s}",
        )
        # Case 1. Two attributes from one function that has multiple (three)
        #         invocations.
        f1 = Feedback(
            lambda a1, b1: 0.9 if a1 == b1 else 0.1,
            name="blah1",
            higher_is_better=True,
        )
        compute_feedback_by_span_group(
            events,
            f1,
            selectors={"a1": get_selector("a1"), "b1": get_selector("b1")},
        )
        # Case 2. Attributes across functions with span groups.
        f2 = Feedback(
            lambda a2, a0: 0.9 if 2 * a2 == a0 else 0.1,
            name="blah2",
            higher_is_better=True,
        )
        compute_feedback_by_span_group(
            events,
            f2,
            selectors={"a2": get_selector("a2"), "a0": get_selector("a0")},
        )
        # Case 3. Attributes across functions with span groups where one
        #         function is invoked once and the other multiple (three)
        #         times.
        f3 = Feedback(
            lambda a3, a0: 0.9 if 3 * a3 == a0 else 0.1,
            name="blah3",
            higher_is_better=True,
        )
        compute_feedback_by_span_group(
            events,
            f3,
            selectors={"a3": get_selector("a3"), "a0": get_selector("a0")},
        )
        # Case 4. Attributes across functions where both functions are invoked
        #         more than once (error case).
        with self.assertRaisesRegex(
            ValueError,
            "^No feedbacks were computed!$",
        ):
            f4 = Feedback(
                lambda a4, a0: 0.9 if 4 * a4 == a0 else 0.1,
                name="blah4",
                higher_is_better=True,
            )
            compute_feedback_by_span_group(
                events,
                f4,
                selectors={"a4": get_selector("a4"), "a0": get_selector("a0")},
            )
        # Compare results to expected.
        TruSession().force_flush()
        events = self._get_events()
        eval_root_record_attributes = [
            curr["record_attributes"]
            for _, curr in events.iterrows()
            if curr["record_attributes"][SpanAttributes.SPAN_TYPE]
            == SpanAttributes.SpanType.EVAL_ROOT
        ]
        expected_case_number = [1, 1, 1, 2, 2, 2, 3, 3, 3]
        self.assertEqual(
            len(expected_case_number),
            len(eval_root_record_attributes),
        )
        for curr in eval_root_record_attributes:
            self.assertEqual(curr[SpanAttributes.EVAL_ROOT.SCORE], 0.9)
        for i in range(len(expected_case_number)):
            self.assertEqual(
                eval_root_record_attributes[i][
                    SpanAttributes.EVAL_ROOT.METRIC_NAME
                ],
                f"blah{expected_case_number[i]}",
            )

    def test__group_kwargs_by_selectors(self) -> None:
        # Test empty case.
        self.assertEqual([], _group_kwargs_by_selectors({}))
        # Test single selector case.
        self.assertEqual(
            [("input",)],
            _group_kwargs_by_selectors({
                "input": Selector(
                    span_type="a", span_name="b", span_attribute="c"
                )
            }),
        )
        # Test complex grouping.
        kwarg_to_selector = {
            "input1": Selector(
                span_type="a", span_name="b", span_attribute="x"
            ),
            "input2": Selector(
                span_type="a", span_name="b", span_attribute="x"
            ),
            "input3": Selector(
                span_type="c", span_name="d", span_attribute="x"
            ),
            "input4": Selector(span_type="c", span_attribute="x"),
            "input5": Selector(
                span_type="c", span_name="e", span_attribute="x"
            ),
        }
        self.assertEqual(
            [
                ("input1", "input2"),
                ("input3",),
                ("input4",),
                ("input5",),
            ],
            sorted(_group_kwargs_by_selectors(kwarg_to_selector)),
        )

    def test__validate_unflattened_inputs(self) -> None:
        kwarg_groups = [("a",), ("b",)]
        record_id = "record_id1"
        span_group = "span_group1"
        record_ids_with_record_roots = [record_id]
        feedback_name = "feedback1"
        # Happy path.
        unflattened_inputs = {
            (record_id, span_group): {
                ("a",): [{"a": 1}, {"a": 1.1}],
                ("b",): [{"b": 2}],
            }
        }
        res = _validate_unflattened_inputs(
            unflattened_inputs,
            kwarg_groups,
            record_ids_with_record_roots,
            feedback_name,
        )
        self.assertEqual(
            {
                (record_id, span_group): {
                    ("a",): [{"a": 1}, {"a": 1.1}],
                    ("b",): [{"b": 2}],
                }
            },
            res,
        )
        # Missing inputs are removed.
        unflattened_inputs = {(record_id, span_group): {("a",): [{"a": 1}]}}
        res = _validate_unflattened_inputs(
            unflattened_inputs,
            kwarg_groups,
            record_ids_with_record_roots,
            feedback_name,
        )
        self.assertEqual({}, res)
        # Ambiguous inputs are removed and have error message.
        unflattened_inputs = {
            (record_id, span_group): {
                ("a",): [{"a": 1}, {"a", 1.1}],
                ("b",): [{"b": 2}, {"b": 2.2}],
            }
        }
        res = _validate_unflattened_inputs(
            unflattened_inputs,
            kwarg_groups,
            record_ids_with_record_roots,
            feedback_name,
        )
        self.assertEqual({}, res)
        # Records without a record root are removed and have error message.
        invalid_record_id = "invalid_record_id"
        unflattened_inputs = {
            (invalid_record_id, span_group): {
                ("a",): [{"a": 1}, {"a": 1.1}],
                ("b",): [{"b": 2}],
            }
        }
        res = _validate_unflattened_inputs(
            unflattened_inputs,
            kwarg_groups,
            record_ids_with_record_roots,
            feedback_name,
        )
        self.assertEqual({}, res)

    def test__flatten_inputs(self) -> None:
        record_id = "record_id1"
        span_group = "span_group1"
        unflattened_inputs = {
            (record_id, span_group): {
                ("a",): [{"a": 1}, {"a": 1.1}],
                ("b",): [{"b": 2}],
            }
        }
        res = _flatten_inputs(unflattened_inputs)
        self.assertEqual(
            [
                (record_id, span_group, {"a": 1, "b": 2}),
                (record_id, span_group, {"a": 1.1, "b": 2}),
            ],
            res,
        )

    def test__remove_already_computed_feedbacks(self) -> None:
        events = pd.DataFrame({
            "record_attributes": [
                {
                    SpanAttributes.SPAN_TYPE: SpanAttributes.SpanType.EVAL_ROOT,
                    SpanAttributes.EVAL_ROOT.METRIC_NAME: "feedback1",
                    SpanAttributes.RECORD_ID: "record_id1",
                    SpanAttributes.EVAL_ROOT.SPAN_GROUP: "span_group1",
                    SpanAttributes.EVAL_ROOT.ARGS_SPAN_ID + ".a": "span_id1a",
                    SpanAttributes.EVAL_ROOT.ARGS_SPAN_ID + ".b": "span_id1b",
                    SpanAttributes.EVAL_ROOT.ARGS_SPAN_ATTRIBUTE
                    + ".a": "span_attribute1a",
                    SpanAttributes.EVAL_ROOT.ARGS_SPAN_ATTRIBUTE
                    + ".b": "span_attribute1b",
                },
            ]
        })
        flattened_inputs = [
            (
                "record_id1",
                "span_group1",
                {
                    "a": FeedbackFunctionInput(
                        span_id="span_id1a", span_attribute="span_attribute1a"
                    ),
                    "b": FeedbackFunctionInput(
                        span_id="span_id1b", span_attribute="span_attribute1b"
                    ),
                },
            ),
            (
                "record_id1",
                "span_group1",
                {
                    "a": FeedbackFunctionInput(
                        span_id="span_id1a", span_attribute="span_attribute1a"
                    ),
                    "b": FeedbackFunctionInput(
                        span_id="span_id2b", span_attribute="span_attribute1b"
                    ),
                },
            ),
        ]
        res = _remove_already_computed_feedbacks(
            events, "feedback1", flattened_inputs
        )
        self.assertEqual([flattened_inputs[1]], res)

    def _create_invoked_app_with_custom_feedback(
        self, higher_is_better: bool = True
    ) -> TruApp:
        # Create feedback function.
        def custom(input: str, output: str) -> float:
            if (
                input == "What is multi-headed attention?"
                and output == "This is a mocked response for prompt 0."
            ):
                return 0.42
            return 0.0

        f_custom = Feedback(
            custom, name="custom", higher_is_better=higher_is_better
        ).on({
            "input": Selector(
                span_type=SpanAttributes.SpanType.RECORD_ROOT,
                span_attribute=SpanAttributes.RECORD_ROOT.INPUT,
            ),
            "output": Selector(
                span_type=SpanAttributes.SpanType.RECORD_ROOT,
                span_attribute=SpanAttributes.RECORD_ROOT.OUTPUT,
            ),
        })

        # Create app.
        rag_chain = (
            tests.unit.test_otel_tru_chain.TestOtelTruChain._create_simple_rag()
        )
        tru_recorder = TruChain(
            rag_chain,
            app_name="Simple RAG",
            app_version="v1",
            main_method=rag_chain.invoke,
            feedbacks=[f_custom],
        )
        # Record and invoke.
        tru_recorder.instrumented_invoke_main_method(
            run_name="test run",
            input_id="42",
            main_method_args=("What is multi-headed attention?",),
        )
        TruSession().force_flush()
        return tru_recorder

    def test_custom_feedback(self) -> None:
        tru_recorder = self._create_invoked_app_with_custom_feedback(
            higher_is_better=False
        )
        # Compute feedback on record we just ingested.
        num_events = len(self._get_events())
        tru_recorder.compute_feedbacks()
        TruSession().force_flush()
        # Compare results to expected.
        events = self._get_events()
        self.assertEqual(num_events + 2, len(events))
        self.assertEqual(
            SpanAttributes.SpanType.RECORD_ROOT,
            events.iloc[0]["record_attributes"][SpanAttributes.SPAN_TYPE],
        )
        record_root_span_id = events.iloc[0]["trace"]["span_id"]
        eval_root_record_attributes = events.iloc[-2]["record_attributes"]
        self.assertEqual(
            SpanAttributes.SpanType.EVAL_ROOT,
            eval_root_record_attributes[SpanAttributes.SPAN_TYPE],
        )
        self.assertEqual(
            "custom",
            eval_root_record_attributes[SpanAttributes.EVAL_ROOT.METRIC_NAME],
        )
        self.assertEqual(
            0.42,
            eval_root_record_attributes[SpanAttributes.EVAL_ROOT.SCORE],
        )
        self.assertEqual(
            record_root_span_id,
            eval_root_record_attributes[
                SpanAttributes.EVAL_ROOT.ARGS_SPAN_ID + ".input"
            ],
        )
        self.assertEqual(
            SpanAttributes.RECORD_ROOT.INPUT,
            eval_root_record_attributes[
                SpanAttributes.EVAL_ROOT.ARGS_SPAN_ATTRIBUTE + ".input"
            ],
        )
        self.assertEqual(
            record_root_span_id,
            eval_root_record_attributes[
                SpanAttributes.EVAL_ROOT.ARGS_SPAN_ID + ".output"
            ],
        )
        self.assertEqual(
            SpanAttributes.RECORD_ROOT.OUTPUT,
            eval_root_record_attributes[
                SpanAttributes.EVAL_ROOT.ARGS_SPAN_ATTRIBUTE + ".output"
            ],
        )
        self.assertFalse(
            eval_root_record_attributes[
                SpanAttributes.EVAL_ROOT.HIGHER_IS_BETTER
            ],
        )
        eval_record_attributes = events.iloc[-1]["record_attributes"]
        self.assertEqual(
            SpanAttributes.SpanType.EVAL,
            eval_record_attributes[SpanAttributes.SPAN_TYPE],
        )
        self.assertEqual(
            "custom",
            eval_record_attributes[SpanAttributes.EVAL.METRIC_NAME],
        )
        self.assertEqual(
            0.42,
            eval_record_attributes[SpanAttributes.EVAL.SCORE],
        )
        # Check that when trying to compute feedbacks again, nothing happens.
        old_num_events = len(self._get_events())
        tru_recorder.compute_feedbacks(
            raise_error_on_no_feedbacks_computed=False
        )
        TruSession().force_flush()
        new_num_events = len(self._get_events())
        self.assertEqual(old_num_events, new_num_events)

    def test_trace_level_feedback(self) -> None:
        # Create feedback.
        def custom(trace: Trace) -> float:
            if not isinstance(trace, Trace):
                return -1
            if [
                curr["record_attributes"][SpanAttributes.CALL.FUNCTION].split(
                    "."
                )[-1]
                for _, curr in trace.events.iterrows()
            ] != ["b", "d", "e"]:
                return -2
            if len(trace.processed_content_roots) != 2:
                return -3
            b = trace.processed_content_roots[0]
            d = b.children[0]
            e = trace.processed_content_roots[1]
            bde = [b, d, e]
            if [curr.content for curr in bde] != ["B", "D", "E"]:
                return -4
            if [curr.parent for curr in bde] != [None, b, None]:
                return -5
            if [curr.children for curr in bde] != [[d], [], []]:
                return -6
            for i in range(len(trace.events)):
                if not pd.Series.equals(
                    bde[i].event, trace.events.iloc[i].drop("processed_content")
                ):
                    return -7
            return 0.21

        f_custom = Feedback(custom, name="custom").on({
            "trace": Selector(
                trace_level=True,
                span_type=SpanAttributes.SpanType.RETRIEVAL,
                span_attributes_processor=lambda attr: attr[
                    SpanAttributes.CALL.FUNCTION
                ]
                .split(".")[-1]
                .upper(),
            )
        })

        # Create app.
        class _App:
            @instrument()
            def a(self) -> None:
                self.b()
                self.e()

            @instrument(span_type=SpanAttributes.SpanType.RETRIEVAL)
            def b(self) -> None:
                self.c()

            @instrument()
            def c(self) -> None:
                self.d()

            @instrument(span_type=SpanAttributes.SpanType.RETRIEVAL)
            def d(self) -> None:
                pass

            @instrument(span_type=SpanAttributes.SpanType.RETRIEVAL)
            def e(self) -> None:
                pass

        app = _App()
        tru_app = TruApp(
            app,
            app_name="test_trace_level_feedback",
            app_version="v1",
            feedbacks=[f_custom],
        )
        # Record and invoke.
        tru_app.stop_evaluator()
        with tru_app:
            app.a()
        # Compute feedback on record we just ingested.
        TruSession().force_flush()
        tru_app.compute_feedbacks()
        TruSession().force_flush()
        # Ensure feedback was computed.
        events = self._get_events()
        self.assertEqual(7, len(events))
        self.assertEqual(
            ["a", "b", "c", "d", "e", "", "custom"],
            [
                curr["record_attributes"]
                .get(SpanAttributes.CALL.FUNCTION, ".")
                .split(".")[-1]
                for _, curr in events.iterrows()
            ],
        )
        self.assertEqual(
            0.21,
            events.iloc[-2]["record_attributes"][
                SpanAttributes.EVAL_ROOT.SCORE
            ],
        )

    def test_evaluator(self) -> None:
        tru_recorder = self._create_invoked_app_with_custom_feedback()
        num_events = len(self._get_events())
        # Wait for there to be a feedback computed.
        self._wait(lambda: len(self._get_events()) > num_events)
        events = self._get_events()
        self.assertEqual(num_events + 2, len(events))
        self.assertEqual(
            SpanAttributes.SpanType.RECORD_ROOT,
            events.iloc[0]["record_attributes"][SpanAttributes.SPAN_TYPE],
        )
        # Ensure thread to be stopped when app is garbage collected.
        tru_recorder_ref = weakref.ref(tru_recorder)
        evaluator_ref = weakref.ref(tru_recorder._evaluator)
        thread_ref = weakref.ref(tru_recorder._evaluator._thread)
        del tru_recorder
        gc.collect()
        self.assertCollected(tru_recorder_ref)
        self._wait(lambda: thread_ref() is None)
        self.assertCollected(evaluator_ref)
        self.assertCollected(thread_ref)

    def test_aggregation(self) -> None:
        # Create feedback function.
        def custom_with_explanations(a: float, b: float) -> Tuple[float, dict]:
            return a * b, {"explanation": f"{a} * {b}"}

        f_custom = Feedback(custom_with_explanations, name="custom").on({
            "a": Selector(
                span_type=SpanAttributes.SpanType.RECORD_ROOT,
                function_attribute="xs",
                collect_list=False,
            ),
            "b": Selector(
                span_type=SpanAttributes.SpanType.RECORD_ROOT,
                function_attribute="ys",
                collect_list=False,
            ),
        })

        # Create app.
        class _App:
            @instrument(span_type=SpanAttributes.SpanType.RECORD_ROOT)
            def invoke(self, xs: List[int], ys: List[int]) -> float:
                return 0.0

        app = _App()
        tru_app = TruApp(
            app, app_name="Simple App", app_version="v1", feedbacks=[f_custom]
        )
        # Record and invoke.
        with tru_app:
            app.invoke([2, 3], [5, 7])
        TruSession().force_flush()
        # Compute feedback on record we just ingested.
        tru_app.compute_feedbacks()
        TruSession().force_flush()
        # Compare results to expected.
        events = self._get_events()
        self.assertEqual(len(events), 6)
        expected_span_types = [
            SpanAttributes.SpanType.RECORD_ROOT,
            SpanAttributes.SpanType.EVAL_ROOT,
        ] + 4 * [SpanAttributes.SpanType.EVAL]
        self.assertListEqual(
            expected_span_types,
            [
                curr["record_attributes"][SpanAttributes.SPAN_TYPE]
                for _, curr in events.iterrows()
            ],
        )
        expected_sub_scores = [2 * 5, 2 * 7, 3 * 5, 3 * 7]
        self.assertEqual(
            np.mean(expected_sub_scores),
            events.iloc[1]["record_attributes"][SpanAttributes.EVAL_ROOT.SCORE],
        )
        self.assertListEqual(
            expected_sub_scores,
            [
                curr["record_attributes"][SpanAttributes.EVAL.SCORE]
                for _, curr in events.iloc[2:].iterrows()
            ],
        )
        expected_explanations = ["2 * 5", "2 * 7", "3 * 5", "3 * 7"]
        self.assertListEqual(
            expected_explanations,
            [
                curr["record_attributes"][
                    f"{SpanAttributes.EVAL.METADATA}.explanation"
                ]
                for _, curr in events.iloc[2:].iterrows()
            ],
        )

    def test_compute_feedbacks_on_events(self) -> None:
        # Create apps.
        class _TestApp:
            @instrument()
            def greet(self, name: str) -> str:
                return f"Hello, {name}!"

        app = _TestApp()
        tru_app_1 = TruApp(
            app, app_name="Test App 1", app_version="v1", main_method=app.greet
        )
        tru_app_2 = TruApp(
            app, app_name="Test App 2", app_version="v1", main_method=app.greet
        )
        # Record and invoke.
        with tru_app_1 as recording1:
            app.greet("Kojikun")
        record_id_1 = recording1.get().record_id
        with tru_app_2 as recording2:
            app.greet("Not Kojikun")
        record_id_2 = recording2.get().record_id
        TruSession().force_flush()

        # Create feedback function.
        def best_baby_checker(input: str) -> float:
            if input in ["Kojikun", "Nolan", "Sachiboy"]:
                return 1.0
            return 0.0

        f_best_baby_checker = Feedback(
            best_baby_checker, name="best_baby_checker", higher_is_better=True
        ).on({
            "input": Selector(
                span_type=SpanAttributes.SpanType.RECORD_ROOT,
                span_attribute=SpanAttributes.RECORD_ROOT.INPUT,
            )
        })
        # Compute feedbacks on events.
        events = TruSession().get_events(app_name=None, app_version=None)
        self.assertEqual(2, len(events))
        TruSession().compute_feedbacks_on_events(events, [f_best_baby_checker])
        TruSession().force_flush()
        # Compare results to expected.
        events = self._get_events()
        self.assertEqual(6, len(events))
        self.assertEqual(
            [
                SpanAttributes.SpanType.RECORD_ROOT,
                SpanAttributes.SpanType.RECORD_ROOT,
                SpanAttributes.SpanType.EVAL_ROOT,
                SpanAttributes.SpanType.EVAL,
                SpanAttributes.SpanType.EVAL_ROOT,
                SpanAttributes.SpanType.EVAL,
            ],
            [
                curr["record_attributes"][SpanAttributes.SPAN_TYPE]
                for _, curr in events.iterrows()
            ],
        )
        self.assertEqual(
            1.0,
            events.iloc[2]["record_attributes"][SpanAttributes.EVAL_ROOT.SCORE],
        )
        self.assertEqual(
            0.0,
            events.iloc[4]["record_attributes"][SpanAttributes.EVAL_ROOT.SCORE],
        )
        self.assertEqual(
            record_id_1,
            events.iloc[2]["record_attributes"][SpanAttributes.RECORD_ID],
        )
        self.assertEqual(
            record_id_2,
            events.iloc[4]["record_attributes"][SpanAttributes.RECORD_ID],
        )

    @staticmethod
    def _wait(
        f: Callable[[], bool],
        timeout_in_seconds: float = 30,
        cooldown_in_seconds: float = 1,
    ) -> None:
        start_time = time.time()
        while time.time() - start_time < timeout_in_seconds:
            if f():
                return
            time.sleep(cooldown_in_seconds)
        raise TimeoutError(
            f"Timed out waiting for condition to be met after {timeout_in_seconds} seconds."
        )
