{ "cells": [ { "cell_type": "markdown", "id": "c2308ffe", "metadata": {}, "source": [ "#### Инициализация Keras\n", "\n", "torch был заменен на jax, так как с torch рекуррентные сети не работали" ] }, { "cell_type": "code", "execution_count": 1, "id": "507915ea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.9.2\n" ] } ], "source": [ "import os\n", "\n", "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", "import keras\n", "\n", "print(keras.__version__)" ] }, { "cell_type": "markdown", "id": "8e4a9a71", "metadata": {}, "source": [ "#### Загрузка данных для классификации с помощью глубоких сетей\n", "\n", "В качестве набора данных используется набор отзывов к фильмам с сайта IMDB.\n", "\n", "Набор включает 50 000 отзывов, половина из которых находится в обучающем наборе данных (x_train), а половина - в тестовом (x_valid). \n", "\n", "Метки (y_train и y_valid) имеют бинарный характер и назначены в соответствии с этими 10-балльными оценками:\n", "- отзывы с четырьмя звездами или меньше считаются отрицательным (y = 0);\n", "- отзывы с семью звездами или больше считаются положительными (y = 1);\n", "- умеренные отзывы — с пятью или шестью звездами — не включались в набор данных, что упрощает задачу бинарной классификации.\n", "\n", "Данные уже предобработаны для простоты работы с ними.\n", "\n", "unique_words - в векторное пространство включается только слова, которые встречаются в корпусе не менее 10 000 раз.\n", "\n", "max_length - максимальная длина отзыва (если больше, то обрезается, если меньше, то дополняется \"пустыми\" словами)." ] }, { "cell_type": "code", "execution_count": 2, "id": "e0043e5c", "metadata": {}, "outputs": [], "source": [ "from keras.api.datasets import imdb\n", "import os\n", "\n", "unique_words = 10000\n", "max_length = 100\n", "\n", "output_dir = \"tmp\"\n", "if not os.path.exists(output_dir):\n", " os.makedirs(output_dir)\n", "\n", "(X_train, y_train), (X_valid, y_valid) = imdb.load_data(num_words=unique_words)" ] }, { "cell_type": "markdown", "id": "c58423e9", "metadata": {}, "source": [ "#### Приведение отзывов к длине max_length (100)\n", "\n", "padding и truncating - дополнение и обрезка отзывов начинается с начала (учитывается специфика затухания градиента в рекуррентных сетях)" ] }, { "cell_type": "code", "execution_count": 3, "id": "131e125a", "metadata": {}, "outputs": [], "source": [ "from keras.api.preprocessing.sequence import pad_sequences\n", "\n", "X_train = pad_sequences(X_train, maxlen=max_length, padding=\"pre\", truncating=\"pre\", value=0)\n", "X_valid = pad_sequences(X_valid, maxlen=max_length, padding=\"pre\", truncating=\"pre\", value=0)" ] }, { "cell_type": "markdown", "id": "7db364f4", "metadata": {}, "source": [ "#### Формирование архитектуры глубокой рекуррентной двунаправленной LSTM сети\n", "\n", "\n", "Первый слой (Embedding) выполняет векторизацию" ] }, { "cell_type": "code", "execution_count": 4, "id": "1e3fb0ec", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ embedding (Embedding) │ (None, 100, 64) │ 640,000 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ spatial_dropout1d │ (None, 100, 64) │ 0 │\n", "│ (SpatialDropout1D) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ bidirectional (Bidirectional) │ (None, 512) │ 657,408 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense (Dense) │ (None, 1) │ 513 │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m640,000\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ spatial_dropout1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "│ (\u001b[38;5;33mSpatialDropout1D\u001b[0m) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ bidirectional (\u001b[38;5;33mBidirectional\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m657,408\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m513\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 1,297,921 (4.95 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,297,921\u001b[0m (4.95 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 1,297,921 (4.95 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,297,921\u001b[0m (4.95 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from keras.api.models import Sequential\n", "from keras.api.layers import InputLayer, Embedding, SpatialDropout1D, LSTM, Bidirectional, Dense\n", "\n", "blstm_model = Sequential()\n", "blstm_model.add(InputLayer(shape=(max_length,), dtype=\"float32\"))\n", "blstm_model.add(Embedding(unique_words, 64))\n", "blstm_model.add(SpatialDropout1D(0.2))\n", "blstm_model.add(Bidirectional(LSTM(256, dropout=0.2)))\n", "blstm_model.add(Dense(1, activation=\"sigmoid\"))\n", "\n", "blstm_model.summary()" ] }, { "cell_type": "markdown", "id": "3a826105", "metadata": {}, "source": [ "#### Обучение модели\n", "\n", "Веса модели сохраняются в каталог tmp после каждой эпохи обучения с помощью callback-параметра\n", "\n", "В дальнейшем веса можно загрузить" ] }, { "cell_type": "code", "execution_count": 5, "id": "11236198", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/6\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m134s\u001b[0m 682ms/step - accuracy: 0.6565 - loss: 0.6039 - val_accuracy: 0.8432 - val_loss: 0.3756\n", "Epoch 2/6\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m166s\u001b[0m 848ms/step - accuracy: 0.8841 - loss: 0.2820 - val_accuracy: 0.8425 - val_loss: 0.3577\n", "Epoch 3/6\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m176s\u001b[0m 902ms/step - accuracy: 0.9148 - loss: 0.2238 - val_accuracy: 0.8459 - val_loss: 0.3929\n", "Epoch 4/6\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m171s\u001b[0m 875ms/step - accuracy: 0.9375 - loss: 0.1744 - val_accuracy: 0.8434 - val_loss: 0.3572\n", "Epoch 5/6\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m155s\u001b[0m 790ms/step - accuracy: 0.9466 - loss: 0.1520 - val_accuracy: 0.8385 - val_loss: 0.4029\n", "Epoch 6/6\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m158s\u001b[0m 807ms/step - accuracy: 0.9584 - loss: 0.1172 - val_accuracy: 0.8337 - val_loss: 0.4419\n" ] }, { "data": { "text/plain": [ "