32 KiB
Инициализация Keras¶
torch был заменен на jax, так как с torch рекуррентные сети не работали
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
print(keras.__version__)
Загрузка данных для классификации с помощью глубоких сетей¶
В качестве набора данных используется набор отзывов к фильмам с сайта IMDB.
Набор включает 50 000 отзывов, половина из которых находится в обучающем наборе данных (x_train), а половина - в тестовом (x_valid).
Метки (y_train и y_valid) имеют бинарный характер и назначены в соответствии с этими 10-балльными оценками:
- отзывы с четырьмя звездами или меньше считаются отрицательным (y = 0);
- отзывы с семью звездами или больше считаются положительными (y = 1);
- умеренные отзывы — с пятью или шестью звездами — не включались в набор данных, что упрощает задачу бинарной классификации.
Данные уже предобработаны для простоты работы с ними.
unique_words - в векторное пространство включается только слова, которые встречаются в корпусе не менее 10 000 раз.
max_length - максимальная длина отзыва (если больше, то обрезается, если меньше, то дополняется "пустыми" словами).
from keras.api.datasets import imdb
import os
unique_words = 10000
max_length = 100
output_dir = "tmp"
if not os.path.exists(output_dir):
os.makedirs(output_dir)
(X_train, y_train), (X_valid, y_valid) = imdb.load_data(num_words=unique_words)
Приведение отзывов к длине max_length (100)¶
padding и truncating - дополнение и обрезка отзывов начинается с начала (учитывается специфика затухания градиента в рекуррентных сетях)
from keras.api.preprocessing.sequence import pad_sequences
X_train = pad_sequences(X_train, maxlen=max_length, padding="pre", truncating="pre", value=0)
X_valid = pad_sequences(X_valid, maxlen=max_length, padding="pre", truncating="pre", value=0)
Формирование архитектуры глубокой рекуррентной LSTM сети¶
Первый слой (Embedding) выполняет векторизацию
from keras.api.models import Sequential
from keras.api.layers import InputLayer, Embedding, SpatialDropout1D, LSTM, Dense
lstm_model = Sequential()
lstm_model.add(InputLayer(shape=(max_length,), dtype="float32"))
lstm_model.add(Embedding(unique_words, 64))
lstm_model.add(SpatialDropout1D(0.2))
lstm_model.add(LSTM(256, dropout=0.2))
lstm_model.add(Dense(1, activation="sigmoid"))
lstm_model.summary()
Обучение модели¶
Веса модели сохраняются в каталог tmp после каждой эпохи обучения с помощью callback-параметра
В дальнейшем веса можно загрузить
from keras.api.callbacks import ModelCheckpoint
lstm_model.compile(
loss="binary_crossentropy",
optimizer="adam",
metrics=["accuracy"],
)
lstm_model.fit(
X_train,
y_train,
batch_size=128,
epochs=4,
validation_data=(X_valid, y_valid),
callbacks=[ModelCheckpoint(filepath=output_dir + "/lstm_weights.{epoch:02d}.keras")],
)
Загрузка лучшей модели и оценка ее качества¶
Качество модели - 85.3 %.
lstm_model.load_weights(output_dir + "/lstm_weights.02.keras")
lstm_model.evaluate(X_valid, y_valid)
Визуализация распределения вероятностей результатов модели на валидационной выборке¶
import matplotlib.pyplot as plt
plt.hist(lstm_model.predict(X_valid))
_ = plt.axvline(x=0.5, color="orange")