# This code is part of qtealeaves.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
Observable to save the tensors forming the final MPS
"""
import logging
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Self
import h5py
from qtealeaves.emulator import ATTN, LPTN, MPS, TTN, TTO, StateVector
from qtealeaves.tooling import QTeaLeavesError
from .tnobase import _TNObsBase
# The abstract TN base class is only needed for annotations; importing it
# only during type checking avoids a runtime import of that module.
if TYPE_CHECKING:
    from qtealeaves.abstracttns.abstract_tn import _AbstractTN
else:
    _AbstractTN = Any

__all__ = ["State2File"]

logger = logging.getLogger(__name__)
class State2File(_TNObsBase):
    """
    Write the state to a file.

    We stress that saving the state to a file enables measuring
    further observables afterwards, since all the information
    available at the end of the simulation is preserved.

    .. warning::
       While saving the state can be useful, it can really slow
       down the evolution when it is saved at each time-step of
       a time evolution. Please use this observable carefully!

    Refer to the description on the backend for more specific
    information.

    **Arguments**

    name : str
        Filename to save the state.

    formatting : str (single character)
        Specifies format, i.e., 'F' for formatted, 'U' for
        unformatted, or 'D' for formatted without symmetries.
        On the python backend, U is pickled, F is formatted, and
        D is converted to a dense tensor and pickled (especially
        the last one is different from fortran, where the dense
        TN is stored as formatted file).
    """

    # mypy triggered because of StateVector
    _measurable_ansaetze = (MPS, TTN, TTO, ATTN, LPTN, StateVector)  # type: ignore[assignment]

    def __init__(self, name: str, formatting: str):
        super().__init__(name)
        # Kept as a list in parallel to ``self.name`` so that observables
        # can be merged via ``__iadd__``.
        self.formatting = [formatting]

    @classmethod
    def empty(cls) -> Self:
        """
        Documentation see :func:`_TNObsBase.empty`.
        """
        obj = cls("x", "F")
        obj.name = []
        obj.formatting = []
        return obj

    def __iadd__(self, other: Any) -> Self:
        """
        Documentation see :func:`_TNObsBase.__iadd__`.
        """
        if isinstance(other, State2File):
            self.name += other.name
            self.formatting += other.formatting
        else:
            raise QTeaLeavesError(
                f"__iadd__ not defined for types {type(self)} and {type(other)}."
            )

        return self

    def measure(
        self, state: _AbstractTN, operators: Any = None, **kwargs: Any
    ) -> dict[str, Any]:
        """
        Save the state to file.

        **Arguments**

        state : instance of class from :py:mod:`qtealeaves.emulator`
            The state to measure the observable on.

        operators : extra argument to maintain same signature from
            parent class, pass `None`.

        params : dict (in kwargs)
            Dictionary with parameters of the simulation.
            Needed for inserting params of the observables into filename.
            If missing, a warning is logged and an empty dictionary is used.

        postfix : str (in kwargs)
            String appended to the evaluated filename before the
            file extension. Default is an empty string.

        **Returns**

        results_buffer : dict
            Maps each entry of ``self.name`` to the filename written for it.

        **Raises**

        QTeaLeavesError
            If a formatting other than 'U', 'F', or 'D' is set.
        """
        if not self.check_measurable(state.__class__):
            logger.warning("Observable %s not measurable for %s", self.name, str(state))
            return self.results_buffer

        # Tag the filename with the number of effective projectors; presumably
        # this keeps files of different excited states apart.
        num_excited = len(state.eff_proj)
        excited_index = f"_excited{num_excited:03}" if num_excited > 0 else ""

        postfix = kwargs.get("postfix", "")
        params = kwargs.get("params", None)
        if params is None:
            logger.warning(
                "No params provided, cannot add them to filename. "
                "If this is intended, please pass an empty dictionary to silence this warning."
            )
            params = {}

        for jj, name_jj in enumerate(self.name):
            formatting = self.formatting[jj]
            basename = (
                str(self.eval_str_param(name_jj, params)) + postfix + excited_index
            )

            # Build the suffix per format instead of post-processing the full
            # path, so that ".pkl" occurring in the user's name/directory is
            # never touched.
            if formatting == "U":
                filename_tmp = basename + ".pkl" + state.extension
                state.save_pickle(filename_tmp)
            elif formatting == "F":
                filename_tmp = basename + "." + state.extension
                state.write(filename_tmp)
            elif formatting == "D":
                filename_tmp = basename + ".pkl" + state.extension
                state_dense = state.to_dense()
                state_dense.save_pickle(filename_tmp)
            else:
                # Previously an unknown format wrote nothing but still
                # reported a filename in the results.
                raise QTeaLeavesError(
                    f"Unknown formatting `{formatting}` for observable `{name_jj}`; "
                    "expected 'U', 'F', or 'D'."
                )

            self.results_buffer[name_jj] = filename_tmp

        return self.results_buffer

    def add_trajectories(
        self, all_results: dict[str, list[str]], new: dict[str, str]
    ) -> dict[str, list[str]]:
        """
        Documentation see :func:`_TNObsBase.add_trajectories`.
        Here, we generate a list of filenames.
        """
        for name in self.name:
            filename = new[name]
            all_results.setdefault(name, []).append(filename)

        return all_results

    def avg_trajectories(
        self, all_results: dict[str, list[str]], num_trajectories: int
    ) -> dict[str, list[str]]:
        """
        Documentation see :func:`_TNObsBase.avg_trajectories`.
        Here, we return the list of filenames as is, no action possible
        for averaging.
        """
        return all_results

    def write_results(
        self,
        fh: h5py.File,
        state_ansatz: type[_AbstractTN],
        # pylint: disable-next=unused-argument
        **kwargs: Any,
    ) -> None:
        """
        Write the actual results to a HDF5 file.

        The results have to be stored in the result buffer; the buffer
        is reset afterwards.

        **Arguments**

        fh : h5py.File
            Open HDF5 file where the results are written to.

        state_ansatz : type
            Class of the state ansatz currently in use; used to decide
            whether the observable was measurable.
        """
        is_measured = self.check_measurable(state_ansatz)
        fg = fh.create_group(str(self), track_order=True)
        fg.attrs["is_measured"] = is_measured

        if is_measured:
            # One string dataset per filename, keyed by its position in
            # ``self.name``; ``read`` relies on the same naming scheme.
            for ii, name in enumerate(self.name):
                fg.create_dataset(
                    f"{str(self)}_{ii}",
                    data=self.results_buffer[name],
                    dtype=h5py.string_dtype(),
                )

        self.results_buffer = {}

    def read(self, fh: h5py.File, **kwargs: Any) -> Iterator[tuple[str, str]]:
        """
        Read file observable from HDF5 file.

        **Arguments**

        fh : h5py.File
            Read the information about the measurements from this HDF5 file.

        params : dict (in kwargs)
            The parameter dictionary, which is required to obtain
            the output folder. It is required to evaluate
            callable etc. used in ``self.name``.

        **Returns**

        Iterator over tuples ``(name, filename)``, one per entry of
        ``self.name``: the evaluated name and the filename stored
        during the simulation (empty string if not measured).

        **Raises**

        QTeaLeavesError
            If the observable's group is missing in ``fh``. As this is a
            generator, the error is raised when iteration starts.
        """
        params = kwargs.get("params")

        # check that group exists
        if str(self) not in fh:
            raise QTeaLeavesError("Observable group not found in file.")

        fg = fh[str(self)]
        is_measured = fg.attrs.get("is_measured", False)

        for ii, name in enumerate(self.name):
            filename = self.eval_str_param(name, params)
            if not is_measured:
                yield filename, ""
                continue

            filename_final = fg[f"{str(self)}_{ii}"][()]
            # h5py returns string datasets as bytes by default; decode only
            # then so an already-decoded ``str`` does not fail.
            if isinstance(filename_final, bytes):
                filename_final = filename_final.decode("utf-8")
            yield filename, filename_final