{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "507915ea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.9.2\n" ] } ], "source": [ "import os\n", "\n", "# KERAS_BACKEND is set before the first `import keras` so that Keras picks up the JAX backend.\n", "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", "import keras\n", "\n", "print(keras.__version__)" ] }, { "cell_type": "code", "execution_count": 2, "id": "e0043e5c", "metadata": {}, "outputs": [], "source": [ "from keras.api.datasets import imdb\n", "import os\n", "\n", "# unique_words is passed as num_words to imdb.load_data; max_length is the padded review length.\n", "unique_words = 10000\n", "max_length = 100\n", "\n", "# exist_ok=True makes directory creation idempotent and avoids the check-then-create race.\n", "output_dir = \"tmp\"\n", "os.makedirs(output_dir, exist_ok=True)\n", "\n", "(X_train, y_train), (X_valid, y_valid) = imdb.load_data(num_words=unique_words)" ] }, { "cell_type": "code", "execution_count": 3, "id": "131e125a", "metadata": {}, "outputs": [], "source": [ "from keras.api.preprocessing.sequence import pad_sequences\n", "\n", "# Pad with 0 or truncate, both at the front, so every review is exactly max_length tokens long.\n", "X_train = pad_sequences(X_train, maxlen=max_length, padding=\"pre\", truncating=\"pre\", value=0)\n", "X_valid = pad_sequences(X_valid, maxlen=max_length, padding=\"pre\", truncating=\"pre\", value=0)" ] }, { "cell_type": "code", "execution_count": 4, "id": "1e3fb0ec", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ embedding (Embedding) │ (None, 100, 64) │ 640,000 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ spatial_dropout1d │ (None, 100, 64) │ 0 │\n", "│ (SpatialDropout1D) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ simple_rnn (SimpleRNN) │ (None, 256) │ 82,176 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense (Dense) │ (None, 1) │ 257 │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m640,000\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ spatial_dropout1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "│ (\u001b[38;5;33mSpatialDropout1D\u001b[0m) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ simple_rnn (\u001b[38;5;33mSimpleRNN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m82,176\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense (\u001b[38;5;33mDense\u001b[0m) │ 
(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m257\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 722,433 (2.76 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m722,433\u001b[0m (2.76 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 722,433 (2.76 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m722,433\u001b[0m (2.76 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from keras.api.models import Sequential\n", "from keras.api.layers import InputLayer, Embedding, SpatialDropout1D, SimpleRNN, Dense\n", "\n", "# Stack: Embedding -> SpatialDropout1D -> SimpleRNN -> single sigmoid unit.\n", "rnn_model = Sequential()\n", "# The inputs are integer token ids consumed by the Embedding lookup, so declare an integer dtype.\n", "rnn_model.add(InputLayer(shape=(max_length,), dtype=\"int32\"))\n", "rnn_model.add(Embedding(unique_words, 64))\n", "rnn_model.add(SpatialDropout1D(0.2))\n", "rnn_model.add(SimpleRNN(256, dropout=0.2))\n", "rnn_model.add(Dense(1, activation=\"sigmoid\"))\n", "\n", "rnn_model.summary()" ] }, { "cell_type": "code", "execution_count": 5, "id": "11236198", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 68ms/step - accuracy: 0.5207 - loss: 0.6994 - val_accuracy: 0.5872 - val_loss: 0.6700\n", "Epoch 2/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 64ms/step - accuracy: 0.6188 - loss: 0.6423 - val_accuracy: 0.6368 - val_loss: 0.6183\n", "Epoch 3/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 64ms/step - accuracy: 0.7102 - loss: 0.5539 - val_accuracy: 0.6463 - val_loss: 0.6441\n", "Epoch 4/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 65ms/step - accuracy: 0.7746 - loss: 0.4737 - val_accuracy: 0.7338 - val_loss: 0.5681\n", "Epoch 5/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 65ms/step - accuracy: 0.8127 - loss: 0.4065 - val_accuracy: 0.6766 - val_loss: 0.6422\n", "Epoch 6/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m13s\u001b[0m 67ms/step - accuracy: 0.8613 - loss: 0.3246 - val_accuracy: 0.7152 - val_loss: 0.6385\n", "Epoch 7/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 66ms/step - accuracy: 0.8923 - loss: 0.2667 - val_accuracy: 0.7202 - val_loss: 0.6684\n", "Epoch 8/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 67ms/step - accuracy: 0.9032 - loss: 0.2335 - val_accuracy: 0.7296 - val_loss: 0.6990\n", "Epoch 9/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 70ms/step - accuracy: 0.9118 - loss: 0.2143 - val_accuracy: 0.6944 - val_loss: 0.7852\n", "Epoch 10/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 70ms/step - accuracy: 0.9205 - loss: 0.2022 - val_accuracy: 0.7359 - val_loss: 0.7074\n", "Epoch 11/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 69ms/step - accuracy: 0.9418 - loss: 0.1523 - val_accuracy: 0.7127 - val_loss: 0.8376\n", "Epoch 12/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 68ms/step - accuracy: 0.9440 - loss: 0.1462 - val_accuracy: 0.7288 - val_loss: 0.8534\n", "Epoch 13/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 69ms/step - accuracy: 0.9344 - loss: 0.1649 - val_accuracy: 0.7157 - val_loss: 0.8279\n", "Epoch 14/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 70ms/step - accuracy: 0.9201 - loss: 0.1998 - val_accuracy: 0.6386 - val_loss: 1.1343\n", "Epoch 15/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 68ms/step - accuracy: 0.9301 - loss: 0.1774 - 
val_accuracy: 0.7041 - val_loss: 0.9636\n", "Epoch 16/16\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 69ms/step - accuracy: 0.9616 - loss: 0.1055 - val_accuracy: 0.6747 - val_loss: 1.1050\n" ] }, { "data": { "text/plain": [ "