ckiias/lec3-1-lenet.ipynb

50 KiB
Raw Blame History

Инициализация Keras

In [1]:
import os

os.environ["KERAS_BACKEND"] = "torch"
import keras

print(keras.__version__)
3.9.2

Загрузка набора данных для задачи классификации

База данных MNIST (сокращение от "Modified National Institute of Standards and Technology") — объёмная база данных образцов рукописного написания цифр. База данных является стандартом, предложенным Национальным институтом стандартов и технологий США с целью обучения и сопоставления методов распознавания изображений с помощью машинного обучения в первую очередь на основе нейронных сетей. Данные состоят из заранее подготовленных примеров изображений, на основе которых проводится обучение и тестирование систем.

База данных MNIST содержит 60000 изображений для обучения и 10000 изображений для тестирования.

In [2]:
from keras.api.datasets import mnist

(X_train, y_train), (X_valid, y_valid) = mnist.load_data()

Предобработка данных

Количество классов - 10 (от 0 до 9).

Все изображения из X трансформируются в матрицы 28*28 признака и нормализуются.

Для целевых признаков применяется унитарное кодирование в бинарные векторы длиной 10 (нормализация).

Четвертое измерение в reshape определяет количество цветовых каналов.

Используется только один канал, так как изображения не цветные.

Для цветных изображений следует использовать три канала (RGB).

In [3]:
n_classes = 10

X_train = X_train.reshape(60000, 28, 28, 1).astype("float32") / 255
X_valid = X_valid.reshape(10000, 28, 28, 1).astype("float32") / 255
y_train = keras.utils.to_categorical(y_train, n_classes)
y_valid = keras.utils.to_categorical(y_valid, n_classes)

display(X_train[0])
display(y_train[0])
array([[[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.01176471],
        [0.07058824],
        [0.07058824],
        [0.07058824],
        [0.49411765],
        [0.53333336],
        [0.6862745 ],
        [0.10196079],
        [0.6509804 ],
        [1.        ],
        [0.96862745],
        [0.49803922],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.11764706],
        [0.14117648],
        [0.36862746],
        [0.6039216 ],
        [0.6666667 ],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.88235295],
        [0.6745098 ],
        [0.99215686],
        [0.9490196 ],
        [0.7647059 ],
        [0.2509804 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.19215687],
        [0.93333334],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.9843137 ],
        [0.3647059 ],
        [0.32156864],
        [0.32156864],
        [0.21960784],
        [0.15294118],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.07058824],
        [0.85882354],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.7764706 ],
        [0.7137255 ],
        [0.96862745],
        [0.94509804],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.3137255 ],
        [0.6117647 ],
        [0.41960785],
        [0.99215686],
        [0.99215686],
        [0.8039216 ],
        [0.04313726],
        [0.        ],
        [0.16862746],
        [0.6039216 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.05490196],
        [0.00392157],
        [0.6039216 ],
        [0.99215686],
        [0.3529412 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.54509807],
        [0.99215686],
        [0.74509805],
        [0.00784314],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.04313726],
        [0.74509805],
        [0.99215686],
        [0.27450982],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.13725491],
        [0.94509804],
        [0.88235295],
        [0.627451  ],
        [0.42352942],
        [0.00392157],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.31764707],
        [0.9411765 ],
        [0.99215686],
        [0.99215686],
        [0.46666667],
        [0.09803922],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.1764706 ],
        [0.7294118 ],
        [0.99215686],
        [0.99215686],
        [0.5882353 ],
        [0.10588235],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.0627451 ],
        [0.3647059 ],
        [0.9882353 ],
        [0.99215686],
        [0.73333335],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.9764706 ],
        [0.99215686],
        [0.9764706 ],
        [0.2509804 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.18039216],
        [0.50980395],
        [0.7176471 ],
        [0.99215686],
        [0.99215686],
        [0.8117647 ],
        [0.00784314],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.15294118],
        [0.5803922 ],
        [0.8980392 ],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.98039216],
        [0.7137255 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.09411765],
        [0.44705883],
        [0.8666667 ],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.7882353 ],
        [0.30588236],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.09019608],
        [0.25882354],
        [0.8352941 ],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.7764706 ],
        [0.31764707],
        [0.00784314],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.07058824],
        [0.67058825],
        [0.85882354],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.7647059 ],
        [0.3137255 ],
        [0.03529412],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.21568628],
        [0.6745098 ],
        [0.8862745 ],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.95686275],
        [0.52156866],
        [0.04313726],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.53333336],
        [0.99215686],
        [0.99215686],
        [0.99215686],
        [0.83137256],
        [0.5294118 ],
        [0.5176471 ],
        [0.0627451 ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]]], dtype=float32)
array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])

Архитектура LeNet-5

Изменения относительно оригинальной архитектуры:

  • Увеличение фильтров в первом сверточном слое с 6 до 32 и втором сверточном слое с 16 до 64.
  • Снижение количества субдискретизации активаций до одной вместо двух, только после второго сверточного слоя.
  • Применение функции активации ReLU.

Проектирование архитектуры LeNet-5

In [4]:
from keras.api.models import Sequential
from keras.api.layers import InputLayer, Conv2D, MaxPooling2D, Dropout, Flatten, Dense

lenet_model = Sequential()

# Входной слой
lenet_model.add(InputLayer(shape=(28, 28, 1)))

# Первый скрытый слой
lenet_model.add(Conv2D(32, kernel_size=(3, 3), activation="relu"))

# Второй скрытый слой
lenet_model.add(Conv2D(64, kernel_size=(3, 3), activation="relu"))
lenet_model.add(MaxPooling2D(pool_size=(2, 2)))
lenet_model.add(Dropout(0.25))

# Третий скрытый слой
lenet_model.add(Flatten())
lenet_model.add(Dense(128, activation="relu"))
lenet_model.add(Dropout(0.5))

# Выходной слой
lenet_model.add(Dense(n_classes, activation="softmax"))

lenet_model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 26, 26, 32)     │           320 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 24, 24, 64)     │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 12, 12, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 12, 12, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten (Flatten)               │ (None, 9216)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 128)            │     1,179,776 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 10)             │         1,290 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 1,199,882 (4.58 MB)
 Trainable params: 1,199,882 (4.58 MB)
 Non-trainable params: 0 (0.00 B)

Обучение глубокой модели

In [5]:
lenet_model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

lenet_model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=10,
    validation_data=(X_valid, y_valid),
)
Epoch 1/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 11s 23ms/step - accuracy: 0.8517 - loss: 0.4710 - val_accuracy: 0.9815 - val_loss: 0.0585
Epoch 2/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 10s 22ms/step - accuracy: 0.9740 - loss: 0.0872 - val_accuracy: 0.9866 - val_loss: 0.0377
Epoch 3/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 10s 22ms/step - accuracy: 0.9806 - loss: 0.0641 - val_accuracy: 0.9889 - val_loss: 0.0325
Epoch 4/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 11s 23ms/step - accuracy: 0.9827 - loss: 0.0553 - val_accuracy: 0.9910 - val_loss: 0.0285
Epoch 5/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 11s 22ms/step - accuracy: 0.9851 - loss: 0.0452 - val_accuracy: 0.9909 - val_loss: 0.0291
Epoch 6/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 11s 22ms/step - accuracy: 0.9888 - loss: 0.0359 - val_accuracy: 0.9900 - val_loss: 0.0317
Epoch 7/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 10s 22ms/step - accuracy: 0.9893 - loss: 0.0361 - val_accuracy: 0.9917 - val_loss: 0.0282
Epoch 8/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 11s 22ms/step - accuracy: 0.9906 - loss: 0.0289 - val_accuracy: 0.9915 - val_loss: 0.0285
Epoch 9/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 11s 23ms/step - accuracy: 0.9912 - loss: 0.0280 - val_accuracy: 0.9918 - val_loss: 0.0264
Epoch 10/10
469/469 ━━━━━━━━━━━━━━━━━━━━ 11s 23ms/step - accuracy: 0.9916 - loss: 0.0246 - val_accuracy: 0.9913 - val_loss: 0.0276
Out[5]:
<keras.src.callbacks.history.History at 0x22d135acda0>

Оценка качества модели

Точность модели на тестовой выборке -- 99.13 %

In [6]:
lenet_model.evaluate(X_valid, y_valid)
313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - accuracy: 0.9888 - loss: 0.0329
Out[6]:
[0.027569113299250603, 0.9912999868392944]