{
"cells": [
{
"cell_type": "markdown",
"id": "c2308ffe",
"metadata": {},
"source": [
"#### Keras initialization\n",
"\n",
"The Keras backend was switched from torch to jax because the recurrent networks did not work with the torch backend."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "507915ea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.9.2\n"
]
}
],
"source": [
"import os\n",
"\n",
"# KERAS_BACKEND is read when keras is first imported, so it must be set before the import\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
"import keras\n",
"\n",
"print(keras.__version__)"
]
},
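{
"cell_type": "markdown",
"id": "b7e1d0a2",
"metadata": {},
"source": [
"A quick check, assuming Keras 3: `keras.backend.backend()` returns the name of the active backend. It should print `jax` if the environment variable was set before the first `import keras`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4f9c2e61",
"metadata": {},
"outputs": [],
"source": [
"# Name of the backend Keras actually loaded (expected: jax)\n",
"print(keras.backend.backend())"
]
},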
{
"cell_type": "markdown",
"id": "8e4a9a71",
"metadata": {},
"source": [
"#### Loading data for classification with deep networks\n",
"\n",
"The dataset is the IMDB movie review dataset.\n",
"\n",
"It contains 50,000 reviews: half form the training set (X_train) and half the test set (X_valid).\n",
"\n",
"The labels (y_train and y_valid) are binary and were assigned from the reviewers' 10-star ratings:\n",
"- reviews with four stars or fewer are negative (y = 0);\n",
"- reviews with seven stars or more are positive (y = 1);\n",
"- moderate reviews with five or six stars were not included in the dataset, which simplifies the binary classification task.\n",
"\n",
"The data is already preprocessed: each review is stored as a sequence of integer word indices.\n",
"\n",
"unique_words: only the 10,000 most frequent words in the corpus are kept in the vocabulary. Rarer words are replaced with a single out-of-vocabulary index.\n",
"\n",
"max_length: the fixed review length in tokens. Longer reviews are truncated, and shorter ones are padded with the \"empty\" index 0."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e0043e5c",
"metadata": {},
"outputs": [],
"source": [
"from keras.datasets import imdb\n",
"import os\n",
"\n",
"unique_words = 10000  # vocabulary size: keep only the 10,000 most frequent words\n",
"max_length = 100  # every review is padded or truncated to 100 tokens\n",
"\n",
"# Directory for the model checkpoints saved during training\n",
"output_dir = \"tmp\"\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"(X_train, y_train), (X_valid, y_valid) = imdb.load_data(num_words=unique_words)"
]
},
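{
"cell_type": "markdown",
"id": "d3a85b17",
"metadata": {},
"source": [
"This is a minimal sketch for inspecting the loaded data. It assumes the default `imdb.load_data` arguments `start_char=1`, `oov_char=2` and `index_from=3`. Under those defaults, stored indices are the vocabulary indices shifted by 3, so the reverse lookup adds that offset back.\n",
"\n",
"`decode_review` is a hypothetical helper, not part of Keras. `np.bincount` counts the negative (0) and positive (1) training reviews."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a6f2c40",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# load_data shifted every word index by index_from=3 (default)\n",
"word_index = imdb.get_word_index()\n",
"index_to_word = {index + 3: word for word, index in word_index.items()}\n",
"# Reserved indices: 0 = padding, 1 = start of review, 2 = out-of-vocabulary word\n",
"index_to_word.update({0: \"<PAD>\", 1: \"<START>\", 2: \"<UNK>\"})\n",
"\n",
"\n",
"def decode_review(encoded):\n",
"    # Hypothetical helper: turn a sequence of word indices back into text\n",
"    return \" \".join(index_to_word.get(int(i), \"?\") for i in encoded)\n",
"\n",
"\n",
"print(decode_review(X_train[0]))\n",
"# Number of negative (0) and positive (1) reviews in the training set\n",
"print(np.bincount(y_train))"
]
},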
{
"cell_type": "markdown",
"id": "c58423e9",
"metadata": {},
"source": [
"#### Bringing reviews to length max_length (100)\n",
"\n",
"padding=\"pre\" and truncating=\"pre\": zeros are added to, and excess words are removed from, the beginning of each review. This keeps the end of every review right before the final state of the recurrent layer, which matters because of vanishing gradients in recurrent networks."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "131e125a",
"metadata": {},
"outputs": [],
"source": [
"from keras.utils import pad_sequences\n",
"\n",
"# Pad and truncate at the start (\"pre\") so the last words of each review are kept\n",
"X_train = pad_sequences(X_train, maxlen=max_length, padding=\"pre\", truncating=\"pre\", value=0)\n",
"X_valid = pad_sequences(X_valid, maxlen=max_length, padding=\"pre\", truncating=\"pre\", value=0)"
]
},
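{
"cell_type": "markdown",
"id": "6c0f4e93",
"metadata": {},
"source": [
"The effect of `padding=\"pre\"` and `truncating=\"pre\"` is easiest to see on toy input. The sequences below are hypothetical and use `maxlen=4` instead of `max_length`. The short sequence gets zeros in front, and the long one loses its first tokens so that its last four are kept."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e29b7a05",
"metadata": {},
"outputs": [],
"source": [
"# Toy sequences (hypothetical): one shorter and one longer than maxlen=4\n",
"toy = [[1, 2, 3], [4, 5, 6, 7, 8]]\n",
"print(pad_sequences(toy, maxlen=4, padding=\"pre\", truncating=\"pre\", value=0))\n",
"# [[0 1 2 3]\n",
"#  [5 6 7 8]]\n",
"\n",
"# After padding, every review has exactly max_length tokens\n",
"print(X_train.shape, X_valid.shape)\n",
"# (25000, 100) (25000, 100)"
]
},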
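{
"cell_type": "markdown",
"id": "7db364f4",
"metadata": {},
"source": [
"#### Building a deep bidirectional recurrent LSTM network\n",
"\n",
"The first layer (Embedding) performs vectorization: it maps each of the unique_words word indices to a trainable 64-dimensional vector.\n",
"\n",
"The remaining layers work as follows:\n",
"- SpatialDropout1D(0.2) randomly drops entire embedding channels across all time steps.\n",
"- Bidirectional(LSTM(256)) reads each review both forwards and backwards and concatenates the two final states (2 × 256 = 512 values).\n",
"- Dense(1) with a sigmoid activation outputs the probability that the review is positive."
]
},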
{
"cell_type": "code",
"execution_count": 4,
"id": "1e3fb0ec",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ embedding (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Embedding</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">640,000</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ spatial_dropout1d │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">100</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">SpatialDropout1D</span>) │ │ │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ bidirectional (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Bidirectional</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">657,408</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">513</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m640,000\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ spatial_dropout1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"│ (\u001b[38;5;33mSpatialDropout1D\u001b[0m) │ │ │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ bidirectional (\u001b[38;5;33mBidirectional\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m657,408\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m513\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,297,921</span> (4.95 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,297,921\u001b[0m (4.95 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,297,921</span> (4.95 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,297,921\u001b[0m (4.95 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from keras.models import Sequential\n",
"from keras.layers import InputLayer, Embedding, SpatialDropout1D, LSTM, Bidirectional, Dense\n",
"\n",
"blstm_model = Sequential()\n",
"# Input: a review as max_length integer word indices\n",
"blstm_model.add(InputLayer(shape=(max_length,), dtype=\"int32\"))\n",
"# Each word index -> trainable 64-dimensional vector\n",
"blstm_model.add(Embedding(unique_words, 64))\n",
"# Drop whole embedding channels with probability 0.2\n",
"blstm_model.add(SpatialDropout1D(0.2))\n",
"# LSTM with 256 units in each direction; the final states are concatenated (512 values)\n",
"blstm_model.add(Bidirectional(LSTM(256, dropout=0.2)))\n",
"# Probability that the review is positive\n",
"blstm_model.add(Dense(1, activation=\"sigmoid\"))\n",
"\n",
"blstm_model.summary()"
]
},
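{
"cell_type": "markdown",
"id": "f14d6c38",
"metadata": {},
"source": [
"The parameter counts in the summary can be checked by hand. This assumes the standard LSTM formulation that Keras `LSTM` uses by default: four gates, each with an input kernel, a recurrent kernel and a bias. One LSTM then has 4 · (units · (input_dim + units) + units) parameters, and `Bidirectional` doubles this.\n",
"\n",
"The sketch below is plain Python and does not need Keras."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b8e5d71",
"metadata": {},
"outputs": [],
"source": [
"vocab, emb_dim, units = 10000, 64, 256\n",
"\n",
"embedding = vocab * emb_dim  # one 64-d vector per word\n",
"lstm = 4 * (units * (emb_dim + units) + units)  # 4 gates: input kernel, recurrent kernel, bias\n",
"bidirectional = 2 * lstm  # forward + backward LSTM\n",
"dense = 2 * units * 1 + 1  # 512 inputs -> 1 output, plus bias\n",
"\n",
"print(embedding, bidirectional, dense, embedding + bidirectional + dense)\n",
"# 640000 657408 513 1297921"
]
},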
{
"cell_type": "markdown",
"id": "3a826105",
"metadata": {},
"source": [
"#### Training the model\n",
"\n",
"After every epoch, the ModelCheckpoint callback saves the model (architecture and weights) to the tmp directory as blstm_weights.{epoch:02d}.keras.\n",
"\n",
"The weights from any of these files can be loaded later."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "11236198",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/6\n",
"\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m134s\u001b[0m 682ms/step - accuracy: 0.6565 - loss: 0.6039 - val_accuracy: 0.8432 - val_loss: 0.3756\n",
"Epoch 2/6\n",
"\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m166s\u001b[0m 848ms/step - accuracy: 0.8841 - loss: 0.2820 - val_accuracy: 0.8425 - val_loss: 0.3577\n",
"Epoch 3/6\n",
"\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m176s\u001b[0m 902ms/step - accuracy: 0.9148 - loss: 0.2238 - val_accuracy: 0.8459 - val_loss: 0.3929\n",
"Epoch 4/6\n",
"\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m171s\u001b[0m 875ms/step - accuracy: 0.9375 - loss: 0.1744 - val_accuracy: 0.8434 - val_loss: 0.3572\n",
"Epoch 5/6\n",
"\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m155s\u001b[0m 790ms/step - accuracy: 0.9466 - loss: 0.1520 - val_accuracy: 0.8385 - val_loss: 0.4029\n",
"Epoch 6/6\n",
"\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m158s\u001b[0m 807ms/step - accuracy: 0.9584 - loss: 0.1172 - val_accuracy: 0.8337 - val_loss: 0.4419\n"
]
},
{
"data": {
"text/plain": [
"<keras.src.callbacks.history.History at 0x3455e57f0>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from keras.callbacks import ModelCheckpoint\n",
"\n",
"blstm_model.compile(\n",
"    loss=\"binary_crossentropy\",\n",
"    optimizer=\"adam\",\n",
"    metrics=[\"accuracy\"],\n",
")\n",
"\n",
"blstm_model.fit(\n",
"    X_train,\n",
"    y_train,\n",
"    batch_size=128,\n",
"    epochs=6,\n",
"    validation_data=(X_valid, y_valid),\n",
"    # Save the full model after every epoch: tmp/blstm_weights.01.keras ... tmp/blstm_weights.06.keras\n",
"    callbacks=[ModelCheckpoint(filepath=output_dir + \"/blstm_weights.{epoch:02d}.keras\")],\n",
")"
]
},
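{
"cell_type": "markdown",
"id": "0d7c9a3e",
"metadata": {},
"source": [
"This sketch shows how the checkpoint for evaluation can be chosen from the training log. val_accuracy peaks at epoch 3 and then declines while training accuracy keeps rising, which points to overfitting.\n",
"\n",
"The values are copied from the log above. If the result of `fit` were stored as `history`, the same list would be available as `history.history[\"val_accuracy\"]`. Alternatively, `ModelCheckpoint(..., monitor=\"val_accuracy\", save_best_only=True)` keeps only the best checkpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e2a8b94",
"metadata": {},
"outputs": [],
"source": [
"# val_accuracy per epoch, copied from the training log above\n",
"val_accuracy = [0.8432, 0.8425, 0.8459, 0.8434, 0.8385, 0.8337]\n",
"\n",
"# Epochs are numbered from 1, matching the {epoch:02d} checkpoint names\n",
"best_epoch = max(range(len(val_accuracy)), key=val_accuracy.__getitem__) + 1\n",
"print(best_epoch, f\"{output_dir}/blstm_weights.{best_epoch:02d}.keras\")\n",
"# 3 tmp/blstm_weights.03.keras"
]
},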
{
"cell_type": "markdown",
"id": "a47a8ff6",
"metadata": {},
"source": [
"#### Loading the best model and evaluating it\n",
"\n",
"The checkpoint from epoch 3 (blstm_weights.03.keras) has the highest val_accuracy in the training log, so it is the one loaded. Its accuracy on the validation set is 84.6%."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "94987771",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m782/782\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m86s\u001b[0m 110ms/step - accuracy: 0.8449 - loss: 0.3976\n"
]
},
{
"data": {
"text/plain": [
"[0.3929494023323059, 0.8458799719810486]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Restore the weights from the epoch-3 checkpoint (highest val_accuracy)\n",
"blstm_model.load_weights(output_dir + \"/blstm_weights.03.keras\")\n",
"# Returns [loss, accuracy] on the validation set\n",
"blstm_model.evaluate(X_valid, y_valid)"
]
},
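{
"cell_type": "markdown",
"id": "8b3f1e6d",
"metadata": {},
"source": [
"This sketch relates the reported accuracy to the predicted probabilities. It assumes the Keras binary accuracy convention of a 0.5 threshold, so a review counts as positive when the sigmoid output exceeds 0.5. The share of correct labels should closely match the accuracy returned by `evaluate` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6a40f27",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Sigmoid outputs of shape (25000, 1) -> flat vector of probabilities\n",
"probabilities = blstm_model.predict(X_valid).ravel()\n",
"# Positive review when the probability exceeds 0.5\n",
"predicted_labels = (probabilities > 0.5).astype(\"int32\")\n",
"\n",
"print(\"accuracy:\", np.mean(predicted_labels == y_valid))\n",
"print(\"share predicted positive:\", predicted_labels.mean())"
]
},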
{
"cell_type": "markdown",
"id": "7001f712",
"metadata": {},
"source": [
"#### Distribution of the model's predicted probabilities on the validation set\n",
"\n",
"The orange line marks the 0.5 decision threshold. Reviews to its right are classified as positive, and reviews to its left as negative."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8965a612",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m782/782\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m85s\u001b[0m 108ms/step\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Histogram of the predicted probabilities of a positive review\n",
"plt.hist(blstm_model.predict(X_valid))\n",
"# Decision threshold\n",
"_ = plt.axvline(x=0.5, color=\"orange\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv (3.12.10)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}