{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "545e3473",
   "metadata": {},
   "source": [
    "# Stable-Baselines3 Maskable PPO\n",
    "\n",
    "Trains a `MaskablePPO` agent from `sb3_contrib` on a small job-shop instance (two jobs, three tasks each) wrapped in `DisjunctiveGraphJspEnv`. An `ActionMasker` wrapper exposes the environment's `valid_action_mask()` to the agent."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1a7c3b2",
   "metadata": {},
   "source": [
    "## Imports and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca0c8d0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "import sb3_contrib\n",
    "from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy\n",
    "from sb3_contrib.common.wrappers import ActionMasker\n",
    "from stable_baselines3.common.monitor import Monitor\n",
    "\n",
    "from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv\n",
    "from graph_jsp_env.disjunctive_graph_logger import log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "163ed03e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything a reader might want to tune lives in this cell.\n",
    "SEED = 42  # passed to MaskablePPO(seed=...) so Restart & Run All is repeatable (up to hardware nondeterminism)\n",
    "TOTAL_TIMESTEPS = 10_000  # training budget passed to model.learn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4b92d06",
   "metadata": {},
   "source": [
    "## Problem instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "367c9f7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "jsp = np.array([\n",
    "    [[1, 2, 0],  # job 0\n",
    "     [0, 2, 1]],  # job 1\n",
    "    [[17, 12, 19],  # task durations of job 0\n",
    "     [8, 6, 2]]  # task durations of job 1\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a90c5e17",
   "metadata": {},
   "source": [
    "## Environment and action masking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3dd47b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the keyword `jps_instance` is kept exactly as originally written;\n",
    "# presumably this is the library's own spelling - verify before renaming.\n",
    "jsp_env = DisjunctiveGraphJspEnv(\n",
    "    jps_instance=jsp,\n",
    "    perform_left_shift_if_possible=True,\n",
    "    normalize_observation_space=True,\n",
    "    flat_observation_space=True,\n",
    "    action_mode='task',  # alternative 'job'\n",
    ")\n",
    "# Each wrapping stage gets its own name so that no cell rebinds its own input.\n",
    "monitored_env = Monitor(jsp_env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1035ba3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mask_fn(env: gym.Env) -> np.ndarray:\n",
    "    \"\"\"Return the valid-action mask of the innermost (unwrapped) environment.\n",
    "\n",
    "    Args:\n",
    "        env: The (possibly wrapped) environment handed in by ``ActionMasker``.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``env.unwrapped.valid_action_mask()`` returns.\n",
    "    \"\"\"\n",
    "    return env.unwrapped.valid_action_mask()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbc7a1f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap `monitored_env` (not `env` itself) so re-running this cell never stacks wrappers.\n",
    "env = ActionMasker(monitored_env, mask_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c27d8f3a",
   "metadata": {},
   "source": [
    "## Model and training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15366ef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = sb3_contrib.MaskablePPO(\n",
    "    MaskableActorCriticPolicy,\n",
    "    env,\n",
    "    seed=SEED,\n",
    "    verbose=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48f8df00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the agent\n",
    "log.info(\"training the model\")\n",
    "model.learn(total_timesteps=TOTAL_TIMESTEPS);"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}