# Source code for qtealeaves.observables.tnobase

# This code is part of qtealeaves.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Abstract base class for observables.
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Iterator, Self

import h5py

from qtealeaves.abstracttns.abstract_tn import _AbstractTN
from qtealeaves.tooling import QTeaLeavesError
from qtealeaves.tooling.parameterized import _ParameterizedClass

# Import ``TNOperators`` only while static type checkers analyse the module;
# at runtime the name is bound to ``Any`` so annotations that mention it still
# resolve without importing ``qtealeaves.operators``.
# NOTE(review): presumably this avoids a circular import - confirm.
if TYPE_CHECKING:
    from qtealeaves.operators import TNOperators
else:
    TNOperators = Any

__all__ = ["_TNObsBase"]


class _TNObsBase(_ParameterizedClass, ABC):
    """
    Abstract base class for observables.

    Attributes
    ----------
    name : list[str]
        Names identifying the measurements of this observable. The
        constructor initialises it as a one-element list holding the
        name passed in; ``len(obs)`` is the length of this list.
    results_buffer : dict[str, Any]
        Stores the results of the measurement of the observable until
        they are written with :py:meth:`write_results`.
    measurable_ansaetze : tuple[type[_AbstractTN], ...]
        Tuple of tensor network classes for which this observable is
        measurable. Empty in the base class.
    """

    _measurable_ansaetze: tuple[type[_AbstractTN], ...] = ()

    def __init__(self, name: str):
        # The name is kept inside a list: every other method iterates
        # over ``self.name``.
        self.name = [name]
        self.results_buffer: dict[str, Any] = {}

    @property
    def measurable_ansaetze(self) -> tuple[type[_AbstractTN], ...]:
        """
        Tuple of ansatzes for which this observable is measurable.
        """
        return self._measurable_ansaetze

    def check_measurable(self, tn_type: type[_AbstractTN]) -> bool:
        """
        Check whether or not the observable can be measured for a given
        ansatz.

        Args:
            tn_type (type): Class of the TN ansatz.

        Returns:
            bool
                True if ``tn_type`` is a subclass of at least one class
                listed in the measurable ansaetze, False otherwise.
        """
        return any(issubclass(tn_type, cls) for cls in self._measurable_ansaetze)

    @classmethod
    @abstractmethod
    def empty(cls) -> Self:
        """
        Constructor of the class without any content.
        """
        raise NotImplementedError("Must be implemented by actual class.")

    def __len__(self) -> int:
        """
        Provide appropriate length method: the number of entries in
        ``self.name``.
        """
        return len(self.name)

    @abstractmethod
    def __iadd__(self, other: Any) -> Self:
        """
        Overwrite operator ``+=`` to simplify syntax.
        """
        raise NotImplementedError("Must be implemented by actual class.")

    def measure(
        self,
        state: _AbstractTN,  # pylint: disable=unused-argument
        operators: TNOperators,  # pylint: disable=unused-argument
        **kwargs: Any,  # pylint: disable=unused-argument
    ) -> dict[str, Any]:
        """
        Run the measurement of the observable for a given state and save it
        in the results buffer.

        The base implementation performs no measurement: it returns the
        results buffer unchanged when the observable has no names, and
        raises otherwise.

        Args:
            state : instance of class from :py:mod:`qtealeaves.emulator`
                The state to measure the observable on.
            operators : instance of :py:class:`qtealeaves.operators.TNOperators`
                The operators to be measured.
            **kwargs : dict
                Additional keyword arguments, which might be needed for the
                measurement of derived observables.

        Returns:
            dict
                Dictionary with the results of the measurement.

        Raises:
            NotImplementedError
                If the observable contains at least one name and the
                derived class did not override this method.
        """
        # Cover the case the derived class does not implement this method,
        # but is also not used.
        if not self.name:
            return self.results_buffer

        raise NotImplementedError("This observable has no measurement implemented yet.")

    def add_trajectories(self, all_results: dict, new: dict) -> dict:
        """
        Add the observables for different quantum trajectories.

        **Arguments**

        all_results : dict
            Dictionary with observables. Updated in place and returned.

        new : dict
            Dictionary with new observables to add to all_results.

        **Returns**

        all_results : dict
            The same dictionary that was passed in.

        **Raises**

        QTeaLeavesError
            If a result seen for the first time is a dictionary.
        """
        for name in self.name:
            if name in all_results:
                all_results[name] += new[name]
                continue

            # First trajectory for this name: validate before storing so a
            # rejected value is never left behind in ``all_results``.
            if isinstance(new[name], dict):
                raise QTeaLeavesError("Dictionary addition not implemented.")
            all_results[name] = new[name]

        return all_results

    def avg_trajectories(self, all_results: dict, num_trajectories: int) -> dict:
        """
        Get the average of quantum trajectories observables.

        **Arguments**

        all_results : dict
            Dictionary with observables. Updated in place and returned.

        num_trajectories : int
            Total number of quantum trajectories.

        **Returns**

        all_results : dict
            The same dictionary, with each entry of this observable
            divided by ``num_trajectories``.

        **Raises**

        QTeaLeavesError
            If an entry of this observable is a dictionary.
        """
        for name in self.name:
            if isinstance(all_results[name], dict):
                raise QTeaLeavesError("Dictionary addition not implemented.")
            all_results[name] /= num_trajectories

        return all_results

    def write_results(
        self,
        fh: h5py.File,
        state_ansatz: type[_AbstractTN],
        # pylint: disable-next=unused-argument
        **kwargs: Any,
    ) -> None:
        """
        Write the results to a HDF5 file. The results have to be stored in
        the result buffer.

        A group named ``str(self)`` is created with the attribute
        ``is_measured``; one dataset per name is written only if the
        observable is measurable for ``state_ansatz``. The results buffer
        is reset afterwards.

        **Arguments**

        fh : h5py.File
            Open HDF5 file where the results are written to.

        state_ansatz : type[_AbstractTN]
            Class identifying the state ansatz currently in use.
        """
        is_measured = self.check_measurable(state_ansatz)

        fg = fh.create_group(str(self), track_order=True)
        fg.attrs["is_measured"] = is_measured

        if is_measured:
            for name in self.name:
                fg.create_dataset(name, data=self.results_buffer[name])

        # Start from an empty buffer for the next measurement.
        self.results_buffer = {}

    # pylint: disable-next=unused-argument
    def read(self, fh: h5py.File, **kwargs: Any) -> Iterator[tuple[str, Any]]:
        """
        Read observables from HDF5 file.

        This is a generator, so the file is only accessed (and the error
        below only raised) once iteration starts.

        **Arguments**

        fh : h5py.File
            Read the information about the measurements from this HDF5 file.

        **Yields**

        tuple[str, Any]
            Pairs of name and stored value; the value is ``None`` if the
            group's ``is_measured`` attribute is missing or false.

        **Raises**

        QTeaLeavesError
            If the file has no group named ``str(self)``.
        """
        group_name = str(self)

        # check that group exists
        if group_name not in fh:
            raise QTeaLeavesError("Observable group not found in file.")

        fg = fh[group_name]
        is_measured = fg.attrs.get("is_measured", False)

        for name in self.name:
            yield name, (fg[name][()] if is_measured else None)

    def __repr__(self) -> str:
        """
        Return the class name as representation.
        """
        return self.__class__.__name__

    def collect_operators(self) -> Iterator[tuple[str, Any]]:
        """
        Observables which require operators must provide this method,
        because operators with symmetries might not be written otherwise.
        The base implementation always raises ``NotImplementedError``.

        **Details**

        The problems are, for example, correlations with equal operators
        because they cannot be contracted over their third link.
        """
        raise NotImplementedError(
            "This observable does not support collecting operators."
        )

    def get_id(self) -> str:
        """
        Get the address in memory, which is useful instead of hashing the
        complete object or for comparisons.

        Returns the hexadecimal string of ``id(self)``.
        """
        return hex(id(self))