
Initializing Keras

To speed up training on a GPU, the backend should be configured for the specific OS and GPU model.

To accelerate PyTorch on Windows with a recent NVIDIA card, install the CUDA builds instead of the regular PyTorch packages:

torch = { version = "^2.7.0+cu128", source = "pytorch-cuda128" }
torchaudio = { version = "^2.7.0+cu128", source = "pytorch-cuda128" }
torchvision = { version = "^0.22.0+cu128", source = "pytorch-cuda128" }

The corresponding package source must also be enabled:

[[tool.poetry.source]]
name = "pytorch-cuda128"
url = "https://download.pytorch.org/whl/cu128"
priority = "explicit"

On macOS you can use jax 0.5.0 (this exact version is required) together with jax-metal 0.1.1.
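
A quick way to confirm that the CUDA build is actually being used (a minimal sketch, not part of the original notebook; it assumes the torch backend configured in the next cell):

import torch

# The version string of a CUDA build ends with the CUDA tag, e.g. "2.7.0+cu128"
print(torch.__version__)
# True only if the CUDA build is installed and a compatible driver/GPU is present
print(torch.cuda.is_available())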

In [1]:
import os

os.environ["KERAS_BACKEND"] = "torch"
import keras

print(keras.__version__)
3.9.2

Loading the dataset for the classification task

This example uses a portion of the Cats and Dogs Classification Dataset.

The dataset contains two classes (24,998 images in total): cats (12,499 images) and dogs (12,499 images).

Link: https://www.kaggle.com/datasets/bhavikjikadara/dog-and-cat-classification-dataset

In [2]:
import kagglehub
import os

path = kagglehub.dataset_download("bhavikjikadara/dog-and-cat-classification-dataset")
path = os.path.join(path, "PetImages")

Building the training and validation sets

The subsets are built with the deprecated ImageDataGenerator class.

The recommended replacement is image_dataset_from_directory (https://keras.io/api/data_loading/image/); a sketch is shown after this description.

Using image_dataset_from_directory requires tensorflow.

ImageDataGenerator produces two subsets, training and validation, with an 80/20 split.

In both subsets the images are resized to 224 × 224 pixels in the RGB color space.

Images are loaded from disk on the fly while the model is trained and validated.
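
A minimal sketch of the recommended alternative (assuming tensorflow is installed; the parameters mirror the generator in the next cell, and this sketch was not part of the original run):

import keras

train_ds, valid_ds = keras.utils.image_dataset_from_directory(
    path,
    label_mode="binary",      # two classes -> single sigmoid output
    color_mode="rgb",
    image_size=(224, 224),
    batch_size=64,
    shuffle=True,
    seed=9,
    validation_split=0.2,
    subset="both",            # returns the (training, validation) pair
)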

In [3]:
from keras.src.legacy.preprocessing.image import ImageDataGenerator

batch_size = 64

data_loader = ImageDataGenerator(validation_split=0.2)

train = data_loader.flow_from_directory(
    directory=path,
    target_size=(224, 224),
    color_mode="rgb",
    class_mode="binary",
    batch_size=batch_size,
    shuffle=True,
    seed=9,
    subset="training",
)

valid = data_loader.flow_from_directory(
    directory=path,
    target_size=(224, 224),
    color_mode="rgb",
    class_mode="binary",
    batch_size=batch_size,
    shuffle=True,
    seed=9,
    subset="validation",
)

train.class_indices
Found 20000 images belonging to 2 classes.
Found 4998 images belonging to 2 classes.
Out[3]:
{'Cat': 0, 'Dog': 1}

AlexNet architecture

The AlexNet model is described in the deep learning lecture.

Designing the AlexNet architecture

In [4]:
from keras.api.models import Sequential
from keras.api.layers import InputLayer, Conv2D, MaxPooling2D, Dropout, Flatten, Dense, BatchNormalization

alexnet_model = Sequential()

# Input layer
alexnet_model.add(InputLayer(shape=(224, 224, 3)))

# First hidden layer
alexnet_model.add(Conv2D(96, kernel_size=(11, 11), strides=(4, 4), activation="relu"))
alexnet_model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
alexnet_model.add(BatchNormalization())

# Second hidden layer
alexnet_model.add(Conv2D(256, kernel_size=(5, 5), activation="relu"))
alexnet_model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
alexnet_model.add(BatchNormalization())

# Third hidden layer
alexnet_model.add(Conv2D(256, kernel_size=(3, 3), activation="relu"))

# Fourth hidden layer
alexnet_model.add(Conv2D(384, kernel_size=(3, 3), activation="relu"))

# Fifth hidden layer
alexnet_model.add(Conv2D(384, kernel_size=(3, 3), activation="relu"))
alexnet_model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
alexnet_model.add(BatchNormalization())

# Sixth hidden layer
alexnet_model.add(Flatten())
alexnet_model.add(Dense(4096, activation="tanh"))
alexnet_model.add(Dropout(0.5))

# Seventh hidden layer
alexnet_model.add(Dense(4096, activation="tanh"))
alexnet_model.add(Dropout(0.5))

# Output layer
alexnet_model.add(Dense(1, activation="sigmoid"))

alexnet_model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 54, 54, 96)     │        34,944 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 26, 26, 96)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization             │ (None, 26, 26, 96)     │           384 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 22, 22, 256)    │       614,656 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 10, 10, 256)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_1           │ (None, 10, 10, 256)    │         1,024 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_2 (Conv2D)               │ (None, 8, 8, 256)      │       590,080 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_3 (Conv2D)               │ (None, 6, 6, 384)      │       885,120 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_4 (Conv2D)               │ (None, 4, 4, 384)      │     1,327,488 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_2 (MaxPooling2D)  │ (None, 1, 1, 384)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_2           │ (None, 1, 1, 384)      │         1,536 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten (Flatten)               │ (None, 384)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 4096)           │     1,576,960 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 4096)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 4096)           │    16,781,312 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 4096)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 1)              │         4,097 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 21,817,601 (83.23 MB)
 Trainable params: 21,816,129 (83.22 MB)
 Non-trainable params: 1,472 (5.75 KB)

Training the deep model

In [5]:
alexnet_model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

alexnet_model.fit(
    x=train,
    validation_data=valid,
    epochs=100
)
d:\Projects\Python\mai\.venv\Lib\site-packages\keras\src\trainers\data_adapters\py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.
  self._warn_if_super_not_called()
d:\Projects\Python\mai\.venv\Lib\site-packages\PIL\TiffImagePlugin.py:900: UserWarning: Truncated File Read
  warnings.warn(str(msg))
Epoch 1/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 93s 295ms/step - accuracy: 0.5094 - loss: 1.5595 - val_accuracy: 0.5290 - val_loss: 0.7144
Epoch 2/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 92s 294ms/step - accuracy: 0.5305 - loss: 0.7776 - val_accuracy: 0.5314 - val_loss: 0.7015
Epoch 3/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 92s 294ms/step - accuracy: 0.5392 - loss: 0.7418 - val_accuracy: 0.5136 - val_loss: 0.7653
Epoch 4/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 90s 288ms/step - accuracy: 0.5461 - loss: 0.7339 - val_accuracy: 0.5676 - val_loss: 0.6940
Epoch 5/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 97s 310ms/step - accuracy: 0.5631 - loss: 0.7349 - val_accuracy: 0.4854 - val_loss: 0.7876
Epoch 6/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 97s 309ms/step - accuracy: 0.5519 - loss: 0.7588 - val_accuracy: 0.5784 - val_loss: 0.7633
Epoch 7/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 92s 293ms/step - accuracy: 0.5918 - loss: 0.6969 - val_accuracy: 0.5990 - val_loss: 0.6865
Epoch 8/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.6017 - loss: 0.6950 - val_accuracy: 0.5470 - val_loss: 0.7832
Epoch 9/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 280ms/step - accuracy: 0.5869 - loss: 0.7124 - val_accuracy: 0.5500 - val_loss: 0.7952
Epoch 10/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 86s 276ms/step - accuracy: 0.5894 - loss: 0.7112 - val_accuracy: 0.6182 - val_loss: 0.7114
Epoch 11/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 93s 296ms/step - accuracy: 0.5923 - loss: 0.7114 - val_accuracy: 0.5674 - val_loss: 0.7310
Epoch 12/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 97s 310ms/step - accuracy: 0.6261 - loss: 0.6881 - val_accuracy: 0.5842 - val_loss: 0.7458
Epoch 13/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 98s 312ms/step - accuracy: 0.6293 - loss: 0.6767 - val_accuracy: 0.5020 - val_loss: 0.9032
Epoch 14/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 95s 305ms/step - accuracy: 0.6181 - loss: 0.6952 - val_accuracy: 0.6417 - val_loss: 0.6388
Epoch 15/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 93s 299ms/step - accuracy: 0.6178 - loss: 0.6890 - val_accuracy: 0.5462 - val_loss: 0.7037
Epoch 16/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 99s 315ms/step - accuracy: 0.5895 - loss: 0.7066 - val_accuracy: 0.6188 - val_loss: 0.7259
Epoch 17/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 102s 325ms/step - accuracy: 0.6549 - loss: 0.6548 - val_accuracy: 0.5402 - val_loss: 0.8502
Epoch 18/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 99s 315ms/step - accuracy: 0.6935 - loss: 0.6069 - val_accuracy: 0.5174 - val_loss: 1.1147
Epoch 19/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 89s 286ms/step - accuracy: 0.7211 - loss: 0.5727 - val_accuracy: 0.7047 - val_loss: 0.6050
Epoch 20/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 86s 276ms/step - accuracy: 0.7356 - loss: 0.5669 - val_accuracy: 0.6951 - val_loss: 0.5860
Epoch 21/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 88s 281ms/step - accuracy: 0.7518 - loss: 0.5439 - val_accuracy: 0.7709 - val_loss: 0.4893
Epoch 22/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.7632 - loss: 0.5239 - val_accuracy: 0.7467 - val_loss: 0.5864
Epoch 23/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.7727 - loss: 0.5033 - val_accuracy: 0.7751 - val_loss: 0.4713
Epoch 24/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.7841 - loss: 0.4682 - val_accuracy: 0.7643 - val_loss: 0.5510
Epoch 25/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 86s 274ms/step - accuracy: 0.7872 - loss: 0.4699 - val_accuracy: 0.5776 - val_loss: 1.0140
Epoch 26/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 85s 273ms/step - accuracy: 0.7962 - loss: 0.4578 - val_accuracy: 0.6791 - val_loss: 0.6313
Epoch 27/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 86s 275ms/step - accuracy: 0.8093 - loss: 0.4240 - val_accuracy: 0.6463 - val_loss: 0.8024
Epoch 28/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 91s 291ms/step - accuracy: 0.8099 - loss: 0.4352 - val_accuracy: 0.7421 - val_loss: 0.5641
Epoch 29/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 92s 293ms/step - accuracy: 0.8185 - loss: 0.4183 - val_accuracy: 0.7937 - val_loss: 0.4554
Epoch 30/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 93s 296ms/step - accuracy: 0.8300 - loss: 0.3931 - val_accuracy: 0.7837 - val_loss: 0.4655
Epoch 31/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 107s 342ms/step - accuracy: 0.8468 - loss: 0.3578 - val_accuracy: 0.7977 - val_loss: 0.5012
Epoch 32/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 116s 372ms/step - accuracy: 0.8535 - loss: 0.3602 - val_accuracy: 0.7783 - val_loss: 0.5194
Epoch 33/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 116s 371ms/step - accuracy: 0.8608 - loss: 0.3326 - val_accuracy: 0.7873 - val_loss: 0.4888
Epoch 34/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 111s 356ms/step - accuracy: 0.8580 - loss: 0.3339 - val_accuracy: 0.7375 - val_loss: 0.6566
Epoch 35/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 111s 354ms/step - accuracy: 0.8615 - loss: 0.3329 - val_accuracy: 0.8181 - val_loss: 0.4174
Epoch 36/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 110s 350ms/step - accuracy: 0.8610 - loss: 0.3358 - val_accuracy: 0.6757 - val_loss: 0.9422
Epoch 37/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 101s 324ms/step - accuracy: 0.8593 - loss: 0.3543 - val_accuracy: 0.8081 - val_loss: 0.5241
Epoch 38/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 101s 323ms/step - accuracy: 0.8805 - loss: 0.3017 - val_accuracy: 0.8401 - val_loss: 0.3856
Epoch 39/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 101s 324ms/step - accuracy: 0.8928 - loss: 0.2749 - val_accuracy: 0.7851 - val_loss: 0.4438
Epoch 40/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 105s 337ms/step - accuracy: 0.8995 - loss: 0.2546 - val_accuracy: 0.8591 - val_loss: 0.3600
Epoch 41/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 106s 338ms/step - accuracy: 0.9109 - loss: 0.2269 - val_accuracy: 0.7961 - val_loss: 0.5176
Epoch 42/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 106s 340ms/step - accuracy: 0.9057 - loss: 0.2371 - val_accuracy: 0.8575 - val_loss: 0.3894
Epoch 43/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 106s 340ms/step - accuracy: 0.9111 - loss: 0.2292 - val_accuracy: 0.8493 - val_loss: 0.4270
Epoch 44/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 106s 339ms/step - accuracy: 0.9188 - loss: 0.2122 - val_accuracy: 0.8497 - val_loss: 0.4038
Epoch 45/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 106s 338ms/step - accuracy: 0.9249 - loss: 0.1933 - val_accuracy: 0.7949 - val_loss: 0.5533
Epoch 46/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 97s 310ms/step - accuracy: 0.9347 - loss: 0.1671 - val_accuracy: 0.7715 - val_loss: 0.8307
Epoch 47/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9231 - loss: 0.2009 - val_accuracy: 0.7877 - val_loss: 0.7301
Epoch 48/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9295 - loss: 0.1965 - val_accuracy: 0.8457 - val_loss: 0.5038
Epoch 49/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 90s 289ms/step - accuracy: 0.9285 - loss: 0.1886 - val_accuracy: 0.8737 - val_loss: 0.4602
Epoch 50/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 90s 288ms/step - accuracy: 0.9429 - loss: 0.1447 - val_accuracy: 0.8281 - val_loss: 0.4814
Epoch 51/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 276ms/step - accuracy: 0.9410 - loss: 0.1527 - val_accuracy: 0.8800 - val_loss: 0.3787
Epoch 52/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9284 - loss: 0.1853 - val_accuracy: 0.7073 - val_loss: 0.8980
Epoch 53/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.8610 - loss: 0.3486 - val_accuracy: 0.8417 - val_loss: 0.4740
Epoch 54/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.9170 - loss: 0.2164 - val_accuracy: 0.8693 - val_loss: 0.4258
Epoch 55/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9350 - loss: 0.1721 - val_accuracy: 0.8671 - val_loss: 0.3911
Epoch 56/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 276ms/step - accuracy: 0.9512 - loss: 0.1330 - val_accuracy: 0.8513 - val_loss: 0.4823
Epoch 57/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9433 - loss: 0.1466 - val_accuracy: 0.8745 - val_loss: 0.4241
Epoch 58/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9478 - loss: 0.1368 - val_accuracy: 0.8768 - val_loss: 0.3645
Epoch 59/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9453 - loss: 0.1556 - val_accuracy: 0.8427 - val_loss: 0.4757
Epoch 60/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9205 - loss: 0.2193 - val_accuracy: 0.8575 - val_loss: 0.4185
Epoch 61/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9466 - loss: 0.1508 - val_accuracy: 0.8760 - val_loss: 0.4537
Epoch 62/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.9552 - loss: 0.1254 - val_accuracy: 0.8589 - val_loss: 0.5982
Epoch 63/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9567 - loss: 0.1279 - val_accuracy: 0.8443 - val_loss: 0.6917
Epoch 64/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 89s 284ms/step - accuracy: 0.9634 - loss: 0.1068 - val_accuracy: 0.8834 - val_loss: 0.4579
Epoch 65/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 89s 285ms/step - accuracy: 0.9675 - loss: 0.0933 - val_accuracy: 0.8691 - val_loss: 0.6827
Epoch 66/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 95s 304ms/step - accuracy: 0.9714 - loss: 0.0819 - val_accuracy: 0.8788 - val_loss: 0.5256
Epoch 67/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 93s 297ms/step - accuracy: 0.9672 - loss: 0.0972 - val_accuracy: 0.8565 - val_loss: 0.5663
Epoch 68/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 92s 294ms/step - accuracy: 0.9640 - loss: 0.1079 - val_accuracy: 0.8810 - val_loss: 0.4636
Epoch 69/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 104s 332ms/step - accuracy: 0.9725 - loss: 0.0784 - val_accuracy: 0.8577 - val_loss: 0.4973
Epoch 70/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 97s 311ms/step - accuracy: 0.9629 - loss: 0.1124 - val_accuracy: 0.8669 - val_loss: 0.6146
Epoch 71/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 106s 340ms/step - accuracy: 0.9687 - loss: 0.0921 - val_accuracy: 0.8715 - val_loss: 0.4832
Epoch 72/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 100s 318ms/step - accuracy: 0.9716 - loss: 0.0872 - val_accuracy: 0.8445 - val_loss: 0.6765
Epoch 73/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.9651 - loss: 0.1045 - val_accuracy: 0.8621 - val_loss: 0.6374
Epoch 74/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9641 - loss: 0.1208 - val_accuracy: 0.8531 - val_loss: 0.6291
Epoch 75/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9669 - loss: 0.1037 - val_accuracy: 0.8675 - val_loss: 0.5712
Epoch 76/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9676 - loss: 0.1039 - val_accuracy: 0.8647 - val_loss: 0.5575
Epoch 77/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.9654 - loss: 0.1088 - val_accuracy: 0.8798 - val_loss: 0.5096
Epoch 78/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 88s 280ms/step - accuracy: 0.9704 - loss: 0.0928 - val_accuracy: 0.8143 - val_loss: 0.9064
Epoch 79/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.9386 - loss: 0.1918 - val_accuracy: 0.8621 - val_loss: 0.6179
Epoch 80/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.9692 - loss: 0.0961 - val_accuracy: 0.8659 - val_loss: 0.5571
Epoch 81/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9760 - loss: 0.0721 - val_accuracy: 0.8719 - val_loss: 0.8253
Epoch 82/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9751 - loss: 0.0857 - val_accuracy: 0.8559 - val_loss: 0.6966
Epoch 83/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9750 - loss: 0.0845 - val_accuracy: 0.8387 - val_loss: 0.8816
Epoch 84/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9675 - loss: 0.1127 - val_accuracy: 0.8729 - val_loss: 0.5734
Epoch 85/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9703 - loss: 0.1002 - val_accuracy: 0.8485 - val_loss: 0.6070
Epoch 86/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.8775 - loss: 0.3334 - val_accuracy: 0.8033 - val_loss: 0.6031
Epoch 87/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.9396 - loss: 0.1664 - val_accuracy: 0.8483 - val_loss: 0.5745
Epoch 88/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 88s 280ms/step - accuracy: 0.9685 - loss: 0.0909 - val_accuracy: 0.8701 - val_loss: 0.5936
Epoch 89/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 280ms/step - accuracy: 0.9808 - loss: 0.0539 - val_accuracy: 0.8900 - val_loss: 0.5556
Epoch 90/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9860 - loss: 0.0455 - val_accuracy: 0.8792 - val_loss: 0.6251
Epoch 91/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9842 - loss: 0.0473 - val_accuracy: 0.8635 - val_loss: 0.7786
Epoch 92/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9871 - loss: 0.0405 - val_accuracy: 0.8699 - val_loss: 0.6566
Epoch 93/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9833 - loss: 0.0583 - val_accuracy: 0.8778 - val_loss: 0.7701
Epoch 94/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 279ms/step - accuracy: 0.9854 - loss: 0.0500 - val_accuracy: 0.8477 - val_loss: 0.5307
Epoch 95/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9766 - loss: 0.0759 - val_accuracy: 0.8739 - val_loss: 0.7431
Epoch 96/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 280ms/step - accuracy: 0.9822 - loss: 0.0606 - val_accuracy: 0.8443 - val_loss: 0.9705
Epoch 97/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9854 - loss: 0.0469 - val_accuracy: 0.8707 - val_loss: 0.7642
Epoch 98/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 277ms/step - accuracy: 0.9818 - loss: 0.0676 - val_accuracy: 0.8790 - val_loss: 0.7260
Epoch 99/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9854 - loss: 0.0496 - val_accuracy: 0.8555 - val_loss: 0.8581
Epoch 100/100
313/313 ━━━━━━━━━━━━━━━━━━━━ 87s 278ms/step - accuracy: 0.9840 - loss: 0.0528 - val_accuracy: 0.8824 - val_loss: 0.6971
Out[5]:
<keras.src.callbacks.history.History at 0x27eee7fcad0>
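
The validation loss stops improving long before epoch 100 while the training accuracy keeps rising, so in practice the run could be cut short. A minimal sketch with callbacks (assumed additions, not used in the run above):

from keras.api.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    # Stop once val_loss has not improved for 10 epochs and restore the best weights
    EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
    # Keep the best model on disk; "alexnet_best.keras" is an assumed file name
    ModelCheckpoint("alexnet_best.keras", monitor="val_loss", save_best_only=True),
]

# alexnet_model.fit(x=train, validation_data=valid, epochs=100, callbacks=callbacks)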

Evaluating the model

The model reaches an accuracy of 88.2 %.

In [6]:
alexnet_model.evaluate(valid)
79/79 ━━━━━━━━━━━━━━━━━━━━ 10s 124ms/step - accuracy: 0.8849 - loss: 0.6870
Out[6]:
[0.697141706943512, 0.8823529481887817]

Using the trained model

Random images from the Internet are used for this example.

In [16]:
import mahotas as mh
from matplotlib import pyplot as plt

cat = mh.imread("data/-cat.jpg")
plt.imshow(cat)
plt.show()

dog = mh.imread("data/-dog.jpg")
plt.imshow(dog)
plt.show()
In [17]:
resized_cat = mh.resize.resize_rgb_to(cat, (224, 224))

resized_dog = mh.resize.resize_rgb_to(dog, (224, 224))
resized_dog.shape
Out[17]:
(224, 224, 3)
In [19]:
# Threshold the sigmoid output: values above 0.5 map to class 1 ("Dog"), otherwise 0 ("Cat")
results = [
    1 if alexnet_model.predict(item.reshape(1, 224, 224, 3).astype("float32")) > 0.5 else 0
    for item in [resized_cat, resized_dog]
]

# Map each numeric label back to its class name using valid.class_indices
for result in results:
    display(result, list(valid.class_indices.keys())[list(valid.class_indices.values()).index(result)])
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 121ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 134ms/step
0
'Cat'
1
'Dog'
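
The same preprocessing can also be done without mahotas (a minimal sketch assuming the Keras image utilities; the file path is the one used above):

from keras.utils import load_img, img_to_array

# Load and resize with Keras utilities, then add a batch dimension
img = load_img("data/-cat.jpg", target_size=(224, 224))
x = img_to_array(img)[None, ...]    # shape (1, 224, 224, 3), float32
prob = alexnet_model.predict(x)     # sigmoid output: probability of class "Dog"
print("Dog" if prob > 0.5 else "Cat")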