# IMDB Sentiment Classification with an LSTM (Keras 3, JAX backend)
In [1]:
import os
# The backend must be chosen via KERAS_BACKEND *before* keras is imported;
# once keras is loaded, the backend selection is fixed for the process.
os.environ["KERAS_BACKEND"] = "jax"
import keras
# Confirm which Keras version the kernel picked up.
print(keras.__version__)
In [2]:
from keras.api.datasets import imdb
import os

# Vocabulary / sequence-length hyperparameters.
unique_words = 10000  # keep only the 10,000 most frequent words
max_length = 100      # every review is padded / truncated to 100 tokens below

# Directory where training checkpoints are written.
output_dir = "tmp"
# exist_ok=True makes this idempotent and avoids the exists()/makedirs() race
# of the previous explicit check.
os.makedirs(output_dir, exist_ok=True)

# IMDB reviews arrive pre-tokenized as lists of word indices, with binary
# sentiment labels.
(X_train, y_train), (X_valid, y_valid) = imdb.load_data(num_words=unique_words)
In [3]:
from keras.api.preprocessing.sequence import pad_sequences

# Force every review to exactly `max_length` tokens: shorter reviews are
# zero-padded at the front, longer ones are cut from the front.
X_train, X_valid = (
    pad_sequences(split, maxlen=max_length, padding="pre", truncating="pre", value=0)
    for split in (X_train, X_valid)
)
In [4]:
from keras.api.models import Sequential
from keras.api.layers import InputLayer, Embedding, SpatialDropout1D, LSTM, Dense

# Embedding -> channel-wise dropout -> LSTM -> sigmoid head for binary sentiment.
lstm_model = Sequential(
    [
        InputLayer(shape=(max_length,), dtype="float32"),
        Embedding(unique_words, 64),
        SpatialDropout1D(0.2),
        LSTM(256, dropout=0.2),
        Dense(1, activation="sigmoid"),
    ]
)
lstm_model.summary()
In [5]:
from keras.api.callbacks import ModelCheckpoint

# Save the full model after every epoch so any epoch can be restored later.
checkpoint_callback = ModelCheckpoint(
    filepath=output_dir + "/lstm_weights.{epoch:02d}.keras"
)

lstm_model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

lstm_model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=4,
    validation_data=(X_valid, y_valid),
    callbacks=[checkpoint_callback],
)
Out[5]:
In [6]:
# Restore the epoch-2 checkpoint — presumably the epoch with the best
# validation score in the run above; NOTE(review): confirm against the fit log.
lstm_model.load_weights(output_dir + "/lstm_weights.02.keras")
# Report loss and accuracy on the held-out validation split.
lstm_model.evaluate(X_valid, y_valid)
Out[6]:
In [7]:
import matplotlib.pyplot as plt

# Distribution of predicted positive-class probabilities on the validation
# set; the orange line marks the 0.5 decision threshold. Labels and title
# added so the figure stands alone when the notebook is skimmed.
fig, ax = plt.subplots()
ax.hist(lstm_model.predict(X_valid))
ax.axvline(x=0.5, color="orange")
ax.set_xlabel("Predicted probability of positive sentiment")
ax.set_ylabel("Number of reviews")
ax.set_title("Validation-set prediction distribution")
plt.show()