Stable-Baselines3 (sb3-contrib) Maskable PPO on the Disjunctive Graph Job Shop Environment

[1]:
import gymnasium as gym
import sb3_contrib
import numpy as np
from stable_baselines3.common.monitor import Monitor
[2]:
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
from graph_jsp_env.disjunctive_graph_logger import log
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy



     ▐███▌         ▟███▛▟███████▛▐███▌
     ▐███▌        ▟███▛▟███████▛ ▐███▌
     ▐███▌ ▟███  ▟███▛    ▟███▛  ▐███▌
     ▐███▌▟████ ▟███▛    ▟███▛   ▐███▌
     ▐██████▛▐█████▛    ▟████████▐█████████▛
     ▐█████▛ ▐████▛    ▟█████████▐████████▛


     ▐█▀▜▙█▙   ███████▐█ ▐█    ▟█▙   ▟█▙  ▟███▐█ ▐█▐█▀▀▐█▙ █
     ▐█▄▟▛▜█▙▟▙██ ▐█  ▐████   ▟▛ ▜▙ ▟▛ ▜▙ █▍  ▐████▐█▀▀▐██▙█
     ▐█ ▜▙ ▜█▛▜██ ▐█  ▐█ ▐█  ▟█▛▀▜█▙█▛▀▜█▙▜███▐█ ▐█▐█▆▆▐█ ▜█
           ▐█  ▐█▐█▙ █▐███▜█▙ ▟███▀▀▐█▀▜▙▟█▀▜█▐███▐█████▙ ▟▛
           ▐█  ▐█▐██▙█ ▐█  ▜█▄█▛▐█▀▀▐█▄▟▛▜█▆▆▄ ▐█  ▐█  ▜█▄▛
            ▜███▛▐█ ▜█▐███  ▜█▛ ▐███▐█ ▜▙▐█▆▆▛▐███ ▐█   ██

    
    Graph Matrix Job Shop Problem Environment
    

    Version:    1.1.0

[3]:
# Tiny 2-job x 3-task instance, stored as a stack of two (jobs x tasks) matrices:
#   jsp[0] -> per-job task sequence (presumably the machine index of each task — TODO confirm
#             against the graph_jsp_env instance format)
#   jsp[1] -> per-job task durations
jsp = np.stack([
    np.array([
        [1, 2, 0],  # job 0
        [0, 2, 1],  # job 1
    ]),
    np.array([
        [17, 12, 19],  # task durations of job 0
        [8, 6, 2],     # task durations of job 1
    ]),
])
[4]:
# Build the JSP environment and immediately wrap it in Monitor, which records the
# per-episode statistics that appear as ep_len_mean / ep_rew_mean in the training log.
# NOTE: `jps_instance` is the keyword name used by DisjunctiveGraphJspEnv itself.
env = Monitor(
    DisjunctiveGraphJspEnv(
        jps_instance=jsp,
        perform_left_shift_if_possible=True,
        normalize_observation_space=True,
        flat_observation_space=True,
        action_mode='task',  # alternative 'job'
    )
)
[5]:
def mask_fn(env: gym.Env) -> np.ndarray:
    """Return the action mask consumed by MaskablePPO.

    ActionMasker calls this with the env it wraps (in this notebook, the
    Monitor). The mask is provided by the innermost DisjunctiveGraphJspEnv via
    ``valid_action_mask()``, so the wrapper chain is stripped first.
    """
    core_env = env.unwrapped
    mask = core_env.valid_action_mask()
    return mask
[6]:
# Expose action masks to MaskablePPO. Guarded so that re-running this cell on its
# own does not wrap an already-masked env a second time (keeps the cell idempotent).
if not isinstance(env, ActionMasker):
    env = ActionMasker(env, mask_fn)
[7]:
# Fixed seed so that Restart & Run All reproduces the same training run.
# NOTE(review): 42 is arbitrary; bit-exact determinism may additionally depend on torch/CPU settings.
SEED = 42
model = sb3_contrib.MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, seed=SEED)
Using cpu device
Wrapping the env in a DummyVecEnv.
[8]:
# Train the agent. Rollouts are collected in fixed-size chunks, so the logged
# total_timesteps can overshoot the requested budget (the log shows 10240 for 10_000).
# The trailing ';' suppresses the `<MaskablePPO at 0x...>` repr that learn() returns.
log.info("training the model")
model.learn(total_timesteps=10_000);
[09:51:29] INFO     training the model
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 6        |
|    ep_rew_mean     | -49.9    |
| time/              |          |
|    fps             | 1336     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 6           |
|    ep_rew_mean          | -48.6       |
| time/                   |             |
|    fps                  | 1129        |
|    iterations           | 2           |
|    time_elapsed         | 3           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.010743486 |
|    clip_fraction        | 0.122       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.47       |
|    explained_variance   | -0.00118    |
|    learning_rate        | 0.0003      |
|    loss                 | 421         |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0223     |
|    value_loss           | 1.31e+03    |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 6           |
|    ep_rew_mean          | -48.3       |
| time/                   |             |
|    fps                  | 1079        |
|    iterations           | 3           |
|    time_elapsed         | 5           |
|    total_timesteps      | 6144        |
| train/                  |             |
|    approx_kl            | 0.013307016 |
|    clip_fraction        | 0.135       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.458      |
|    explained_variance   | 0.000836    |
|    learning_rate        | 0.0003      |
|    loss                 | 299         |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.0145     |
|    value_loss           | 783         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 6           |
|    ep_rew_mean          | -48         |
| time/                   |             |
|    fps                  | 1056        |
|    iterations           | 4           |
|    time_elapsed         | 7           |
|    total_timesteps      | 8192        |
| train/                  |             |
|    approx_kl            | 0.013091141 |
|    clip_fraction        | 0.0723      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.425      |
|    explained_variance   | 0.000584    |
|    learning_rate        | 0.0003      |
|    loss                 | 169         |
|    n_updates            | 30          |
|    policy_gradient_loss | -0.0109     |
|    value_loss           | 474         |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 6           |
|    ep_rew_mean          | -48         |
| time/                   |             |
|    fps                  | 1043        |
|    iterations           | 5           |
|    time_elapsed         | 9           |
|    total_timesteps      | 10240       |
| train/                  |             |
|    approx_kl            | 0.018714432 |
|    clip_fraction        | 0.094       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.372      |
|    explained_variance   | 0.00048     |
|    learning_rate        | 0.0003      |
|    loss                 | 86          |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.00577    |
|    value_loss           | 254         |
-----------------------------------------
[8]:
<sb3_contrib.ppo_mask.ppo_mask.MaskablePPO at 0x73bb9be42c10>