diff --git a/src/parse_tree/parse_tree.py b/src/parse_tree/parse_tree.py index b601f81..39beb61 100644 --- a/src/parse_tree/parse_tree.py +++ b/src/parse_tree/parse_tree.py @@ -1,44 +1,49 @@ -import warnings -from typing import List, Dict +from typing import List, Dict, Optional -from anytree import Node, RenderTree +from anytree import RenderTree -from src.parse_tree.parse_tree_item import ParseItem +from src.parse_tree.parse_tree_node import ParseTreeNode class ParseTree: def __init__(self, raw_tree: str): - self._tree: Node = self.__create_tree(self.__create_nodes_array(raw_tree)) + self._tree: ParseTreeNode = self.__create_tree(self.__create_nodes_array(raw_tree)) @staticmethod - def __create_nodes_array(raw_tree: str) -> List[ParseItem]: - nodes: List[ParseItem] = [] + def __parse_raw_tree_line(raw_tree_line: str) -> Optional[ParseTreeNode]: + parsed_str = raw_tree_line.split('\t') + if len(parsed_str) != 10: + return None + return ParseTreeNode(int(parsed_str[0]), parsed_str[1], parsed_str[3], int(parsed_str[6]), parsed_str[7]) + + def __create_nodes_array(self, raw_tree: str) -> List[ParseTreeNode]: + nodes: List[ParseTreeNode] = [] parsed_syntax_lines = raw_tree.split('\n') for line in parsed_syntax_lines: - try: - nodes.append(ParseItem(line)) - except AssertionError: - warnings.warn('Empty line') + tree_node = self.__parse_raw_tree_line(line) + if tree_node: + nodes.append(tree_node) return nodes - def __create_tree(self, nodes: List[ParseItem]) -> Node: + @staticmethod + def __create_tree(nodes: List[ParseTreeNode]) -> ParseTreeNode: tree_nodes_count: int = 0 - parents: Dict[int, Node] = {} - root = None + parents: Dict[int, ParseTreeNode] = {} + root: ParseTreeNode = Optional[ParseTreeNode] while tree_nodes_count < len(nodes): for node in nodes: - if parents.get(node.index()) is not None: + if parents.get(node.index) is not None: continue - parent: Node = parents.get(node.parent_index()) - if parent is not None or node.parent_index() == 0: - new_node = Node(node, parent) - parents[node.index()] = new_node + parent: ParseTreeNode = parents.get(node.parent_index) + if parent is not None or node.parent_index == 0: + node.parent = parent + parents[node.index] = node tree_nodes_count = tree_nodes_count + 1 - if node.parent_index() == 0: - root = new_node + if node.parent_index == 0: + root = node break return root def __repr__(self) -> str: - return '\n'.join([f'{pre}{node.name}' for pre, fill, node in RenderTree(self._tree)]) + return '\n'.join([f'{pre}{node}' for pre, fill, node in RenderTree(self._tree)]) diff --git a/src/parse_tree/parse_tree_item.py b/src/parse_tree/parse_tree_item.py deleted file mode 100644 index 79bd4d4..0000000 --- a/src/parse_tree/parse_tree_item.py +++ /dev/null @@ -1,22 +0,0 @@ -class ParseItem: - def __init__(self, syntax_result_line: str): - parsed_str = syntax_result_line.split('\t') - if len(parsed_str) != 10: - raise AssertionError(f'{syntax_result_line} is not CoNNL-U-2 line') - self._index: int = int(parsed_str[0]) - self._lemma: str = parsed_str[1] - self._upos: str = parsed_str[3] - self._parent_index: int = int(parsed_str[6]) - self._relation: str = parsed_str[7] - - def index(self) -> int: - return self._index - - def parent_index(self) -> int: - return self._parent_index - - def __repr__(self) -> str: - return f'{self._index} {self._lemma} {self._upos} {self._parent_index} {self._relation}' - - - diff --git a/src/parse_tree/parse_tree_node.py b/src/parse_tree/parse_tree_node.py new file mode 100644 index 0000000..d0aa55a --- /dev/null +++ b/src/parse_tree/parse_tree_node.py @@ -0,0 +1,20 @@ +from anytree import NodeMixin + + +class ParseTreeNode(NodeMixin): + + def __init__(self, index: int, lemma: str, upos: str, parent_index: int, relation: str, parent=None, children=None): + self.index = index + self.lemma = lemma + self.upos = upos + self.parent_index = parent_index + self.relation = relation + self.parent = parent + if children: + self.children = children + + def __repr__(self) -> str: + return f'{self.index} {self.lemma} {self.upos} {self.parent_index} {self.relation}' + + + diff --git a/src/speech.py b/src/speech.py index 885e495..f391b71 100644 --- a/src/speech.py +++ b/src/speech.py @@ -3,6 +3,7 @@ from scipy.io import wavfile class Speech: + @staticmethod def __check_wav(wav_file): sample_rate, sig = wavfile.read(wav_file)