Replace ParseItem with a NodeMixin-based ParseTreeNode class
This commit is contained in:
parent
cab7a1c129
commit
9ad8bcc667
@ -1,44 +1,49 @@
|
|||||||
import warnings
|
from typing import List, Dict, Optional
|
||||||
from typing import List, Dict
|
|
||||||
|
|
||||||
from anytree import Node, RenderTree
|
from anytree import RenderTree
|
||||||
|
|
||||||
from src.parse_tree.parse_tree_item import ParseItem
|
from src.parse_tree.parse_tree_node import ParseTreeNode
|
||||||
|
|
||||||
|
|
||||||
class ParseTree:
    """Dependency tree built from tab-separated parser output.

    Every usable line of the raw text becomes a ``ParseTreeNode``. The nodes
    are then linked through their ``parent`` attribute; a node whose
    ``parent_index`` is 0 has no parent and is treated as the root.
    """

    def __init__(self, raw_tree: str):
        """Parse ``raw_tree`` (one token per line) and link the nodes.

        Raises:
            ValueError: if a line has a non-integer index field, or if the
                nodes cannot all be linked into a tree.
        """
        # Stays None when the input yields no root node (e.g. empty input).
        self._tree: Optional['ParseTreeNode'] = self.__create_tree(self.__create_nodes_array(raw_tree))

    @staticmethod
    def __parse_raw_tree_line(raw_tree_line: str) -> Optional['ParseTreeNode']:
        """Convert one raw line into a node.

        Returns:
            The node, or None when the line does not split into exactly
            10 tab-separated fields (blank lines, for instance).

        Raises:
            ValueError: if field 0 or field 6 is not an integer.
        """
        fields = raw_tree_line.split('\t')
        if len(fields) != 10:
            return None
        # NOTE(review): field 1 is stored as the node's ``lemma``. If the input
        # is CoNLL-U, column 1 is FORM and column 2 is LEMMA -- confirm which
        # one is intended.
        return ParseTreeNode(int(fields[0]), fields[1], fields[3], int(fields[6]), fields[7])

    def __create_nodes_array(self, raw_tree: str) -> List['ParseTreeNode']:
        """Return the nodes for every parsable line of ``raw_tree``, in input order."""
        nodes: List['ParseTreeNode'] = []
        for line in raw_tree.split('\n'):
            tree_node = self.__parse_raw_tree_line(line)
            if tree_node is not None:
                nodes.append(tree_node)
        return nodes

    @staticmethod
    def __create_tree(nodes: List['ParseTreeNode']) -> Optional['ParseTreeNode']:
        """Link ``nodes`` through their ``parent`` attribute and return the root.

        A node is attached once its parent has been attached, or immediately
        when its ``parent_index`` is 0. After each attachment the scan restarts
        from the beginning of ``nodes``.

        Returns:
            The last attached node whose ``parent_index`` is 0, or None when
            ``nodes`` is empty.

        Raises:
            ValueError: if a full scan attaches nothing while nodes are still
                pending (missing parent, duplicate index, or a cycle).
        """
        tree_nodes_count: int = 0
        parents: Dict[int, 'ParseTreeNode'] = {}
        root: Optional['ParseTreeNode'] = None
        while tree_nodes_count < len(nodes):
            for node in nodes:
                if parents.get(node.index) is not None:
                    continue  # already attached
                parent = parents.get(node.parent_index)
                if parent is not None or node.parent_index == 0:
                    node.parent = parent
                    parents[node.index] = node
                    tree_nodes_count += 1
                    if node.parent_index == 0:
                        root = node
                    break
            else:
                # No progress in a whole pass: without this the outer loop
                # would spin forever.
                raise ValueError('parse tree nodes cannot be linked: missing parent, duplicate index or cycle')
        return root

    def __repr__(self) -> str:
        """Render the tree, one node per line; empty string for an empty tree."""
        if self._tree is None:
            return ''
        return '\n'.join(f'{pre}{node}' for pre, _fill, node in RenderTree(self._tree))
|
||||||
|
@ -1,22 +0,0 @@
|
|||||||
class ParseItem:
    """One token line of a dependency parse, split into tab-separated fields."""

    # Number of tab-separated fields a valid line must contain.
    _FIELD_COUNT = 10

    def __init__(self, syntax_result_line: str):
        """Parse ``syntax_result_line`` and keep the fields used by the tree.

        Raises:
            AssertionError: if the line does not have exactly 10 fields.
                Kept as AssertionError because callers catch that type.
            ValueError: if field 0 or field 6 is not an integer.
        """
        parsed_str = syntax_result_line.split('\t')
        if len(parsed_str) != self._FIELD_COUNT:
            raise AssertionError(f'{syntax_result_line} is not a CoNLL-U line')
        self._index: int = int(parsed_str[0])
        self._lemma: str = parsed_str[1]
        self._upos: str = parsed_str[3]
        self._parent_index: int = int(parsed_str[6])
        self._relation: str = parsed_str[7]

    def index(self) -> int:
        """Return the integer parsed from field 0."""
        return self._index

    def parent_index(self) -> int:
        """Return the integer parsed from field 6."""
        return self._parent_index

    def __repr__(self) -> str:
        """Return the stored fields separated by single spaces."""
        return f'{self._index} {self._lemma} {self._upos} {self._parent_index} {self._relation}'
|
|
||||||
|
|
||||||
|
|
||||||
|
|
20
src/parse_tree/parse_tree_node.py
Normal file
20
src/parse_tree/parse_tree_node.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from anytree import NodeMixin
|
||||||
|
|
||||||
|
|
||||||
|
class ParseTreeNode(NodeMixin):
    """Tree node carrying the fields of one parsed token line.

    The ``parent`` and ``children`` attributes are provided by
    ``anytree.NodeMixin``; assigning them links this node into a tree.
    """

    def __init__(self, index: int, lemma: str, upos: str, parent_index: int, relation: str, parent=None, children=None):
        """Store the token fields and optionally attach the node.

        Args:
            index: index of this token.
            lemma: text stored for the token.
            upos: part-of-speech tag.
            parent_index: index of the parent token.
            relation: relation label to the parent.
            parent: node to attach this one under, or None.
            children: nodes to attach under this one; ignored when empty.
        """
        # Keep the initialiser chain intact for cooperative inheritance.
        super().__init__()
        self.index = index
        self.lemma = lemma
        self.upos = upos
        self.parent_index = parent_index
        self.relation = relation
        self.parent = parent
        if children:
            self.children = children

    def __repr__(self) -> str:
        """Return the stored fields separated by single spaces."""
        return f'{self.index} {self.lemma} {self.upos} {self.parent_index} {self.relation}'
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -3,6 +3,7 @@ from scipy.io import wavfile
|
|||||||
|
|
||||||
|
|
||||||
class Speech:
|
class Speech:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __check_wav(wav_file):
|
def __check_wav(wav_file):
|
||||||
sample_rate, sig = wavfile.read(wav_file)
|
sample_rate, sig = wavfile.read(wav_file)
|
||||||
|
Loading…
Reference in New Issue
Block a user