[RLlib] DeepMind preprocessor not working as expected #45186

Open
rajfly opened this issue May 7, 2024 · 0 comments
Labels
bug Something that is supposed to be working; but isn't rllib RLlib related issues rllib-oldstack-cleanup Issues related to cleaning up classes, utilities on the old API stack


rajfly commented May 7, 2024

What happened + What you expected to happen

With the PPO algorithm (hyperparameters configured as per the original paper) on the Atari environments, the built-in DeepMind preprocessor is not working as expected. I get the following error, which suggests that the model is receiving the raw environment observation (3 RGB channels) instead of the DeepMind-preprocessed observation (a stack of 4 frames):

RuntimeError: Given groups=1, weight of size [32, 4, 8, 8], expected input[32, 84, 84, 3] to have 4 channels, but got 84 channels instead

Since nn.Conv2d treats dimension 1 as the channel axis, it reads 84 "channels" from the [32, 84, 84, 3] input, and the trailing 3 shows the observation is still a raw RGB frame rather than an 84x84x4 stack. To run the reproduction script, save it to a file and run it with Python, e.g., python ppo_atari.py
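
For reference, here is what the DeepMind pipeline should produce, checked outside of RLlib. This is an illustrative sketch only: it uses gymnasium's AtariPreprocessing and FrameStack wrappers rather than RLlib's internal code path (RLlib's own "deepmind" wrappers stack the frames channels-last, i.e. (84, 84, 4)):

# Sketch: verify the expected DeepMind-preprocessed observation shape.
# Assumes gymnasium < 1.0 (FrameStack was renamed in later releases) and
# ale-py with the Atari ROMs installed; NOT the code path RLlib uses internally.
import gymnasium as gym
import numpy as np
from gymnasium.wrappers import AtariPreprocessing, FrameStack

env = gym.make("PongNoFrameskip-v4")
print(env.observation_space.shape)  # raw RGB frames: (210, 160, 3)

# Grayscale, resize to 84x84, then keep the 4 most recent frames.
env = AtariPreprocessing(env, frame_skip=4, screen_size=84, grayscale_obs=True)
env = FrameStack(env, 4)

obs, _ = env.reset()
print(np.asarray(obs).shape)  # preprocessed: (4, 84, 84)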

Versions / Dependencies

python = 3.11
ray[tune,rllib] = 2.20.0
OS: Ubuntu LTS

Reproduction script

import argparse
import json
import os
import pathlib
import time
import uuid

import numpy as np
import torch
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.tune.logger import pretty_print
from torch import nn


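# Builds a piecewise schedule as a list of (timestep, lr) pairs, the format
# RLlib's lr_schedule option expects; lr is annealed linearly toward 0.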
def linear_schedule(lr, n_iterations, iteration_steps):
    ts_lr = []
    ts = 0
    for iteration in range(1, n_iterations + 1):
        frac = 1.0 - (iteration - 1.0) / n_iterations
        ts_lr.append((ts, frac * lr))
        ts += iteration_steps
    return ts_lr


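# Orthogonal weight and constant bias initialization, as in the original PPO code.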
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


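# Nature-CNN actor-critic model. The first conv layer expects 4 input channels
# (the DeepMind frame stack) in NCHW layout.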
class Agent(TorchModelV2, nn.Module):
    def __init__(
        self, observation_space, action_space, num_outputs, model_config, name
    ):
        TorchModelV2.__init__(
            self, observation_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        assert action_space.n == num_outputs

        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        )
        self.actor = layer_init(nn.Linear(512, num_outputs), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)
        self.output = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
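        # input_dict["obs"] should be the preprocessed (84, 84, 4) frame stack;
        # note that even then it arrives channels-last, so a
        # permute(0, 3, 1, 2) to NCHW would presumably be needed before Conv2d.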
        self.output = self.network(input_dict["obs"] / 255.0)
        return self.actor(self.output), []

    @override(ModelV2)
    def value_function(self):
        assert self.output is not None, "must call forward first!"
        return torch.reshape(self.critic(self.output), [-1])


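# PPO hyperparameters follow the Atari settings of the original PPO paper.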
def train_atari(args):
    total_timesteps = int(10e6)
    lr = 2.5e-4
    n_envs = 8
    n_steps = 128
    n_iterations = total_timesteps // (n_envs * n_steps)
    lr_schedule = linear_schedule(lr, n_iterations, n_steps * n_envs)

    ModelCatalog.register_custom_model("Agent", Agent)

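    # Build PPO on the old API stack; preprocessor_pref="deepmind" below is
    # what should trigger the DeepMind Atari preprocessing.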
    ppo = (
        PPOConfig()
        .training(
            gamma=0.99,
            grad_clip_by="global_norm",
            train_batch_size=128 * 8,
            model={"custom_model": "Agent"},
            optimizer={"eps": 1e-5},
            lr_schedule=lr_schedule,
            use_critic=True,
            use_gae=True,
            lambda_=0.95,
            use_kl_loss=False,
            kl_coeff=None,
            kl_target=None,
            sgd_minibatch_size=256,
            num_sgd_iter=4,
            shuffle_sequences=True,
            vf_loss_coeff=0.5,
            entropy_coeff=0.01,
            entropy_coeff_schedule=None,
            clip_param=0.1,
            vf_clip_param=0.1,
            grad_clip=0.5,
        )
        .environment(
            env=f"{args.env}NoFrameskip-v4",
            env_config={"frameskip": 1},
            render_env=False,
            clip_rewards=True,
            normalize_actions=False,
            clip_actions=False,
            is_atari=True,
        )
        .env_runners(
            num_env_runners=1,
            num_envs_per_env_runner=8,
            rollout_fragment_length=128,
            batch_mode="truncate_episodes",
            create_env_on_local_worker=False,
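            # Expect grayscale, 84x84, 4-frame-stacked observations from this: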
            preprocessor_pref="deepmind",
            observation_filter="NoFilter",
            explore=True,
            exploration_config={"type": "StochasticSampling"},
        )
        .framework(framework="torch")
        .evaluation(
            evaluation_interval=None,
            evaluation_duration=100,
            evaluation_duration_unit="episodes",
            evaluation_config={
                "explore": True,
                "exploration_config": {"type": "StochasticSampling"},
            },
            evaluation_num_env_runners=1,
        )
        .debugging(logger_config={"type": "ray.tune.logger.NoopLogger"}, seed=args.seed)
        .resources(
            num_gpus=0.4,
            num_cpus_per_worker=1,
            num_gpus_per_worker=0,
        )
        .reporting(
            metrics_num_episodes_for_smoothing=100,
            min_train_timesteps_per_iteration=128 * 8,
            min_sample_timesteps_per_iteration=128 * 8,
        )
        .build()
    )
    for _ in range(10):
        result = ppo.train()
        print(pretty_print(result))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-g",
        "--gpu",
        type=int,
        help="Specify GPU index",
        default=0,
    )
    parser.add_argument(
        "-e",
        "--env",
        type=str,
        help="Specify Atari environment w/o version",
        default="Pong",
    )
    parser.add_argument(
        "-t",
        "--trials",
        type=int,
        help="Specify number of trials",
        default=5,
    )
    args = parser.parse_args()
    for _ in range(args.trials):
        args.id = uuid.uuid4().hex
        args.path = os.path.join("trials", "ppo", args.env, args.id)
        args.seed = int(time.time())

        # create dir
        # pathlib.Path(args.path).mkdir(parents=True, exist_ok=True)

        # set gpu
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}"

        train_atari(args)
        break  # a single trial is enough to reproduce the error

        # save trial info
        # with open(os.path.join(args.path, "info.json"), "w") as f:
        #     json.dump(vars(args), f, indent=4)

Issue Severity

High: It blocks me from completing my task.

@rajfly rajfly added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels May 7, 2024
@anyscalesam anyscalesam added the rllib RLlib related issues label May 7, 2024
@simonsays1980 simonsays1980 added rllib-oldstack-cleanup Issues related to cleaning up classes, utilities on the old API stack and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels May 15, 2024