There is a newer version of the record available.

Published July 31, 2023 | Version v0.6.0
Software Open

ASEM000/PyTreeClass: v0.6.0

Authors/Creators

  • 1. KAIST

Description

v0.6.0
  • Allow nested mutations using .at[method](*args, **kwargs). After the change, inner methods can mutate copied new instances at any level, not just the top level. A motivation for this is to experiment with a lazy initialization scheme, where inner layers need to mutate their inner state. See the example below for flax-like lazy initialization as described here

    <details>

    ```python

    # Standard library
    from typing import Any, Callable, TypeVar

    # Third-party
    import jax
    import jax.numpy as jnp
    import jax.random as jr
    import pytreeclass as pytc

    # Return type of parameter-initializer callables.
    T = TypeVar("T")

    @pytc.autoinit
    class LazyLinear(pytc.TreeClass):
        """Linear layer whose parameters are materialized on first call.

        The weight (and optional bias) are not created at construction time;
        they are created lazily by ``param`` the first time ``__call__`` runs,
        once the input feature size is known from ``x.shape[-1]``.
        """

        outdim: int
        weight_init: Callable[..., T] = jax.nn.initializers.glorot_normal()
        bias_init: Callable[..., T] = jax.nn.initializers.zeros

        def param(self, name: str, init_func: Callable[..., T], *args) -> T:
            # Create the attribute only on first access; afterwards reuse the
            # already-initialized value stored in the instance dict.
            if name not in vars(self):
                setattr(self, name, init_func(*args))
            return vars(self)[name]

        def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
            # Weight shape depends on the runtime input: (in_features, outdim).
            w = self.param("weight", self.weight_init, key, (x.shape[-1], self.outdim))
            y = x @ w
            if self.bias_init is not None:
                b = self.param("bias", self.bias_init, key, (self.outdim,))
                return y + b
            return y
    
@pytc.autoinit
class StackedLinear(pytc.TreeClass):
    """Two lazily-initialized linear layers with a ReLU in between."""

    l1: LazyLinear = LazyLinear(outdim=10)
    l2: LazyLinear = LazyLinear(outdim=1)

    def __call__(self, x: jax.Array):
        return self.l2(jax.nn.relu(self.l1(x)))

lazy_layer = StackedLinear()
print(repr(lazy_layer))
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10, 
#     weight_init=init(key, shape, dtype), 
#     bias_init=zeros(key, shape, dtype)
#   ), 
#   l2=LazyLinear(
#     outdim=1, 
#     weight_init=init(key, shape, dtype), 
#     bias_init=zeros(key, shape, dtype)
#   )
# )

# Functional (out-of-place) call: returns the method's output together with a
# new instance whose lazily-created parameters are materialized at every level.
_, materialized_layer = lazy_layer.at["__call__"](jnp.ones((1, 5)))
materialized_layer
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10, 
#     weight_init=init(key, shape, dtype), 
#     bias_init=zeros(key, shape, dtype), 
#     weight=f32[5,10](μ=-0.04, σ=0.32, ∈[-0.74,0.63]), 
#     bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   ), 
#   l2=LazyLinear(
#     outdim=1, 
#     weight_init=init(key, shape, dtype), 
#     bias_init=zeros(key, shape, dtype), 
#     weight=f32[10,1](μ=-0.07, σ=0.23, ∈[-0.34,0.34]), 
#     bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   )
# )

# The materialized instance can now be called directly.
materialized_layer(jnp.ones((1, 5)))
# Array([[0.16712935]], dtype=float32)
```
</details>

Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.5...v0.6.0

Files

ASEM000/PyTreeClass-v0.6.0.zip

Files (392.2 kB)

Name Size Download all
md5:88d833974823ea22f874c1a1962f85e1
392.2 kB Preview Download

Additional details

Related works