Миграция на YOLOv8

This commit is contained in:
Vladislav Moiseev 2023-06-27 00:28:41 +04:00
parent b6a8209eb4
commit 19dd74c4f7
4 changed files with 46 additions and 40 deletions

2
.gitignore vendored
View File

@ -252,4 +252,4 @@ cython_debug/
#.idea/
# End of https://www.toptal.com/developers/gitignore/api/python,pycharm+all
yolov5s.pt
yolov8s.pt

17
main.py
View File

@ -2,9 +2,7 @@ import os
import sys
import cv2 as cv
import numpy as np
import requests
import torch
import imageWorking
import neuralNetwork
@ -25,18 +23,19 @@ def analyze_file(uid: str, image_path: str) -> None:
raise Exception(f'Онтология с uid {uid} не существует')
if not os.path.isfile(image_path):
raise Exception(f'Изображение {image_path} не существует')
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model.names = neuralNetwork.rename_entity(model.names)
model = neuralNetwork.load_model()
# Распознавание изображения.
results = model(imageWorking.get_image_as_array(image_path))
results = model.predict(source=imageWorking.get_image_as_array(image_path))
# Создание аксиом онтологии на основе результатов распознавания.
object_properties = list()
data_properties = list()
for i, res in enumerate(results.pred):
results_ndarray = np.array(res)
request = ontologyWorking.get_request_data(model.names, results_ndarray)
for res in results:
classes = res.boxes.cls.int()
conf = res.boxes.conf
boxes = res.boxes.xywh
request = ontologyWorking.get_request_data(model.names, classes, conf, boxes)
object_properties += request[0]
data_properties += request[1]
@ -76,7 +75,7 @@ def analyze_file(uid: str, image_path: str) -> None:
print('Неизвестное состояние')
# Вывод изображения.
cv.imshow('result', results.render()[0][:, :, ::-1])
cv.imshow('result', results[0].plot())
cv.waitKey(0)
cv.destroyAllWindows()

View File

@ -1,10 +1,9 @@
def rename_entity(list_names: dict) -> dict:
from ultralytics import YOLO
def load_model(name: str = 'yolov8s.pt') -> YOLO:
'''
Нормализация названий объектов.
@param list_names: Список названий объектов.
Загрузка предварительно натренированной модели.
@param name: Название модели.
'''
temp_list = list()
for entity in list_names.values():
entity: str
temp_list.append(entity.title().replace(' ', ''))
return temp_list
model = YOLO(name);
return model

View File

@ -15,41 +15,49 @@ def is_ontology_exists(uid: str, url: str) -> bool:
return False
def get_entity_square(results_ndarray_i: np.ndarray) -> float:
def rename_entity(list_names: dict) -> dict:
'''
Нормализация названий объектов.
@param list_names: Список названий объектов.
'''
temp_list = list()
for entity in list_names.values():
entity: str
temp_list.append(entity.title().replace(' ', ''))
return temp_list
def get_entity_square(width: float, height: float) -> float:
'''
Получение площади занимаемой области.
@param results_ndarray_i: Описание местоположения объекта.
@param width: Ширина области в px.
@param height: Высота области в px.
'''
square = float((results_ndarray_i[2] - results_ndarray_i[0]) *
(results_ndarray_i[3] - results_ndarray_i[1]))
return abs(square)
return abs(width * height)
def get_request_data(entities: dict, results_ndarray: np.ndarray) -> tuple[list, list]:
def get_request_data(entities: dict, objects: np.ndarray, confs: np.ndarray, boxes: np.ndarray) -> tuple[list, list]:
'''
Формирование данных для сервиса онтологий.
@param entities: Список имён объектов.
@param results_ndarray: Результат распознавания объектов.
'''
classroom = 'classroom'
entities = rename_entity(entities)
object_properties = list()
data_properties = list()
for i, entity in enumerate(entities): # запись в лист имен объектов и присутствие
if (results_ndarray[:, -1] == i).sum() > 0: # если объект найден
object_properties.append({'domain': entity,
'property': 'locatedIn',
'range': classroom})
else:
object_properties.append({'domain': entity,
'property': 'notLocatedIn',
'range': classroom})
for i in range(results_ndarray.shape[0]):
data_properties.append({'domain': entities[int(results_ndarray[i, 5])],
'property': 'hasArea',
'value': get_entity_square(results_ndarray[i])})
data_properties.append({'domain': entities[int(results_ndarray[i, 5])],
'property': 'hasConfidence',
'value': float(results_ndarray[i, 4])})
for entity_idx, entity in enumerate(entities):
if (entity_idx in objects):
object_properties.append({'domain': entity, 'property': 'locatedIn', 'range': classroom})
else:
object_properties.append({'domain': entity, 'property': 'notLocatedIn', 'range': classroom})
for object_idx, object in enumerate(objects):
conf = confs[object_idx]
box = boxes[object_idx]
entity = entities[object.item()]
data_properties.append({'domain': entity, 'property': 'hasArea', 'value': get_entity_square(float(box[2]), float(box[3]))})
data_properties.append({'domain': entity, 'property': 'hasConfidence', 'value': float(conf)})
return object_properties, data_properties