diff --git a/.gitignore b/.gitignore index 9a80b0e..e4923a2 100644 --- a/.gitignore +++ b/.gitignore @@ -252,4 +252,4 @@ cython_debug/ #.idea/ # End of https://www.toptal.com/developers/gitignore/api/python,pycharm+all -yolov5s.pt \ No newline at end of file +yolov8s.pt \ No newline at end of file diff --git a/main.py b/main.py index b5973fa..643f279 100644 --- a/main.py +++ b/main.py @@ -2,9 +2,7 @@ import os import sys import cv2 as cv -import numpy as np import requests -import torch import imageWorking import neuralNetwork @@ -25,18 +23,19 @@ def analyze_file(uid: str, image_path: str) -> None: raise Exception(f'Онтология с uid {uid} не существует') if not os.path.isfile(image_path): raise Exception(f'Изображение {image_path} не существует') - model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) - model.names = neuralNetwork.rename_entity(model.names) + model = neuralNetwork.load_model() # Распознавание изображения. - results = model(imageWorking.get_image_as_array(image_path)) + results = model.predict(source=imageWorking.get_image_as_array(image_path)) # Создание аксиом онтологии на основе результатов распознавания. object_properties = list() data_properties = list() - for i, res in enumerate(results.pred): - results_ndarray = np.array(res) - request = ontologyWorking.get_request_data(model.names, results_ndarray) + for res in results: + classes = res.boxes.cls.int() + conf = res.boxes.conf + boxes = res.boxes.xywh + request = ontologyWorking.get_request_data(model.names, classes, conf, boxes) object_properties += request[0] data_properties += request[1] @@ -76,7 +75,7 @@ def analyze_file(uid: str, image_path: str) -> None: print('Неизвестное состояние') # Вывод изображения. - cv.imshow('result', results.render()[0][:, :, ::-1]) + cv.imshow('result', results[0].plot()) cv.waitKey(0) cv.destroyAllWindows() diff --git a/neuralNetwork.py b/neuralNetwork.py index 35c4f39..22f88c8 100644 --- a/neuralNetwork.py +++ b/neuralNetwork.py @@ -1,10 +1,9 @@ -def rename_entity(list_names: dict) -> dict: +from ultralytics import YOLO + +def load_model(name: str = 'yolov8s.pt') -> YOLO: ''' - Нормализация названий объектов. - @param list_names: Список названий объектов. + Загрузка предварительно натренированной модели. + @param name: Название модели. ''' - temp_list = list() - for entity in list_names.values(): - entity: str - temp_list.append(entity.title().replace(' ', '')) - return temp_list + model = YOLO(name); + return model diff --git a/ontologyWorking.py b/ontologyWorking.py index 8f97ed2..11e7f94 100644 --- a/ontologyWorking.py +++ b/ontologyWorking.py @@ -15,41 +15,49 @@ def is_ontology_exists(uid: str, url: str) -> bool: return False -def get_entity_square(results_ndarray_i: np.ndarray) -> float: +def rename_entity(list_names: dict) -> dict: + ''' + Нормализация названий объектов. + @param list_names: Список названий объектов. + ''' + temp_list = list() + for entity in list_names.values(): + entity: str + temp_list.append(entity.title().replace(' ', '')) + return temp_list + + +def get_entity_square(width: float, height: float) -> float: ''' Получение площади занимаемой области. - @param results_ndarray_i: Описание местоположения объекта. + @param width: Ширина области в px. + @param height: Высота области в px. ''' - square = float((results_ndarray_i[2] - results_ndarray_i[0]) * - (results_ndarray_i[3] - results_ndarray_i[1])) - return abs(square) + return abs(width * height) -def get_request_data(entities: dict, results_ndarray: np.ndarray) -> tuple[list, list]: +def get_request_data(entities: dict, objects: np.ndarray, confs: np.ndarray, boxes: np.ndarray) -> tuple[list, list]: ''' Формирование данных для сервиса онтологий. @param entities: Список имён объектов. @param results_ndarray: Результат распознавания объектов. ''' classroom = 'classroom' + entities = rename_entity(entities) object_properties = list() data_properties = list() - for i, entity in enumerate(entities): # запись в лист имен объектов и присутствие - if (results_ndarray[:, -1] == i).sum() > 0: # если объект найден - object_properties.append({'domain': entity, - 'property': 'locatedIn', - 'range': classroom}) + + for entity_idx, entity in enumerate(entities): + if (entity_idx in objects): + object_properties.append({'domain': entity, 'property': 'locatedIn', 'range': classroom}) else: - object_properties.append({'domain': entity, - 'property': 'notLocatedIn', - 'range': classroom}) - - for i in range(results_ndarray.shape[0]): - data_properties.append({'domain': entities[int(results_ndarray[i, 5])], - 'property': 'hasArea', - 'value': get_entity_square(results_ndarray[i])}) - data_properties.append({'domain': entities[int(results_ndarray[i, 5])], - 'property': 'hasConfidence', - 'value': float(results_ndarray[i, 4])}) + object_properties.append({'domain': entity, 'property': 'notLocatedIn', 'range': classroom}) + + for object_idx, object in enumerate(objects): + conf = confs[object_idx] + box = boxes[object_idx] + entity = entities[object.item()] + data_properties.append({'domain': entity, 'property': 'hasArea', 'value': get_entity_square(float(box[2]), float(box[3]))}) + data_properties.append({'domain': entity, 'property': 'hasConfidence', 'value': float(conf)}) return object_properties, data_properties