numpyro/optim.py
Killed 13 out of 31 mutants (18 survived).

Survived
The following mutants survived mutation testing; each one points to a gap in the test suite.

Mutant 536
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -24,7 +24,7 @@
'SM3',
]
-_Params = TypeVar('_Params')
+_Params = TypeVar('XX_ParamsXX')
_OptState = TypeVar('_OptState')
_IterOptState = Tuple[int, _OptState]
Mutant 537
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -24,7 +24,7 @@
'SM3',
]
-_Params = TypeVar('_Params')
+_Params = None
_OptState = TypeVar('_OptState')
_IterOptState = Tuple[int, _OptState]
Mutant 538
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -25,7 +25,7 @@
]
_Params = TypeVar('_Params')
-_OptState = TypeVar('_OptState')
+_OptState = TypeVar('XX_OptStateXX')
_IterOptState = Tuple[int, _OptState]
Mutant 539
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -25,7 +25,7 @@
]
_Params = TypeVar('_Params')
-_OptState = TypeVar('_OptState')
+_OptState = None
_IterOptState = Tuple[int, _OptState]
Mutant 540
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -26,7 +26,7 @@
_Params = TypeVar('_Params')
_OptState = TypeVar('_OptState')
-_IterOptState = Tuple[int, _OptState]
+_IterOptState = None
class _NumpyroOptim(object):
Mutant 543
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -41,7 +41,7 @@
:return: initial optimizer state.
"""
opt_state = self.init_fn(params)
- return jnp.array(0), opt_state
+ return jnp.array(1), opt_state
def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
"""
Mutant 547
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -53,7 +53,7 @@
"""
i, opt_state = state
opt_state = self.update_fn(i, g, opt_state)
- return i + 1, opt_state
+ return i + 2, opt_state
def get_params(self, state: _IterOptState) -> _Params:
"""
Mutant 549
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -68,7 +68,7 @@
def _add_doc(fn):
def _wrapped(cls):
- cls.__doc__ = 'Wrapper class for the JAX optimizer: :func:`~jax.experimental.optimizers.{}`'\
+ cls.__doc__ = 'XXWrapper class for the JAX optimizer: :func:`~jax.experimental.optimizers.{}`XX'\
.format(fn.__name__)
return cls
Mutant 550
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -68,8 +68,7 @@
def _add_doc(fn):
def _wrapped(cls):
- cls.__doc__ = 'Wrapper class for the JAX optimizer: :func:`~jax.experimental.optimizers.{}`'\
- .format(fn.__name__)
+ cls.__doc__ = None
return cls
return _wrapped
Mutant 551
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -74,8 +74,6 @@
return _wrapped
-
-@_add_doc(optimizers.adam)
class Adam(_NumpyroOptim):
def __init__(self, *args, **kwargs):
super(Adam, self).__init__(optimizers.adam, *args, **kwargs)
Mutant 552
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -93,7 +93,7 @@
`A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
https://arxiv.org/abs/1412.6980
"""
- def __init__(self, *args, clip_norm=10., **kwargs):
+ def __init__(self, *args, clip_norm=11.0, **kwargs):
self.clip_norm = clip_norm
super(ClippedAdam, self).__init__(optimizers.adam, *args, **kwargs)
Mutant 560
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -102,7 +102,7 @@
# clip norm
g = tree_map(lambda g_: jnp.clip(g_, a_min=-self.clip_norm, a_max=self.clip_norm), g)
opt_state = self.update_fn(i, g, opt_state)
- return i + 1, opt_state
+ return i + 2, opt_state
@_add_doc(optimizers.adagrad)
Mutant 561
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -104,8 +104,6 @@
opt_state = self.update_fn(i, g, opt_state)
return i + 1, opt_state
-
-@_add_doc(optimizers.adagrad)
class Adagrad(_NumpyroOptim):
def __init__(self, *args, **kwargs):
super(Adagrad, self).__init__(optimizers.adagrad, *args, **kwargs)
Mutant 562
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -110,8 +110,6 @@
def __init__(self, *args, **kwargs):
super(Adagrad, self).__init__(optimizers.adagrad, *args, **kwargs)
-
-@_add_doc(optimizers.momentum)
class Momentum(_NumpyroOptim):
def __init__(self, *args, **kwargs):
super(Momentum, self).__init__(optimizers.momentum, *args, **kwargs)
Mutant 563
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -116,8 +116,6 @@
def __init__(self, *args, **kwargs):
super(Momentum, self).__init__(optimizers.momentum, *args, **kwargs)
-
-@_add_doc(optimizers.rmsprop)
class RMSProp(_NumpyroOptim):
def __init__(self, *args, **kwargs):
super(RMSProp, self).__init__(optimizers.rmsprop, *args, **kwargs)
Mutant 564
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -122,8 +122,6 @@
def __init__(self, *args, **kwargs):
super(RMSProp, self).__init__(optimizers.rmsprop, *args, **kwargs)
-
-@_add_doc(optimizers.rmsprop_momentum)
class RMSPropMomentum(_NumpyroOptim):
def __init__(self, *args, **kwargs):
super(RMSPropMomentum, self).__init__(optimizers.rmsprop_momentum, *args, **kwargs)
Mutant 565
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -128,8 +128,6 @@
def __init__(self, *args, **kwargs):
super(RMSPropMomentum, self).__init__(optimizers.rmsprop_momentum, *args, **kwargs)
-
-@_add_doc(optimizers.sgd)
class SGD(_NumpyroOptim):
def __init__(self, *args, **kwargs):
super(SGD, self).__init__(optimizers.sgd, *args, **kwargs)
Mutant 566
--- numpyro/optim.py
+++ numpyro/optim.py
@@ -134,8 +134,6 @@
def __init__(self, *args, **kwargs):
super(SGD, self).__init__(optimizers.sgd, *args, **kwargs)
-
-@_add_doc(optimizers.sm3)
class SM3(_NumpyroOptim):
def __init__(self, *args, **kwargs):
super(SM3, self).__init__(optimizers.sm3, *args, **kwargs)