# This code is part of qtealeaves.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
Abstract base class for observables.
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Iterator, Self
import h5py
from qtealeaves.abstracttns.abstract_tn import _AbstractTN
from qtealeaves.tooling import QTeaLeavesError
from qtealeaves.tooling.parameterized import _ParameterizedClass
# ``TNOperators`` is only needed for type annotations. At runtime the name is
# aliased to ``Any`` so the operators module is not imported here
# (presumably to avoid an import cycle -- TODO confirm).
if TYPE_CHECKING:
    from qtealeaves.operators import TNOperators
else:
    TNOperators = Any

__all__ = ["_TNObsBase"]
class _TNObsBase(_ParameterizedClass, ABC):
    """
    Abstract base class for observables.

    Attributes
    ----------
    name: list[str]
        Names identifying the results of the observable. The constructor
        wraps its single ``name`` argument into a one-element list.
    results_buffer : dict
        Store the results of the measurement of the observable, keyed by
        the entries of ``name``.
    measurable_ansaetze : tuple(type)
        Tuple of tensor-network ansatz classes for which this observable is
        measurable; subclasses of these classes are accepted as well.
    """

    _measurable_ansaetze: tuple[type[_AbstractTN], ...] = ()

    def __init__(self, name: str):
        self.name = [name]
        self.results_buffer: dict[str, Any] = {}

    @property
    def measurable_ansaetze(self) -> tuple[type[_AbstractTN], ...]:
        """
        Tuple of ansatzes for which this observable is measurable.
        """
        return self._measurable_ansaetze

    def check_measurable(self, tn_type: type[_AbstractTN]) -> bool:
        """
        Checks whether or not the observable can be measured for a given ansatz.

        Args:
            tn_type (type): Class of the TN ansatz. Subclasses of any entry
                of ``measurable_ansaetze`` are accepted.

        Returns:
            bool: ``True`` if the observable is measurable for ``tn_type``.
        """
        return any(issubclass(tn_type, cls) for cls in self._measurable_ansaetze)

    @classmethod
    @abstractmethod
    def empty(cls) -> Self:
        """
        Constructor of the class without any content.
        """
        raise NotImplementedError("Must be implemented by actual class.")

    def __len__(self) -> int:
        """
        Provide appropriate length method, i.e., the number of entries in
        ``self.name``.
        """
        return len(self.name)

    @abstractmethod
    def __iadd__(self, other: Any) -> Self:
        """
        Overwrite operator ``+=`` to simplify syntax.
        """
        raise NotImplementedError("Must be implemented by actual class.")

    def measure(
        self,
        state: _AbstractTN,  # pylint: disable=unused-argument
        operators: TNOperators,  # pylint: disable=unused-argument
        **kwargs: Any,  # pylint: disable=unused-argument
    ) -> dict[str, Any]:
        """
        Run the measurement of the observable for a given state and save it in the
        results buffer.

        Args:
            state : instance of class from :py:mod:`qtealeaves.emulator`
                The state to measure the observable on.
            operators : instance of :py:class:`qtealeaves.operators.TNOperators`
                The operators to be measured.
            **kwargs : dict
                Additional keyword arguments, which might be needed for the measurement
                of derived observables.

        Returns:
            dict
                Dictionary with the results of the measurement.

        Raises:
            NotImplementedError
                If the observable holds at least one name, because the base
                class has no measurement implementation.
        """
        # Cover the case the derived class does not implement this method, but is also not used.
        if len(self.name) == 0:
            return self.results_buffer
        raise NotImplementedError("This observable has no measurement implemented yet.")

    def add_trajectories(self, all_results: dict, new: dict) -> dict:
        """
        Add the observables for different quantum trajectories.

        The first trajectory seen for a name is stored as-is. Every later
        trajectory is added to it, building the sum out of place so that
        objects from earlier trajectories are never modified.

        **Arguments**

        all_results : dict
            Dictionary with observables. Updated in place and returned.
        new : dict
            Dictionary with new observables to add to all_results.

        **Returns**

        dict
            ``all_results`` including the contributions of ``new``.

        **Raises**

        QTeaLeavesError
            If an observable in ``new`` is a dictionary, which cannot be summed.
        """
        for name in self.name:
            value = new[name]
            if isinstance(value, dict):
                raise QTeaLeavesError("Dictionary addition not implemented.")

            if name in all_results:
                # Out-of-place sum: ``+=`` would mutate the object stored from
                # the first trajectory, which is shared with that trajectory's
                # ``new`` dictionary.
                all_results[name] = all_results[name] + value
            else:
                all_results[name] = value

        return all_results

    def avg_trajectories(self, all_results: dict, num_trajectories: int) -> dict:
        """
        Get the average of quantum trajectories observables.

        **Arguments**

        all_results : dict
            Dictionary with observables (summed over trajectories). Updated
            in place and returned.
        num_trajectories : int
            Total number of quantum trajectories.

        **Returns**

        dict
            ``all_results`` with each observable divided by ``num_trajectories``.

        **Raises**

        QTeaLeavesError
            If ``num_trajectories`` is not positive, or if an observable is a
            dictionary, which cannot be averaged.
        """
        if num_trajectories <= 0:
            raise QTeaLeavesError(
                f"Number of trajectories must be positive, got {num_trajectories}."
            )

        for name in self.name:
            value = all_results[name]
            if isinstance(value, dict):
                raise QTeaLeavesError("Dictionary averaging not implemented.")
            # Out-of-place division: an in-place ``/=`` fails for integer numpy
            # arrays because the float result cannot be cast back in place.
            all_results[name] = value / num_trajectories

        return all_results

    def write_results(
        self,
        fh: h5py.File,
        state_ansatz: type[_AbstractTN],
        # pylint: disable-next=unused-argument
        **kwargs: Any,
    ) -> None:
        """
        Write the results to a HDF5 file.

        The results have to be stored in the result buffer. A group named
        after the observable (see ``__repr__``) is created with an
        ``is_measured`` attribute. One dataset per entry of ``name`` is
        written only if the observable is measurable for ``state_ansatz``.
        The result buffer is emptied afterwards in either case.

        **Arguments**

        fh : h5py.File
            Open HDF5 file where the results are written to.
        state_ansatz : type[_AbstractTN]
            Class of the state ansatz currently in use.
        """
        is_measured = self.check_measurable(state_ansatz)

        fg = fh.create_group(str(self), track_order=True)
        fg.attrs["is_measured"] = is_measured

        if is_measured:
            for name in self.name:
                fg.create_dataset(name, data=self.results_buffer[name])

        self.results_buffer = {}

    # pylint: disable-next=unused-argument
    def read(self, fh: h5py.File, **kwargs: Any) -> Iterator[tuple[str, Any]]:
        """
        Read observables from HDF5 file.

        **Arguments**

        fh : h5py.File
            Read the information about the measurements from this HDF5 file.

        **Yields**

        tuple[str, Any]
            ``(name, value)`` for each entry of ``name``. ``value`` is ``None``
            if the group's ``is_measured`` attribute is missing or false.

        **Raises**

        QTeaLeavesError
            If the group of this observable is not present in the file.
        """
        # check that group exists
        if str(self) not in fh:
            raise QTeaLeavesError("Observable group not found in file.")

        fg = fh[str(self)]
        is_measured = fg.attrs.get("is_measured", False)

        for name in self.name:
            if not is_measured:
                yield name, None
            else:
                yield name, fg[name][()]

    def __repr__(self) -> str:
        """
        Return the class name as representation.
        """
        return self.__class__.__name__

    def collect_operators(self) -> Iterator[tuple[str, Any]]:
        """
        Observables which require operators must provide this method,
        because operators with symmetries might not be written
        otherwise.

        **Details**

        The problems are, for example, correlations with equal operators
        because they cannot be contracted over their third link.
        """
        raise NotImplementedError(
            "This observable does not support collecting operators."
        )

    def get_id(self) -> str:
        """
        Get the address in memory, which is useful instead of hashing
        the complete object or for comparisons.
        """
        return hex(id(self))