ASEM000/PyTreeClass: v0.6.0
Description
Allow nested mutations using
<details>`.at[method](*args, **kwargs)`. After the change, inner methods can mutate copied new instances at any level, not just the top level. A motivation for this is to experiment with lazy initialization schemes, where inner layers need to mutate their inner state. See the example below for flax-like lazy initialization as described here:
```python
import pytreeclass as pytc
import jax.random as jr
from typing import Any
import jax
import jax.numpy as jnp
from typing import Callable, TypeVar
T = TypeVar("T")
@pytc.autoinit
class LazyLinear(pytc.TreeClass):
    outdim: int
    weight_init: Callable[..., T] = jax.nn.initializers.glorot_normal()
    bias_init: Callable[..., T] = jax.nn.initializers.zeros

    def param(self, name: str, init_func: Callable[..., T], *args) -> T:
        if name not in vars(self):
            setattr(self, name, init_func(*args))
        return vars(self)[name]

    def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
        w = self.param("weight", self.weight_init, key, (x.shape[-1], self.outdim))
        y = x @ w
        if self.bias_init is not None:
            b = self.param("bias", self.bias_init, key, (self.outdim,))
            return y + b
        return y
@pytc.autoinit
class StackedLinear(pytc.TreeClass):
    l1: LazyLinear = LazyLinear(outdim=10)
    l2: LazyLinear = LazyLinear(outdim=1)

    def __call__(self, x: jax.Array):
        return self.l2(jax.nn.relu(self.l1(x)))
lazy_layer = StackedLinear()
print(repr(lazy_layer))
# StackedLinear(
# l1=LazyLinear(
# outdim=10,
# weight_init=init(key, shape, dtype),
# bias_init=zeros(key, shape, dtype)
# ),
# l2=LazyLinear(
# outdim=1,
# weight_init=init(key, shape, dtype),
# bias_init=zeros(key, shape, dtype)
# )
# )
_, materialized_layer = lazy_layer.at["__call__"](jnp.ones((1, 5)))
materialized_layer
# StackedLinear(
# l1=LazyLinear(
# outdim=10,
# weight_init=init(key, shape, dtype),
# bias_init=zeros(key, shape, dtype),
# weight=f32[5,10](μ=-0.04, σ=0.32, ∈[-0.74,0.63]),
# bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
# ),
# l2=LazyLinear(
# outdim=1,
# weight_init=init(key, shape, dtype),
# bias_init=zeros(key, shape, dtype),
# weight=f32[10,1](μ=-0.07, σ=0.23, ∈[-0.34,0.34]),
# bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
# )
# )
materialized_layer(jnp.ones((1, 5)))
# Array([[0.16712935]], dtype=float32)
```
</details>
Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.5...v0.6.0
Files
ASEM000/PyTreeClass-v0.6.0.zip
Files
(392.2 kB)
| Name | Size | Download all |
|---|---|---|
|
md5:88d833974823ea22f874c1a1962f85e1
|
392.2 kB | Preview Download |
Additional details
Related works
- Is supplement to
- https://github.com/ASEM000/PyTreeClass/tree/v0.6.0 (URL)