{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[['_' '_' 'X']\n",
" ['X' '' '']]\n",
"[[0. 0. 0.]\n",
" [0. 0. 0.]\n",
" [0. 0. 0.]]\n",
"[['_' '_' 'X']\n",
" ['X' '' '']]\n",
"[[-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"[['_' '_' 'X']\n",
" ['' 'X' '']]\n",
"[[-0.1 -0.1 0. ]\n",
" [ 0. 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"[['_' '_' 'X']\n",
" ['X' '' '']]\n",
"[[-0.1 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"[['_' '_' 'X']\n",
" ['' '' 'X']]\n",
"[[-0.1 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"Episode 1 is done\n",
"[['_' '_' 'X']\n",
" ['X' '' '']]\n",
"[[-0.19 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"[['_' '_' 'X']\n",
" ['' '' 'X']]\n",
"[[-0.19 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"Episode 2 is done\n",
"[['_' '_' 'X']\n",
" ['' '' 'X']]\n",
"[[-0.19 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"Episode 3 is done\n",
"[['_' '_' 'X']\n",
" ['' '' 'X']]\n",
"[[-0.19 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"Episode 4 is done\n",
"[['_' '_' 'X']\n",
" ['' '' 'X']]\n",
"[[-0.19 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"Episode 5 is done\n",
"[['_' '_' 'X']\n",
" ['X' '' '']]\n",
"[[-0.271 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"[['_' '_' 'X']\n",
" ['' '' 'X']]\n",
"[[-0.271 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"Episode 6 is done\n",
"[['_' '_' 'X']\n",
" ['' '' 'X']]\n",
"[[-0.271 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"Episode 7 is done\n",
"[['_' '_' 'X']\n",
" ['' '' 'X']]\n",
"[[-0.271 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"Episode 8 is done\n",
"[['_' '_' 'X']\n",
" ['' '' 'X']]\n",
"[[-0.271 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"Episode 9 is done\n",
"[['_' '_' 'X']\n",
" ['' '' 'X']]\n",
"[[-0.271 -0.1 0. ]\n",
" [-0.1 0. 0. ]\n",
" [ 0. 0. 0. ]]\n",
"Episode 10 is done\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"def set_state(Q, map, state):\n",
" map[1] = [\"\", \"\", \"\"]\n",
" map[1][state] = \"X\"\n",
" print(map)\n",
" print(Q)\n",
"\n",
"Q = np.zeros([3, 3])\n",
"map = np.array([[\"_\", \"_\", \"X\"], [\"\", \"\", \"\"]])\n",
"\n",
"max_steps = 10\n",
"alpha = 0.1 # коэф. обучения\n",
"gamma = 0.9 # коэф. дисконтирования\n",
"epsilon = 0.1 # параметр исследования vs. эксплуатации\n",
"\n",
"actions = {\"left\": -1, \"right\": 1}\n",
"\n",
"for episode in range(max_steps): # колво эпизодов обучения\n",
" state = 0 # стартовое состояние\n",
" if episode == 0:\n",
" set_state(Q, map, state)\n",
"\n",
" while state != 2: # пока не достигнута цель\n",
" # выбор действия\n",
" if np.random.rand() < epsilon:\n",
" direction = actions[list(actions.keys())[np.random.randint(0, 1)]]\n",
" action = np.clip(state + direction, 0, 2) # случайное действие\n",
" else:\n",
" action = np.argmax(Q[state])\n",
"\n",
" # переход в новое состояние и получение награды\n",
" new_state = action\n",
" reward = (\n",
" -1 if new_state != 2 else 0\n",
" ) # награда -1 за каждый шаг, 0 за достижение цели\n",
"\n",
" # обновление Q-значения\n",
" Q[state, action] = (1 - alpha) * Q[state, action] + alpha * (\n",
" reward + gamma * np.max(Q[new_state])\n",
" )\n",
"\n",
" state = new_state\n",
" set_state(Q, map, state)\n",
" \n",
" print(f\"Episode {episode + 1} is done\")"
]
},
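{
"cell_type": "markdown",
"metadata": {},
"source": [
"The update in the cell above is the standard tabular Q-learning rule\n",
"\n",
"$$Q(s,a) \\leftarrow (1-\\alpha)\\,Q(s,a) + \\alpha\\,\\bigl(r + \\gamma \\max_{a'} Q(s',a')\\bigr),$$\n",
"\n",
"and the learned behaviour is read off by acting greedily, $\\arg\\max_a Q(s,a)$, in every state. The next cell is a minimal sketch of that readout (the `greedy_rollout` helper is illustrative, not part of the lecture code) and assumes `Q` from the previous cell is in scope.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: follow the greedy policy encoded in the learned Q table.\n",
"def greedy_rollout(Q, start=0, goal=2, max_len=10):\n",
"    path = [start]\n",
"    state = start\n",
"    while state != goal and len(path) < max_len:\n",
"        state = int(np.argmax(Q[state]))  # greedy choice of the next cell\n",
"        path.append(state)\n",
"    return path\n",
"\n",
"print(\"greedy path:\", greedy_rollout(Q))"
]
},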
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\user\\Projects\\python\\ckmai\\.venv\\Lib\\site-packages\\gymnasium\\envs\\registration.py:642: UserWarning: \u001b[33mWARN: Overriding environment gymnasium_env/MyEnv-v1 already in registry.\u001b[0m\n",
" logger.warn(f\"Overriding environment {new_spec.id} already in registry.\")\n"
]
}
],
"source": [
"import gymnasium as gym\n",
"\n",
"\n",
"class MyAgent:\n",
" def __init__(\n",
" self,\n",
" env: gym.Env,\n",
" learning_rate: float,\n",
" epsilon: float,\n",
" discount_factor,\n",
" ):\n",
" self.env = env\n",
" self.q_values = np.zeros([3, 3]) # type: ignore\n",
"\n",
" self.learning_rate = learning_rate\n",
" self.discount_factor = discount_factor\n",
"\n",
" self.epsilon = epsilon\n",
"\n",
" self.training_error = []\n",
"\n",
" def get_action(self, state: int) -> int:\n",
" if np.random.rand() < self.epsilon:\n",
" return self.env.action_space.sample()\n",
" else:\n",
" return int(np.argmax(self.q_values[state]))\n",
"\n",
" def update(\n",
" self,\n",
" state: int,\n",
" action: int,\n",
" reward: float,\n",
" terminated: bool,\n",
" new_state,\n",
" ):\n",
" future_q_value = (not terminated) * np.max(self.q_values[new_state])\n",
" current_q_value = self.q_values[state][action]\n",
" temporal_difference = (\n",
" reward + self.discount_factor * future_q_value - current_q_value\n",
" )\n",
"\n",
" self.q_values[state][action] = (\n",
" 1 - self.learning_rate\n",
" ) * current_q_value + self.learning_rate * (\n",
" reward + self.discount_factor * future_q_value\n",
" )\n",
"\n",
" self.training_error.append(temporal_difference)\n",
"\n",
" def render(self):\n",
" return \"\\n\".join(\n",
" [\", \".join([str(item) for item in row]) for row in self.q_values]\n",
" )\n",
"\n",
"\n",
"class MyEnv(gym.Env):\n",
" metadata = {\"render_modes\": [\"ansi\", \"rgb_array\"], \"render_fps\": 4}\n",
"\n",
" def __init__(self, render_mode=None):\n",
" self.size = 3\n",
"\n",
" self.render_mode = render_mode\n",
"\n",
" self._agent_location = 0\n",
" self._target_location = 2\n",
"\n",
" self.observation_space = gym.spaces.Dict(\n",
" {\n",
" \"agent\": gym.spaces.Discrete(1),\n",
" \"target\": gym.spaces.Discrete(1),\n",
" }\n",
" )\n",
"\n",
" self.action_space = gym.spaces.Discrete(2)\n",
" self._action_to_direction = {\n",
" 0: -1, # left\n",
" 1: 1, # right\n",
" }\n",
"\n",
" def _get_obs(self):\n",
" return {\n",
" \"agent\": self._agent_location,\n",
" \"target\": self._target_location,\n",
" }\n",
"\n",
" def _get_info(self):\n",
" return {\"distance\": abs(self._agent_location - self._target_location)}\n",
"\n",
" def reset(self, seed=None, options=None):\n",
" super().reset(seed=seed)\n",
" self._agent_location = 0\n",
" self._target_location = 2\n",
" observation = self._get_obs()\n",
" info = self._get_info()\n",
"\n",
" return observation, info\n",
"\n",
" def step(self, action):\n",
" direction = 0 if action == 2 else self._action_to_direction[action]\n",
" self._agent_location = np.clip(\n",
" self._agent_location + direction, 0, self.size - 1\n",
" )\n",
" if (action == 2):\n",
" self._agent_location = self._target_location\n",
"\n",
" terminated = self._agent_location == self._target_location\n",
" truncated = False\n",
" reward = -1 if not terminated else 0\n",
" observation = self._get_obs()\n",
" info = self._get_info()\n",
"\n",
" return observation, reward, terminated, truncated, info\n",
"\n",
" def render(self):\n",
" map = ['_', '_', '_']\n",
" map[self._agent_location] = \"A\"\n",
" map[self._target_location] = \"X\"\n",
" return \" \".join(map)\n",
"\n",
"\n",
"env_id = \"gymnasium_env/MyEnv-v1\"\n",
"gym.register(\n",
" id=env_id,\n",
" entry_point=MyEnv, # type: ignore\n",
")"
]
},
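{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training on a custom environment it is worth running Gymnasium's built-in validator, which exercises `reset`, `step`, and the declared spaces. A minimal sketch, assuming the registration above succeeded; `check_env` raises on structural problems (for example, observations outside the declared observation space):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from gymnasium.utils.env_checker import check_env\n",
"\n",
"# unwrap past the wrappers added by gym.make so the checker sees the raw env\n",
"check_env(gym.make(env_id).unwrapped, skip_render_check=True)"
]
},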
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"A _ X\n",
"0.0, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Start episode 1\n",
"\n",
"Start action 1\n",
"A _ X\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"Start action 2\n",
"_ A X\n",
"-0.1, -0.1, 0.0\n",
"0.0, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"Start action 3\n",
"A _ X\n",
"-0.1, -0.1, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"Start action 4\n",
"_ _ X\n",
"-0.1, -0.1, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Episode 1 is done\n",
"\n",
"===> Start episode 2\n",
"\n",
"Start action 5\n",
"_ _ X\n",
"-0.1, -0.1, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Episode 2 is done\n",
"\n",
"===> Start episode 3\n",
"\n",
"Start action 6\n",
"_ _ X\n",
"-0.1, -0.1, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Episode 3 is done\n",
"\n",
"===> Start episode 4\n",
"\n",
"Start action 7\n",
"_ A X\n",
"-0.1, -0.19, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"Start action 8\n",
"_ _ X\n",
"-0.1, -0.19, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Episode 4 is done\n",
"\n",
"===> Start episode 5\n",
"\n",
"Start action 9\n",
"_ _ X\n",
"-0.1, -0.19, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Episode 5 is done\n",
"\n",
"===> Start episode 6\n",
"\n",
"Start action 10\n",
"_ _ X\n",
"-0.1, -0.19, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Episode 6 is done\n",
"\n",
"===> Start episode 7\n",
"\n",
"Start action 11\n",
"_ _ X\n",
"-0.1, -0.19, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Episode 7 is done\n",
"\n",
"===> Start episode 8\n",
"\n",
"Start action 12\n",
"_ _ X\n",
"-0.1, -0.19, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Episode 8 is done\n",
"\n",
"===> Start episode 9\n",
"\n",
"Start action 13\n",
"_ A X\n",
"-0.1, -0.271, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"Start action 14\n",
"_ _ X\n",
"-0.1, -0.271, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Episode 9 is done\n",
"\n",
"===> Start episode 10\n",
"\n",
"Start action 15\n",
"_ _ X\n",
"-0.1, -0.271, 0.0\n",
"-0.1, 0.0, 0.0\n",
"0.0, 0.0, 0.0\n",
"\n",
"===> Episode 10 is done\n"
]
}
],
"source": [
"from gymnasium.wrappers import RecordEpisodeStatistics\n",
"\n",
"myenv = gym.make(env_id, render_mode=\"ansi\", max_episode_steps=max_steps)\n",
"myenv = RecordEpisodeStatistics(myenv, buffer_length=max_steps)\n",
"\n",
"agent = MyAgent(\n",
" env=myenv,\n",
" learning_rate=alpha,\n",
" epsilon=epsilon,\n",
" discount_factor=gamma,\n",
")\n",
"\n",
"action_num = 0\n",
"for episode in range(max_steps):\n",
" obs, info = myenv.reset()\n",
" done = False\n",
"\n",
" if episode == 0:\n",
" print(myenv.render())\n",
" print(agent.render())\n",
"\n",
" print(f\"\\n===> Start episode {episode + 1}\")\n",
"\n",
" while not done:\n",
" print(f\"\\nStart action {action_num + 1}\")\n",
" action_num = action_num + 1\n",
"\n",
" action = agent.get_action(obs[\"agent\"])\n",
" next_obs, reward, terminated, truncated, info = myenv.step(action)\n",
"\n",
" # update the agent\n",
" agent.update(obs[\"agent\"], action, float(reward), terminated, next_obs[\"agent\"])\n",
"\n",
" # update if the environment is done and the current obs\n",
" done = terminated or truncated\n",
" obs = next_obs\n",
" print(myenv.render())\n",
" print(agent.render())\n",
"\n",
" print(f\"\\n===> Episode {episode + 1} is done\")\n",
"\n",
"myenv.close()"
]
},
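{
"cell_type": "markdown",
"metadata": {},
"source": [
"`MyAgent` keeps `epsilon` fixed at 0.1 for all ten episodes. A common refinement is to decay epsilon over training so the agent explores early and exploits late. A minimal standalone sketch of a linear schedule (the `start_epsilon`, `final_epsilon`, and `n_episodes` values are illustrative, not from the lecture):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"start_epsilon = 1.0\n",
"final_epsilon = 0.05\n",
"n_episodes = 10\n",
"epsilon_decay = (start_epsilon - final_epsilon) / n_episodes  # linear decay per episode\n",
"\n",
"eps = start_epsilon\n",
"for ep in range(n_episodes):\n",
"    # ...one training episode would run here with exploration rate eps...\n",
"    eps = max(final_epsilon, eps - epsilon_decay)\n",
"    print(f\"episode {ep + 1}: epsilon = {eps:.2f}\")"
]
},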
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB8UAAAMWCAYAAABoQVdvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/GU6VOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXxcBbn/8e+ZmcxMkskkaZIutKVrWlqgLK2UlrZBFot4UXBDhVtARUVREL1IlUUBQb0XgSuyuAEuXBB+ghuLiHQBWqCFArK06U5Ll+yTPZmZ8/tj5kwSmn2WM8vn/Xr1pZ3M8jTNDKfne57nMUzTNAUAAAAAAAAAAAAAQBZy2F0AAAAAAAAAAAAAAADJQigOAAAAAAAAAAAAAMhahOIAAAAAAAAAAAAAgKxFKA4AAAAAAAAAAAAAyFqE4gAAAAAAAAAAAACArEUoDgAAAAAAAAAAAADIWoTiAAAAAAAAAAAAAICsRSgOAAAAAAAAAAAAAMhahOIAAAAAAAAAAAAAgKxFKA4gZ33/+9+XYRgpfc2dO3fKMAzdd999KX3dTHfhhRdq6tSpdpcBAADQB8eT9rrvvvtkGIY2bNhgdykAAAApF8/5MjuOYwHAboTiADKCdcJroF/r16+3u0TbvP974ff7VVVVpb///e92lwYAAJA2OJ4cmGEYuvTSS+0uY0B33nknFwEAAICMMdgxZ+9fq1atsrtUW1x44YUDfk+8Xq/d5QHIYi67CwCAkbj++us1bdq0Q26fOXPmiJ/r6quv1lVXXZWIsmx3+umna8WKFTJNU7t27dJdd92ls846S0888YSWL19ud3kAAABpg+PJzHPnnXeqvLxcF154od2lAAAADOl3v/tdn9//9re/1dNPP33I7XPmzInrdX75y18qHA6P6rF2H8d6PB796le/OuR2p9NpQzUAcgWhOICM8uEPf1gLFixIyHO5XC65XNnxMThr1iydf/75sd9/4hOf0Ny5c3X77bdnRCje0dEht9sth4MBJgAAILk4ngQAAEAy9T5HJ0nr16/X008/fcjt79fW1qaCgoJhv05eXt6o6pPsP451uVxDfj/609raqsLCwn6/NtLv3/sFg0GFw2G53e5RPweA9Eb6ACCrWDsW/+d//ke33nqrpkyZovz8fFVVVenf//53n/v2tzvn6aef1pIlS1RSUiKfz6fZs2fru9/9bp/7HDx4UF/4whc0btw4eb1eHXPMMbr//vsPqaWxsVEXXnihiouLVVJSogsuuECNjY391v3OO+/ok5/8pMaMGSOv16sFCxboL3/5y6i/D3PmzFF5ebm2bdvW5/bOzk5dd911mjlzpjwejyZPnqwrr7xSnZ2dsft8/OMf1/HHH9/ncWeddZYMw+hT04svvijDMPTEE09Ikurr6/Xtb39bRx99tHw+n/x+vz784Q/rtdde6/Ncq1atkmEYevDBB3X11Vdr4sSJKigoUCAQkCQ99thjOuqoo+T1enXUUUfp0Ucf7ffP+OCDD2r+/PkqKiqS3+/X0Ucfrdtvv33U3zMAAACJ48mBhMNh3XbbbTryyCPl9Xo1btw4ffnLX1ZDQ0Of+02dOlX/8R//oeeee04nnHCCvF6vpk+frt/+9reHPOfrr7+uqqoq5efna9KkSbrxxht17733yjAM7dy5M/Z8b775plavXh0bq3nyySf3eZ7Ozk5dccUVqqioUGFhoc455xzV1NT0uc+GDRu0fPlylZeXKz8/X9OmTdPnP//5hH1/AAAARuLkk0/WUUcdpY0bN2rZsmUqKCiIHTP++c9/1kc+8hEddthh8ng8mjFjhm644QaFQqE+z/H+neK9j2N/8YtfaMaMGfJ4PPrABz6gl19+uc9j+zuOtVbqWOfmPB6PjjzySD355JOH1L9q1SotWLBAXq9XM2bM0D333JPwPeXW+qPVq1frq1/9qsaOHatJkyZJGvz7N5xj7d7fq9tuuy32vXrrrbcSVj+A9MMl7QAySlNTk2pra/vcZhiGysrK+tz229/+Vs3Nzfra176mjo4O3X777TrllFP0xhtvaNy4cf0+95tvvqn/+I//0Lx583T99dfL4/Fo69atev7552P3aW9v18knn6ytW7fq0ksv1bRp0/Twww/rwgsvVGNjoy677DJJkmma+tjHPqbnnntOX/nKVzRnzhw9+uijuuCCC/p93ZNOOkkTJ07UVVddpcLCQv3xj3/U2Wefrf/3//6fzjnnnFF9nxoaGjRjxozYbeFwWB/96Ef13HPP6Utf+pLmzJmjN954Q7feequ2bNmixx57TJK0dOlS/fnPf1YgEJDf75dpmnr++eflcDi0du1affSjH5UkrV27Vg6HQyeddJIkafv27Xrsscf0qU99StOmTdOBAwd0zz33qKqqSm+99ZYOO+ywPjXecMMNcrvd+va3v63Ozk653W794x//iHW533zzzaqrq9NFF10UO+C1PP300/rsZz+rU089VT/+8Y8lSW+//baef/752N8BAABAfzieHJ0vf/nLuu+++3TRRRfpG9/4hnbs2KE77rhDr776qp5//vk+nUpbt27VJz/5SX3hC1/QBRdcoN/85je68MILNX/+fB155JGSpL179+qDH/ygDMPQypUrVVhYqF/96lfyeDx9Xve2227T17/+dfl8Pn3ve9+TpEO+/1//+tdVWlqq6667Tjt37tRtt92mSy+9VA899JCkyInRD33oQ6qoqNBVV12lkpIS7dy5U3/605/i/r4AAACMVl1dnT784Q/rM5/5jM4///zYMc59990nn8+nK664Qj6fT//617907bXXKhAI6L//+7+HfN4HHnhAzc3N+vKXvyzDMPSTn/xEH//4x7V9+/Yhu8ufe+45/elPf9JXv/pVFRUV6X//93/1iU98Qrt3744dL7/66qs644wzNGHCBP3gBz9QKBTS9ddfr4qKihH9+d9/TC5Jbrdbfr+/z21f/epXVVFRoWuvvVatra2x2/v7/g33WNty7733qqOjQ1/60pfk8Xg0ZsyYEf0ZAGQYEwAywL333mtK6veXx+OJ3W/Hjh2mJDM/P9/cs2dP7PYXX3zRlGR+85vfjN123XXXmb0/Bm+99VZTkllTUzNgHbfddpspyfz9738fu62rq8tctGiR6fP5zEAgYJqmaT722GOmJPMnP/lJ7H7BYNBcunSpKcm89957Y7efeuqp5tFHH212dHTEbguHw+bixYvNysrKIb83kswvfOELZk1NjXnw4EFzw4YN5hlnnGFKMv/7v/87dr/f/e53psPhMNeuXdvn8XfffbcpyXz++edN0zTNl19+2ZRkPv7446Zpmubrr79uSjI/9alPmQsXLow97qMf/ah53HHHxX7f0dFhhkKhPs+9Y8cO0+PxmNdff33stmeffdaUZE6fPt1sa2vrc/9jjz3WnDBhgtnY2Bi77R//+IcpyZwyZUrstssuu8z0+/1mMBgc8vsDAABgmhxPDkaS+bWvfW3Ar69du9aUZP7hD3/oc/uTTz55yO1TpkwxJZlr1qyJ3Xbw4
EHT4/GY3/rWt2K3ff3rXzcNwzBfffXV2G11dXXmmDFjTEnmjh07YrcfeeSRZlVV1SF1WX+np512mhkOh2O3f/Ob3zSdTmfsmPLRRx81JZkvv/zykN8LAACARPva177W55jRNE2zqqrKlGTefffdh9z//efLTNM0v/zlL5sFBQV9jvcuuOCCPufLrOPYsrIys76+Pnb7n//8Z1OS+de//jV22/uPY00zckzodrvNrVu3xm577bXXTEnmz372s9htZ511lllQUGDu3bs3dlt1dbXpcrkOec7+XHDBBQMely9fvjx2P+tYb8mSJYecAxzo+zfcY23re+X3+82DBw8OWTOA7MD4dAAZ5ec//7mefvrpPr+s8d29nX322Zo4cWLs9yeccIIWLlyoxx9/fMDnLikpkRQZURQOh/u9z+OPP67x48frs5/9bOy2vLw8feMb31BLS4tWr14du5/L5dIll1wSu5/T6dTXv/71Ps9XX1+vf/3rX/r0pz+t5uZm1dbWqra2VnV1dVq+fLmqq6u1d+/
"text/plain": [
"<Figure size 2000x800 with 3 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"\n",
"fig, axs = plt.subplots(1, 3, figsize=(20, 8))\n",
"\n",
"axs[0].plot(np.convolve(myenv.return_queue, np.ones(1))) # type: ignore\n",
"axs[0].set_title(\"Episode Rewards\")\n",
"axs[0].set_xlabel(\"Episode\")\n",
"axs[0].set_ylabel(\"Reward\")\n",
"\n",
"axs[1].plot(np.convolve(myenv.length_queue, np.ones(1))) # type: ignore\n",
"axs[1].set_title(\"Episode Lengths\")\n",
"axs[1].set_xlabel(\"Episode\")\n",
"axs[1].set_ylabel(\"Length\")\n",
"\n",
"axs[2].plot(np.convolve(agent.training_error, np.ones(1) * -1))\n",
"axs[2].set_title(\"Training Error\")\n",
"axs[2].set_xlabel(\"Action\")\n",
"axs[2].set_ylabel(\"Temporal Difference\")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}