In [1]:
# The backend must be selected via KERAS_BACKEND before keras is imported.
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
print(keras.__version__)
In [2]:
from keras.api.datasets import imdb
import os

# Vocabulary size, sequence length, and checkpoint directory.
unique_words = 10000
max_length = 100
output_dir = "tmp"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Each review is a list of integer word indices, limited to the
# `unique_words` most frequent tokens.
(X_train, y_train), (X_valid, y_valid) = imdb.load_data(num_words=unique_words)
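To eyeball what these integer-encoded reviews contain, a small decoding sketch like the one below can help. It assumes the standard IMDB convention that load_data offsets word indices by 3, reserving 0, 1, and 2 for padding, start, and unknown tokens; the decode_review helper is illustrative and not part of the original notebook.

# Sketch: decode one integer-encoded review back into words (illustrative only).
word_index = imdb.get_word_index()
index_word = {index + 3: word for word, index in word_index.items()}
index_word.update({0: "<pad>", 1: "<start>", 2: "<unk>"})

def decode_review(encoded_review):
    return " ".join(index_word.get(i, "<unk>") for i in encoded_review)

print(decode_review(X_train[0]))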
In [3]:
from keras.api.preprocessing.sequence import pad_sequences

# Pad or truncate every review to exactly `max_length` tokens, working from
# the front of the sequence and using 0 as the padding value.
X_train = pad_sequences(X_train, maxlen=max_length, padding="pre", truncating="pre", value=0)
X_valid = pad_sequences(X_valid, maxlen=max_length, padding="pre", truncating="pre", value=0)
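A quick sanity check (not in the original notebook): after padding, both arrays should be 2-D with exactly max_length columns.

# Sketch: confirm the padded arrays have shape (num_reviews, max_length).
print(X_train.shape, X_valid.shape)  # expected: (25000, 100) (25000, 100)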
In [4]:
from keras.api.models import Sequential
from keras.api.layers import InputLayer, Embedding, SpatialDropout1D, SimpleRNN, Dense

rnn_model = Sequential()
# Each input is a padded sequence of `max_length` word indices.
rnn_model.add(InputLayer(shape=(max_length,), dtype="float32"))
# Map each of the `unique_words` indices to a 64-dimensional embedding vector.
rnn_model.add(Embedding(unique_words, 64))
# Drop whole embedding channels at once to regularize the embeddings.
rnn_model.add(SpatialDropout1D(0.2))
# A single recurrent layer with 256 units; only its final hidden state is passed on.
rnn_model.add(SimpleRNN(256, dropout=0.2))
# Binary sentiment probability.
rnn_model.add(Dense(1, activation="sigmoid"))
rnn_model.summary()
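The parameter counts reported by summary() can be checked by hand; the arithmetic below mirrors the layer definitions above (embedding matrix, input-to-hidden and hidden-to-hidden RNN weights plus bias, and the dense head).

# Sketch: reproduce the parameter counts shown by rnn_model.summary().
embedding_params = unique_words * 64                 # 10,000 * 64 = 640,000
rnn_params = 64 * 256 + 256 * 256 + 256              # input + recurrent weights + bias = 82,176
dense_params = 256 * 1 + 1                           # weights + bias = 257
print(embedding_params + rnn_params + dense_params)  # 722,433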
In [5]:
from keras.api.callbacks import ModelCheckpoint

rnn_model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
# Save a checkpoint after every epoch so any individual epoch can be reloaded later.
rnn_model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=16,
    validation_data=(X_valid, y_valid),
    callbacks=[ModelCheckpoint(filepath=output_dir + "/rnn_weights.{epoch:02d}.keras")],
)
Out[5]:
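If the return value of fit() is captured (say history = rnn_model.fit(...), which the cell above does not do), the per-epoch curves can be plotted; a minimal sketch under that assumption:

# Sketch, assuming fit() was captured as `history = rnn_model.fit(...)`.
import matplotlib.pyplot as plt
plt.plot(history.history["accuracy"], label="train")
plt.plot(history.history["val_accuracy"], label="valid")
plt.xlabel("epoch")
plt.ylabel("accuracy")
_ = plt.legend()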
In [6]:
# Reload the checkpoint from epoch 10 and re-evaluate on the validation set.
rnn_model.load_weights(output_dir + "/rnn_weights.10.keras")
rnn_model.evaluate(X_valid, y_valid)
Out[6]:
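Beyond accuracy, a threshold-free metric such as ROC AUC is a common complement for a binary classifier; a sketch using scikit-learn (an extra dependency not imported in this notebook):

# Sketch: ROC AUC on the validation set (requires scikit-learn).
from sklearn.metrics import roc_auc_score
y_hat = rnn_model.predict(X_valid).ravel()
print(roc_auc_score(y_valid, y_hat))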
In [7]:
import matplotlib.pyplot as plt

# Histogram of predicted sentiment probabilities; the orange line marks the 0.5 decision threshold.
plt.hist(rnn_model.predict(X_valid))
_ = plt.axvline(x=0.5, color="orange")
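Thresholding the same predictions by hand at 0.5 recovers the validation accuracy reported by evaluate(); a minimal sketch:

# Sketch: hard labels from the 0.5 threshold and the implied validation accuracy.
import numpy as np
y_hat = rnn_model.predict(X_valid).ravel()
print(np.mean((y_hat >= 0.5) == y_valid))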