Improvements I want to make in [finrl]->[agent]->[rllib]->[models.py] #1193

Open
Aditya-dom opened this issue Mar 31, 2024 · 0 comments

Here are the improvements made to the code (the full listing follows, with a short usage sketch at the end):

1 - Imported with_common_config, Trainer, and COMMON_CONFIG to make the code cleaner and more concise.
2 - Imported the individual algorithm trainers and their default configs from their ray.rllib.agents submodules at the top of the file, to keep the MODELS mapping consistent and readable.
3 - Created a private method _get_default_config to handle retrieving the default configuration for each model, reducing code duplication.
4 - Improved error handling in the DRL_prediction method by catching exceptions and raising a ValueError with a meaningful error message.

# DRL models from RLlib
from __future__ import annotations

import ray
from ray.rllib.agents import with_common_config

from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG

# Import individual algorithms for easier access
from ray.rllib.agents.a3c import A3CTrainer, DEFAULT_CONFIG as A3C_CONFIG
from ray.rllib.agents.ddpg import DDPGTrainer, DEFAULT_CONFIG as DDPG_CONFIG
from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG as PPO_CONFIG
from ray.rllib.agents.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
# TD3 keeps its own default config in ddpg.td3 (ddpg.DEFAULT_CONFIG is the DDPG config)
from ray.rllib.agents.ddpg.td3 import TD3Trainer, TD3_DEFAULT_CONFIG as TD3_CONFIG

MODELS = {
    "a3c": A3CTrainer,
    "ddpg": DDPGTrainer,
    "td3": TD3Trainer,
    "sac": SACTrainer,
    "ppo": PPOTrainer,
}

class DRLAgent:
    """Implementations for DRL algorithms

    Attributes
    ----------
        env: gym environment class
            user-defined class
        price_array: numpy array
            OHLC data
        tech_array: numpy array
            technical indicator data
        turbulence_array: numpy array
            turbulence/risk data
    Methods
    -------
        get_model()
            setup DRL algorithms
        train_model()
            train DRL algorithms on a training dataset
            and output the trained model
        DRL_prediction()
            make a prediction on a test dataset and get results
    """

    def __init__(self, env, price_array, tech_array, turbulence_array):
        self.env = env
        self.price_array = price_array
        self.tech_array = tech_array
        self.turbulence_array = turbulence_array

    def get_model(
        self,
        model_name,
        # policy="MlpPolicy",
        # policy_kwargs=None,
        # model_kwargs=None,
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model '{model_name}' is not supported.")

        model = MODELS[model_name]
        model_config = self._get_default_config(model_name)

        # pass env, log_level, price_array, tech_array, and turbulence_array to config
        model_config["env"] = self.env
        model_config["log_level"] = "WARN"
        model_config["env_config"] = {
            "price_array": self.price_array,
            "tech_array": self.tech_array,
            "turbulence_array": self.turbulence_array,
            "if_train": True,
        }

        return model, model_config

    def train_model(
        self, model, model_name, model_config, total_episodes=100, init_ray=True
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model '{model_name}' is not supported.")
        if init_ray:
            ray.init(
                ignore_reinit_error=True
            )

        trainer = model(env=self.env, config=model_config)

        for _ in range(total_episodes):
            trainer.train()

        ray.shutdown()

        cwd = "./test_" + str(model_name)
        trainer.save(cwd)

        return trainer

    @staticmethod
    def DRL_prediction(
        model_name,
        env,
        price_array,
        tech_array,
        turbulence_array,
        agent_path="./test_ppo/checkpoint_000100/checkpoint-100",
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model '{model_name}' is not supported.")

        model = MODELS[model_name]
        # this is a staticmethod, so the helper is reached via the class, not self
        model_config = DRLAgent._get_default_config(model_name)

        model_config["env"] = env
        model_config["log_level"] = "WARN"
        model_config["env_config"] = {
            "price_array": price_array,
            "tech_array": tech_array,
            "turbulence_array": turbulence_array,
            "if_train": False,
        }
        env_config = {
            "price_array": price_array,
            "tech_array": tech_array,
            "turbulence_array": turbulence_array,
            "if_train": False,
        }
        env_instance = env(config=env_config)

        trainer = model(env=env, config=model_config)

        try:
            trainer.restore(agent_path)
            print("Restoring from checkpoint path", agent_path)
        except BaseException as e:
            raise ValueError("Fail to load agent!") from e

        state = env_instance.reset()
        episode_returns = []
        episode_total_assets = [env_instance.initial_total_asset]
        done = False
        while not done:
            action = trainer.compute_single_action(state)
            state, reward, done, _ = env_instance.step(action)

            total_asset = (
                env_instance.amount
                + (env_instance.price_ary[env_instance.day] * env_instance.stocks).sum()
            )
            episode_total_assets.append(total_asset)
            episode_return = total_asset / env_instance.initial_total_asset
            episode_returns.append(episode_return)

        ray.shutdown()
        print("episode return: " + str(episode_return))
        print("Test Finished!")
        return episode_total_assets

    @staticmethod
    def _get_default_config(model_name):
        if model_name == "a3c":
            return A3C_CONFIG.copy()
        elif model_name == "ddpg":
            return DDPG_CONFIG.copy()
        elif model_name == "td3":
            return TD3_CONFIG.copy()
        elif model_name == "sac":
            return SAC_CONFIG.copy()
        elif model_name == "ppo":
            return PPO_CONFIG.copy()
        else:
            raise NotImplementedError(f"Model '{model_name}' is not supported.")
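
For reference, a minimal usage sketch of the class above (not part of the patch). The env class name StockTradingEnv and the arrays are placeholders for whatever FinRL-style environment and numpy data you already have, and the checkpoint path is just the method's default, so adjust both to your setup.

# Usage sketch: assumes price_array, tech_array and turbulence_array are
# prepared numpy arrays, and StockTradingEnv is a FinRL-style env class
# (placeholder name) whose config accepts those arrays plus "if_train".

agent = DRLAgent(
    env=StockTradingEnv,
    price_array=price_array,
    tech_array=tech_array,
    turbulence_array=turbulence_array,
)

model, model_config = agent.get_model("ppo")

trainer = agent.train_model(
    model=model,
    model_name="ppo",
    model_config=model_config,
    total_episodes=100,
)

# DRL_prediction restores from a saved checkpoint; the path below is the
# class default and may differ from the one train_model actually wrote.
episode_total_assets = DRLAgent.DRL_prediction(
    model_name="ppo",
    env=StockTradingEnv,
    price_array=price_array,
    tech_array=tech_array,
    turbulence_array=turbulence_array,
    agent_path="./test_ppo/checkpoint_000100/checkpoint-100",
)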