NoisyNets implementation issues #189

Open

pseudo-rnd-thoughts opened this issue Jan 2, 2022 · 6 comments

@pseudo-rnd-thoughts

I'm implementing my own RL framework in JAX to better understand RL algorithms, and I found your code very helpful.

Looking at the NoisyNets implementation, on lines 316 and 317 of https://github.com/google/dopamine/blob/master/dopamine/jax/networks.py, the same rng_key is used each time noise is generated, meaning that no 'new' noise is drawn each time an input is passed through the layer. In effect, the layer just applies a fixed linear transform, I think.

This is a short testing example:

import jax
import numpy as np

from dopamine.jax.networks import NoisyNetwork

if __name__ == '__main__':
    rng = jax.random.PRNGKey(1)
    rng, rng_net_def, rng_net_param = jax.random.split(rng, num=3)

    net_def = NoisyNetwork(rng_key=rng_net_def, eval_mode=False)
    net_params = net_def.init(rng_net_param, x=np.zeros(10), features=3)

    state = np.random.random(10)
    print(net_def.apply(net_params, x=state, features=3))
    print(net_def.apply(net_params, x=state, features=3))

If this is indeed an issue, here is the code I implemented for my framework:

from typing import Optional, Sequence

import jax
import numpy as onp
import jax.numpy as jnp
from flax import linen as nn

class NoisyDense(nn.Module):
    features: int

    use_bias: bool = True

    @staticmethod
    @jax.jit
    def _f(x: jnp.ndarray) -> jnp.ndarray:
        # See (10) and (11) in Fortunato et al. (2018).
        return jnp.multiply(jnp.sign(x), jnp.power(jnp.abs(x), 0.5))

    @nn.compact
    def __call__(self, inputs: onp.ndarray, eval_mode: bool = True, rng: Optional[jnp.DeviceArray] = None) -> jnp.ndarray:
        if eval_mode:  # Turn off noise during evaluation
            w_epsilon = jnp.zeros(shape=(inputs.shape[0], self.features), dtype=onp.float32)
            b_epsilon = jnp.zeros(shape=(self.features,), dtype=onp.float32)
        else:  # Factored gaussian noise in (10) and (11) in Fortunato et al. (2018).
            p_key, q_key = jax.random.split(rng)
            p, q = jax.random.normal(p_key, [inputs.shape[0], 1]), jax.random.normal(q_key, [1, self.features])
            f_p, f_q = self._f(p), self._f(q)
            w_epsilon, b_epsilon = f_p * f_q, jnp.squeeze(f_q)  # outer product via broadcasting

        def _mu_init(key: jnp.DeviceArray, shape: Sequence[int]):
            # Initialization of the mean parameters: uniform in [-1/sqrt(p), 1/sqrt(p)] (Section 3.2)
            bound = 1 / jnp.power(inputs.shape[0], 0.5)
            return jax.random.uniform(key, minval=-bound, maxval=bound, shape=shape)

        def _sigma_init(_key: jnp.DeviceArray, shape: Sequence[int], dtype=jnp.float32):
            # Initialization of sigma noise parameters (Section 3.2)
            return jnp.ones(shape, dtype) * (0.1 / onp.sqrt(inputs.shape[0]))

        # See (8) and (9) in Fortunato et al. (2018) for output computation.
        w_mu = self.param('kernel_mu', _mu_init, (inputs.shape[0], self.features))
        w_sigma = self.param('kernel_sigma', _sigma_init, (inputs.shape[0], self.features))
        out = jnp.matmul(inputs, w_mu + jnp.multiply(w_sigma, w_epsilon))

        if self.use_bias:
            b_mu = self.param('bias_mu', _mu_init, (self.features,))
            b_sigma = self.param('bias_sigma', _sigma_init, (self.features,))
            out = out + b_mu + jnp.multiply(b_sigma, b_epsilon)
        return out

Here is some similar testing code:

if __name__ == '__main__':
    rng = jax.random.PRNGKey(1)
    rng, rng_net_def, rng_net_param = jax.random.split(rng, num=3)

    net_def = NoisyDense(features=2)
    net_params = net_def.init(rng_net_param, onp.zeros(10))

    state = onp.random.random(10)
    print(net_def.apply(net_params, inputs=state))
    print(net_def.apply(net_params, inputs=state, eval_mode=False, rng=rng_net_def))
    print(net_def.apply(net_params, inputs=state, eval_mode=False, rng=rng))

I would have submitted this as a pull request, but I noticed that you are not accepting merges.

@young-geng

I also realized that this might be an issue. If we want to resample noise, we should either explicitly pass in a new rng every time or use self.make_rng to ensure that RNGs are split correctly.
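For reference, here is a minimal sketch of the self.make_rng option (illustrative only, not Dopamine's API): the layer draws a fresh key from a 'noise' RNG stream supplied to apply(), so noise is resampled on every call without storing a key on the module. The initializers are simplified and the class name is hypothetical.

import jax
import jax.numpy as jnp
from flax import linen as nn

class MakeRngNoisyDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, inputs, eval_mode: bool = True):
        if eval_mode:  # Turn off noise during evaluation
            w_epsilon = jnp.zeros((inputs.shape[0], self.features), dtype=jnp.float32)
        else:
            # self.make_rng draws a fresh key from the 'noise' stream on every apply.
            p_key, q_key = jax.random.split(self.make_rng('noise'))
            p = jax.random.normal(p_key, (inputs.shape[0], 1))
            q = jax.random.normal(q_key, (1, self.features))
            f = lambda x: jnp.sign(x) * jnp.sqrt(jnp.abs(x))  # f from (10) and (11) in Fortunato et al.
            w_epsilon = f(p) * f(q)  # factored noise via broadcasting
        w_mu = self.param('kernel_mu', nn.initializers.lecun_normal(),
                          (inputs.shape[0], self.features))
        w_sigma = self.param('kernel_sigma',
                             lambda key, shape: jnp.full(shape, 0.1 / jnp.sqrt(inputs.shape[0])),
                             (inputs.shape[0], self.features))
        return jnp.matmul(inputs, w_mu + w_sigma * w_epsilon)

Usage would look something like

params = MakeRngNoisyDense(features=2).init(jax.random.PRNGKey(0), jnp.zeros(10))
out = MakeRngNoisyDense(features=2).apply(params, state, eval_mode=False, rngs={'noise': step_key})

with a new step_key split off per forward pass.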

@pseudo-rnd-thoughts
Author

Flax linen module attributes (like rng_key) cannot be updated after construction, so the only way to get "new" random noise is to pass the PRNG key in as a call argument, as I have done in my example code.

@agarwl
Contributor

agarwl commented Feb 22, 2022

Edit: I understood the original comment incorrectly -- it was pointing out the correlated noise on lines 316 & 317. It's unclear how much impact this has on performance, but we will fix it -- thanks for pointing it out! Also, this should fix it:

rng_p, rng_q = jax.random.split(self.rng_key, num=2)
p = NoisyNetwork.sample_noise(rng_p, [x.shape[0], 1])
q = NoisyNetwork.sample_noise(rng_q, [1, features])

I am not sure this is a bug -- as @young-geng mentioned, if we want to resample noise, then we need to pass an explicit rng every time, as is done in the FullRainbowNetwork here. That said, this does seem like a documentation issue about how we expect NoisyNets to work. @psc-g for further visibility

Here's a simplified example to verify that explicitly passing rng works:

import time

import jax
import jax.numpy as jnp
from flax import linen as nn

from dopamine.jax.networks import NoisyNetwork

class DummyNetwork(nn.Module):
  """Dummy network for testing NoisyNets."""

  @nn.compact
  def __call__(self, x, eval_mode=False, key=None):
    if key is None:
      key = jax.random.PRNGKey(int(time.time() * 1e6))
    return NoisyNetwork(rng_key=key, eval_mode=eval_mode)(x, features=2)

def create_noisy_net_and_eval(num_runs=5):
  network_def = DummyNetwork()
  x = jnp.ones(5)
  rng = jax.random.PRNGKey(0)
  rng1, rng = jax.random.split(rng, 2)
  params = network_def.init(rng1, x=x)
  for i in range(num_runs):
    rng1, rng = jax.random.split(rng)
    print(f'rng{i}', network_def.apply(params, x=x, key=rng1))
>> create_noisy_net_and_eval()
rng0 [ 0.49825954 -0.5264382 ]
rng1 [ 0.3296632  -0.56998575]
rng2 [ 0.5706229  -0.42372862]
rng3 [ 0.5419281  -0.47531918]
rng4 [ 0.52439386 -0.46529555]

@psc-g
Collaborator

psc-g commented Feb 22, 2022

hi, thanks for raising this! i agree with what rishabh pointed out. i believe that once the rngs used for p and q are uncorrelated, it is working as expected (e.g. when a new rng is not passed in every time)

@pseudo-rnd-thoughts
Author

@agarwl Thanks, I hadn't spotted that the FullRainbowNetwork implementation passes a new rng key to the noisy network each time, so you are correct. With the modification that you propose, the noisy network works as expected.

But as eval_mode and rng_key are attributes of the network, this is potentially misleading, since they actually need to be passed to the call function every time; conversely, features, use_bias and kernel_init should not be modified after init.
This is the reason I moved these variables from init to call (and vice versa) in my implementation.

@psc-g I may be wrong, but I think a new rng should be passed every time (when eval_mode = False): if new noise is not added on each call, then all that happens is that a fixed linear transformation is applied to the inputs, as the small check below shows.
In my eyes, this defeats the point of the noisy network heuristic, which is to both increase the stability/resilience of the network and increase the "observations" seen by the network.
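To illustrate, a tiny sanity check (not Dopamine code): with a fixed key, JAX's random functions return the same samples every time, so the perturbation is constant rather than fresh noise per forward pass.

import jax

key = jax.random.PRNGKey(0)
noise_a = jax.random.normal(key, (3,))
noise_b = jax.random.normal(key, (3,))
# Same key -> identical samples, i.e. a fixed (linear) perturbation of the weights.
assert (noise_a == noise_b).all()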

@young-geng

Now I see that it passes in a new RNG key every time, so I believe I was wrong about the noise not being resampled, and the implementation should be correct. Sorry for the confusion.
