# Source code for qtealeaves.observables.custom_function_obs

# This code is part of qtealeaves.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Observable to measure any quantity with a custom function
"""

import logging
from typing import TYPE_CHECKING, Any, Callable, Iterator, Self

import h5py
import numpy as np

from qtealeaves.tooling import QTeaLeavesError

from .tnobase import _TNObsBase

if TYPE_CHECKING:
    from qtealeaves.abstracttns.abstract_tn import _AbstractTN
else:
    _AbstractTN = Any

__all__ = ["CustomFunction"]
logger = logging.getLogger(__name__)


class CustomFunction(_TNObsBase):
    """
    Custom observable for tensor network measurements that are not part of
    the standard set of observables. In addition to the name of the
    observable, the user must provide a custom function to perform the
    measurement. Moreover, this observable is Json serializable.

    Several custom observables can be merged with ``+=``; names, functions,
    and keyword arguments are then kept in parallel lists and evaluated in
    that order by :meth:`measure`.

    **Arguments**

    name : str
        Name of the observable. Used as key in the results buffer and as
        dataset name in the HDF5 output.

    function : Callable
        Measurement function, called as ``function(state, **func_kwargs)``
        with the tensor network state as first argument. Its return value is
        stored as the measurement result.

    func_kwargs : dict
        Additional keyword arguments passed to ``function``.
    """

    def __init__(
        self,
        name: str,
        function: Callable,
        func_kwargs: dict,
    ):
        _TNObsBase.__init__(self, name)
        # Parallel lists with one entry per merged observable (see __iadd__).
        self.function = [function]
        self.func_kwargs = [func_kwargs]

    def __iadd__(self, other: Any) -> Self:
        """
        Documentation see :func:`_TNObsBase.__iadd__`.

        Appends the names, functions, and keyword arguments of ``other``
        to the ones of this observable.

        **Raises**

        QTeaLeavesError
            If ``other`` is not a :class:`CustomFunction`.
        """
        if not isinstance(other, CustomFunction):
            raise QTeaLeavesError(
                f"__iadd__ not defined for types {type(self)} and {type(other)}."
            )

        self.name += other.name
        self.function += other.function
        self.func_kwargs += other.func_kwargs
        return self

    def check_measurable(self, tn_type: type[_AbstractTN]) -> bool:
        """
        Assume that the custom function is always measurable for the
        simulation the user is running, thus return True.

        **Arguments**

        tn_type : type[_AbstractTN]
            The type of the TN, e.g., MPS, TTN, TTO, etc.

        **Returns**

        bool
            Always True.
        """
        # need to reimplement this function to avoid the check in the base class
        return True

    @classmethod
    def empty(cls) -> Self:
        """
        Documentation see :func:`_TNObsBase.empty`.

        Returns an instance without any registered custom function; further
        observables can be added via ``+=``.
        """
        # The constructor requires a function, so pass a placeholder and
        # clear all parallel lists afterwards.
        obj = cls("", lambda x: x, {})
        obj.name = []
        obj.function = []
        obj.func_kwargs = []
        return obj

    def measure(
        self, state: _AbstractTN, operators: Any = None, **kwargs: Any
    ) -> dict[str, Any]:
        """
        Documentation see :func:`_TNObsBase.measure`.

        Calls every stored function as ``function(state, **func_kwargs)`` and
        stores its return value under the corresponding name in the results
        buffer. The effective projectors of ``state`` are detached while the
        custom functions run. Afterwards, the isometry center is moved back
        to its initial position (if it was set) and the projectors are
        re-attached. This also happens when a custom function raises.

        **Arguments**

        state : _AbstractTN
            State to be measured.

        operators : Any, optional
            Not used by this observable.

        kwargs : Any
            Not used by this observable.

        **Returns**

        dict[str, Any]
            The results buffer, mapping each name to the value returned by
            its custom function.
        """
        if len(self.name) == 0:
            return self.results_buffer

        # NOTE(review): open question from the original author - only the
        # effective projectors are detached here, not the effective operators,
        # which presumably are still propagated while isometrizing (e.g. in
        # meas_tensor_product). Confirm whether this asymmetry is intended.
        tmp_eff_proj = state.eff_proj
        state.eff_proj = []
        ini_iso_pos = state.iso_center

        try:
            for name, func, func_kwargs in zip(
                self.name,
                self.function,
                self.func_kwargs,
            ):
                # Measure the observable with custom function
                self.results_buffer[name] = func(state, **func_kwargs)
        finally:
            # Restore the state even if a user function raised: first move
            # the isometry center back while the projectors are still
            # detached (same order as before), then re-attach them.
            if ini_iso_pos is not None:
                state.iso_towards(ini_iso_pos)
            state.eff_proj = tmp_eff_proj

        return self.results_buffer

    def write_results(
        self,
        fh: h5py.File,
        state_ansatz: type[_AbstractTN],
        # pylint: disable-next=unused-argument
        **kwargs: Any,
    ) -> None:
        """
        Write the results to a HDF5 file. The results have to be stored in
        the result buffer.

        A group named ``str(self)`` is created with one dataset per name;
        the results buffer is cleared afterwards.

        **Arguments**

        fh : h5py.File
            Open HDF5 file where the results are written to.

        state_ansatz : type[_AbstractTN]
            Label identifying the state ansatz currently in use. Not used by
            this observable.

        kwargs : Any
            Not used by this observable.
        """
        fg = fh.create_group(str(self), track_order=True)
        for name in self.name:
            # Results are handed to h5py as-is, so each must be a value h5py
            # can store as a dataset.
            fg.create_dataset(name, data=self.results_buffer[name])
        self.results_buffer = {}

    # pylint: disable-next=unused-argument
    def read(
        self, fh: h5py.File, **kwargs: Any
    ) -> Iterator[tuple[str, np.ndarray | None]]:
        """
        Read Custom observable from HDF5 file.

        This method is a generator, so the group lookup below only runs once
        iteration starts.

        **Arguments**

        fh : h5py.File
            Read the information about the measurements from this HDF5 file.

        kwargs : Any
            Not used by this observable.

        **Yields**

        tuple[str, np.ndarray | None]
            Pairs of observable name and the value stored for it.

        **Raises**

        QTeaLeavesError
            If the group ``str(self)`` is not present in ``fh``.
        """
        # check that group exists
        if str(self) not in fh:
            raise QTeaLeavesError("Observable group not found in file.")

        fg = fh[str(self)]
        for name in self.name:
            yield name, fg[name][()]