ckiias/lec3-1-lenet.ipynb

1210 lines
50 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"id": "7915f17e",
"metadata": {},
"source": [
"### Инициализация Keras"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "560de685",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.9.2\n"
]
}
],
"source": [
"import os\n",
"\n",
"# The backend must be chosen before the first `import keras`.\n",
"os.environ[\"KERAS_BACKEND\"] = \"torch\"\n",
"import keras\n",
"\n",
"# Seed Python, NumPy and backend RNGs so weight init, dropout and shuffling are reproducible.\n",
"SEED = 42\n",
"keras.utils.set_random_seed(SEED)\n",
"\n",
"print(keras.__version__)"
]
},
{
"cell_type": "markdown",
"id": "27d07c7a",
"metadata": {},
"source": [
"#### Загрузка набора данных для задачи классификации\n",
"\n",
"База данных MNIST (сокращение от \"Modified National Institute of Standards and Technology\") — объёмная база данных образцов рукописного написания цифр. База данных является стандартом, предложенным Национальным институтом стандартов и технологий США с целью обучения и сопоставления методов распознавания изображений с помощью машинного обучения в первую очередь на основе нейронных сетей. Данные состоят из заранее подготовленных примеров изображений, на основе которых проводится обучение и тестирование систем.\n",
"\n",
"База данных MNIST содержит 60000 изображений для обучения и 10000 изображений для тестирования."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4a885965",
"metadata": {},
"outputs": [],
"source": [
"from keras.datasets import mnist\n",
"\n",
"# load_data() returns the standard train/test split; the test split serves as the\n",
"# validation set throughout this notebook.\n",
"(X_train, y_train), (X_valid, y_valid) = mnist.load_data()"
]
},
{
"cell_type": "markdown",
"id": "4761508b",
"metadata": {},
"source": [
"#### Предобработка данных\n",
"\n",
"Количество классов — 10 (цифры от 0 до 9).\n",
"\n",
"Все изображения из X преобразуются в тензоры размерности 28×28×1 (28×28 признаков и один канал), а значения пикселей нормализуются в диапазон [0, 1] делением на 255.\n",
"\n",
"Целевые метки преобразуются унитарным (one-hot) кодированием в бинарные векторы длиной 10.\n",
"\n",
"Четвертое измерение в reshape определяет количество цветовых каналов.\n",
"\n",
"Используется только один канал, так как изображения не цветные.\n",
"\n",
"Для цветных изображений следует использовать три канала (RGB)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d5ca49ce",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.01176471],\n",
" [0.07058824],\n",
" [0.07058824],\n",
" [0.07058824],\n",
" [0.49411765],\n",
" [0.53333336],\n",
" [0.6862745 ],\n",
" [0.10196079],\n",
" [0.6509804 ],\n",
" [1. ],\n",
" [0.96862745],\n",
" [0.49803922],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.11764706],\n",
" [0.14117648],\n",
" [0.36862746],\n",
" [0.6039216 ],\n",
" [0.6666667 ],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.88235295],\n",
" [0.6745098 ],\n",
" [0.99215686],\n",
" [0.9490196 ],\n",
" [0.7647059 ],\n",
" [0.2509804 ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.19215687],\n",
" [0.93333334],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.9843137 ],\n",
" [0.3647059 ],\n",
" [0.32156864],\n",
" [0.32156864],\n",
" [0.21960784],\n",
" [0.15294118],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.07058824],\n",
" [0.85882354],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.7764706 ],\n",
" [0.7137255 ],\n",
" [0.96862745],\n",
" [0.94509804],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.3137255 ],\n",
" [0.6117647 ],\n",
" [0.41960785],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.8039216 ],\n",
" [0.04313726],\n",
" [0. ],\n",
" [0.16862746],\n",
" [0.6039216 ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.05490196],\n",
" [0.00392157],\n",
" [0.6039216 ],\n",
" [0.99215686],\n",
" [0.3529412 ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.54509807],\n",
" [0.99215686],\n",
" [0.74509805],\n",
" [0.00784314],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.04313726],\n",
" [0.74509805],\n",
" [0.99215686],\n",
" [0.27450982],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.13725491],\n",
" [0.94509804],\n",
" [0.88235295],\n",
" [0.627451 ],\n",
" [0.42352942],\n",
" [0.00392157],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.31764707],\n",
" [0.9411765 ],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.46666667],\n",
" [0.09803922],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.1764706 ],\n",
" [0.7294118 ],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.5882353 ],\n",
" [0.10588235],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.0627451 ],\n",
" [0.3647059 ],\n",
" [0.9882353 ],\n",
" [0.99215686],\n",
" [0.73333335],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.9764706 ],\n",
" [0.99215686],\n",
" [0.9764706 ],\n",
" [0.2509804 ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.18039216],\n",
" [0.50980395],\n",
" [0.7176471 ],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.8117647 ],\n",
" [0.00784314],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.15294118],\n",
" [0.5803922 ],\n",
" [0.8980392 ],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.98039216],\n",
" [0.7137255 ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.09411765],\n",
" [0.44705883],\n",
" [0.8666667 ],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.7882353 ],\n",
" [0.30588236],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.09019608],\n",
" [0.25882354],\n",
" [0.8352941 ],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.7764706 ],\n",
" [0.31764707],\n",
" [0.00784314],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.07058824],\n",
" [0.67058825],\n",
" [0.85882354],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.7647059 ],\n",
" [0.3137255 ],\n",
" [0.03529412],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.21568628],\n",
" [0.6745098 ],\n",
" [0.8862745 ],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.95686275],\n",
" [0.52156866],\n",
" [0.04313726],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0.53333336],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.99215686],\n",
" [0.83137256],\n",
" [0.5294118 ],\n",
" [0.5176471 ],\n",
" [0.0627451 ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]],\n",
"\n",
" [[0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ],\n",
" [0. ]]], dtype=float32)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"n_classes = 10\n",
"\n",
"# This cell overwrites X_*/y_* in place. Re-running it without reloading the data would\n",
"# divide the pixels by 255 again and one-hot encode already one-hot labels, so fail loudly.\n",
"assert X_train.ndim == X_valid.ndim == 3, \"Re-run the data loading cell before this one\"\n",
"assert y_train.ndim == y_valid.ndim == 1, \"Re-run the data loading cell before this one\"\n",
"\n",
"# -1 lets NumPy infer the sample count instead of hard-coding 60000 / 10000;\n",
"# the trailing 1 is the single grayscale channel expected by Conv2D.\n",
"X_train = X_train.reshape(-1, 28, 28, 1).astype(\"float32\") / 255\n",
"X_valid = X_valid.reshape(-1, 28, 28, 1).astype(\"float32\") / 255\n",
"y_train = keras.utils.to_categorical(y_train, n_classes)\n",
"y_valid = keras.utils.to_categorical(y_valid, n_classes)\n",
"\n",
"display(X_train[0])\n",
"display(y_train[0])"
]
},
{
"cell_type": "markdown",
"id": "bfb9434d",
"metadata": {},
"source": [
"### Архитектура LeNet-5\n",
"\n",
"Изменения относительно оригинальной архитектуры:\n",
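"- Число фильтров в первом и втором сверточных слоях увеличено с 6 и 16 до 32 и 64.\n",
"- Используется один слой субдискретизации (max pooling) вместо двух.\n",
"- Используется функция активации ReLU.\n",
"- Добавлена регуляризация Dropout после слоя субдискретизации и после полносвязного слоя."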
]
},
{
"cell_type": "markdown",
"id": "3250d20b",
"metadata": {},
"source": [
"#### Проектирование архитектуры LeNet-5"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "904b01b0",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ conv2d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">26</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">32</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">320</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ conv2d_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Conv2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">24</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">24</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">18,496</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ max_pooling2d (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">MaxPooling2D</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dropout (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">64</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ flatten (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Flatten</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">9216</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,179,776</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dropout_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dropout</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,290</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ conv2d (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m26\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m320\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ conv2d_1 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m24\u001b[0m, \u001b[38;5;34m24\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m18,496\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ max_pooling2d (\u001b[38;5;33mMaxPooling2D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m64\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ flatten (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9216\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,179,776\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m) │ \u001b[38;5;34m1,290\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,199,882</span> (4.58 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,199,882\u001b[0m (4.58 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,199,882</span> (4.58 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,199,882\u001b[0m (4.58 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from keras.models import Sequential\n",
"from keras.layers import InputLayer, Conv2D, MaxPooling2D, Dropout, Flatten, Dense\n",
"\n",
"lenet_model = Sequential()\n",
"\n",
"# Input layer: 28x28 images with a single (grayscale) channel\n",
"lenet_model.add(InputLayer(shape=(28, 28, 1)))\n",
"\n",
"# First hidden layer: 32 convolution filters of size 3x3\n",
"lenet_model.add(Conv2D(32, kernel_size=(3, 3), activation=\"relu\"))\n",
"\n",
"# Second hidden layer: 64 3x3 filters, followed by the single 2x2 max-pooling step and dropout\n",
"lenet_model.add(Conv2D(64, kernel_size=(3, 3), activation=\"relu\"))\n",
"lenet_model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"lenet_model.add(Dropout(0.25))\n",
"\n",
"# Third hidden layer: flatten the feature maps into a fully connected layer\n",
"lenet_model.add(Flatten())\n",
"lenet_model.add(Dense(128, activation=\"relu\"))\n",
"lenet_model.add(Dropout(0.5))\n",
"\n",
"# Output layer: softmax over the n_classes digit classes\n",
"lenet_model.add(Dense(n_classes, activation=\"softmax\"))\n",
"\n",
"lenet_model.summary()"
]
},
{
"cell_type": "markdown",
"id": "49fceead",
"metadata": {},
"source": [
"#### Обучение глубокой модели"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fe650631",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 23ms/step - accuracy: 0.8517 - loss: 0.4710 - val_accuracy: 0.9815 - val_loss: 0.0585\n",
"Epoch 2/10\n",
"\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 22ms/step - accuracy: 0.9740 - loss: 0.0872 - val_accuracy: 0.9866 - val_loss: 0.0377\n",
"Epoch 3/10\n",
"\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 22ms/step - accuracy: 0.9806 - loss: 0.0641 - val_accuracy: 0.9889 - val_loss: 0.0325\n",
"Epoch 4/10\n",
"\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 23ms/step - accuracy: 0.9827 - loss: 0.0553 - val_accuracy: 0.9910 - val_loss: 0.0285\n",
"Epoch 5/10\n",
"\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 22ms/step - accuracy: 0.9851 - loss: 0.0452 - val_accuracy: 0.9909 - val_loss: 0.0291\n",
"Epoch 6/10\n",
"\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 22ms/step - accuracy: 0.9888 - loss: 0.0359 - val_accuracy: 0.9900 - val_loss: 0.0317\n",
"Epoch 7/10\n",
"\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 22ms/step - accuracy: 0.9893 - loss: 0.0361 - val_accuracy: 0.9917 - val_loss: 0.0282\n",
"Epoch 8/10\n",
"\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 22ms/step - accuracy: 0.9906 - loss: 0.0289 - val_accuracy: 0.9915 - val_loss: 0.0285\n",
"Epoch 9/10\n",
"\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 23ms/step - accuracy: 0.9912 - loss: 0.0280 - val_accuracy: 0.9918 - val_loss: 0.0264\n",
"Epoch 10/10\n",
"\u001b[1m469/469\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m11s\u001b[0m 23ms/step - accuracy: 0.9916 - loss: 0.0246 - val_accuracy: 0.9913 - val_loss: 0.0276\n"
]
},
{
"data": {
"text/plain": [
"<keras.src.callbacks.history.History at 0x22d135acda0>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Categorical cross-entropy pairs with the one-hot labels produced by to_categorical.\n",
"lenet_model.compile(\n",
"    optimizer=\"adam\",\n",
"    loss=\"categorical_crossentropy\",\n",
"    metrics=[\"accuracy\"],\n",
")\n",
"\n",
"# X_valid / y_valid are only evaluated after each epoch, never trained on.\n",
"lenet_model.fit(\n",
"    x=X_train,\n",
"    y=y_train,\n",
"    epochs=10,\n",
"    batch_size=128,\n",
"    validation_data=(X_valid, y_valid),\n",
")"
]
},
{
"cell_type": "markdown",
"id": "a831795e",
"metadata": {},
"source": [
"#### Оценка качества модели\n",
"\n",
"Точность модели на тестовой выборке MNIST (она же использовалась как валидационная при обучении) — 99.13 %."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7f9d0bc0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m313/313\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 5ms/step - accuracy: 0.9888 - loss: 0.0329\n"
]
},
{
"data": {
"text/plain": [
"[0.027569113299250603, 0.9912999868392944]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Returns [loss, accuracy] in the order configured by compile().\n",
"lenet_model.evaluate(x=X_valid, y=y_valid)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv (3.12.10)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}