This refactor makes the following changes:

1. Imported `with_common_config`, `Trainer`, and `COMMON_CONFIG` to make the code cleaner and more concise.
2. Imported the individual algorithm trainers from their `ray.rllib.agents` submodules for consistency and readability.
3. Added a private helper method, `_get_default_config`, to retrieve the default configuration for each model, reducing code duplication.
4. Improved error handling in `DRL_prediction` by catching exceptions during checkpoint restore and raising a `ValueError` with a meaningful error message.
# DRL models from RLlib
from __future__ import annotations
import ray
from ray.rllib.agents import with_common_config
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
# Import individual algorithms for easier access
from ray.rllib.agents.a3c import A3CTrainer, DEFAULT_CONFIG as A3C_CONFIG
from ray.rllib.agents.ddpg import DDPGTrainer, DEFAULT_CONFIG as DDPG_CONFIG
from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG as PPO_CONFIG
from ray.rllib.agents.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.agents.ddpg.td3 import TD3Trainer, TD3_DEFAULT_CONFIG as TD3_CONFIG
MODELS = {
    "a3c": A3CTrainer,
    "ddpg": DDPGTrainer,
    "td3": TD3Trainer,
    "sac": SACTrainer,
    "ppo": PPOTrainer,
}
class DRLAgent:
"""Implementations for DRL algorithms
Attributes
----------
env: gym environment class
user-defined class
price_array: numpy array
OHLC data
tech_array: numpy array
techical data
turbulence_array: numpy array
turbulence/risk data
Methods
-------
get_model()
setup DRL algorithms
train_model()
train DRL algorithms in a train dataset
and output the trained model
DRL_prediction()
make a prediction in a test dataset and get results
"""
def __init__(self, env, price_array, tech_array, turbulence_array):
self.env = env
self.price_array = price_array
self.tech_array = tech_array
self.turbulence_array = turbulence_array
    def get_model(self, model_name):
        if model_name not in MODELS:
            raise NotImplementedError(
                f"Model '{model_name}' is not supported. Available models: {list(MODELS)}"
            )
model = MODELS[model_name]
model_config = self._get_default_config(model_name)
# pass env, log_level, price_array, tech_array, and turbulence_array to config
model_config["env"] = self.env
model_config["log_level"] = "WARN"
model_config["env_config"] = {
"price_array": self.price_array,
"tech_array": self.tech_array,
"turbulence_array": self.turbulence_array,
"if_train": True,
}
return model, model_config
def train_model(
self, model, model_name, model_config, total_episodes=100, init_ray=True
):
        if model_name not in MODELS:
            raise NotImplementedError(
                f"Model '{model_name}' is not supported. Available models: {list(MODELS)}"
            )
        if init_ray:
            # Reuse an existing Ray session instead of failing on re-init.
            ray.init(ignore_reinit_error=True)
        trainer = model(env=self.env, config=model_config)
        # Each call to train() runs one RLlib training iteration,
        # so total_episodes counts iterations rather than env episodes.
        for _ in range(total_episodes):
            trainer.train()
        ray.shutdown()
        # Save the trained model to a model-specific checkpoint directory.
        cwd = f"./test_{model_name}"
        trainer.save(cwd)
        return trainer
@staticmethod
def DRL_prediction(
model_name,
env,
price_array,
tech_array,
turbulence_array,
agent_path="./test_ppo/checkpoint_000100/checkpoint-100",
):
        if model_name not in MODELS:
            raise NotImplementedError(
                f"Model '{model_name}' is not supported. Available models: {list(MODELS)}"
            )
        model = MODELS[model_name]
        # DRL_prediction is a static method, so call the helper via the class.
        model_config = DRLAgent._get_default_config(model_name)
model_config["env"] = env
model_config["log_level"] = "WARN"
model_config["env_config"] = {
"price_array": price_array,
"tech_array": tech_array,
"turbulence_array": turbulence_array,
"if_train": False,
}
env_config = {
"price_array": price_array,
"tech_array": tech_array,
"turbulence_array": turbulence_array,
"if_train": False,
}
env_instance = env(config=env_config)
trainer = model(env=env, config=model_config)
        try:
            trainer.restore(agent_path)
            print("Restored agent from checkpoint:", agent_path)
        except Exception as e:
            raise ValueError(f"Failed to load agent from {agent_path}!") from e
        # Roll out one full episode in the test environment.
        state = env_instance.reset()
        episode_returns = []
        episode_total_assets = [env_instance.initial_total_asset]
        episode_return = 0.0  # defined up front in case the episode ends immediately
        done = False
        while not done:
            action = trainer.compute_single_action(state)
            state, reward, done, _ = env_instance.step(action)
            # Total asset = cash + current market value of all held shares.
            total_asset = (
                env_instance.amount
                + (env_instance.price_ary[env_instance.day] * env_instance.stocks).sum()
            )
            episode_total_assets.append(total_asset)
            episode_return = total_asset / env_instance.initial_total_asset
            episode_returns.append(episode_return)
ray.shutdown()
print("episode return: " + str(episode_return))
print("Test Finished!")
return episode_total_assets
    @staticmethod
    def _get_default_config(model_name):
        # Return a copy of the RLlib default config for the requested model.
        default_configs = {
            "a3c": A3C_CONFIG,
            "ddpg": DDPG_CONFIG,
            "td3": TD3_CONFIG,
            "sac": SAC_CONFIG,
            "ppo": PPO_CONFIG,
        }
        if model_name not in default_configs:
            raise NotImplementedError(
                f"Model '{model_name}' is not supported. Available models: {list(default_configs)}"
            )
        return default_configs[model_name].copy()
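
For reference, here is a minimal usage sketch of the refactored class. It is illustrative only: `StockTradingEnv` stands in for any user-defined env class that accepts the `env_config` keys used above (`my_envs` is a placeholder module), and the numpy arrays are dummy data rather than real market inputs.

import numpy as np

from my_envs import StockTradingEnv  # placeholder: any user-defined env class

# Dummy data for illustration; real usage passes preprocessed market data.
price_array = np.random.rand(1000, 5)
tech_array = np.random.rand(1000, 10)
turbulence_array = np.random.rand(1000)

agent = DRLAgent(
    env=StockTradingEnv,
    price_array=price_array,
    tech_array=tech_array,
    turbulence_array=turbulence_array,
)

# Build the PPO trainer class plus its config, then run 100 training iterations.
model, model_config = agent.get_model("ppo")
trainer = agent.train_model(
    model=model,
    model_name="ppo",
    model_config=model_config,
    total_episodes=100,
)

# Evaluate from the checkpoint written by train_model; the path below is the
# default used by DRL_prediction and should point at an existing checkpoint.
episode_total_assets = DRLAgent.DRL_prediction(
    model_name="ppo",
    env=StockTradingEnv,
    price_array=price_array,
    tech_array=tech_array,
    turbulence_array=turbulence_array,
    agent_path="./test_ppo/checkpoint_000100/checkpoint-100",
)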