{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "507915ea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3.9.2\n" ] } ], "source": [
  "import os\n",
  "\n",
  "# The backend has to be selected before the first `import keras`.\n",
  "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
  "import keras\n",
  "\n",
  "# Seed the Python, NumPy and backend RNGs so a fresh Run All is repeatable.\n",
  "keras.utils.set_random_seed(42)\n",
  "\n",
  "print(keras.__version__)"
 ] }, { "cell_type": "code", "execution_count": 2, "id": "e0043e5c", "metadata": {}, "outputs": [], "source": [
  "# Public import path; `keras.api` is the generated namespace behind it.\n",
  "from keras.datasets import imdb\n",
  "\n",
  "unique_words = 10000  # vocabulary cap handed to imdb.load_data(num_words=...)\n",
  "max_length = 100  # number of tokens kept per review after padding/truncation\n",
  "\n",
  "# Checkpoints are written here; exist_ok avoids the check-then-create race.\n",
  "output_dir = \"tmp\"\n",
  "os.makedirs(output_dir, exist_ok=True)\n",
  "\n",
  "# The second tuple is the split that imdb.load_data returns as its test set;\n",
  "# this notebook uses it as validation data.\n",
  "(X_train, y_train), (X_valid, y_valid) = imdb.load_data(num_words=unique_words)"
 ] }, { "cell_type": "code", "execution_count": 3, "id": "131e125a", "metadata": {}, "outputs": [], "source": [
  "from keras.utils import pad_sequences\n",
  "\n",
  "# Left-pad short reviews with 0 and drop tokens from the front of long ones,\n",
  "# so every row ends up with exactly `max_length` entries.\n",
  "X_train = pad_sequences(X_train, maxlen=max_length, padding=\"pre\", truncating=\"pre\", value=0)\n",
  "X_valid = pad_sequences(X_valid, maxlen=max_length, padding=\"pre\", truncating=\"pre\", value=0)"
 ] }, { "cell_type": "code", "execution_count": 4, "id": "1e3fb0ec", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
       "│ embedding (Embedding)           │ (None, 100, 64)        │       640,000 │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ spatial_dropout1d               │ (None, 100, 64)        │             0 │\n",
       "│ (SpatialDropout1D)              │                        │               │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ lstm (LSTM)                     │ (None, 256)            │       328,704 │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dense (Dense)                   │ (None, 1)              │           257 │\n",
       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m640,000\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ spatial_dropout1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", "│ (\u001b[38;5;33mSpatialDropout1D\u001b[0m) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ lstm (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m328,704\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m257\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 968,961 (3.70 MB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m968,961\u001b[0m (3.70 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 968,961 (3.70 MB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m968,961\u001b[0m (3.70 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "from keras.models import Sequential\n",
  "from keras.layers import InputLayer, Embedding, SpatialDropout1D, LSTM, Dense\n",
  "\n",
  "lstm_model = Sequential()\n",
  "# The inputs are integer token ids, so declare an integer input for the Embedding lookup.\n",
  "lstm_model.add(InputLayer(shape=(max_length,), dtype=\"int32\"))\n",
  "lstm_model.add(Embedding(unique_words, 64))\n",
  "lstm_model.add(SpatialDropout1D(0.2))\n",
  "lstm_model.add(LSTM(256, dropout=0.2))\n",
  "# Single sigmoid unit: one probability per review.\n",
  "lstm_model.add(Dense(1, activation=\"sigmoid\"))\n",
  "\n",
  "lstm_model.summary()"
 ] }, { "cell_type": "code", "execution_count": 5, "id": "11236198", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/4\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m42s\u001b[0m 214ms/step - accuracy: 0.6435 - loss: 0.6105 - val_accuracy: 0.8497 - val_loss: 0.3466\n", "Epoch 2/4\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m46s\u001b[0m 231ms/step - accuracy: 0.8819 - loss: 0.2947 - val_accuracy: 0.8527 - val_loss: 0.3380\n", "Epoch 3/4\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m53s\u001b[0m 273ms/step - accuracy: 0.9121 - loss: 0.2282 - val_accuracy: 0.8472 - val_loss: 0.3587\n", "Epoch 4/4\n", "\u001b[1m196/196\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m64s\u001b[0m 325ms/step - accuracy: 0.9299 - loss: 0.1847 - val_accuracy: 0.8332 - val_loss: 0.3998\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "from keras.callbacks import ModelCheckpoint\n",
  "\n",
  "lstm_model.compile(\n",
  "    loss=\"binary_crossentropy\",\n",
  "    optimizer=\"adam\",\n",
  "    metrics=[\"accuracy\"],\n",
  ")\n",
  "\n",
  "# One checkpoint is written per epoch; {epoch:02d} is filled with the 1-based\n",
  "# epoch number. The returned History is kept so the next cell can pick the\n",
  "# best epoch from the recorded validation losses.\n",
  "history = lstm_model.fit(\n",
  "    X_train,\n",
  "    y_train,\n",
  "    batch_size=128,\n",
  "    epochs=4,\n",
  "    validation_data=(X_valid, y_valid),\n",
  "    callbacks=[ModelCheckpoint(filepath=output_dir + \"/lstm_weights.{epoch:02d}.keras\")],\n",
  ")"
 ] }, { "cell_type": "code", "execution_count": 6, "id": "94987771", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m782/782\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m39s\u001b[0m 50ms/step - accuracy: 0.8509 - loss: 0.3421\n" ] }, { "data": { "text/plain": [ "[0.33803924918174744, 0.8527200222015381]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "# Reload the checkpoint with the lowest validation loss rather than a\n",
  "# hard-coded epoch, which is only right for one particular training run.\n",
  "val_losses = history.history[\"val_loss\"]\n",
  "best_epoch = val_losses.index(min(val_losses)) + 1  # checkpoint files count from 1\n",
  "lstm_model.load_weights(output_dir + f\"/lstm_weights.{best_epoch:02d}.keras\")\n",
  "lstm_model.evaluate(X_valid, y_valid)"
 ] }, { "cell_type": "code", "execution_count": 7, "id": "8965a612", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m782/782\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m36s\u001b[0m 47ms/step\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "import matplotlib.pyplot as plt\n",
  "\n",
  "# The model ends in Dense(1), so predict() yields an (n, 1) array; flatten it\n",
  "# so the histogram is drawn over one 1-D sample of probabilities.\n",
  "y_hat = lstm_model.predict(X_valid).ravel()\n",
  "\n",
  "fig, ax = plt.subplots()\n",
  "ax.hist(y_hat)\n",
  "ax.axvline(x=0.5, color=\"orange\")  # midpoint of the sigmoid output range\n",
  "ax.set_xlabel(\"Predicted probability\")\n",
  "ax.set_ylabel(\"Number of reviews\")\n",
  "ax.set_title(\"LSTM predictions on the validation reviews\");"
 ] } ], "metadata": { "kernelspec": { "display_name": ".venv (3.12.10)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" } }, "nbformat": 4, "nbformat_minor": 5 }