Skip to content

Commit 5fca324

Browse files
authored
Speech API: Fill audio buffer in a separate thread. (#527)
This avoids timing issues where the request thread doesn't poll the generator fast enough to consume all the incoming audio data from the input device. In that case, the audio device buffer overflows, leading to lost data, exceptions, and other nastiness. Addresses #515.
1 parent e8a10bf commit 5fca324

File tree

3 files changed

+114
-55
lines changed

3 files changed

+114
-55
lines changed

speech/grpc/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ gcloud==0.18.1
22
grpcio==1.0.0
33
PyAudio==0.2.9
44
grpc-google-cloud-speech-v1beta1==1.0.1
5+
six==1.10.0

speech/grpc/transcribe_streaming.py

Lines changed: 90 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,26 @@
1818

1919
import contextlib
2020
import re
21+
import signal
2122
import threading
2223

23-
from gcloud.credentials import get_credentials
24+
from gcloud import credentials
2425
from google.cloud.speech.v1beta1 import cloud_speech_pb2 as cloud_speech
2526
from google.rpc import code_pb2
2627
from grpc.beta import implementations
28+
from grpc.framework.interfaces.face import face
2729
import pyaudio
30+
from six.moves import queue
2831

2932
# Audio recording parameters
3033
RATE = 16000
31-
CHANNELS = 1
3234
CHUNK = int(RATE / 10) # 100ms
3335

34-
# Keep the request alive for this many seconds
35-
DEADLINE_SECS = 8 * 60 * 60
36+
# The Speech API has a streaming limit of 60 seconds of audio*, so keep the
37+
# connection alive for that long, plus some more to give the API time to figure
38+
# out the transcription.
39+
# * https://g.co/cloud/speech/limits#content
40+
DEADLINE_SECS = 60 * 3 + 5
3641
SPEECH_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'
3742

3843

@@ -42,7 +47,7 @@ def make_channel(host, port):
4247
ssl_channel = implementations.ssl_channel_credentials(None, None, None)
4348

4449
# Grab application default credentials from the environment
45-
creds = get_credentials().create_scoped([SPEECH_SCOPE])
50+
creds = credentials.get_credentials().create_scoped([SPEECH_SCOPE])
4651
# Add a plugin to inject the creds into the header
4752
auth_header = (
4853
'Authorization',
@@ -58,33 +63,81 @@ def make_channel(host, port):
5863
return implementations.secure_channel(host, port, composite_channel)
5964

6065

66+
def _audio_data_generator(buff):
    """A generator that yields all available data in the given buffer.

    Args:
        buff - a Queue object, where each element is a chunk of data. A
            falsey element (such as `None`) marks the end of the stream.
    Yields:
        A chunk of data that is the aggregate of all chunks of data in `buff`.
        The function will block until at least one data chunk is available.
    """
    stream_open = True
    while stream_open:
        # Use a blocking get() to ensure there's at least one chunk of data
        chunk = buff.get()
        if not chunk:
            # A falsey value indicates the stream is closed.
            return
        data = [chunk]

        # Now consume whatever other data's still buffered.
        while True:
            try:
                chunk = buff.get(block=False)
            except queue.Empty:
                break
            if not chunk:
                # The end-of-stream marker was already queued behind the
                # audio. It must never reach b''.join() (None is not bytes),
                # so flush what has been collected and stop afterwards.
                stream_open = False
                break
            data.append(chunk)

        yield b''.join(data)
90+
91+
92+
def _fill_buffer(audio_stream, buff, chunk):
    """Continuously collect data from the audio stream, into the buffer.

    Runs until `audio_stream.read` raises. A final `None` is always queued so
    that whoever consumes `buff` can tell that no more data will arrive.

    Args:
        audio_stream: object whose `read(chunk)` returns the next audio data.
        buff: Queue-like object that receives each piece of data via `put`.
        chunk: value handed to `audio_stream.read` on every call.
    """
    try:
        while True:
            buff.put(audio_stream.read(chunk))
    except IOError:
        # This happens when the stream is closed; it is the normal way out.
        pass
    finally:
        # Signal that we're done. This is in `finally` so that an unexpected
        # error from read() (which still propagates) cannot leave the
        # consumer blocked forever waiting on an end marker.
        buff.put(None)
100+
101+
61102
# [START audio_stream]
62103
@contextlib.contextmanager
def record_audio(rate, chunk):
    """Opens a recording stream in a context manager.

    A background thread runs `_fill_buffer` to copy data from the input
    stream into a queue, and the context manager yields
    `_audio_data_generator` over that queue.

    Args:
        rate: The sampling rate in hertz, passed to the audio interface.
        chunk: Passed as `frames_per_buffer` when opening the stream and as
            the read size used by the buffering thread.
    Yields:
        A generator of the buffered audio data.
    """
    audio_interface = pyaudio.PyAudio()
    audio_stream = audio_interface.open(
        format=pyaudio.paInt16,
        # The API currently only supports 1-channel (mono) audio
        # https://goo.gl/z757pE
        channels=1, rate=rate,
        input=True, frames_per_buffer=chunk,
    )

    # Create a thread-safe buffer of audio data
    buff = queue.Queue()

    # Spin up a separate thread to buffer audio data from the microphone
    # This is necessary so that the input device's buffer doesn't overflow
    # while the calling thread makes network requests, etc.
    fill_buffer_thread = threading.Thread(
        target=_fill_buffer, args=(audio_stream, buff, chunk))
    fill_buffer_thread.start()

    try:
        yield _audio_data_generator(buff)
    finally:
        # Tear down even when the `with` body raised: otherwise the stream
        # stays open and the (non-daemon) buffering thread is never joined.
        audio_stream.stop_stream()
        audio_stream.close()
        fill_buffer_thread.join()
        audio_interface.terminate()
76131
# [END audio_stream]
77132

78133

79-
def request_stream(stop_audio, channels=CHANNELS, rate=RATE, chunk=CHUNK):
134+
def request_stream(data_stream, rate):
80135
"""Yields `StreamingRecognizeRequest`s constructed from a recording audio
81136
stream.
82137
83138
Args:
84-
stop_audio: A threading.Event object stops the recording when set.
85-
channels: How many audio channels to record.
139+
data_stream: A generator that yields raw audio data to send.
86140
rate: The sampling rate in hertz.
87-
chunk: Buffer audio into chunks of this size before sending to the api.
88141
"""
89142
# The initial request must contain metadata about the stream, so the
90143
# server knows how to interpret it.
@@ -105,14 +158,9 @@ def request_stream(stop_audio, channels=CHANNELS, rate=RATE, chunk=CHUNK):
105158
yield cloud_speech.StreamingRecognizeRequest(
106159
streaming_config=streaming_config)
107160

108-
with record_audio(channels, rate, chunk) as audio_stream:
109-
while not stop_audio.is_set():
110-
data = audio_stream.read(chunk)
111-
if not data:
112-
raise StopIteration()
113-
114-
# Subsequent requests can all just have the content
115-
yield cloud_speech.StreamingRecognizeRequest(audio_content=data)
161+
for data in data_stream:
162+
# Subsequent requests can all just have the content
163+
yield cloud_speech.StreamingRecognizeRequest(audio_content=data)
116164

117165

118166
def listen_print_loop(recognize_stream):
@@ -126,25 +174,36 @@ def listen_print_loop(recognize_stream):
126174

127175
# Exit recognition if any of the transcribed phrases could be
128176
# one of our keywords.
129-
if any(re.search(r'\b(exit|quit)\b', alt.transcript)
177+
if any(re.search(r'\b(exit|quit)\b', alt.transcript, re.I)
130178
for result in resp.results
131179
for alt in result.alternatives):
132180
print('Exiting..')
133-
return
181+
break
134182

135183

136184
def main():
    """Stream recorded audio to the Speech API and print the responses.

    Opens the service stub, starts recording, feeds the buffered audio into
    a streaming recognize call, and hands the responses to
    `listen_print_loop`. SIGINT cancels the call.
    """
    with cloud_speech.beta_create_Speech_stub(
            make_channel('speech.googleapis.com', 443)) as service:
        # For streaming audio from the microphone, there are three threads.
        # First, a thread that collects audio data as it comes in
        with record_audio(RATE, CHUNK) as buffered_audio_data:
            # Second, a thread that sends requests with that data
            requests = request_stream(buffered_audio_data, RATE)
            # Third, a thread that listens for transcription responses
            recognize_stream = service.StreamingRecognize(
                requests, DEADLINE_SECS)

            # Exit things cleanly on interrupt
            signal.signal(signal.SIGINT, lambda *_: recognize_stream.cancel())

            # Now, put the transcription responses to use.
            try:
                listen_print_loop(recognize_stream)
            except face.CancellationError:
                # This happens because of the interrupt handler
                pass
            finally:
                # Cancel on every exit path, not only after a clean loop
                # exit, so an error while handling responses does not leave
                # the streaming call open.
                recognize_stream.cancel()
148207

149208

150209
if __name__ == '__main__':

speech/grpc/transcribe_streaming_test.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,54 +11,53 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
import contextlib
15-
import io
1614
import re
1715
import time
1816

1917
import transcribe_streaming
2018

2119

22-
class MockAudioStream(object):
23-
def __init__(self, audio_filename, trailing_silence_secs=10):
20+
class MockPyAudio(object):
    """Test double that plays an audio file instead of recording.

    One instance stands in for three things: it is callable and returns
    itself, `open()` returns itself, and it offers `read`, `stop_stream`,
    `close` and `terminate`.
    """

    def __init__(self, audio_filename):
        self.audio_filename = audio_filename

    def __call__(self, *args):
        # Calling the instance hands back the instance itself.
        return self

    def open(self, *args, **kwargs):
        # All arguments are ignored; the data comes from the file.
        self.audio_file = open(self.audio_filename, 'rb')
        return self

    def close(self):
        self.audio_file.close()

    def stop_stream(self):
        pass

    def terminate(self):
        pass

    def read(self, num_frames):
        """Return the next `num_frames` frames, or raise IOError when the
        file is closed or exhausted."""
        if self.audio_file.closed:
            raise IOError()
        # Approximate realtime by sleeping for as long as the requested
        # number of frames would take to record
        seconds = num_frames / float(transcribe_streaming.RATE)
        time.sleep(seconds)
        try:
            # audio is 16-bit samples, whereas python byte is 8-bit
            payload = self.audio_file.read(2 * num_frames)
        except ValueError:
            raise IOError()
        if payload:
            return payload
        raise IOError()
4655

4756

48-
def mock_audio_stream(filename):
49-
@contextlib.contextmanager
50-
def mock_audio_stream(channels, rate, chunk):
51-
with open(filename, 'rb') as audio_file:
52-
yield audio_file
53-
54-
return mock_audio_stream
55-
56-
5757
def test_main(resource, monkeypatch, capsys):
5858
monkeypatch.setattr(
59-
transcribe_streaming, 'record_audio',
60-
mock_audio_stream(resource('quit.raw')))
61-
monkeypatch.setattr(transcribe_streaming, 'DEADLINE_SECS', 30)
59+
transcribe_streaming.pyaudio, 'PyAudio',
60+
MockPyAudio(resource('quit.raw')))
6261

6362
transcribe_streaming.main()
6463
out, err = capsys.readouterr()

0 commit comments

Comments
 (0)