Move parse_tree_node.py to NodeMixin class

This commit is contained in:
Aleksey Filippov 2022-01-22 11:58:15 +04:00
parent cab7a1c129
commit 9ad8bcc667
4 changed files with 48 additions and 44 deletions

View File

@ -1,44 +1,49 @@
import warnings from typing import List, Dict, Optional
from typing import List, Dict
from anytree import Node, RenderTree from anytree import RenderTree
from src.parse_tree.parse_tree_item import ParseItem from src.parse_tree.parse_tree_node import ParseTreeNode
class ParseTree:
    """Dependency tree built from tab-separated parser output.

    Every line of the raw text that splits into exactly ten tab-separated
    fields and carries integer ID (field 0) and HEAD (field 6) values becomes
    a ``ParseTreeNode``. Nodes are then linked through ``parent_index``; a
    ``parent_index`` of 0 marks a root.
    """

    def __init__(self, raw_tree: str):
        """Parse *raw_tree* and link the resulting nodes into a tree."""
        # None when raw_tree contains no usable root line.
        self._tree: Optional['ParseTreeNode'] = self.__create_tree(self.__create_nodes_array(raw_tree))

    @staticmethod
    def __parse_raw_tree_line(raw_tree_line: str) -> Optional['ParseTreeNode']:
        """Convert one raw line into a node.

        Returns:
            The node, or None when the line does not have exactly ten
            tab-separated fields or its ID/HEAD fields are not integers.
        """
        parsed_str = raw_tree_line.split('\t')
        if len(parsed_str) != 10:
            return None
        try:
            index = int(parsed_str[0])
            parent_index = int(parsed_str[6])
        except ValueError:
            # Ten-field lines can still carry non-integer ID/HEAD values
            # (e.g. range IDs such as "1-2" or a "_" head); they cannot be
            # placed in the tree, so they are skipped like malformed lines.
            return None
        return ParseTreeNode(index, parsed_str[1], parsed_str[3], parent_index, parsed_str[7])

    def __create_nodes_array(self, raw_tree: str) -> List['ParseTreeNode']:
        """Return the nodes for all parseable lines of *raw_tree*, in input order."""
        nodes: List[ParseTreeNode] = []
        for line in raw_tree.split('\n'):
            tree_node = self.__parse_raw_tree_line(line)
            # Explicit None check: a node's truthiness is not part of its contract.
            if tree_node is not None:
                nodes.append(tree_node)
        return nodes

    @staticmethod
    def __create_tree(nodes: List['ParseTreeNode']) -> Optional['ParseTreeNode']:
        """Link *nodes* via ``parent_index`` and return the root.

        A node is attached once its parent has been attached, or immediately
        when its ``parent_index`` is 0. Nodes that can never be attached
        (missing parent, cycle, duplicate index) are left detached and
        reported with a warning instead of looping forever.

        Returns:
            The last attached node whose ``parent_index`` is 0, or None when
            there is no such node (including empty input).
        """
        import warnings

        tree_nodes_count: int = 0
        parents: Dict[int, ParseTreeNode] = {}
        root: Optional[ParseTreeNode] = None
        while tree_nodes_count < len(nodes):
            # Attach the first attachable node, then rescan from the start so
            # siblings are always attached in input order.
            for node in nodes:
                if parents.get(node.index) is not None:
                    continue
                parent = parents.get(node.parent_index)
                if parent is not None or node.parent_index == 0:
                    node.parent = parent
                    parents[node.index] = node
                    tree_nodes_count += 1
                    if node.parent_index == 0:
                        root = node
                    break
            else:
                # A full scan attached nothing: the rest is unreachable.
                warnings.warn(f'{len(nodes) - tree_nodes_count} node(s) could not be attached to the tree')
                break
        return root

    def __repr__(self) -> str:
        """Render the tree one node per line; empty string for an empty tree."""
        if self._tree is None:
            return ''
        return '\n'.join([f'{pre}{node}' for pre, fill, node in RenderTree(self._tree)])

View File

@ -1,22 +0,0 @@
class ParseItem:
    """One token line of tab-separated parser output.

    The line must split into exactly ten tab-separated fields; fields 0 and 6
    are read as the integer token index and parent index, fields 1, 3 and 7
    are kept as strings.
    """

    def __init__(self, syntax_result_line: str):
        """Parse *syntax_result_line*.

        Raises:
            AssertionError: if the line does not have exactly ten
                tab-separated fields, or its index / parent index fields are
                not integers.
        """
        parsed_str = syntax_result_line.split('\t')
        if len(parsed_str) != 10:
            raise AssertionError(f'{syntax_result_line} is not CoNNL-U-2 line')
        try:
            self._index: int = int(parsed_str[0])
            self._parent_index: int = int(parsed_str[6])
        except ValueError as err:
            # Report unparseable index fields through the same exception type
            # as a wrong field count, so callers that skip bad lines on
            # AssertionError also skip these instead of crashing.
            raise AssertionError(f'{syntax_result_line} is not CoNNL-U-2 line') from err
        # NOTE(review): field 1 is stored as the lemma; confirm against the
        # producer's column layout that this is the intended column.
        self._lemma: str = parsed_str[1]
        self._upos: str = parsed_str[3]
        self._relation: str = parsed_str[7]

    def index(self) -> int:
        """Return the token index (field 0)."""
        return self._index

    def parent_index(self) -> int:
        """Return the parent token index (field 6)."""
        return self._parent_index

    def __repr__(self) -> str:
        return f'{self._index} {self._lemma} {self._upos} {self._parent_index} {self._relation}'

View File

@ -0,0 +1,20 @@
from anytree import NodeMixin
class ParseTreeNode(NodeMixin):
    """Tree node holding one parsed token and its dependency attributes."""

    def __init__(self, index: int, lemma: str, upos: str, parent_index: int, relation: str, parent=None, children=None):
        """Store the token attributes and optionally link the node into a tree.

        Args:
            index: index of this token.
            lemma: lemma string of the token.
            upos: part-of-speech tag of the token.
            parent_index: index of the parent token.
            relation: relation label to the parent.
            parent: node to attach this node under, or None.
            children: optional iterable of nodes to attach below this node.
        """
        # Keep the initialisation chain intact for any other base classes.
        super().__init__()
        self.index = index
        self.lemma = lemma
        self.upos = upos
        self.parent_index = parent_index
        self.relation = relation
        self.parent = parent
        # Only assign when children were actually supplied (skips None/empty).
        if children:
            self.children = children

    def __repr__(self) -> str:
        return f'{self.index} {self.lemma} {self.upos} {self.parent_index} {self.relation}'

View File

@ -3,6 +3,7 @@ from scipy.io import wavfile
class Speech: class Speech:
@staticmethod @staticmethod
def __check_wav(wav_file): def __check_wav(wav_file):
sample_rate, sig = wavfile.read(wav_file) sample_rate, sig = wavfile.read(wav_file)