# This code is part of qtealeaves.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
Observable to measure any quantity with a custom function
"""
import logging
from typing import TYPE_CHECKING, Any, Callable, Iterator, Self
import h5py
import numpy as np
from qtealeaves.tooling import QTeaLeavesError
from .tnobase import _TNObsBase
if TYPE_CHECKING:
    # Only static type checkers import the real class.
    from qtealeaves.abstracttns.abstract_tn import _AbstractTN
else:
    # Annotations in this module are evaluated when the functions are defined
    # (no postponed evaluation), so the name must exist at runtime; a plain
    # ``Any`` placeholder is sufficient for that.
    _AbstractTN = Any

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = ["CustomFunction"]
class CustomFunction(_TNObsBase):
    """
    Custom observable for tensor network measurements
    that are not part of the standard set of observables.

    In addition to the name of the observable, the user
    must provide a custom function to perform the
    measurement. Moreover, this observable is Json serializable.

    Several custom observables can be combined with ``+=``; the instance
    then holds parallel lists of names, functions and keyword arguments,
    which :meth:`measure` evaluates in order.

    **Arguments**

    name : str
        Name of the observable. It is used as the key in the results buffer
        and as the dataset name in the HDF5 output.

    function : Callable
        Function called as ``function(state, **func_kwargs)`` to perform the
        measurement; its return value is the measurement result.

    func_kwargs : dict
        Keyword arguments forwarded to ``function``.
    """

    def __init__(
        self,
        name: str,
        function: Callable,
        func_kwargs: dict,
    ):
        _TNObsBase.__init__(self, name)
        # Kept as lists so that observables can be concatenated with ``+=``
        # and iterated in lockstep with ``self.name`` in ``measure``.
        self.function = [function]
        self.func_kwargs = [func_kwargs]

    def __iadd__(self, other: Any) -> Self:
        """
        Documentation see :func:`_TNObsBase.__iadd__`.

        Appends the names, functions and keyword arguments of ``other`` to
        this observable.

        **Raises**

        QTeaLeavesError
            If ``other`` is not a :class:`CustomFunction`.
        """
        if not isinstance(other, CustomFunction):
            raise QTeaLeavesError(
                f"__iadd__ not defined for types {type(self)} and {type(other)}."
            )

        self.name += other.name
        self.function += other.function
        self.func_kwargs += other.func_kwargs
        return self

    def check_measurable(self, tn_type: type[_AbstractTN]) -> bool:
        """
        Assume that the custom function is always measurable for
        the simulation the user is running, thus return True.

        **Arguments**

        tn_type : type[_AbstractTN]
            The type of the TN, e.g., MPS, TTN, TTO, etc.

        **Returns**

        bool
            Always ``True``.
        """
        # Reimplemented on purpose to skip the check done in the base class.
        return True

    @classmethod
    def empty(cls) -> Self:
        """
        Documentation see :func:`_TNObsBase.empty`.

        Returns an instance without any registered custom function, ready to
        be extended with ``+=``.
        """
        obj = cls("", lambda x: x, {})
        # Drop the placeholder entry created by the constructor call above.
        obj.name = []
        obj.function = []
        obj.func_kwargs = []
        return obj

    def measure(
        self, state: _AbstractTN, operators: Any = None, **kwargs: Any
    ) -> dict[str, Any]:
        """
        Documentation see :func:`_TNObsBase.measure`.

        Every registered function is called as ``func(state, **func_kwargs)``
        and its return value is stored in ``self.results_buffer`` under the
        corresponding name. The effective projectors of ``state`` are
        detached while the functions run. They are restored afterwards,
        together with the initial isometry center, even if a custom function
        raises.
        """
        if not self.name:
            return self.results_buffer

        # NOTE(review): only the effective projectors are detached here, while
        # the effective operators are still propagated when isometrizing
        # (e.g. in meas_tensor_product). Confirm whether this asymmetry is
        # intended.
        ini_iso_pos = state.iso_center
        tmp_eff_proj = state.eff_proj
        state.eff_proj = []

        try:
            for name, func, func_kwargs in zip(
                self.name, self.function, self.func_kwargs
            ):
                # Measure the observable with the custom function
                self.results_buffer[name] = func(state, **func_kwargs)
        finally:
            # Always hand the state back as we received it. Otherwise an
            # exception in a user function would leave the caller without its
            # effective projectors and with a moved isometry center.
            try:
                if ini_iso_pos is not None:
                    state.iso_towards(ini_iso_pos)
            finally:
                state.eff_proj = tmp_eff_proj

        return self.results_buffer

    def write_results(
        self,
        fh: h5py.File,
        # pylint: disable-next=unused-argument
        state_ansatz: type[_AbstractTN],
        # pylint: disable-next=unused-argument
        **kwargs: Any,
    ) -> None:
        """
        Write the results to a HDF5 file.

        The results have to be stored in the result buffer. One dataset per
        registered name is created inside a new group named ``str(self)``.
        The results buffer is cleared afterwards.

        **Arguments**

        fh : h5py.File
            Open HDF5 file where the results are written to.

        state_ansatz : type[_AbstractTN]
            Label identifying the state ansatz currently in use (unused).

        **Raises**

        KeyError
            If a registered name has no entry in the results buffer, e.g.,
            because :meth:`measure` was not called before.
        """
        fg = fh.create_group(str(self), track_order=True)
        for name in self.name:
            fg.create_dataset(name, data=self.results_buffer[name])

        self.results_buffer = {}

    def read(
        self,
        fh: h5py.File,
        # pylint: disable-next=unused-argument
        **kwargs: Any,
    ) -> Iterator[tuple[str, np.ndarray | None]]:
        """
        Read Custom observable from HDF5 file.

        **Arguments**

        fh : h5py.File
            Read the information about the measurements from this HDF5 file.

        **Yields**

        tuple[str, np.ndarray | None]
            The name of each registered custom observable and the data stored
            for it.

        **Raises**

        QTeaLeavesError
            If the group of this observable is not present in ``fh``. Since
            this method is a generator, the error surfaces when iteration
            starts, not when the method is called.
        """
        group_name = str(self)
        if group_name not in fh:
            raise QTeaLeavesError("Observable group not found in file.")

        fg = fh[group_name]
        for name in self.name:
            yield name, fg[name][()]