# Source code for qtealeaves.observables.state2file

# This code is part of qtealeaves.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""
Observable to save the tensors forming the final MPS
"""
import logging
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Self

import h5py

from qtealeaves.emulator import ATTN, LPTN, MPS, TTN, TTO, StateVector
from qtealeaves.tooling import QTeaLeavesError

from .tnobase import _TNObsBase

if TYPE_CHECKING:
    # Imported only for static type checkers (presumably to avoid an import
    # cycle at runtime).
    from qtealeaves.abstracttns.abstract_tn import _AbstractTN
else:
    # At runtime the annotations only need a placeholder type.
    _AbstractTN = Any

# Module-level logger used by the observable below.
logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = ["State2File"]


class State2File(_TNObsBase):
    """
    Write the state to a file.

    Saving the state to a file makes it possible to measure further
    observables afterwards, since all the information available at the end
    of the simulation is preserved.

    .. warning::
        While saving the state can be useful, it can really slow down the
        evolution when it is saved at each time-step of a time evolution.
        Please use this observable carefully!

    Refer to the description of the backend for more specific information.

    **Arguments**

    name : str
        Filename to save the state.

    formatting : str
        Single character specifying the format, i.e., 'F' for formatted,
        'U' for unformatted, or 'D' for formatted without symmetries.
        On the python backend, 'U' is pickled, 'F' is formatted, and 'D' is
        converted via ``to_dense`` and pickled (the last one in particular
        differs from fortran, where the dense TN is stored as a formatted
        file).
    """

    # mypy triggered because of StateVector
    _measurable_ansaetze = (MPS, TTN, TTO, ATTN, LPTN, StateVector)  # type: ignore[assignment]

    def __init__(self, name: str, formatting: str):
        super().__init__(name)
        # One formatting flag per entry of ``self.name``; both lists are kept
        # in sync by ``empty`` and ``__iadd__``.
        self.formatting = [formatting]

    @classmethod
    def empty(cls) -> Self:
        """
        Documentation see :func:`_TNObsBase.empty`.
        """
        obj = cls("x", "F")
        obj.name = []
        obj.formatting = []
        return obj

    def __iadd__(self, other: Any) -> Self:
        """
        Documentation see :func:`_TNObsBase.__iadd__`.
        """
        if not isinstance(other, State2File):
            raise QTeaLeavesError(
                f"__iadd__ not defined for types {type(self)} and {type(other)}."
            )

        self.name += other.name
        self.formatting += other.formatting
        return self

    def measure(
        self, state: _AbstractTN, operators: Any = None, **kwargs: Any
    ) -> dict[str, Any]:
        """
        Save the state to file.

        **Arguments**

        state : instance of class from :py:mod:`qtealeaves.emulator`
            The state to measure the observable on.

        operators : extra argument to maintain same signature from parent
            class, pass `None`.

        postfix : str, optional (in kwargs)
            String appended to every filename before the format-specific
            extension. Default to ``""``.

        params : dict (in kwargs)
            Dictionary with parameters of the simulation. Needed for
            inserting params of the observables into filename. If missing,
            a warning is logged and an empty dictionary is used.

        **Returns**

        results_buffer : dict
            Maps each entry of ``self.name`` to the filename the state was
            written to. If the state ansatz is not measurable, the buffer is
            returned unchanged.

        **Raises**

        QTeaLeavesError
            If a formatting flag is not one of 'U', 'F', or 'D'.
        """
        if not self.check_measurable(state.__class__):
            logger.warning(
                "Observable %s not measurable for %s", self.name, str(state)
            )
            return self.results_buffer

        num_excited = len(state.eff_proj)
        excited_index = f"_excited{num_excited:03}" if num_excited > 0 else ""

        postfix = kwargs.get("postfix", "")
        params = kwargs.get("params", None)
        if params is None:
            logger.warning(
                "No params provided, cannot add them to filename. "
                "If this is intended, please pass an empty dictionary to silence this warning."
            )
            params = {}

        for name_jj, formatting_jj in zip(self.name, self.formatting, strict=True):
            # Part of the filename shared by all formats. The format-specific
            # suffix is appended per branch so that a user-provided name that
            # itself contains ".pkl" is never rewritten.
            basename = str(self.eval_str_param(name_jj, params))
            basename += postfix + excited_index

            if formatting_jj == "U":
                filename_tmp = basename + ".pkl" + state.extension
                state.save_pickle(filename_tmp)
            elif formatting_jj == "F":
                filename_tmp = basename + "." + state.extension
                state.write(filename_tmp)
            elif formatting_jj == "D":
                filename_tmp = basename + ".pkl" + state.extension
                state_dense = state.to_dense()
                state_dense.save_pickle(filename_tmp)
            else:
                # Without this, a filename would be recorded for a file that
                # was never written.
                raise QTeaLeavesError(
                    f"Unknown formatting `{formatting_jj}` for observable "
                    f"`{name_jj}`; expected 'U', 'F', or 'D'."
                )

            self.results_buffer[name_jj] = filename_tmp

        return self.results_buffer

    def add_trajectories(
        self, all_results: dict[str, list[str]], new: dict[str, str]
    ) -> dict[str, list[str]]:
        """
        Documentation see :func:`_TNObsBase.add_trajectories`.
        Here, we generate a list of filenames.
        """
        for name in self.name:
            # Look up first so a missing key does not leave an empty list.
            filename = new[name]
            all_results.setdefault(name, []).append(filename)
        return all_results

    def avg_trajectories(
        self, all_results: dict[str, list[str]], num_trajectories: int
    ) -> dict[str, list[str]]:
        """
        Documentation see :func:`_TNObsBase.avg_trajectories`.
        Here, we return the list of filenames as is, no action
        possible for averaging.
        """
        return all_results

    def write_results(
        self,
        fh: h5py.File,
        state_ansatz: type[_AbstractTN],
        # pylint: disable-next=unused-argument
        **kwargs: Any,
    ) -> None:
        """
        Write the actual results to a HDF5 file. The results
        have to be stored in the result buffer, which is cleared
        afterwards.

        **Arguments**

        fh : h5py.File
            Open HDF5 file where the results are written to.

        state_ansatz : type
            Class of the state ansatz currently in use; determines whether
            the observable was measurable.
        """
        is_measured = self.check_measurable(state_ansatz)

        fg = fh.create_group(str(self), track_order=True)
        fg.attrs["is_measured"] = is_measured

        if is_measured:
            for ii, name in enumerate(self.name):
                fg.create_dataset(
                    f"{str(self)}_{ii}",
                    data=self.results_buffer[name],
                    dtype=h5py.string_dtype(),
                )

        self.results_buffer = {}

    def read(self, fh: h5py.File, **kwargs: Any) -> Iterator[tuple[str, str]]:
        """
        Read file observable from HDF5 file.

        **Arguments**

        fh : h5py.File
            Read the information about the measurements from this HDF5 file.

        params : dict (in kwargs)
            The parameter dictionary, which is required to obtain
            the output folder. It is required to evaluate callable
            etc. used in ``self.name``.

        **Yields**

        (filename, filename_final) : tuple
            The evaluated name of each entry and the filename stored in
            the HDF5 file, or ``""`` if the observable was not measured.

        **Raises**

        QTeaLeavesError
            If the observable group is missing from ``fh``.
        """
        params = kwargs.get("params")

        # check that group exists
        if str(self) not in fh:
            raise QTeaLeavesError("Observable group not found in file.")

        fg = fh[str(self)]
        is_measured = fg.attrs.get("is_measured", False)

        for ii, name in enumerate(self.name):
            filename = self.eval_str_param(name, params)

            if not is_measured:
                yield filename, ""
                continue

            filename_final = fg[f"{str(self)}_{ii}"][()]
            # h5py 3 reads variable-length strings as bytes; tolerate a value
            # that is already a str.
            if isinstance(filename_final, bytes):
                filename_final = filename_final.decode("utf-8")
            yield filename, filename_final