Stable-Baselines3 Maskable PPO
[1]:
import gymnasium as gym
import sb3_contrib
import numpy as np
from stable_baselines3.common.monitor import Monitor
[2]:
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
from graph_jsp_env.disjunctive_graph_logger import log
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
▐███▌ ▟███▛▟███████▛▐███▌
▐███▌ ▟███▛▟███████▛ ▐███▌
▐███▌ ▟███ ▟███▛ ▟███▛ ▐███▌
▐███▌▟████ ▟███▛ ▟███▛ ▐███▌
▐██████▛▐█████▛ ▟████████▐█████████▛
▐█████▛ ▐████▛ ▟█████████▐████████▛
▐█▀▜▙█▙ ███████▐█ ▐█ ▟█▙ ▟█▙ ▟███▐█ ▐█▐█▀▀▐█▙ █
▐█▄▟▛▜█▙▟▙██ ▐█ ▐████ ▟▛ ▜▙ ▟▛ ▜▙ █▍ ▐████▐█▀▀▐██▙█
▐█ ▜▙ ▜█▛▜██ ▐█ ▐█ ▐█ ▟█▛▀▜█▙█▛▀▜█▙▜███▐█ ▐█▐█▆▆▐█ ▜█
▐█ ▐█▐█▙ █▐███▜█▙ ▟███▀▀▐█▀▜▙▟█▀▜█▐███▐█████▙ ▟▛
▐█ ▐█▐██▙█ ▐█ ▜█▄█▛▐█▀▀▐█▄▟▛▜█▆▆▄ ▐█ ▐█ ▜█▄▛
▜███▛▐█ ▜█▐███ ▜█▛ ▐███▐█ ▜▙▐█▆▆▛▐███ ▐█ ██
Graph Matrix Job Shop Problem Environment
Version: 1.1.0
[3]:
# Job-shop instance as one integer array of shape (2, 2, 3):
# two stacked 2x3 tables with one row per job and one column per task.
# The second table holds the task durations; the first presumably holds the
# machine index each task runs on (TODO confirm against the environment docs).
jsp = np.array(
    [
        [
            [1, 2, 0],  # job 0
            [0, 2, 1],  # job 1
        ],
        [
            [17, 12, 19],  # task durations of job 0
            [8, 6, 2],  # task durations of job 1
        ],
    ]
)
[4]:
# Build the job-shop environment from the `jsp` instance and wrap it in
# stable_baselines3's Monitor in a single expression.
env = Monitor(
    DisjunctiveGraphJspEnv(
        # NOTE(review): the keyword is spelled `jps_instance` (not
        # `jsp_instance`); the recorded run of this notebook used this
        # spelling, so it is kept as-is.
        jps_instance=jsp,
        perform_left_shift_if_possible=True,
        normalize_observation_space=True,
        flat_observation_space=True,
        action_mode='task',  # alternative 'job'
    )
)
[5]:
def mask_fn(env: gym.Env) -> np.ndarray:
    """Return the action mask reported by the innermost environment.

    Any wrappers around ``env`` are bypassed through ``env.unwrapped`` and
    the result of that object's ``valid_action_mask()`` is returned as-is.

    Args:
        env: The (possibly wrapped) environment to query.

    Returns:
        Whatever ``env.unwrapped.valid_action_mask()`` returns; annotated
        as a NumPy array.
    """
    # The body must be indented under the ``def``; a column-0 ``return``
    # is an IndentationError.
    return env.unwrapped.valid_action_mask()
[6]:
# Wrap the environment in sb3_contrib's ActionMasker, handing it `mask_fn`
# as the callable that supplies the action mask.
env = ActionMasker(env=env, action_mask_fn=mask_fn)
[7]:
# Create the MaskablePPO agent on the wrapped environment; verbose=1 turns on
# the console output recorded below this cell.
model = sb3_contrib.MaskablePPO(
    policy=MaskableActorCriticPolicy,
    env=env,
    verbose=1,
)
Using cpu device
Wrapping the env in a DummyVecEnv.
[8]:
# Train the agent
# Emit a log line through the graph_jsp_env logger, then start training.
# `total_timesteps` is a requested budget: the run recorded below this cell
# ends at 10240 timesteps (5 iterations of 2048 timesteps each).
# `model.learn(...)` is deliberately the cell's last expression; the recorded
# cell result is the MaskablePPO object it evaluates to.
log.info("training the model")
model.learn(total_timesteps=10_000)
[09:51:29] INFO training the model
---------------------------------
| rollout/ | |
| ep_len_mean | 6 |
| ep_rew_mean | -49.9 |
| time/ | |
| fps | 1336 |
| iterations | 1 |
| time_elapsed | 1 |
| total_timesteps | 2048 |
---------------------------------
-----------------------------------------
| rollout/ | |
| ep_len_mean | 6 |
| ep_rew_mean | -48.6 |
| time/ | |
| fps | 1129 |
| iterations | 2 |
| time_elapsed | 3 |
| total_timesteps | 4096 |
| train/ | |
| approx_kl | 0.010743486 |
| clip_fraction | 0.122 |
| clip_range | 0.2 |
| entropy_loss | -0.47 |
| explained_variance | -0.00118 |
| learning_rate | 0.0003 |
| loss | 421 |
| n_updates | 10 |
| policy_gradient_loss | -0.0223 |
| value_loss | 1.31e+03 |
-----------------------------------------
-----------------------------------------
| rollout/ | |
| ep_len_mean | 6 |
| ep_rew_mean | -48.3 |
| time/ | |
| fps | 1079 |
| iterations | 3 |
| time_elapsed | 5 |
| total_timesteps | 6144 |
| train/ | |
| approx_kl | 0.013307016 |
| clip_fraction | 0.135 |
| clip_range | 0.2 |
| entropy_loss | -0.458 |
| explained_variance | 0.000836 |
| learning_rate | 0.0003 |
| loss | 299 |
| n_updates | 20 |
| policy_gradient_loss | -0.0145 |
| value_loss | 783 |
-----------------------------------------
-----------------------------------------
| rollout/ | |
| ep_len_mean | 6 |
| ep_rew_mean | -48 |
| time/ | |
| fps | 1056 |
| iterations | 4 |
| time_elapsed | 7 |
| total_timesteps | 8192 |
| train/ | |
| approx_kl | 0.013091141 |
| clip_fraction | 0.0723 |
| clip_range | 0.2 |
| entropy_loss | -0.425 |
| explained_variance | 0.000584 |
| learning_rate | 0.0003 |
| loss | 169 |
| n_updates | 30 |
| policy_gradient_loss | -0.0109 |
| value_loss | 474 |
-----------------------------------------
-----------------------------------------
| rollout/ | |
| ep_len_mean | 6 |
| ep_rew_mean | -48 |
| time/ | |
| fps | 1043 |
| iterations | 5 |
| time_elapsed | 9 |
| total_timesteps | 10240 |
| train/ | |
| approx_kl | 0.018714432 |
| clip_fraction | 0.094 |
| clip_range | 0.2 |
| entropy_loss | -0.372 |
| explained_variance | 0.00048 |
| learning_rate | 0.0003 |
| loss | 86 |
| n_updates | 40 |
| policy_gradient_loss | -0.00577 |
| value_loss | 254 |
-----------------------------------------
[8]:
<sb3_contrib.ppo_mask.ppo_mask.MaskablePPO at 0x73bb9be42c10>