#!/usr/bin/env python3
"""Command-line entry point: send a WAV file to the speech server and print the text."""
import sys
import warnings

from speech import Speech

# Server used when no URL is given on the command line.
DEFAULT_SPEECH_SERVER = 'http://vosk.athene.tech'

# Suppress every warning unless the user explicitly configured warnings
# (``-W`` option / ``PYTHONWARNINGS``), in which case their choice wins.
if not sys.warnoptions:
    warnings.simplefilter("ignore")


def _main():
    """Run speech-to-text on the file named in ``sys.argv[1]`` and print the result.

    An optional second argument overrides the speech server URL; without it
    ``DEFAULT_SPEECH_SERVER`` is used. Exits with status 1 when no file
    argument is supplied.
    """
    if len(sys.argv) < 2:
        print(f'Usage: {sys.argv[0]} FILE [URL]', file=sys.stderr)
        # ``sys.exit`` instead of the ``exit`` builtin: the latter is injected
        # by the ``site`` module and is missing under ``python -S`` or in
        # frozen builds.
        sys.exit(1)
    wav_file = sys.argv[1]
    speech_server = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_SPEECH_SERVER
    text = Speech().run(wav_file, speech_server)
    print(f'Text: {text}')


if __name__ == '__main__':
    _main()
import json
from urllib.request import Request, urlopen

from scipy.io import wavfile


class Speech:
    """Client for an HTTP speech-to-text service that accepts WAV uploads.

    The audio file is validated locally (16000 Hz, one channel, ``int16``
    samples) and then sent to ``<server>/stt``. The service is expected to
    answer with a JSON object containing ``code`` and ``text``.
    """

    @staticmethod
    def __check_wav(wav_file):
        """Validate the audio format of *wav_file*.

        Raises:
            ValueError: if the sample rate is not 16000, the file is not
                mono, or the sample type is not ``int16``.
        """
        sample_rate, sig = wavfile.read(wav_file)
        # scipy returns a 1-D array for mono data and a 2-D
        # (samples, channels) array otherwise, so the channel count is the
        # size of the second axis, not the number of axes.
        channels = 1 if sig.ndim == 1 else sig.shape[1]
        bits = sig.dtype.base.name
        if sample_rate != 16000:
            raise ValueError(f'Sample rate is not 16000: {sample_rate}')
        if channels != 1:
            raise ValueError(f'Number of Channels is not 1 (Not mono): {channels}')
        if bits != 'int16':
            raise ValueError(f'Bits per sample is not 16: {bits}')

    @staticmethod
    def __load_wav(wav_file):
        """Return the raw bytes of *wav_file* after validating its format."""
        Speech.__check_wav(wav_file)
        with open(wav_file, 'rb') as file:
            return file.read()

    @staticmethod
    def __stt(wav_file: str, server: str) -> str:
        """Upload *wav_file* to *server* and return the recognised text.

        Returns the ``text`` field of the reply when ``code`` is falsy,
        otherwise a ``'Server error: ...'`` string containing the whole reply.

        Raises:
            RuntimeError: if the reply is not a JSON object with both
                ``code`` and ``text`` keys.
        """
        print(f'Connecting to \'{server}\'...')
        # The file is read and validated here, before any connection is made.
        payload = Speech.__load_wav(wav_file)
        # Strip a trailing slash so 'http://host/' does not become '//stt'.
        request = Request(url=f'{server.rstrip("/")}/stt',
                          data=payload,
                          headers={'Content-Type': 'audio/wav'})
        # Context manager guarantees the HTTP response is closed.
        with urlopen(request) as response:
            result = json.loads(response.read().decode('utf-8'))

        # Reject non-object replies up front; indexing a list or string with
        # 'text' would otherwise fail with an unrelated TypeError.
        if not (isinstance(result, dict) and 'code' in result and 'text' in result):
            raise RuntimeError(f'Wrong reply from server: {result}')
        if result['code']:
            return f'Server error: {result}'
        return result['text']

    def run(self, wav_file: str, server: str) -> str:
        """Validate *wav_file*, send it to *server* and return the result string."""
        return self.__stt(wav_file, server)