diff --git a/lec6.ipynb b/lec6.ipynb new file mode 100644 index 0000000..0b09299 --- /dev/null +++ b/lec6.ipynb @@ -0,0 +1,549 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[['_' '_' 'X']\n", + " ['X' '' '']]\n", + "[[0. 0. 0.]\n", + " [0. 0. 0.]\n", + " [0. 0. 0.]]\n", + "[['_' '_' 'X']\n", + " ['X' '' '']]\n", + "[[-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "[['_' '_' 'X']\n", + " ['' 'X' '']]\n", + "[[-0.1 -0.1 0. ]\n", + " [ 0. 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "[['_' '_' 'X']\n", + " ['X' '' '']]\n", + "[[-0.1 -0.1 0. ]\n", + " [-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "[['_' '_' 'X']\n", + " ['' '' 'X']]\n", + "[[-0.1 -0.1 0. ]\n", + " [-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "Episode 1 is done\n", + "[['_' '_' 'X']\n", + " ['X' '' '']]\n", + "[[-0.19 -0.1 0. ]\n", + " [-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "[['_' '_' 'X']\n", + " ['' '' 'X']]\n", + "[[-0.19 -0.1 0. ]\n", + " [-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "Episode 2 is done\n", + "[['_' '_' 'X']\n", + " ['' '' 'X']]\n", + "[[-0.19 -0.1 0. ]\n", + " [-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "Episode 3 is done\n", + "[['_' '_' 'X']\n", + " ['' '' 'X']]\n", + "[[-0.19 -0.1 0. ]\n", + " [-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "Episode 4 is done\n", + "[['_' '_' 'X']\n", + " ['' '' 'X']]\n", + "[[-0.19 -0.1 0. ]\n", + " [-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "Episode 5 is done\n", + "[['_' '_' 'X']\n", + " ['X' '' '']]\n", + "[[-0.271 -0.1 0. ]\n", + " [-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "[['_' '_' 'X']\n", + " ['' '' 'X']]\n", + "[[-0.271 -0.1 0. ]\n", + " [-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "Episode 6 is done\n", + "[['_' '_' 'X']\n", + " ['' '' 'X']]\n", + "[[-0.271 -0.1 0. ]\n", + " [-0.1 0. 0. ]\n", + " [ 0. 0. 0. ]]\n", + "Episode 7 is done\n", + "[['_' '_' 'X']\n", + " ['' '' 'X']]\n", + "[[-0.271 -0.1 0. ]\n", + " [-0.1 0. 0. 
import numpy as np


def set_state(Q, map, state):
    """Mark the agent's position on the map and print the map and the Q-table.

    Args:
        Q: Q-table to print.
        map: 2-row array; row 0 is the static world, row 1 is overwritten with
            the agent marker. (The name shadows the builtin only inside this
            function; it is kept so the signature stays unchanged.)
        state: Index of the cell the agent currently occupies.
    """
    map[1] = ["", "", ""]
    map[1][state] = "X"
    print(map)
    print(Q)


def reachable_states(state, directions, n_states):
    """Return the distinct states reachable from ``state`` in one move.

    Each direction is added to ``state`` and clipped to ``[0, n_states - 1]``,
    so moving into a wall keeps the agent in place.
    """
    candidates = []
    for direction in directions:
        target = int(np.clip(state + direction, 0, n_states - 1))
        if target not in candidates:
            candidates.append(target)
    return candidates


def choose_next_state(Q, state, epsilon, directions):
    """Pick the next state epsilon-greedily among one-step neighbours.

    With probability ``epsilon`` a direction is drawn uniformly at random;
    otherwise the reachable neighbour with the highest ``Q[state, neighbour]``
    is chosen (first one wins on ties).
    """
    n_states = Q.shape[0]
    if np.random.rand() < epsilon:
        # randint's upper bound is exclusive, so pass the number of directions;
        # randint(0, 1) would always return 0 and never explore "right".
        direction = directions[np.random.randint(len(directions))]
        return int(np.clip(state + direction, 0, n_states - 1))
    # Only consider states one move away: an argmax over the whole row would
    # let the agent jump straight from state 0 to the goal.
    return max(
        reachable_states(state, directions, n_states),
        key=lambda target: Q[state, target],
    )


# Q[state, next_state]: value of moving from `state` to `next_state`.
Q = np.zeros([3, 3])
# Renamed from `map` so the builtin is not shadowed for the rest of the session.
world_map = np.array([["_", "_", "X"], ["", "", ""]])

max_steps = 10  # used below as the number of training episodes
alpha = 0.1  # learning rate
gamma = 0.9  # discount factor
epsilon = 0.1  # exploration vs. exploitation parameter

actions = {"left": -1, "right": 1}
directions = list(actions.values())
n_states = Q.shape[0]
goal_state = 2

for episode in range(max_steps):  # number of training episodes
    state = 0  # start state
    if episode == 0:
        set_state(Q, world_map, state)

    while state != goal_state:  # until the goal is reached
        # choose the move
        new_state = choose_next_state(Q, state, epsilon, directions)

        # reward: -1 for every step, 0 for reaching the goal
        reward = 0 if new_state == goal_state else -1

        # Bootstrap only from moves that are actually possible in new_state;
        # unreachable entries are never updated and would pin the max at 0.
        best_next = max(
            Q[new_state, target]
            for target in reachable_states(new_state, directions, n_states)
        )

        # Q-value update
        Q[state, new_state] = (1 - alpha) * Q[state, new_state] + alpha * (
            reward + gamma * best_next
        )

        state = new_state
        set_state(Q, world_map, state)

    print(f"Episode {episode + 1} is done")
import gymnasium as gym
import numpy as np


class MyAgent:
    """Tabular epsilon-greedy Q-learning agent.

    The Q-table has one row per agent position and one column per action of
    the environment, so a greedy choice is always a valid action.
    """

    def __init__(
        self,
        env: gym.Env,
        learning_rate: float,
        epsilon: float,
        discount_factor: float,
    ):
        """Create the agent.

        Args:
            env: Environment with a Dict observation space holding a Discrete
                "agent" entry, and a Discrete action space.
            learning_rate: Step size of the Q-value update.
            epsilon: Probability of taking a random action.
            discount_factor: Weight of the bootstrapped future value.
        """
        self.env = env
        n_states = int(env.observation_space["agent"].n)  # type: ignore
        n_actions = int(env.action_space.n)  # type: ignore
        self.q_values = np.zeros((n_states, n_actions))

        self.learning_rate = learning_rate
        self.discount_factor = discount_factor

        self.epsilon = epsilon

        # Temporal-difference error of every update, in call order.
        self.training_error = []

    def get_action(self, state: int) -> int:
        """Return a random action with probability epsilon, else the greedy one."""
        if np.random.rand() < self.epsilon:
            # sample() yields a numpy integer; convert to honour the annotation.
            return int(self.env.action_space.sample())
        return int(np.argmax(self.q_values[state]))

    def update(
        self,
        state: int,
        action: int,
        reward: float,
        terminated: bool,
        new_state: int,
    ):
        """Apply one Q-learning update and record its temporal-difference error.

        Args:
            state: Agent position before the step.
            action: Action that was taken.
            reward: Reward returned by the environment.
            terminated: Whether the step ended the episode; if so, no future
                value is bootstrapped.
            new_state: Agent position after the step.
        """
        future_q_value = (
            0.0 if terminated else float(np.max(self.q_values[new_state]))
        )
        temporal_difference = (
            reward
            + self.discount_factor * future_q_value
            - self.q_values[state, action]
        )
        self.q_values[state, action] += self.learning_rate * temporal_difference
        self.training_error.append(temporal_difference)

    def render(self):
        """Return the Q-table as text, one comma-separated row per line."""
        return "\n".join(
            ", ".join(str(item) for item in row) for row in self.q_values
        )


class MyEnv(gym.Env):
    """One-dimensional corridor: the agent starts at cell 0, the target is the last cell.

    Actions: 0 moves left, 1 moves right; moves are clipped at the walls.
    Every step yields reward -1 except the step that reaches the target,
    which yields 0 and terminates the episode.
    """

    # Only the text rendering is implemented, so only "ansi" is advertised.
    metadata = {"render_modes": ["ansi"], "render_fps": 4}

    def __init__(self, render_mode=None):
        self.size = 3

        self.render_mode = render_mode

        self._agent_location = 0
        self._target_location = self.size - 1

        # Locations range over 0..size-1, so each Discrete space needs `size` values.
        self.observation_space = gym.spaces.Dict(
            {
                "agent": gym.spaces.Discrete(self.size),
                "target": gym.spaces.Discrete(self.size),
            }
        )

        self.action_space = gym.spaces.Discrete(2)
        self._action_to_direction = {
            0: -1,  # left
            1: 1,  # right
        }

    def _get_obs(self):
        """Return the current observation dict."""
        return {
            "agent": self._agent_location,
            "target": self._target_location,
        }

    def _get_info(self):
        """Return auxiliary info: distance between agent and target."""
        return {"distance": abs(self._agent_location - self._target_location)}

    def reset(self, seed=None, options=None):
        """Put the agent back at cell 0 and return ``(observation, info)``."""
        super().reset(seed=seed)
        self._agent_location = 0
        self._target_location = self.size - 1
        return self._get_obs(), self._get_info()

    def step(self, action):
        """Move the agent one cell and return the standard 5-tuple.

        Raises:
            KeyError: If ``action`` is not one of the two defined actions.
        """
        direction = self._action_to_direction[int(action)]
        # int() keeps the location a plain Python int instead of a numpy scalar.
        self._agent_location = int(
            np.clip(self._agent_location + direction, 0, self.size - 1)
        )

        terminated = self._agent_location == self._target_location
        truncated = False
        reward = 0 if terminated else -1

        return self._get_obs(), reward, terminated, truncated, self._get_info()

    def render(self):
        """Return a one-line text picture: ``A`` is the agent, ``X`` the target."""
        cells = ["_"] * self.size
        cells[self._agent_location] = "A"
        # Written last, so the target hides the agent once it has been reached.
        cells[self._target_location] = "X"
        return " ".join(cells)


env_id = "gymnasium_env/MyEnv-v1"
gym.register(
    id=env_id,
    entry_point=MyEnv,  # type: ignore
)
"execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A _ X\n", + "0.0, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "===> Start episode 1\n", + "\n", + "Start action 1\n", + "A _ X\n", + "-0.1, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "Start action 2\n", + "_ A X\n", + "-0.1, -0.1, 0.0\n", + "0.0, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "Start action 3\n", + "A _ X\n", + "-0.1, -0.1, 0.0\n", + "-0.1, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "Start action 4\n", + "_ _ X\n", + "-0.1, -0.1, 0.0\n", + "-0.1, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "===> Episode 1 is done\n", + "\n", + "===> Start episode 2\n", + "\n", + "Start action 5\n", + "_ _ X\n", + "-0.1, -0.1, 0.0\n", + "-0.1, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "===> Episode 2 is done\n", + "\n", + "===> Start episode 3\n", + "\n", + "Start action 6\n", + "_ _ X\n", + "-0.1, -0.1, 0.0\n", + "-0.1, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "===> Episode 3 is done\n", + "\n", + "===> Start episode 4\n", + "\n", + "Start action 7\n", + "_ A X\n", + "-0.1, -0.19, 0.0\n", + "-0.1, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "Start action 8\n", + "_ _ X\n", + "-0.1, -0.19, 0.0\n", + "-0.1, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "===> Episode 4 is done\n", + "\n", + "===> Start episode 5\n", + "\n", + "Start action 9\n", + "_ _ X\n", + "-0.1, -0.19, 0.0\n", + "-0.1, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "===> Episode 5 is done\n", + "\n", + "===> Start episode 6\n", + "\n", + "Start action 10\n", + "_ _ X\n", + "-0.1, -0.19, 0.0\n", + "-0.1, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "===> Episode 6 is done\n", + "\n", + "===> Start episode 7\n", + "\n", + "Start action 11\n", + "_ _ X\n", + "-0.1, -0.19, 0.0\n", + "-0.1, 0.0, 0.0\n", + "0.0, 0.0, 0.0\n", + "\n", + "===> Episode 7 is done\n", + "\n", + "===> Start episode 8\n", + "\n", + "Start action 12\n", + "_ _ X\n", + 
from gymnasium.wrappers import RecordEpisodeStatistics

# Build the registered environment and wrap it so per-episode returns and
# lengths are recorded.
myenv = RecordEpisodeStatistics(
    gym.make(env_id, render_mode="ansi", max_episode_steps=max_steps),
    buffer_length=max_steps,
)

agent = MyAgent(
    env=myenv,
    learning_rate=alpha,
    epsilon=epsilon,
    discount_factor=gamma,
)

# Running count of environment steps across all episodes (used for logging).
action_num = 0
for episode in range(max_steps):
    obs, info = myenv.reset()

    # Show the initial world and the initial Q-table once.
    if episode == 0:
        print(myenv.render())
        print(agent.render())

    print(f"\n===> Start episode {episode + 1}")

    done = False
    while not done:
        print(f"\nStart action {action_num + 1}")
        action_num += 1

        # Act, observe the transition, and learn from it.
        action = agent.get_action(obs["agent"])
        next_obs, reward, terminated, truncated, info = myenv.step(action)
        agent.update(
            obs["agent"],
            action,
            float(reward),
            terminated,
            next_obs["agent"],
        )

        # The episode ends on termination or truncation.
        done = terminated or truncated
        obs = next_obs

        print(myenv.render())
        print(agent.render())

    print(f"\n===> Episode {episode + 1} is done")

myenv.close()
"metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
from matplotlib import pyplot as plt

fig, axs = plt.subplots(1, 3, figsize=(20, 8))

# One entry per subplot: (data, convolution kernel, title, x label, y label).
# A length-1 kernel of ones leaves the series unchanged; the -1 kernel flips
# the sign of the training error.
panels = [
    (myenv.return_queue, np.ones(1), "Episode Rewards", "Episode", "Reward"),
    (myenv.length_queue, np.ones(1), "Episode Lengths", "Episode", "Length"),
    (
        agent.training_error,
        np.ones(1) * -1,
        "Training Error",
        "Action",
        "Temporal Difference",
    ),
]

for ax, (series, kernel, title, xlabel, ylabel) in zip(axs, panels):
    ax.plot(np.convolve(series, kernel))  # type: ignore
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

plt.tight_layout()
plt.show()
+optional = false +python-versions = "*" +files = [ + {file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"}, + {file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"}, +] + [[package]] name = "fastjsonschema" version = "2.21.1" @@ -739,6 +750,36 @@ files = [ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] +[[package]] +name = "gymnasium" +version = "1.0.0" +description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)." +optional = false +python-versions = ">=3.8" +files = [ + {file = "gymnasium-1.0.0-py3-none-any.whl", hash = "sha256:b6f40e1e24c5bd419361e1a5b86a9117d2499baecc3a660d44dfff4c465393ad"}, + {file = "gymnasium-1.0.0.tar.gz", hash = "sha256:9d2b66f30c1b34fe3c2ce7fae65ecf365d0e9982d2b3d860235e773328a3b403"}, +] + +[package.dependencies] +cloudpickle = ">=1.2.0" +farama-notifications = ">=0.0.1" +numpy = ">=1.21.0" +typing-extensions = ">=4.3.0" + +[package.extras] +all = ["ale-py (>=0.9)", "box2d-py (==2.3.5)", "cython (<3)", "flax (>=0.5.0)", "imageio (>=2.14.1)", "jax (>=0.4.0)", "jaxlib (>=0.4.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.1.5)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (>=2.1.3)", "swig (==4.*)", "torch (>=1.0.0)"] +atari = ["ale-py (>=0.9)"] +box2d = ["box2d-py (==2.3.5)", "pygame (>=2.1.3)", "swig (==4.*)"] +classic-control = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"] +jax = ["flax (>=0.5.0)", "jax (>=0.4.0)", "jaxlib (>=0.4.0)"] +mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.1.5)"] +mujoco-py = ["cython (<3)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "mujoco-py (>=2.1,<2.2)"] +other = ["matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)"] +testing = ["dill (>=0.3.7)", "pytest (==7.1.3)", "scipy (>=1.7.3)"] +torch = 
["torch (>=1.0.0)"] +toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"] + [[package]] name = "h11" version = "0.14.0" @@ -2927,6 +2968,17 @@ files = [ {file = "types_python_dateutil-2.9.0.20241003-py3-none-any.whl", hash = "sha256:250e1d8e80e7bbc3a6c99b907762711d1a1cdd00e978ad39cb5940f6f0a87f3d"}, ] +[[package]] +name = "typing-extensions" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, +] + [[package]] name = "tzdata" version = "2024.2" @@ -3058,4 +3110,4 @@ updater = ["alteryx-open-src-update-checker (>=3.1.0)"] [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "14251a2aa051d0453baa081f3c5967a8fc2d57d32f379f3b899973001543c094" +content-hash = "299cd82afa9f00a090d3ef039f4dabc019809a8a0f8b4111a3086133dec02d69" diff --git a/pyproject.toml b/pyproject.toml index d84d8aa..4ef6233 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ pandas = "^2.2.2" matplotlib = "^3.9.2" imbalanced-learn = "^0.12.3" featuretools = "^1.31.0" +gymnasium = "^1.0.0" [build-system]