diff --git a/README.md b/README.md
index 76e8c0b..72154b7 100644
--- a/README.md
+++ b/README.md
@@ -699,79 +699,187 @@ for result in transcript.auto_highlights.results:
### **Streaming Examples**
-[Read more about our streaming service.](https://www.assemblyai.com/docs/streaming/universal-3-pro)
+Real-time speech-to-text via WebSocket against the `u3-rt-pro` model. The SDK ships two clients with identical option/event/handler surfaces — `StreamingClient` (threaded) and `AsyncStreamingClient` (asyncio). Pick whichever fits your codebase.
+
+**Handler contract**: every handler is called as `handler(client, event)`. Plain functions and `async def` functions both work; `AsyncStreamingClient` awaits async handlers inline on the read task, so don't block — use `asyncio.create_task(...)` if you need concurrent work.
+
+[Read more about the streaming service.](https://www.assemblyai.com/docs/streaming/universal-3-pro)
+
+
+ Stream a local file (sync)
+
+```python
+import assemblyai as aai
+from assemblyai.streaming.v3 import (
+ BeginEvent, StreamingClient, StreamingClientOptions, StreamingError,
+ StreamingEvents, StreamingParameters, TerminationEvent, TurnEvent,
+)
+
+def on_begin(client, event: BeginEvent):
+ print(f"Session started: {event.id}")
+
+def on_turn(client, event: TurnEvent):
+ print(f"{event.transcript} (end_of_turn={event.end_of_turn})")
+
+def on_terminated(client, event: TerminationEvent):
+ print(f"Done: {event.audio_duration_seconds}s of audio processed")
+
+def on_error(client, error: StreamingError):
+ print(f"Error: {error} (code={error.code})")
+
+client = StreamingClient(StreamingClientOptions(api_key=""))
+client.on(StreamingEvents.Begin, on_begin)
+client.on(StreamingEvents.Turn, on_turn)
+client.on(StreamingEvents.Termination, on_terminated)
+client.on(StreamingEvents.Error, on_error)
+
+client.connect(StreamingParameters(
+ sample_rate=16000, speech_model="u3-rt-pro", format_turns=True,
+))
+try:
+ client.stream(aai.extras.stream_file(filepath="audio.wav", sample_rate=16000))
+finally:
+ client.disconnect(terminate=True)
+```
+
+
- Stream your microphone in real-time
+ Stream your microphone (sync)
+
+`MicrophoneStream` requires PyAudio:
```bash
-pip install -U assemblyai
+pip install -U "assemblyai[extras]"
```
```python
-import logging
-from typing import Type
-
import assemblyai as aai
from assemblyai.streaming.v3 import (
- BeginEvent,
- StreamingClient,
- StreamingClientOptions,
- StreamingError,
- StreamingEvents,
- StreamingParameters,
- TurnEvent,
- TerminationEvent,
+ StreamingClient, StreamingClientOptions, StreamingEvents, StreamingParameters,
)
-api_key = ""
+def on_turn(client, event):
+ print(f"{event.transcript} (end_of_turn={event.end_of_turn})")
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
+client = StreamingClient(StreamingClientOptions(api_key=""))
+client.on(StreamingEvents.Turn, on_turn)
+client.connect(StreamingParameters(sample_rate=16000, speech_model="u3-rt-pro"))
-def on_begin(self: Type[StreamingClient], event: BeginEvent):
- print(f"Session started: {event.id}")
+try:
+ client.stream(aai.extras.MicrophoneStream(sample_rate=16000))
+finally:
+ client.disconnect(terminate=True)
+```
+
+
+
+
+ Stream a local file (async)
+
+`AsyncStreamingClient` mirrors `StreamingClient` with async methods. It's safe to use as an async context manager — `disconnect()` runs on block exit even if user code raises. Don't pass `extras.stream_file` directly (it uses blocking `time.sleep`); pace from an async generator instead.
+
+```python
+import asyncio
+from assemblyai.streaming.v3 import (
+ AsyncStreamingClient, StreamingClientOptions, StreamingEvents, StreamingParameters,
+)
+
+async def stream_file_async(path: str, sample_rate: int, chunk_duration: float = 0.3):
+ bytes_per_chunk = int(sample_rate * chunk_duration) * 2
+ with open(path, "rb") as f:
+ while chunk := f.read(bytes_per_chunk):
+ yield chunk
+ await asyncio.sleep(chunk_duration)
+
+async def on_turn(client, event):
+ print(f"{event.transcript} (end_of_turn={event.end_of_turn})")
+
+async def main():
+ async with AsyncStreamingClient(StreamingClientOptions(api_key="")) as client:
+ client.on(StreamingEvents.Turn, on_turn)
+ await client.connect(StreamingParameters(
+ sample_rate=16000, speech_model="u3-rt-pro", format_turns=True,
+ ))
+ await client.stream(stream_file_async("audio.wav", 16000))
+
+asyncio.run(main())
+```
+
+
-def on_turn(self: Type[StreamingClient], event: TurnEvent):
- print(f"{event.transcript} ({event.end_of_turn})")
-
-def on_terminated(self: Type[StreamingClient], event: TerminationEvent):
- print(
- f"Session terminated: {event.audio_duration_seconds} seconds of audio processed"
- )
-
-def on_error(self: Type[StreamingClient], error: StreamingError):
- print(f"Error occurred: {error}")
-
-def main():
- client = StreamingClient(
- StreamingClientOptions(
- api_key=api_key,
- api_host="streaming.assemblyai.com",
- )
- )
-
- client.on(StreamingEvents.Begin, on_begin)
- client.on(StreamingEvents.Turn, on_turn)
- client.on(StreamingEvents.Termination, on_terminated)
- client.on(StreamingEvents.Error, on_error)
-
- client.connect(
- StreamingParameters(
- sample_rate=16000,
- speech_model="u3-rt-pro",
- )
- )
-
- try:
- client.stream(
- aai.extras.MicrophoneStream(sample_rate=16000)
- )
- finally:
- client.disconnect(terminate=True)
-
-if __name__ == "__main__":
- main()
+
+ Handle errors
+
+Server-side errors arrive on the `Error` event rather than being raised. The handler receives a `StreamingError` (an `Exception` subclass) with `.code: int | None` — **not** the wire `ErrorEvent` class.
+
+`StreamingErrorCodes` is a `dict[int, str]` mapping wire codes to human-readable messages. Use `.get(...)` for lookup:
+
+```python
+from assemblyai.streaming.v3 import StreamingErrorCodes
+
+def on_error(client, error):
+ message = StreamingErrorCodes.get(error.code, str(error))
+ print(f"Streaming error {error.code}: {message}")
+```
+
+Common codes: `4001` Not Authorized, `4002` Insufficient Funds, `4029` Client sent audio too fast, `4031` Session idle for too long.
+
+
+
+
+ Change settings mid-session
+
+`set_params` updates an active session. With `AsyncStreamingClient`, `set_params` and `force_endpoint` are coroutines and must be awaited. Typical use: enable turn formatting (punctuation, casing) only on confirmed end-of-turn so partial transcripts stay raw:
+
+```python
+from assemblyai.streaming.v3 import StreamingSessionParameters
+
+def on_turn(client, event):
+ if event.end_of_turn and not event.turn_is_formatted:
+ client.set_params(StreamingSessionParameters(format_turns=True))
+```
+
+For voice agents, `force_endpoint()` flushes the current turn — useful when an external signal (UI button, barge-in detection) determines the user has stopped speaking before VAD does:
+
+```python
+client.force_endpoint() # ends the current turn immediately
+```
+
+
+
+
+ Temporary tokens for browser / edge clients
+
+Don't ship your API key to browsers. Mint a short-lived token server-side and pass it to the client.
+
+**Sync server (Flask / WSGI / scripts):**
+```python
+client = StreamingClient(StreamingClientOptions(api_key=""))
+token = client.create_temporary_token(expires_in_seconds=60)
+# Send `token` to the browser / edge client, which connects with StreamingClientOptions(token=token).
+```
+
+**Async server (FastAPI / asyncio):** always wrap in `async with` even though you don't call `connect()` — `create_temporary_token` lazily opens an `httpx.AsyncClient` pool. The context manager closes it on exit; without it, every request leaks a connection pool.
+
+```python
+from fastapi import FastAPI
+from assemblyai.streaming.v3 import AsyncStreamingClient, StreamingClientOptions
+
+app = FastAPI()
+MASTER_KEY = ""
+
+@app.get("/streaming-token")
+async def streaming_token():
+ async with AsyncStreamingClient(StreamingClientOptions(api_key=MASTER_KEY)) as client:
+ return {"token": await client.create_temporary_token(expires_in_seconds=60)}
+```
+
+**Browser / edge client:** pass the token via `StreamingClientOptions(token=...)`:
+
+```python
+client = StreamingClient(StreamingClientOptions(token=""))
+client.connect(StreamingParameters(sample_rate=16000, speech_model="u3-rt-pro"))
```
diff --git a/assemblyai/__init__.py b/assemblyai/__init__.py
index 77efb71..4662522 100644
--- a/assemblyai/__init__.py
+++ b/assemblyai/__init__.py
@@ -2,10 +2,9 @@
from .__version__ import __version__
from .client import Client
from .lemur import Lemur
-from .transcriber import RealtimeTranscriber, Transcriber, Transcript, TranscriptGroup
+from .transcriber import Transcriber, Transcript, TranscriptGroup
from .types import (
AssemblyAIError,
- AudioEncoding,
AutohighlightResponse,
AutohighlightResult,
Chapter,
@@ -47,13 +46,6 @@
PIIRedactionPolicy,
PIISubstitutionPolicy,
RawTranscriptionConfig,
- RealtimeError,
- RealtimeFinalTranscript,
- RealtimePartialTranscript,
- RealtimeSessionInformation,
- RealtimeSessionOpened,
- RealtimeTranscript,
- RealtimeWord,
RedactPiiAudioOptions,
Sentence,
Sentiment,
@@ -93,7 +85,6 @@
__all__ = [
# types
"AssemblyAIError",
- "AudioEncoding",
"AutohighlightResponse",
"AutohighlightResult",
"Chapter",
@@ -170,14 +161,6 @@
"Word",
"WordBoost",
"WordSearchMatch",
- "RealtimeTranscriber",
- "RealtimeError",
- "RealtimeFinalTranscript",
- "RealtimePartialTranscript",
- "RealtimeSessionInformation",
- "RealtimeSessionOpened",
- "RealtimeTranscript",
- "RealtimeWord",
# package globals
"settings",
# packages
diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py
index 8261441..79afd54 100644
--- a/assemblyai/__version__.py
+++ b/assemblyai/__version__.py
@@ -1 +1 @@
-__version__ = "0.64.2"
+__version__ = "0.64.3"
diff --git a/assemblyai/api.py b/assemblyai/api.py
index b2f666a..7c20645 100644
--- a/assemblyai/api.py
+++ b/assemblyai/api.py
@@ -9,8 +9,6 @@
ENDPOINT_UPLOAD = "/v2/upload"
ENDPOINT_LEMUR_BASE = "/lemur/v3"
ENDPOINT_LEMUR = f"{ENDPOINT_LEMUR_BASE}/generate"
-ENDPOINT_REALTIME_WEBSOCKET = "/v2/realtime/ws"
-ENDPOINT_REALTIME_TOKEN = "/v2/realtime/token"
def _get_error_message(response: httpx.Response) -> str:
@@ -415,24 +413,3 @@ def lemur_get_response_data(
return types.LemurQuestionResponse.parse_obj(json_data)
return types.LemurStringResponse.parse_obj(json_data)
-
-
-def create_temporary_token(
- client: httpx.Client,
- request: types.RealtimeCreateTemporaryTokenRequest,
- http_timeout: Optional[float],
-) -> str:
- response = client.post(
- f"{ENDPOINT_REALTIME_TOKEN}",
- json=request.dict(exclude_none=True),
- timeout=http_timeout,
- )
-
- if response.status_code != httpx.codes.OK:
- raise types.AssemblyAIError(
- f"Failed to create temporary token: {_get_error_message(response)}",
- response.status_code,
- )
-
- data = types.RealtimeCreateTemporaryTokenResponse.parse_obj(response.json())
- return data.token
diff --git a/assemblyai/streaming/v3/__init__.py b/assemblyai/streaming/v3/__init__.py
index e89ad55..c7d0806 100644
--- a/assemblyai/streaming/v3/__init__.py
+++ b/assemblyai/streaming/v3/__init__.py
@@ -1,3 +1,4 @@
+from .async_client import AsyncStreamingClient
from .client import StreamingClient
from .models import (
BeginEvent,
@@ -9,6 +10,7 @@
SpeechStartedEvent,
StreamingClientOptions,
StreamingError,
+ StreamingErrorCodes,
StreamingEvents,
StreamingParameters,
StreamingPiiPolicy,
@@ -21,6 +23,7 @@
)
__all__ = [
+ "AsyncStreamingClient",
"BeginEvent",
"Encoding",
"EventMessage",
@@ -31,6 +34,7 @@
"StreamingClient",
"StreamingClientOptions",
"StreamingError",
+ "StreamingErrorCodes",
"StreamingEvents",
"StreamingParameters",
"StreamingPiiPolicy",
diff --git a/assemblyai/streaming/v3/_base.py b/assemblyai/streaming/v3/_base.py
new file mode 100644
index 0000000..d46f90a
--- /dev/null
+++ b/assemblyai/streaming/v3/_base.py
@@ -0,0 +1,267 @@
+"""Sync/async-agnostic core for streaming v3 clients.
+
+Houses the pieces that are *exactly* the same between the threaded
+``StreamingClient`` and the asyncio-based ``AsyncStreamingClient``:
+
+- Wire-format helpers (``_dump_model``, ``_parse_model``, ``_build_uri``,
+ ``_build_headers``, parameter normalization, user-agent construction).
+- Inbound message parsing (``_parse_message`` + ``_parse_event_type``).
+- Connection-closed error mapping (``_build_connection_closed_error``).
+- The ``_BaseStreamingClient`` base class with shared init state and
+ the ``on(...)`` handler-registration entrypoint.
+
+Subclasses must implement the I/O loops (``_read_*`` / ``_write_*``) plus
+``connect``, ``disconnect``, ``stream``, ``set_params``, ``force_endpoint``,
+and ``create_temporary_token``. Sync subclasses use plain methods; async
+subclasses use ``async def``. The sync/async return-type divergence is
+why those methods aren't ``@abstractmethod`` on this base.
+"""
+
+import json
+import logging
+import sys
+from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
+from urllib.parse import urlencode
+
+import websockets
+from pydantic import BaseModel
+
+from assemblyai import __version__
+
+from .models import (
+ BeginEvent,
+ ErrorEvent,
+ EventMessage,
+ LLMGatewayResponseEvent,
+ SpeechStartedEvent,
+ StreamingClientOptions,
+ StreamingError,
+ StreamingErrorCodes,
+ StreamingEvents,
+ StreamingParameters,
+ TerminationEvent,
+ TurnEvent,
+ WarningEvent,
+)
+
+logger = logging.getLogger(__name__)
+
+
+_M = TypeVar("_M", bound=BaseModel)
+
+
+def _dump_model(model: BaseModel) -> Dict[str, Any]:
+    """Serialize ``model`` to a plain dict, leaving out ``None`` fields.
+
+    Prefers ``model_dump`` when the model defines it; otherwise falls back
+    to ``dict``.
+    """
+    serializer = model.model_dump if hasattr(model, "model_dump") else model.dict
+    return serializer(exclude_none=True)
+
+
+def _dump_model_json(model: BaseModel) -> str:
+    """Serialize ``model`` to a JSON string, leaving out ``None`` fields.
+
+    Prefers ``model_dump_json`` when the model defines it; otherwise falls
+    back to ``json``.
+    """
+    serializer = (
+        model.model_dump_json if hasattr(model, "model_dump_json") else model.json
+    )
+    return serializer(exclude_none=True)
+
+
+def _parse_model(model_class: Type[_M], data: Dict[str, Any]) -> _M:
+    """Build a ``model_class`` instance from the raw ``data`` mapping.
+
+    Prefers ``model_validate`` when the class defines it; otherwise falls
+    back to ``parse_obj``.
+    """
+    validator = (
+        model_class.model_validate
+        if hasattr(model_class, "model_validate")
+        else model_class.parse_obj
+    )
+    return validator(data)
+
+
+def _normalize_min_turn_silence(params_dict: dict) -> dict:
+    """Fold ``min_end_of_turn_silence_when_confident`` into ``min_turn_silence``.
+
+    The deprecated key is always removed from ``params_dict`` (which is
+    mutated in place and returned), so at most one of the two keys is left.
+    When the deprecated key carried a value, a deprecation warning is logged;
+    its value is only copied over if ``min_turn_silence`` is not already set.
+    """
+    legacy_value = params_dict.pop("min_end_of_turn_silence_when_confident", None)
+    if legacy_value is not None:
+        if "min_turn_silence" not in params_dict:
+            logger.warning(
+                "[Deprecation Warning] `min_end_of_turn_silence_when_confident` is "
+                "deprecated and will be removed in a future release. Please use "
+                "`min_turn_silence` instead."
+            )
+            params_dict["min_turn_silence"] = legacy_value
+        else:
+            # The new key wins; the deprecated value is dropped.
+            logger.warning(
+                "[Deprecation Warning] Both `min_end_of_turn_silence_when_confident` and "
+                "`min_turn_silence` are set. Using `min_turn_silence`; "
+                "`min_end_of_turn_silence_when_confident` is deprecated."
+            )
+    return params_dict
+
+
+def _normalize_voice_focus(params_dict: dict) -> dict:
+    """Rename the deprecated ``noise_suppression_*`` keys to ``voice_focus*``.
+
+    Each deprecated key is removed from ``params_dict`` (mutated in place and
+    returned). If it carried a value, a deprecation warning is logged and the
+    value is copied to the replacement key unless that key is already set.
+    """
+    renames = {
+        "noise_suppression_model": "voice_focus",
+        "noise_suppression_threshold": "voice_focus_threshold",
+    }
+    for old_key, new_key in renames.items():
+        legacy_value = params_dict.pop(old_key, None)
+        if legacy_value is None:
+            continue
+        if new_key not in params_dict:
+            logger.warning(
+                f"[Deprecation Warning] `{old_key}` is deprecated and will be removed "
+                f"in a future release. Please use `{new_key}` instead."
+            )
+            params_dict[new_key] = legacy_value
+            continue
+        # The new key wins; the deprecated value is dropped.
+        logger.warning(
+            f"[Deprecation Warning] Both `{old_key}` and `{new_key}` are set. "
+            f"Using `{new_key}`; `{old_key}` is deprecated."
+        )
+    return params_dict
+
+
+def _user_agent() -> str:
+    """Return the ``User-Agent`` string built from the SDK and interpreter versions."""
+    # ``sys.version_info[:3]`` is (major, minor, micro).
+    python_version = ".".join(str(part) for part in sys.version_info[:3])
+    return (
+        f"AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})"
+    )
+
+
+def _emit_param_warnings(params: StreamingParameters) -> None:
+    """Log advisory warnings for connection parameters that need attention.
+
+    Nothing is modified or rejected; the function only logs.
+    """
+    uses_deprecated_model = params.speech_model == "u3-pro"
+    if uses_deprecated_model:
+        logger.warning(
+            "[Deprecation Warning] The speech model `u3-pro` is deprecated and will be removed in a future release. "
+            "Please use `u3-rt-pro` instead."
+        )
+
+    captures_audio = params.customer_support_audio_capture
+    if captures_audio:
+        logger.warning(
+            "`customer_support_audio_capture=True` will record session audio. "
+            "Only enable this when explicitly coordinating with AssemblyAI support."
+        )
+
+
+def _build_uri(host: str, params: StreamingParameters) -> str:
+    """Build the ``/v3/ws`` websocket URL for ``host`` with ``params`` as the query.
+
+    Deprecated parameter names are normalized first. List and dict values are
+    JSON-encoded (e.g. ``keyterms_prompt``, ``llm_gateway``) before the query
+    string is URL-encoded. ``wss://`` is prepended unless ``host`` already
+    starts with ``ws://`` or ``wss://``.
+    """
+    normalized = _normalize_voice_focus(
+        _normalize_min_turn_silence(_dump_model(params))
+    )
+    query = urlencode(
+        {
+            key: json.dumps(value) if isinstance(value, (list, dict)) else value
+            for key, value in normalized.items()
+        }
+    )
+    scheme_prefix = "" if host.startswith(("ws://", "wss://")) else "wss://"
+    return f"{scheme_prefix}{host}/v3/ws?{query}"
+
+
+def _build_headers(options: StreamingClientOptions) -> Dict[str, Optional[str]]:
+    """Return the request headers for a streaming connection.
+
+    ``Authorization`` is ``options.token`` when truthy, otherwise
+    ``options.api_key``.
+    """
+    # Matches the pre-refactor sync behavior: ``Authorization`` is left as the
+    # raw value (may be ``None`` when neither ``token`` nor ``api_key`` is set,
+    # which surfaces the misconfiguration through the websockets/httpx layer).
+    credential = options.token or options.api_key
+    headers: Dict[str, Optional[str]] = {}
+    headers["Authorization"] = credential
+    headers["User-Agent"] = _user_agent()
+    headers["AssemblyAI-Version"] = "2025-05-12"
+    return headers
+
+
+class _BaseStreamingClient:
+    """Sync/async-agnostic core for streaming clients.
+
+    Subclasses must implement: ``connect``, ``disconnect``, ``stream``,
+    ``set_params``, ``force_endpoint``, ``create_temporary_token``, plus
+    the I/O loops (``_read_*`` / ``_write_*``). Sync subclasses use plain
+    methods; async subclasses use ``async def`` — the return-type
+    divergence is why these aren't ``@abstractmethod`` on this base.
+    """
+
+    # Maps each recognized wire ``type`` to the model that parses it. Event
+    # types missing from this table are reported as unparseable (``None``).
+    _EVENT_MODELS: Dict[StreamingEvents, Type[BaseModel]] = {
+        StreamingEvents.Begin: BeginEvent,
+        StreamingEvents.Termination: TerminationEvent,
+        StreamingEvents.Turn: TurnEvent,
+        StreamingEvents.SpeechStarted: SpeechStartedEvent,
+        StreamingEvents.LLMGatewayResponse: LLMGatewayResponseEvent,
+        StreamingEvents.Error: ErrorEvent,
+        StreamingEvents.Warning: WarningEvent,
+    }
+
+    def __init__(self, options: StreamingClientOptions):
+        """Store ``options`` and create an empty handler list per event."""
+        self._options = options
+        self._handlers: Dict[StreamingEvents, List[Callable]] = {
+            event: [] for event in StreamingEvents.__members__.values()
+        }
+        # Dedup flags for one-time error dispatch. ``_report_connection_closed``
+        # and ``_report_server_error`` perform their flag check + set
+        # synchronously (no ``await`` / yield between them) before any
+        # dispatch, so even when both I/O tasks/threads race to report the
+        # same close only the first caller executes the dispatch body.
+        # - Threading: the read thread is the sole dispatcher; the write
+        #   thread stages closes via ``_pending_close_error`` for the read
+        #   thread to drain.
+        # - Asyncio: either task may call the report function; the sync
+        #   check-and-set inside the function gives the dedup atomicity.
+        self._connection_closed_reported = False
+        self._server_error_reported = False
+        self._websocket: Optional[Any] = None
+
+    def on(self, event: StreamingEvents, handler: Callable) -> None:
+        """Register a handler for a streaming event.
+
+        ``event`` is a value from ``StreamingEvents`` (``Begin``, ``Turn``,
+        ``Termination``, ``SpeechStarted``, ``Error``, ``Warning``,
+        ``LLMGatewayResponse``). ``handler`` is invoked as
+        ``handler(client, event)``. For ``AsyncStreamingClient``, async
+        handlers are awaited inline on the read task. Exceptions raised by
+        handlers are logged and swallowed — they do not terminate the
+        session.
+
+        A registration with an unknown ``event`` or a non-callable
+        ``handler`` is ignored, as before, but now logs a warning so a
+        mistyped event does not silently yield a handler that never fires.
+        """
+        if event not in StreamingEvents.__members__.values():
+            logger.warning(
+                "Ignoring handler registration for unknown streaming event: %r",
+                event,
+            )
+            return
+        if not callable(handler):
+            logger.warning(
+                "Ignoring non-callable handler for streaming event %r: %r",
+                event,
+                handler,
+            )
+            return
+        self._handlers[event].append(handler)
+
+    @staticmethod
+    def _parse_event_type(message_type: Optional[Any]) -> Optional[StreamingEvents]:
+        """Return the ``StreamingEvents`` member named ``message_type``.
+
+        Returns ``None`` when ``message_type`` is not a string or does not
+        name a member.
+        """
+        if not isinstance(message_type, str):
+            return None
+        try:
+            return StreamingEvents[message_type]
+        except KeyError:
+            return None
+
+    @classmethod
+    def _parse_message(cls, data: Dict[str, Any]) -> Optional[EventMessage]:
+        """Parse a decoded inbound payload into its event model.
+
+        Returns ``None`` when ``data`` is not a dict, when its ``type`` is
+        missing from ``_EVENT_MODELS``, or when it has neither a ``type``
+        nor an ``error`` key. A payload without ``type`` but with ``error``
+        is parsed as an ``ErrorEvent``.
+        """
+        # Valid JSON need not be an object; lists/scalars have no ``.get``.
+        if not isinstance(data, dict):
+            return None
+
+        if "type" in data:
+            event_type = cls._parse_event_type(data.get("type"))
+            model_class = cls._EVENT_MODELS.get(event_type)
+            if model_class is None:
+                return None
+            return _parse_model(model_class, data)
+        if "error" in data:
+            return _parse_model(ErrorEvent, data)
+        return None
+
+    @staticmethod
+    def _build_connection_closed_error(
+        error: Union[
+            StreamingError,
+            ErrorEvent,
+            websockets.exceptions.ConnectionClosed,
+            OSError,
+        ],
+    ) -> Optional[StreamingError]:
+        """Convert a connection failure into a ``StreamingError``.
+
+        Returns ``None`` for a ``ConnectionClosed`` with code 1000, which
+        is therefore not reported as an error. For other close codes the
+        message comes from ``StreamingErrorCodes`` when the code is listed
+        there, else from the close reason, else a generic fallback. Any
+        other exception becomes a ``Connection failed`` error with no code.
+        """
+        if isinstance(error, StreamingError):
+            return error
+        if isinstance(error, ErrorEvent):
+            return StreamingError(message=error.error, code=error.error_code)
+        if isinstance(error, websockets.exceptions.ConnectionClosed):
+            if error.code == 1000:
+                return None
+            if error.code is not None and error.code in StreamingErrorCodes:
+                message = StreamingErrorCodes[error.code]
+            else:
+                message = error.reason or f"Connection closed (code={error.code})"
+            return StreamingError(message=message, code=error.code)
+        return StreamingError(message=f"Connection failed: {error}")
diff --git a/assemblyai/streaming/v3/async_client.py b/assemblyai/streaming/v3/async_client.py
new file mode 100644
index 0000000..6804982
--- /dev/null
+++ b/assemblyai/streaming/v3/async_client.py
@@ -0,0 +1,512 @@
+import asyncio
+import collections.abc
+import inspect
+import json
+import logging
+from typing import Any, AsyncIterable, Callable, Dict, Iterable, Optional, Union
+
+import httpx
+import websockets
+from pydantic import BaseModel
+
+# Prefer the new asyncio client API (websockets >= 13). Fall back to the legacy
+# top-level connect for older versions the SDK still supports per ``setup.py``
+# (``websockets>=11.0``). The two APIs differ only in the header-kwarg name
+# (``additional_headers`` vs ``extra_headers``); the ``websocket_connect_async``
+# wrapper below papers that over so tests and callers see one entry point.
+try:
+ from websockets.asyncio.client import connect as _ws_connect
+
+ _WS_HEADER_KW = "additional_headers"
+except ImportError: # pragma: no cover - exercised on websockets <13 only
+ from websockets.client import connect as _ws_connect # type: ignore[no-redef]
+
+ _WS_HEADER_KW = "extra_headers"
+
+from ._base import (
+ _BaseStreamingClient,
+ _build_headers,
+ _build_uri,
+ _dump_model,
+ _dump_model_json,
+ _emit_param_warnings,
+ _normalize_min_turn_silence,
+ _user_agent,
+)
+from .models import (
+ ErrorEvent,
+ EventMessage,
+ ForceEndpoint,
+ OperationMessage,
+ StreamingClientOptions,
+ StreamingError,
+ StreamingEvents,
+ StreamingParameters,
+ StreamingSessionParameters,
+ TerminateSession,
+ TerminationEvent,
+ UpdateConfiguration,
+ WarningEvent,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def websocket_connect_async(
+    uri: str, additional_headers: Dict[str, Optional[str]]
+) -> Any:
+    """Start a websocket connection through whichever ``websockets`` API was imported.
+
+    Returns the object produced by the underlying ``connect`` call so the
+    caller can ``await`` it directly or wrap it in ``asyncio.wait_for``.
+    Kept as a module-level function so tests can patch one attribute.
+
+    ``additional_headers`` has the ``Dict[str, Optional[str]]`` shape
+    returned by ``_build_headers``. It is passed through unchanged under the
+    keyword named by ``_WS_HEADER_KW``, so an ``Authorization`` of ``None``
+    (no credentials configured) reaches the websockets library and the
+    misconfiguration surfaces at the handshake layer.
+    """
+    header_kwargs = {_WS_HEADER_KW: additional_headers}
+    return _ws_connect(uri, **header_kwargs)
+
+
+class AsyncStreamingClient(_BaseStreamingClient):
+ """Asyncio-native counterpart to ``StreamingClient``.
+
+ The public API mirrors the thread-based client one-to-one — same options,
+ parameters, events, and event-handler registration. Methods that touch the
+ network are coroutines. Event handlers may be plain callables or
+ coroutine functions; coroutine handlers are awaited inline by the single
+ internal read task. Handlers should therefore avoid indefinite blocking,
+ just as with the sync client.
+
+ Behavioral notes vs. the sync ``StreamingClient``:
+
+ - ``stream`` / ``set_params`` / ``force_endpoint`` raise ``RuntimeError``
+ when called before ``connect()`` — silent drop would diverge from the
+ sync client (which buffers pre-connect data) in a way that's easy to
+ miss. After the connection has closed, the same calls are silent
+ no-ops so cleanup paths don't need defensive try/except.
+ - ``disconnect(terminate=True)`` waits at most 2.0s for the write task to
+ drain the ``TerminateSession`` frame before forcing teardown. The sync
+ client joins indefinitely.
+ - Supports ``async with``: ``disconnect()`` is invoked on block exit so
+ the websocket / HTTP client are always released even when user code
+ raises.
+ """
+
+ def __init__(self, options: StreamingClientOptions):
+     """Set up shared state and the HTTP client; no network I/O happens here."""
+     super().__init__(options)
+
+     # Created lazily in ``connect()`` so they bind to the loop that runs
+     # ``connect()``, not whatever loop was current at ``__init__`` time
+     # (matters on Python 3.8/3.9 and avoids "no running event loop"
+     # DeprecationWarnings on 3.10+ when constructed outside a loop).
+     self._write_queue: Optional["asyncio.Queue[OperationMessage]"] = None
+     self._stop_event: Optional[asyncio.Event] = None
+     self._read_task: Optional[asyncio.Task] = None
+     self._write_task: Optional[asyncio.Task] = None
+
+     self._client = _AsyncHTTPClient(
+         api_key=options.api_key, api_host=options.api_host
+     )
+
+ async def connect(self, params: StreamingParameters) -> None:
+     """Open the websocket and start the read and write tasks.
+
+     Raises ``RuntimeError`` if this instance was already connected. A
+     failed or timed-out handshake (15 s limit) is not raised: it is
+     passed to ``_report_connection_closed``, the HTTP client is closed,
+     and the method returns without starting any task.
+     """
+     # Single-use: a client whose connection went down (success or
+     # handshake failure) sets ``_connection_closed_reported``; reusing
+     # it would yield a silently dead read/write loop because
+     # ``_stop_event`` is already set.
+     read_task_running = self._read_task is not None and not self._read_task.done()
+     if (
+         self._websocket is not None
+         or self._connection_closed_reported
+         or read_task_running
+     ):
+         raise RuntimeError(
+             "AsyncStreamingClient has already been connected; "
+             "create a new instance for a new connection."
+         )
+
+     self._write_queue = asyncio.Queue()
+     self._stop_event = asyncio.Event()
+
+     _emit_param_warnings(params)
+
+     uri = _build_uri(self._options.api_host, params)
+     headers = _build_headers(self._options)
+
+     try:
+         self._websocket = await asyncio.wait_for(
+             websocket_connect_async(uri, additional_headers=headers),
+             timeout=15,
+         )
+     except (
+         websockets.exceptions.InvalidStatus,
+         websockets.exceptions.InvalidHandshake,
+         websockets.exceptions.ConnectionClosed,
+         OSError,
+         asyncio.TimeoutError,
+         TimeoutError,
+     ) as exc:
+         failure: Any = exc
+         if isinstance(exc, websockets.exceptions.InvalidStatus):
+             # A rejected handshake carries an HTTP status; report it as
+             # the error code instead of passing the raw exception on.
+             status_code = getattr(
+                 getattr(exc, "response", None), "status_code", None
+             )
+             failure = StreamingError(
+                 message=f"WebSocket handshake rejected (HTTP {status_code})",
+                 code=status_code,
+             )
+         await self._report_connection_closed(failure)
+         # Single-use design: a failed handshake terminates the client. Close
+         # the HTTP client now so users who treat ``on_error`` as the
+         # terminal signal don't leak the httpx pool.
+         await self._client.aclose()
+         return
+
+     self._read_task = asyncio.create_task(
+         self._read_loop(), name="AsyncStreamingClient._read_loop"
+     )
+     self._write_task = asyncio.create_task(
+         self._write_loop(), name="AsyncStreamingClient._write_loop"
+     )
+
+     logger.debug("Connected to WebSocket server")
+
+ async def disconnect(self, terminate: bool = False) -> None:
+     """Stop the I/O tasks and release the websocket and HTTP client.
+
+     If ``connect()`` was never called, only the HTTP client is closed.
+
+     Args:
+         terminate: When true, enqueue a ``TerminateSession`` message and
+             wait up to 2.0 seconds for the write task to finish before
+             forcing teardown.
+     """
+     if self._stop_event is None:
+         # Never connected — still close the HTTP client so the pool
+         # doesn't leak.
+         await self._client.aclose()
+         return
+
+     # Enqueue Terminate even when stop is already set: ``_write_loop``
+     # bypasses the stop gate for TerminateSession so the frame still
+     # reaches the server when the write task is alive.
+     if terminate and self._write_queue is not None:
+         await self._write_queue.put(TerminateSession())
+         # Let the write task drain TerminateSession and exit naturally
+         # before we set stop / cancel below. ``asyncio.wait`` does not
+         # cancel the awaited task on timeout, unlike ``wait_for``.
+         if self._write_task is not None and not self._write_task.done():
+             await asyncio.wait({self._write_task}, timeout=2.0)
+
+     self._stop_event.set()
+
+     # Skip the task that is running ``disconnect`` itself, so a call made
+     # from inside one of the I/O tasks does not cancel or await itself.
+     current = asyncio.current_task()
+     for task in (self._read_task, self._write_task):
+         if task is None or task is current or task.done():
+             continue
+         task.cancel()
+         try:
+             await task
+         except asyncio.CancelledError:
+             pass
+         except Exception:
+             logger.exception("Streaming task raised during disconnect")
+
+     await self._close_websocket()
+     await self._client.aclose()
+
+ async def _close_websocket(self) -> None:
+     """Close the websocket if one is held; close failures are logged at debug level."""
+     websocket = self._websocket
+     if websocket:
+         try:
+             await websocket.close()
+         except (OSError, websockets.exceptions.WebSocketException) as exc:
+             logger.debug("Error closing websocket: %s", exc)
+
+ async def stream(
+     self,
+     data: Union[bytes, AsyncIterable[bytes], Iterable[bytes]],
+ ) -> None:
+     """Queue audio for sending.
+
+     ``data`` may be one bytes-like chunk (``bytes``, ``bytearray`` or
+     ``memoryview``), an async iterable of chunks, or a plain iterable of
+     chunks. Iteration stops early once the stop event is set.
+
+     Raises:
+         RuntimeError: If called before ``connect()``.
+     """
+     # Loud on misuse (pre-connect), quiet on natural close (post-stop).
+     # The first guards against silent data loss; the second keeps cleanup
+     # paths simple.
+     write_queue, stop_event = self._ensure_connected("stream")
+     if stop_event.is_set():
+         return
+
+     # Treat every bytes-like object as one chunk. ``bytearray`` and
+     # ``memoryview`` are iterable, so without this check they would be
+     # walked element by element (as ints) and rejected by the write loop.
+     # ``bytes(...)`` normalizes them to what the write loop accepts.
+     if isinstance(data, (bytes, bytearray, memoryview)):
+         await write_queue.put(bytes(data))
+         return
+
+     if isinstance(data, collections.abc.AsyncIterable):
+         async for chunk in data:
+             if stop_event.is_set():
+                 return
+             await write_queue.put(chunk)
+         return
+
+     for chunk in data:
+         if stop_event.is_set():
+             return
+         await write_queue.put(chunk)
+
+ async def set_params(self, params: StreamingSessionParameters) -> None:
+     """Queue an ``UpdateConfiguration`` message built from ``params``.
+
+     Raises ``RuntimeError`` before ``connect()``; does nothing once the
+     stop event is set.
+     """
+     write_queue, stop_event = self._ensure_connected("set_params")
+     if not stop_event.is_set():
+         payload = _normalize_min_turn_silence(_dump_model(params))
+         await write_queue.put(UpdateConfiguration(**payload))
+
+ async def force_endpoint(self) -> None:
+     """Queue a ``ForceEndpoint`` message.
+
+     Raises ``RuntimeError`` before ``connect()``; does nothing once the
+     stop event is set.
+     """
+     write_queue, stop_event = self._ensure_connected("force_endpoint")
+     if not stop_event.is_set():
+         await write_queue.put(ForceEndpoint())
+
+ def _ensure_connected(
+     self, method: str
+ ) -> "tuple[asyncio.Queue[OperationMessage], asyncio.Event]":
+     """Return ``(write_queue, stop_event)``, raising if ``connect()`` has not run.
+
+     ``method`` is the public method name used in the error message.
+     """
+     # Handing back the post-connect primitives lets callers narrow
+     # ``Optional`` locally instead of repeating ``is None`` checks at every
+     # use site (mypy can't propagate narrowing through a separate call).
+     queue, stop = self._write_queue, self._stop_event
+     if queue is None or stop is None:
+         raise RuntimeError(
+             f"AsyncStreamingClient is not connected; call connect() before {method}()"
+         )
+     return queue, stop
+
+ async def _write_loop(self) -> None:
+     """Send queued messages over the websocket until stopped.
+
+     The loop returns when the stop event is set (checked after each
+     1-second queue wait and before each send), after sending a
+     ``TerminateSession``, or when the connection is closed during a send.
+     ``bytes`` are sent as-is; ``BaseModel`` messages are sent as JSON.
+
+     Raises:
+         RuntimeError: If the queue or stop event was never initialized.
+         ValueError: If there is no websocket, or a queued item is
+             neither ``bytes`` nor a ``BaseModel``.
+     """
+     # ``_write_loop`` is only ``create_task``ed inside ``connect()`` after
+     # the primitives are initialized. ``if`` (not ``assert``) so it
+     # survives ``python -O`` if the invariant is ever violated.
+     if self._write_queue is None or self._stop_event is None:
+         raise RuntimeError("AsyncStreamingClient internal state not initialized")
+     while True:
+         if not self._websocket:
+             raise ValueError("Not connected to the WebSocket server")
+
+         # Wake at least once a second so a stop request is noticed even
+         # when nothing is being queued.
+         try:
+             data = await asyncio.wait_for(self._write_queue.get(), timeout=1)
+         except asyncio.TimeoutError:
+             if self._stop_event.is_set():
+                 return
+             continue
+
+         # TerminateSession bypasses the stop gate so disconnect(terminate=True)
+         # can always send it, even when stop is set between put() and the
+         # write loop's next iteration.
+         is_terminate = isinstance(data, TerminateSession)
+         if not is_terminate and self._stop_event.is_set():
+             return
+
+         try:
+             if isinstance(data, bytes):
+                 await self._websocket.send(data)
+             elif isinstance(data, BaseModel):
+                 await self._websocket.send(_dump_model_json(data))
+             else:
+                 raise ValueError(f"Attempted to send invalid message: {type(data)}")
+         except websockets.exceptions.ConnectionClosed as exc:
+             # Dispatch the close directly from the write task. The read
+             # task may short-circuit on ``_stop_event`` at the top of its
+             # loop (e.g. while a buffered message was processed between
+             # ``recv()`` calls) and never observe the close in ``recv()``,
+             # so the write task can't rely on it to dispatch.
+             # ``_report_connection_closed`` is idempotent — its flag check
+             # + set is synchronous (no ``await`` between them), so if the
+             # read task also raises ``ConnectionClosed`` it'll be a no-op.
+             await self._report_connection_closed(exc)
+             return
+
+         # Nothing may follow TerminateSession, so the loop ends here.
+         if is_terminate:
+             return
+
+    async def _read_loop(self) -> None:
+        """Receive frames from the WebSocket and dispatch them until stopped.
+
+        Each frame is JSON-decoded, parsed with ``_parse_message`` and routed:
+        ``ErrorEvent`` to ``_report_server_error``, ``WarningEvent`` to
+        ``_handle_warning``, any other recognized event to
+        ``_handle_message``. Frames that cannot be decoded, are not JSON
+        objects, or carry an unknown type are logged and skipped so a single
+        bad frame cannot kill the read task. The loop returns once
+        ``_stop_event`` is set or the connection closes.
+
+        Raises:
+            RuntimeError: If the stop event was never initialized.
+            ValueError: If there is no WebSocket to read from.
+        """
+        # ``_read_loop`` is only ``create_task``ed inside ``connect()`` after
+        # ``_stop_event`` is initialized. ``if`` (not ``assert``) so it
+        # survives ``python -O`` if the invariant is ever violated.
+        if self._stop_event is None:
+            raise RuntimeError("AsyncStreamingClient internal state not initialized")
+        while True:
+            if not self._websocket:
+                raise ValueError("Not connected to the WebSocket server")
+
+            if self._stop_event.is_set():
+                return
+
+            try:
+                message_data = await self._websocket.recv()
+            except websockets.exceptions.ConnectionClosed as exc:
+                await self._report_connection_closed(exc)
+                return
+
+            try:
+                message_json = json.loads(message_data)
+            except (json.JSONDecodeError, UnicodeDecodeError) as exc:
+                # ``json.loads`` decodes ``bytes`` input itself and raises
+                # ``UnicodeDecodeError`` (not ``JSONDecodeError``) when a
+                # binary frame is not valid UTF-8/16/32. Catch both so a
+                # malformed frame is skipped instead of ending the task.
+                logger.warning(f"Failed to decode message: {exc}")
+                continue
+
+            # Valid JSON that is not an object (list, string, number, null)
+            # cannot carry an event; the fallback branch below also relies on
+            # ``.get``, which only mappings provide.
+            if not isinstance(message_json, dict):
+                logger.warning(
+                    "Unsupported message payload: %s", type(message_json).__name__
+                )
+                continue
+
+            message = self._parse_message(message_json)
+
+            if isinstance(message, ErrorEvent):
+                await self._report_server_error(message)
+            elif isinstance(message, WarningEvent):
+                await self._handle_warning(message)
+            elif message:
+                await self._handle_message(message)
+            else:
+                logger.warning(f"Unsupported event type: {message_json.get('type')}")
+
+    async def _handle_message(self, message: EventMessage) -> None:
+        """Run every handler registered for ``message``'s event type.
+
+        A ``TerminationEvent`` sets ``_stop_event`` before its handlers run.
+
+        Raises:
+            RuntimeError: If the stop event was never initialized.
+        """
+        # Only reached from ``_read_loop``, which runs after ``connect()`` has
+        # created ``_stop_event``; the check guards against that invariant
+        # being broken.
+        stop_event = self._stop_event
+        if stop_event is None:
+            raise RuntimeError("AsyncStreamingClient internal state not initialized")
+
+        if isinstance(message, TerminationEvent):
+            stop_event.set()
+
+        event_type = StreamingEvents[message.type]
+        callbacks = self._handlers[event_type]
+        for callback in callbacks:
+            await self._invoke_handler(callback, message, event_type)
+
+    async def _handle_warning(self, warning: WarningEvent) -> None:
+        """Log a server warning, then pass it to each ``Warning`` handler."""
+        logger.warning(
+            "Streaming warning (code=%s): %s", warning.warning_code, warning.warning
+        )
+        warning_event = StreamingEvents.Warning
+        for callback in self._handlers[warning_event]:
+            await self._invoke_handler(callback, warning, warning_event)
+
+    async def _report_server_error(self, error: ErrorEvent) -> None:
+        """Surface a server ``Error`` frame to handlers, then tear down locally.
+
+        Marks the server error as reported, logs it, dispatches a
+        ``StreamingError`` built from the frame, closes the WebSocket and
+        finally sets ``_stop_event``.
+
+        Raises:
+            RuntimeError: If the stop event was never initialized.
+        """
+        # Only reachable from ``_read_loop`` (after primitives are initialized).
+        if self._stop_event is None:
+            raise RuntimeError("AsyncStreamingClient internal state not initialized")
+
+        self._server_error_reported = True
+        streaming_error = StreamingError(message=error.error, code=error.error_code)
+        logger.error("Streaming error: %s (code=%s)", error.error, error.error_code)
+        await self._dispatch_error(streaming_error)
+
+        # The server may send Error without a trailing close frame; closing
+        # here keeps the read loop from sitting in ``recv()`` forever.
+        # ``_close_websocket`` is idempotent, and if the close frame does
+        # arrive later ``_report_connection_closed`` dedups through
+        # ``_server_error_reported``.
+        await self._close_websocket()
+        self._stop_event.set()
+
+    async def _report_connection_closed(
+        self,
+        error: Union[
+            StreamingError,
+            ErrorEvent,
+            websockets.exceptions.ConnectionClosed,
+            OSError,
+        ],
+    ) -> None:
+        """Report a closed or failed connection once and close the WebSocket.
+
+        The first call sets ``_stop_event``, logs the close, dispatches a
+        ``StreamingError`` to the error handlers (unless a server ``Error``
+        frame was already reported) and closes the WebSocket. When
+        ``_build_connection_closed_error`` returns ``None`` nothing is logged
+        or dispatched; the socket is just closed. Later calls are no-ops.
+
+        Raises:
+            RuntimeError: If the stop event was never initialized.
+        """
+        # Callers (``connect()`` failure path, ``_read_loop``, ``_write_loop``)
+        # all run after ``_stop_event`` is initialized.
+        stop_event = self._stop_event
+        if stop_event is None:
+            raise RuntimeError("AsyncStreamingClient internal state not initialized")
+
+        # Check-and-set without an intervening ``await`` so that two tasks
+        # reporting the same close cannot both get past this point.
+        if self._connection_closed_reported:
+            return
+        self._connection_closed_reported = True
+        stop_event.set()
+
+        streaming_error = self._build_connection_closed_error(error)
+        if streaming_error is not None:
+            if isinstance(error, websockets.exceptions.ConnectionClosed):
+                logger.error(
+                    "Connection closed: %s (code=%s)",
+                    error.reason or "no reason given",
+                    error.code,
+                )
+            else:
+                logger.error(
+                    "Connection failed: %s (code=%s)",
+                    streaming_error,
+                    streaming_error.code,
+                )
+
+            # If a server Error frame already fired on_error, the close is
+            # the effect, not a new cause: it is logged above but the
+            # duplicate user-visible error is skipped.
+            if not self._server_error_reported:
+                await self._dispatch_error(streaming_error)
+
+        await self._close_websocket()
+
+    async def _dispatch_error(self, error: StreamingError) -> None:
+        """Pass ``error`` to each handler registered for ``StreamingEvents.Error``."""
+        error_event = StreamingEvents.Error
+        for callback in self._handlers[error_event]:
+            await self._invoke_handler(callback, error, error_event)
+
+    async def _invoke_handler(
+        self,
+        handler: Callable,
+        payload: Any,
+        event_type: StreamingEvents,
+    ) -> None:
+        """Call one handler as ``handler(self, payload)``, awaiting it if needed.
+
+        Works for plain and ``async def`` handlers alike. Any ``Exception``
+        raised by the handler is logged with its traceback and swallowed, so
+        it does not propagate to the caller.
+
+        Args:
+            handler: The user callback to invoke.
+            payload: The event or error object passed as the second argument.
+            event_type: Used only to name the handler in the log message.
+        """
+        try:
+            outcome = handler(self, payload)
+            if not inspect.isawaitable(outcome):
+                return
+            await outcome
+        except Exception:
+            logger.exception("on_%s handler raised", event_type.name.lower())
+
+    async def create_temporary_token(
+        self,
+        expires_in_seconds: int,
+        max_session_duration_seconds: Optional[int] = None,
+    ) -> str:
+        """Request a temporary token by delegating to the internal HTTP client.
+
+        Args:
+            expires_in_seconds: Forwarded unchanged to the HTTP client.
+            max_session_duration_seconds: Forwarded unchanged to the HTTP
+                client; ``None`` by default.
+
+        Returns:
+            The token string returned by the HTTP client.
+        """
+        token = await self._client.create_temporary_token(
+            expires_in_seconds=expires_in_seconds,
+            max_session_duration_seconds=max_session_duration_seconds,
+        )
+        return token
+
+    async def __aenter__(self) -> "AsyncStreamingClient":
+        """Enter ``async with`` and return this client.
+
+        Entering the context does not call ``connect()``.
+        """
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb) -> None:
+        """Disconnect on leaving ``async with``.
+
+        ``disconnect`` is called with ``terminate=True`` only when the body
+        exited without an exception. Returns ``None``, so an exception from
+        the body is not suppressed.
+        """
+        exited_cleanly = exc_type is None
+        await self.disconnect(terminate=exited_cleanly)
+
+
+class _AsyncHTTPClient:
+    """Lazy ``httpx.AsyncClient`` wrapper used to request temporary tokens.
+
+    The underlying client is created on the first request and released by
+    ``aclose()``. ``aclose()`` returns the wrapper to its initial lazy state,
+    so a request made afterwards builds a fresh client that a later
+    ``aclose()`` will close again; no pool is left without an owner.
+    """
+
+    def __init__(self, api_host: str, api_key: Optional[str] = None):
+        # Lazy: don't instantiate httpx.AsyncClient here. Bare construction of
+        # an AsyncStreamingClient that's never connected (or used only for
+        # connect() — which doesn't go through the HTTP client) must not
+        # leak an httpx pool.
+        self._api_host = api_host
+        self._api_key = api_key
+        self._http_client: Optional[httpx.AsyncClient] = None
+        self._closed = False
+
+    def _get_or_create_client(self) -> httpx.AsyncClient:
+        """Return the cached ``httpx.AsyncClient``, creating it on first use."""
+        if self._http_client is None:
+            headers = {"User-Agent": f"{httpx._client.USER_AGENT} {_user_agent()}"}
+            if self._api_key:
+                headers["Authorization"] = self._api_key
+            self._http_client = httpx.AsyncClient(
+                base_url="https://" + self._api_host,
+                headers=headers,
+            )
+            # A live pool exists again, so the wrapper is no longer closed.
+            self._closed = False
+        return self._http_client
+
+    async def create_temporary_token(
+        self,
+        expires_in_seconds: int,
+        max_session_duration_seconds: Optional[int] = None,
+    ) -> str:
+        """Fetch a temporary token from ``GET /v3/token``.
+
+        Args:
+            expires_in_seconds: Always sent as a query parameter.
+            max_session_duration_seconds: Sent only when not ``None``.
+
+        Returns:
+            The ``token`` field of the JSON response.
+
+        Raises:
+            httpx.HTTPStatusError: If the response status is an error
+                (via ``raise_for_status()``).
+        """
+        # ``expires_in_seconds`` is required per the type; always forward it
+        # so passing ``0`` reaches the server (where it can be validated)
+        # instead of being silently dropped by a falsy check.
+        params: Dict[str, Any] = {"expires_in_seconds": expires_in_seconds}
+
+        if max_session_duration_seconds is not None:
+            params["max_session_duration_seconds"] = max_session_duration_seconds
+
+        response = await self._get_or_create_client().get("/v3/token", params=params)
+        response.raise_for_status()
+        return response.json()["token"]
+
+    async def aclose(self) -> None:
+        """Close the underlying client if one exists; safe to call repeatedly.
+
+        Errors raised while closing are logged at debug level and swallowed.
+        """
+        # Detach the client before awaiting so that a concurrent or repeated
+        # call sees ``None`` and returns, and so that a closed client is
+        # never handed out again by ``_get_or_create_client``. Idempotence
+        # is keyed on the client itself rather than a latched flag: a flag
+        # set while no client existed would make a client created later
+        # impossible to close.
+        client, self._http_client = self._http_client, None
+        self._closed = True
+        if client is None:
+            return
+        try:
+            await client.aclose()
+        except Exception as exc:
+            logger.debug("Error closing async HTTP client: %s", exc)
diff --git a/assemblyai/streaming/v3/client.py b/assemblyai/streaming/v3/client.py
index cce42d8..65ace7a 100644
--- a/assemblyai/streaming/v3/client.py
+++ b/assemblyai/streaming/v3/client.py
@@ -1,35 +1,36 @@
import json
import logging
import queue
-import sys
import threading
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
-from urllib.parse import urlencode
+from typing import Any, Dict, Generator, Iterable, Optional, Union
import httpx
import websockets
from pydantic import BaseModel
from websockets.sync.client import connect as websocket_connect
-from assemblyai import __version__
-
+from ._base import (
+ _BaseStreamingClient,
+ _build_headers,
+ _build_uri,
+ _dump_model,
+ _dump_model_json,
+ _emit_param_warnings,
+ _normalize_min_turn_silence,
+ _user_agent,
+)
from .models import (
- BeginEvent,
ErrorEvent,
EventMessage,
ForceEndpoint,
- LLMGatewayResponseEvent,
OperationMessage,
- SpeechStartedEvent,
StreamingClientOptions,
StreamingError,
- StreamingErrorCodes,
StreamingEvents,
StreamingParameters,
StreamingSessionParameters,
TerminateSession,
TerminationEvent,
- TurnEvent,
UpdateConfiguration,
WarningEvent,
)
@@ -37,144 +38,34 @@
logger = logging.getLogger(__name__)
-def _dump_model(model: BaseModel):
- if hasattr(model, "model_dump"):
- return model.model_dump(exclude_none=True)
- return model.dict(exclude_none=True)
-
-
-def _parse_model(model_class, data):
- if hasattr(model_class, "model_validate"):
- return model_class.model_validate(data)
- return model_class.parse_obj(data)
-
-
-def _normalize_min_turn_silence(params_dict: dict) -> dict:
- """Collapse `min_end_of_turn_silence_when_confident` into `min_turn_silence` so only
- one wire key is ever sent. Emits deprecation warnings."""
- old = params_dict.pop("min_end_of_turn_silence_when_confident", None)
- if old is None:
- return params_dict
- if "min_turn_silence" in params_dict:
- logger.warning(
- "[Deprecation Warning] Both `min_end_of_turn_silence_when_confident` and "
- "`min_turn_silence` are set. Using `min_turn_silence`; "
- "`min_end_of_turn_silence_when_confident` is deprecated."
- )
- else:
- logger.warning(
- "[Deprecation Warning] `min_end_of_turn_silence_when_confident` is "
- "deprecated and will be removed in a future release. Please use "
- "`min_turn_silence` instead."
- )
- params_dict["min_turn_silence"] = old
- return params_dict
-
-
-def _normalize_voice_focus(params_dict: dict) -> dict:
- """Collapse `noise_suppression_model` / `noise_suppression_threshold` into
- `voice_focus` / `voice_focus_threshold` so only the new wire keys are sent.
- Emits deprecation warnings."""
- for old_key, new_key in (
- ("noise_suppression_model", "voice_focus"),
- ("noise_suppression_threshold", "voice_focus_threshold"),
- ):
- old = params_dict.pop(old_key, None)
- if old is None:
- continue
- if new_key in params_dict:
- logger.warning(
- f"[Deprecation Warning] Both `{old_key}` and `{new_key}` are set. "
- f"Using `{new_key}`; `{old_key}` is deprecated."
- )
- else:
- logger.warning(
- f"[Deprecation Warning] `{old_key}` is deprecated and will be removed "
- f"in a future release. Please use `{new_key}` instead."
- )
- params_dict[new_key] = old
- return params_dict
-
-
-def _dump_model_json(model: BaseModel):
- if hasattr(model, "model_dump_json"):
- return model.model_dump_json(exclude_none=True)
- return model.json(exclude_none=True)
-
-
-def _user_agent() -> str:
- vi = sys.version_info
- python_version = f"{vi.major}.{vi.minor}.{vi.micro}"
- return (
- f"AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})"
- )
-
-
-class StreamingClient:
+class StreamingClient(_BaseStreamingClient):
def __init__(self, options: StreamingClientOptions):
- self._options = options
+ super().__init__(options)
self._client = _HTTPClient(api_host=options.api_host, api_key=options.api_key)
- self._handlers: Dict[StreamingEvents, List[Callable]] = {}
-
- for event in StreamingEvents.__members__.values():
- self._handlers[event] = []
-
self._write_queue: queue.Queue[OperationMessage] = queue.Queue()
self._write_thread = threading.Thread(target=self._write_message)
self._read_thread = threading.Thread(target=self._read_message)
self._stop_event = threading.Event()
- # Both flags are read and set only on the read thread (or on the main
- # thread before workers start, for handshake errors). Plain bools are
- # sufficient — no cross-thread synchronization is needed.
- self._connection_closed_reported = False
- self._server_error_reported = False
# Deliberate single-slot shared-memory handoff: the write thread parks
# a ConnectionClosed here and the read thread drains it. Synchronization
# is provided by `_stop_event.set()` (write side) + `recv(timeout=1)`
# (read side), which together give a happens-before within ~1s.
self._pending_close_error: Optional[Exception] = None
- self._websocket = None
def connect(self, params: StreamingParameters) -> None:
- if params.speech_model == "u3-pro":
- logger.warning(
- "[Deprecation Warning] The speech model `u3-pro` is deprecated and will be removed in a future release. "
- "Please use `u3-rt-pro` instead."
- )
+ """Open the WebSocket session and start the read/write threads.
- if params.customer_support_audio_capture:
- logger.warning(
- "`customer_support_audio_capture=True` will record session audio. "
- "Only enable this when explicitly coordinating with AssemblyAI support."
- )
+ Blocks until the handshake completes. If the server rejects the
+ handshake (auth error, etc.) ``Error`` is dispatched to any
+ ``on(StreamingEvents.Error, ...)`` handler rather than raised, so
+ registration order matters: call ``on()`` before ``connect()``.
+ """
+ _emit_param_warnings(params)
- params_dict = _normalize_voice_focus(
- _normalize_min_turn_silence(_dump_model(params))
- )
-
- # JSON-encode list and dict parameters for proper API compatibility (e.g., keyterms_prompt, llm_gateway)
- for key, value in params_dict.items():
- if isinstance(value, list):
- params_dict[key] = json.dumps(value)
- elif isinstance(value, dict):
- params_dict[key] = json.dumps(value)
-
- params_encoded = urlencode(params_dict)
-
- host = self._options.api_host
- if host.startswith(("ws://", "wss://")):
- uri = f"{host}/v3/ws?{params_encoded}"
- else:
- uri = f"wss://{host}/v3/ws?{params_encoded}"
- headers = {
- "Authorization": self._options.token
- if self._options.token
- else self._options.api_key,
- "User-Agent": _user_agent(),
- "AssemblyAI-Version": "2025-05-12",
- }
+ uri = _build_uri(self._options.api_host, params)
+ headers = _build_headers(self._options)
try:
self._websocket = websocket_connect(
@@ -206,6 +97,14 @@ def connect(self, params: StreamingParameters) -> None:
logger.debug("Connected to WebSocket server")
def disconnect(self, terminate: bool = False) -> None:
+ """Stop the read/write threads and close the WebSocket.
+
+ Pass ``terminate=True`` for a graceful close — the client sends a
+ ``TerminateSession`` frame and waits for the server's
+ ``TerminationEvent`` (which reports total audio duration). Without
+ ``terminate=True`` the WebSocket is closed without notifying the
+ server.
+ """
# Enqueue Terminate even when stop is already set: `_write_message`
# bypasses the stop gate for TerminateSession so the frame still
# reaches the server when the write thread is alive.
@@ -236,6 +135,13 @@ def _close_websocket(self) -> None:
def stream(
self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]]
) -> None:
+ """Send audio bytes to the server.
+
+ Accepts a raw ``bytes`` buffer or any (sync) iterable of ``bytes``.
+ Returns once all chunks are enqueued — the write thread does the
+ actual sending. After ``disconnect()`` (or a connection drop) this
+ becomes a silent no-op.
+ """
if self._stop_event.is_set():
return
@@ -257,10 +163,6 @@ def force_endpoint(self):
message = ForceEndpoint()
self._write_queue.put(message)
- def on(self, event: StreamingEvents, handler: Callable) -> None:
- if event in StreamingEvents.__members__.values() and callable(handler):
- self._handlers[event].append(handler)
-
def _write_message(self) -> None:
while True:
if not self._websocket:
@@ -335,7 +237,7 @@ def _read_message(self) -> None:
elif message:
self._handle_message(message)
else:
- logger.warning(f"Unsupported event type: {message_json['type']}")
+ logger.warning(f"Unsupported event type: {message_json.get('type')}")
def _handle_message(self, message: EventMessage) -> None:
if isinstance(message, TerminationEvent):
@@ -349,43 +251,6 @@ def _handle_message(self, message: EventMessage) -> None:
except Exception:
logger.exception("on_%s handler raised", event_type.name.lower())
- def _parse_message(self, data: Dict[str, Any]) -> Optional[EventMessage]:
- if "type" in data:
- message_type = data.get("type")
-
- event_type = self._parse_event_type(message_type)
-
- if event_type == StreamingEvents.Begin:
- return _parse_model(BeginEvent, data)
- elif event_type == StreamingEvents.Termination:
- return _parse_model(TerminationEvent, data)
- elif event_type == StreamingEvents.Turn:
- return _parse_model(TurnEvent, data)
- elif event_type == StreamingEvents.SpeechStarted:
- return _parse_model(SpeechStartedEvent, data)
- elif event_type == StreamingEvents.LLMGatewayResponse:
- return _parse_model(LLMGatewayResponseEvent, data)
- elif event_type == StreamingEvents.Error:
- return _parse_model(ErrorEvent, data)
- elif event_type == StreamingEvents.Warning:
- return _parse_model(WarningEvent, data)
- else:
- return None
- elif "error" in data:
- return _parse_model(ErrorEvent, data)
-
- return None
-
- @staticmethod
- def _parse_event_type(message_type: Optional[Any]) -> Optional[StreamingEvents]:
- if not isinstance(message_type, str):
- return None
-
- try:
- return StreamingEvents[message_type]
- except KeyError:
- return None
-
def _handle_warning(self, warning: WarningEvent):
logger.warning(
"Streaming warning (code=%s): %s", warning.warning_code, warning.warning
@@ -460,29 +325,6 @@ def _dispatch_error(self, error: StreamingError) -> None:
except Exception:
logger.exception("on_error handler raised")
- @staticmethod
- def _build_connection_closed_error(
- error: Union[
- StreamingError,
- ErrorEvent,
- websockets.exceptions.ConnectionClosed,
- OSError,
- ],
- ) -> Optional[StreamingError]:
- if isinstance(error, StreamingError):
- return error
- if isinstance(error, ErrorEvent):
- return StreamingError(message=error.error, code=error.error_code)
- if isinstance(error, websockets.exceptions.ConnectionClosed):
- if error.code == 1000:
- return None
- if error.code is not None and error.code in StreamingErrorCodes:
- message = StreamingErrorCodes[error.code]
- else:
- message = error.reason or f"Connection closed (code={error.code})"
- return StreamingError(message=message, code=error.code)
- return StreamingError(message=f"Connection failed: {error}")
-
def create_temporary_token(
self,
expires_in_seconds: int,
@@ -496,11 +338,7 @@ def create_temporary_token(
class _HTTPClient:
def __init__(self, api_host: str, api_key: Optional[str] = None):
- vi = sys.version_info
- python_version = f"{vi.major}.{vi.minor}.{vi.micro}"
- user_agent = f"{httpx._client.USER_AGENT} AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})"
-
- headers = {"User-Agent": user_agent}
+ headers = {"User-Agent": f"{httpx._client.USER_AGENT} {_user_agent()}"}
if api_key:
headers["Authorization"] = api_key
@@ -515,12 +353,12 @@ def create_temporary_token(
expires_in_seconds: int,
max_session_duration_seconds: Optional[int] = None,
) -> str:
- params: Dict[str, Any] = {}
-
- if expires_in_seconds:
- params["expires_in_seconds"] = expires_in_seconds
+ # ``expires_in_seconds`` is required per the type; always forward it
+ # so passing ``0`` reaches the server (where it can be validated)
+ # instead of being silently dropped by a falsy check.
+ params: Dict[str, Any] = {"expires_in_seconds": expires_in_seconds}
- if max_session_duration_seconds:
+ if max_session_duration_seconds is not None:
params["max_session_duration_seconds"] = max_session_duration_seconds
response = self._http_client.get(
diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py
index 5b35a1c..c7cd5f4 100644
--- a/assemblyai/streaming/v3/models.py
+++ b/assemblyai/streaming/v3/models.py
@@ -106,6 +106,7 @@ class StreamingSessionParameters(BaseModel):
filter_profanity: Optional[bool] = None
prompt: Optional[str] = None
interruption_delay: Optional[int] = None
+ turn_left_pad_ms: Optional[int] = None
class Encoding(str, Enum):
diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py
index ef94bac..7bf40ca 100644
--- a/assemblyai/transcriber.py
+++ b/assemblyai/transcriber.py
@@ -2,18 +2,11 @@
import concurrent.futures
import functools
-import json
import os
-import queue
-import threading
import time
from typing import (
- Any,
BinaryIO,
- Callable,
Dict,
- Generator,
- Iterable,
Iterator,
List,
Optional,
@@ -21,13 +14,10 @@
Tuple,
Union,
)
-from urllib.parse import urlencode, urlparse
+from urllib.parse import urlparse
import httpx
-import websockets
-import websockets.exceptions
from typing_extensions import Self
-from websockets.sync.client import connect as websocket_connect
from . import api, lemur, types
from . import client as _client
@@ -1281,444 +1271,3 @@ def list_transcripts_async(
Returns: A page with a list of transcripts along with page details.
"""
return self._executor.submit(self._impl.list_transcripts, params=params)
-
-
-class _RealtimeTranscriberImpl:
- def __init__(
- self,
- *,
- on_data: Callable[[types.RealtimeTranscript], None],
- on_error: Callable[[types.RealtimeError], None],
- on_open: Optional[Callable[[types.RealtimeSessionOpened], None]],
- on_close: Optional[Callable[[], None]],
- sample_rate: int,
- word_boost: List[str],
- encoding: Optional[types.AudioEncoding] = None,
- token: Optional[str] = None,
- client: _client.Client,
- end_utterance_silence_threshold: Optional[int],
- disable_partial_transcripts: Optional[bool],
- on_extra_session_information: Optional[
- Callable[[types.RealtimeSessionInformation], None]
- ] = None,
- ) -> None:
- self._client = client
- self._websocket: Optional[websockets.sync.client.ClientConnection] = None
-
- self._on_open = on_open
- self._on_data = on_data
- self._on_error = on_error
- self._on_close = on_close
- self._sample_rate = sample_rate
- self._word_boost = word_boost
- self._encoding = encoding
- self._token = token
- self._end_utterance_silence_threshold = end_utterance_silence_threshold
- self._disable_partial_transcripts = disable_partial_transcripts
- self._on_extra_session_information = on_extra_session_information
-
- self._write_queue: queue.Queue[Union[bytes, Dict]] = queue.Queue()
- self._write_thread = threading.Thread(target=self._write)
- self._read_thread = threading.Thread(target=self._read)
- self._stop_event = threading.Event()
-
- def connect(
- self,
- timeout: Optional[float],
- ) -> None:
- """
- Connects to the real-time service.
-
- Args:
- `timeout`: The maximum time to wait for the connection to be established.
- """
-
- params: Dict[str, Any] = {
- "sample_rate": self._sample_rate,
- }
- if self._word_boost:
- params["word_boost"] = json.dumps(self._word_boost)
- if self._encoding:
- params["encoding"] = self._encoding.value
- if self._token:
- params["token"] = self._token
- if self._disable_partial_transcripts:
- params["disable_partial_transcripts"] = self._disable_partial_transcripts
- if self._on_extra_session_information:
- params["enable_extra_session_information"] = True
-
- websocket_base_url = self._client.settings.base_url.replace("https", "wss")
-
- additional_headers = None
- if self._token is None:
- additional_headers = {"Authorization": f"{self._client.settings.api_key}"}
-
- try:
- self._websocket = websocket_connect(
- f"{websocket_base_url}{api.ENDPOINT_REALTIME_WEBSOCKET}?{urlencode(params)}",
- additional_headers=additional_headers,
- open_timeout=timeout,
- )
- except Exception as exc:
- return self._on_error(
- types.RealtimeError(
- f"Could not connect to the real-time service: {exc}"
- )
- )
-
- self._read_thread.start()
- self._write_thread.start()
-
- if self._end_utterance_silence_threshold is not None:
- self.configure_end_utterance_silence_threshold(
- self._end_utterance_silence_threshold
- )
-
- def stream(self, data: bytes) -> None:
- """
- Streams audio data to the real-time service by putting it into a queue.
- """
-
- self._write_queue.put(data)
-
- def configure_end_utterance_silence_threshold(
- self, threshold_milliseconds: int
- ) -> None:
- """
- Configures the end of utterance silence threshold.
- Can be called multiple times during a session at any point after the session starts.
-
- Args:
- `threshold_milliseconds`: The threshold in milliseconds.
- """
-
- self._write_queue.put(
- _RealtimeEndUtteranceSilenceThreshold(threshold_milliseconds).as_dict()
- )
-
- def force_end_utterance(self) -> None:
- """
- Forces the end of the current utterance.
- """
-
- self._write_queue.put(_RealtimeForceEndUtterance().as_dict())
-
- def close(self, terminate: bool = False) -> None:
- """
- Closes the connection to the real-time service gracefully.
- """
- if terminate and not self._stop_event.is_set():
- self._write_queue.put({"terminate_session": True})
-
- try:
- self._read_thread.join()
- self._write_thread.join()
- if self._websocket:
- self._websocket.close()
- except Exception:
- pass
-
- if self._on_close:
- self._on_close()
-
- def _read(self) -> None:
- """
- Reads messages from the real-time service.
-
- Must run in a separate thread to avoid blocking the main thread.
- """
-
- while not self._stop_event.is_set():
- if not self._websocket:
- raise ValueError("Websocket is None")
-
- try:
- recv_message = self._websocket.recv(timeout=1)
- except TimeoutError:
- continue
- except websockets.exceptions.ConnectionClosed as exc:
- return self._handle_error(exc)
-
- try:
- message = json.loads(recv_message)
- except json.JSONDecodeError as exc:
- self._on_error(
- types.RealtimeError(
- f"Could not decode message: {exc}",
- )
- )
- continue
-
- self._handle_message(message)
-
- def _write(self) -> None:
- """
- Writes messages to the real-time service.
-
- Must run in a separate thread to avoid blocking the main thread.
- """
-
- while not self._stop_event.is_set():
- try:
- data = self._write_queue.get(timeout=1)
- except queue.Empty:
- continue
-
- try:
- if not self._websocket:
- raise ValueError("websocket is None")
- elif isinstance(data, dict):
- self._websocket.send(json.dumps(data))
- elif isinstance(data, bytes):
- self._websocket.send(data)
- else:
- raise ValueError("unsupported message type")
- except websockets.exceptions.ConnectionClosed as exc:
- return self._handle_error(exc)
-
- def _handle_message(
- self,
- message: Dict[str, Any],
- ) -> None:
- """
- Handles a message received from the real-time service by calling the appropriate
- callback.
-
- Args:
- `message`: The message to handle.
- """
- if "message_type" in message:
- if message["message_type"] == types.RealtimeMessageTypes.partial_transcript:
- self._on_data(types.RealtimePartialTranscript(**message))
- elif message["message_type"] == types.RealtimeMessageTypes.final_transcript:
- self._on_data(types.RealtimeFinalTranscript(**message))
- elif (
- message["message_type"] == types.RealtimeMessageTypes.session_begins
- and self._on_open
- ):
- self._on_open(types.RealtimeSessionOpened(**message))
- elif (
- message["message_type"] == types.RealtimeMessageTypes.session_terminated
- ):
- self._stop_event.set()
- elif (
- message["message_type"]
- == types.RealtimeMessageTypes.session_information
- ):
- if self._on_extra_session_information is not None:
- self._on_extra_session_information(
- types.RealtimeSessionInformation(**message)
- )
- elif "error" in message:
- self._on_error(types.RealtimeError(message["error"]))
-
- def _handle_error(self, error: websockets.exceptions.ConnectionClosed) -> None:
- """
- Handles a WebSocket error by calling the appropriate callback.
-
- See a list of errors here:
-
- - https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
- - https://www.assemblyai.com/docs/Guides/real-time_streaming_transcription#closing-and-status-codes
- """
- if (
- error.code >= 4000
- and error.code <= 4999
- and error.code in types.RealtimeErrorMapping
- ):
- error_message = types.RealtimeErrorMapping[error.code]
- else:
- error_message = error.reason
-
- if error.code != 1000:
- self._on_error(types.RealtimeError(error_message, error.code))
-
- self.close()
-
- @classmethod
- def create_temporary_token(
- cls,
- expires_in: int,
- timeout: Optional[float] = None,
- ) -> str:
- """
- Request a temporary authentication token.
-
- Args:
- expires_in: The amount of time until the token expires in seconds.
- timeout: The timeout in seconds to wait for a response.
- A `timeout` of `None` means no timeout.
-
- Returns: The temporary authentication token.
- """
-
- return api.create_temporary_token(
- client=_client.Client.get_default().http_client,
- request=types.RealtimeCreateTemporaryTokenRequest(
- expires_in=expires_in,
- ),
- http_timeout=timeout,
- )
-
-
-class _RealtimeForceEndUtterance:
- def as_dict(self) -> Dict[str, bool]:
- return {
- "force_end_utterance": True,
- }
-
-
-class _RealtimeEndUtteranceSilenceThreshold:
- def __init__(self, threshold_milliseconds: int) -> None:
- self._value = threshold_milliseconds
-
- @property
- def value(self) -> int:
- return self._value
-
- def as_dict(self) -> Dict[str, int]:
- return {"end_utterance_silence_threshold": self._value}
-
-
-class RealtimeTranscriber:
- def __init__(
- self,
- *,
- on_data: Callable[[types.RealtimeTranscript], None],
- on_error: Callable[[types.RealtimeError], None],
- on_open: Optional[Callable[[types.RealtimeSessionOpened], None]] = None,
- on_close: Optional[Callable[[], None]] = None,
- sample_rate: int,
- word_boost: List[str] = [],
- encoding: Optional[types.AudioEncoding] = None,
- token: Optional[str] = None,
- client: Optional[_client.Client] = None,
- end_utterance_silence_threshold: Optional[int] = None,
- disable_partial_transcripts: Optional[bool] = None,
- on_extra_session_information: Optional[
- Callable[[types.RealtimeSessionInformation], None]
- ] = None,
- ) -> None:
- """
- Creates a new real-time transcriber.
-
- Args:
- `on_data`: The callback to call when a new transcript is received.
- `on_error`: The callback to call when an error occurs.
- `on_open`: (Optional) The callback to call when the connection to the real-time service opens.
- `on_close`: (Optional) The callback to call when the connection to the real-time service closes.
- `sample_rate`: The sample rate of the audio data.
- `word_boost`: (Optional) A list of words to boost transcription probability for.
- `encoding`: (Optional) The encoding of the audio data.
- `token`: (Optional) A temporary authentication token.
- `client`: (Optional) The client to use for the real-time service.
- `end_utterance_silence_threshold`: (Optional) The end utterance silence threshold in milliseconds.
- `disable_partial_transcripts`: (Optional) If set to `True`, only final transcripts will be received.
- `on_extra_session_information`: (Optional) The callback to call when a `SessionInformation` message is received.
- If this callback is set, the parameter `enable_extra_session_information` is sent to the API, and the client
- receives a `SessionInformation` message right before receiving the session termination message.
- """
-
- self._client = client or _client.Client.get_default(
- api_key_required=token is None
- )
-
- self._impl = _RealtimeTranscriberImpl(
- on_open=on_open,
- on_data=on_data,
- on_error=on_error,
- on_close=on_close,
- sample_rate=sample_rate,
- word_boost=word_boost,
- encoding=encoding,
- token=token,
- client=self._client,
- end_utterance_silence_threshold=end_utterance_silence_threshold,
- disable_partial_transcripts=disable_partial_transcripts,
- on_extra_session_information=on_extra_session_information,
- )
-
- def connect(
- self,
- timeout: Optional[float] = 10.0,
- ) -> None:
- """
- Connects to the real-time service.
-
- Args:
- `timeout`: The timeout in seconds to wait for the connection to be established.
- A `timeout` of `None` means no timeout.
- """
-
- self._impl.connect(timeout=timeout)
-
- def stream(
- self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]]
- ) -> None:
- """
- Streams raw audio data to the real-time service.
-
- Args:
- `data`: Raw audio data in `bytes` or a generator/iterable of `bytes`.
-
- Note: Make sure that `data` matches the `sample_rate` that was given in the constructor.
- """
- if isinstance(data, bytes):
- self._impl.stream(data)
- return
-
- for chunk in data:
- self._impl.stream(chunk)
-
- def configure_end_utterance_silence_threshold(
- self, threshold_milliseconds: int
- ) -> None:
- """
- Configures the silence duration threshold used to detect the end of an utterance.
- In practice, it's used to tune how the transcriptions are split into final transcripts.
- Can be called multiple times during a session at any point after the session starts.
-
- Args:
- `threshold_milliseconds`: The threshold in milliseconds.
- """
- self._impl.configure_end_utterance_silence_threshold(threshold_milliseconds)
-
- def force_end_utterance(self) -> None:
- """
- Forces the end of the current utterance.
- After calling this method, the server will end the current utterance and return a final transcript.
- """
- self._impl.force_end_utterance()
-
- def close(self) -> None:
- """
- Closes the connection to the real-time service.
- """
-
- self._impl.close(terminate=True)
-
- @classmethod
- def create_temporary_token(
- cls,
- expires_in: int,
- timeout: Optional[float] = None,
- ) -> str:
- """
- Request a temporary authentication token.
-
- Example:
- To create a token, you can simply do:
- ```
- token = aai.RealtimeTranscriber.create_temporary_token(expires_in=360000)
- ```
-
- Args:
- expires_in: The amount of time until the token expires in seconds.
- timeout: The timeout in seconds to wait for a response.
- A `timeout` of `None` means no timeout.
-
- Returns: The temporary authentication token.
- """
- return _RealtimeTranscriberImpl.create_temporary_token(
- expires_in=expires_in, timeout=timeout
- )
diff --git a/assemblyai/types.py b/assemblyai/types.py
index f97d176..6f29e36 100644
--- a/assemblyai/types.py
+++ b/assemblyai/types.py
@@ -21,7 +21,7 @@
try:
# pydantic v2 import
- from pydantic import UUID4, BaseModel, ConfigDict, Field, field_validator
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
pydantic_v2 = True
@@ -34,7 +34,7 @@
) from None
# pydantic v1 import (fallback for Python < 3.14)
- from pydantic.v1 import UUID4, BaseModel, BaseSettings, ConfigDict, Field, validator
+ from pydantic.v1 import BaseModel, BaseSettings, ConfigDict, Field, validator
pydantic_v2 = False
@@ -2951,163 +2951,3 @@ class LemurPurgeResponse(BaseModel):
deleted: bool
"The result of the LeMUR purge request"
-
-
-class RealtimeMessageTypes(str, Enum):
- """
- The type of message received from the real-time API
- """
-
- partial_transcript = "PartialTranscript"
- final_transcript = "FinalTranscript"
- session_begins = "SessionBegins"
- session_terminated = "SessionTerminated"
- session_information = "SessionInformation"
-
-
-class AudioEncoding(str, Enum):
- """
- The encoding of the audio data
- """
-
- pcm_s16le = "pcm_s16le"
- pcm_mulaw = "pcm_mulaw"
-
-
-class RealtimeCreateTemporaryTokenRequest(BaseModel):
- expires_in: int
- "The amount of time until the token expires in seconds"
-
-
-class RealtimeCreateTemporaryTokenResponse(BaseModel):
- token: str
- "The temporary authentication token for real-time transcription"
-
-
-class RealtimeSessionOpened(BaseModel):
- """
- Once a real-time session is opened, the client will receive this message
- """
-
- message_type: RealtimeMessageTypes = RealtimeMessageTypes.session_begins
-
- session_id: UUID4
- "Unique identifier for the established session."
-
- expires_at: datetime
- "Timestamp when this session will expire."
-
-
-class RealtimeWord(BaseModel):
- """
- A word in a real-time transcript
- """
-
- start: int
- "Start time of word relative to session start, in milliseconds"
-
- end: int
- "End time of word relative to session start, in milliseconds"
-
- confidence: float
- "The confidence score of the word, between 0 and 1"
-
- text: str
- "The word itself"
-
-
-class RealtimeTranscript(BaseModel):
- """
- Base class for real-time transcript messages.
- """
-
- message_type: RealtimeMessageTypes
- "Describes the type of message"
-
- audio_start: int
- "Start time of audio sample relative to session start, in milliseconds"
-
- audio_end: int
- "End time of audio sample relative to session start, in milliseconds"
-
- confidence: float
- "The confidence score of the entire transcription, between 0 and 1"
-
- text: str
- "The transcript for your audio"
-
- words: List[RealtimeWord]
- """
- An array of objects, with the information for each word in the transcription text.
- Will include the `start`/`end` time (in milliseconds) of the word, the `confidence` score of the word,
- and the `text` (i.e. the word itself)
- """
-
- created: datetime
- "Timestamp when this message was created"
-
-
-class RealtimePartialTranscript(RealtimeTranscript):
- """
- As you send audio data to the service, the service will immediately start responding with partial transcripts.
- """
-
- message_type: RealtimeMessageTypes = RealtimeMessageTypes.partial_transcript
-
-
-class RealtimeFinalTranscript(RealtimeTranscript):
- """
- After you've received your partial results, our model will continue to analyze incoming audio and,
- when it detects the end of an "utterance" (usually a pause in speech), it will finalize the results
- sent to you so far with higher accuracy, as well as add punctuation and casing to the transcription text.
- """
-
- message_type: RealtimeMessageTypes = RealtimeMessageTypes.final_transcript
-
- punctuated: bool
- "Whether the transcript has been punctuated and cased"
-
- text_formatted: bool
- "Whether the transcript has been formatted (e.g. Dollar -> $)"
-
-
-class RealtimeSessionInformation(BaseModel):
- """
- If `on_extra_session_information` is set, the client receives this message
- right before receiving the session termination message.
- """
-
- message_type: RealtimeMessageTypes = RealtimeMessageTypes.session_information
-
- audio_duration_seconds: float
- "The duration of the audio in seconds"
-
-
-class RealtimeError(AssemblyAIError):
- """
- Real-time error message
- """
-
-
-RealtimeErrorMapping = {
- 4000: "Sample rate must be a positive integer",
- 4001: "Not Authorized",
- 4002: "Insufficient Funds",
- 4003: """This feature is paid-only and requires you to add a credit card.
- Please visit https://app.assemblyai.com/ to add a credit card to your account""",
- 4004: "Session Not Found",
- 4008: "Session Expired",
- 4010: "Session Previously Closed",
- 4029: "Client sent audio too fast",
- 4030: "Session is handled by another websocket",
- 4031: "Session idle for too long",
- 4032: "Audio duration is too short",
- 4033: "Audio duration is too long",
- 4034: "Audio too small to transcode",
- 4100: "Endpoint received invalid JSON",
- 4101: "Endpoint received a message with an invalid schema",
- 4102: "This account has exceeded the number of allowed streams",
- 4103: "The session has been reconnected. This websocket is no longer valid.",
- 4104: "Could not parse word boost parameter",
- 1013: "Temporary server condition forced blocking client's request",
-}
diff --git a/tests/unit/test_realtime_transcriber.py b/tests/unit/test_realtime_transcriber.py
deleted file mode 100644
index c8d978a..0000000
--- a/tests/unit/test_realtime_transcriber.py
+++ /dev/null
@@ -1,564 +0,0 @@
-import datetime
-import json
-import uuid
-from unittest.mock import MagicMock
-from urllib.parse import urlencode
-
-import httpx
-import pytest
-import websockets.exceptions
-from faker import Faker
-from pytest_httpx import HTTPXMock
-from pytest_mock import MockFixture
-
-import assemblyai as aai
-from assemblyai.api import ENDPOINT_REALTIME_TOKEN
-
-aai.settings.api_key = "test"
-
-
-def _disable_rw_threads(mocker: MockFixture):
- """
- Disable the read/write threads for the websocket
- """
-
- mocker.patch("threading.Thread.start", return_value=None)
-
-
-@pytest.mark.parametrize(
- "encoding,token,expected_header",
- [
- (None, None, {"Authorization": "test"}),
- (aai.AudioEncoding.pcm_s16le, None, {"Authorization": "test"}),
- (aai.AudioEncoding.pcm_mulaw, None, {"Authorization": "test"}),
- (None, "12345678", None),
- (aai.AudioEncoding.pcm_s16le, "12345678", None),
- ],
-)
-def test_realtime_connect_has_parameters(
- encoding, token, expected_header, mocker: MockFixture
-):
- """
- Test that the connect method has the correct parameters set
- """
- aai.settings.base_url = "https://api.assemblyai.com"
-
- actual_url = None
- actual_additional_headers = None
- actual_open_timeout = None
-
- def mocked_websocket_connect(
- url: str, additional_headers: dict, open_timeout: float
- ):
- nonlocal actual_url, actual_additional_headers, actual_open_timeout
- actual_url = url
- actual_additional_headers = additional_headers
- actual_open_timeout = open_timeout
-
- mocker.patch(
- "assemblyai.transcriber.websocket_connect",
- new=mocked_websocket_connect,
- )
- _disable_rw_threads(mocker)
-
- transcriber = aai.RealtimeTranscriber(
- on_data=lambda: None,
- on_error=lambda error: print(error),
- sample_rate=44_100,
- word_boost=["AssemblyAI"],
- encoding=encoding,
- token=token,
- )
-
- transcriber.connect(timeout=15.0)
-
- params = dict(sample_rate=44100, word_boost=json.dumps(["AssemblyAI"]))
- if encoding:
- params["encoding"] = encoding.value
- if token:
- params["token"] = token
-
- assert actual_url == f"wss://api.assemblyai.com/v2/realtime/ws?{urlencode(params)}"
- assert actual_additional_headers == expected_header
- assert actual_open_timeout == 15.0
-
-
-def test_realtime_connect_succeeds(mocker: MockFixture):
- """
- Tests that the `RealtimeTranscriber` successfully connects to the `real-time` service.
- """
- on_error_called = False
-
- def on_error(error: aai.RealtimeError):
- nonlocal on_error_called
- on_error_called = True
-
- transcriber = aai.RealtimeTranscriber(
- on_data=lambda _: None,
- on_error=on_error,
- sample_rate=44_100,
- )
-
- mocker.patch(
- "assemblyai.transcriber.websocket_connect",
- return_value=MagicMock(),
- )
-
- # mock the read/write threads
- _disable_rw_threads(mocker)
-
- # should pass
- transcriber.connect()
-
- # no errors should be called
- assert not on_error_called
-
-
-def test_realtime_token_connect_succeeds(mocker: MockFixture):
- """
- Tests that the `RealtimeTranscriber` successfully connects
- to the `real-time` service when a token is used.
- """
- on_error_called = False
-
- # reset the API key
- mocker.patch("assemblyai.settings.api_key", new=None)
-
- def on_error(error: aai.RealtimeError):
- nonlocal on_error_called
- on_error_called = True
-
- transcriber = aai.RealtimeTranscriber(
- on_data=lambda _: None, on_error=on_error, sample_rate=44_100, token="12345"
- )
-
- mocker.patch(
- "assemblyai.transcriber.websocket_connect",
- return_value=MagicMock(),
- )
-
- # mock the read/write threads
- _disable_rw_threads(mocker)
-
- # should pass
- transcriber.connect()
-
- # no errors should be called
- assert not on_error_called
-
-
-def test_realtime_connect_fails(mocker: MockFixture):
- """
- Tests that the `RealtimeTranscriber` fails to connect to the `real-time` service.
- """
-
- on_error_called = False
-
- def on_error(error: aai.RealtimeError):
- nonlocal on_error_called
- on_error_called = True
-
- assert isinstance(error, aai.RealtimeError)
- assert "connection failed" in str(error)
-
- transcriber = aai.RealtimeTranscriber(
- on_data=lambda _: None,
- on_error=on_error,
- sample_rate=44_100,
- )
- mocker.patch(
- "assemblyai.transcriber.websocket_connect",
- side_effect=Exception("connection failed"),
- )
-
- transcriber.connect()
-
- assert on_error_called
-
-
-def test_realtime__read_succeeds(mocker: MockFixture, faker: Faker):
- """
- Tests the `_read` method of the `_RealtimeTranscriberImpl` class.
- """
-
- expected_transcripts = [
- aai.RealtimeFinalTranscript(
- created=faker.date_time(),
- text=faker.sentence(),
- audio_start=0,
- audio_end=1,
- confidence=1.0,
- words=[],
- punctuated=True,
- text_formatted=True,
- )
- ]
-
- received_transcripts = []
-
- def on_data(data: aai.RealtimeTranscript):
- nonlocal received_transcripts
- received_transcripts.append(data)
-
- transcriber = aai.RealtimeTranscriber(
- on_data=on_data,
- on_error=lambda _: None,
- sample_rate=44_100,
- )
-
- transcriber._impl._websocket = MagicMock()
- websocket_recv = [
- json.dumps(msg.dict(), default=str) for msg in expected_transcripts
- ]
- transcriber._impl._websocket.recv.side_effect = websocket_recv
-
- with pytest.raises(StopIteration):
- transcriber._impl._read()
-
- assert received_transcripts == expected_transcripts
-
-
-def test_realtime__read_fails(mocker: MockFixture):
- """
- Tests the `_read` method of the `_RealtimeTranscriberImpl` class.
- """
-
- on_error_called = False
-
- def on_error(error: aai.RealtimeError):
- nonlocal on_error_called
- on_error_called = True
-
- transcriber = aai.RealtimeTranscriber(
- on_data=lambda _: None,
- on_error=on_error,
- sample_rate=44_100,
- )
-
- transcriber._impl._websocket = MagicMock()
- error = websockets.exceptions.ConnectionClosedOK(rcvd=None, sent=None)
- transcriber._impl._websocket.recv.side_effect = error
-
- transcriber._impl._read()
-
- assert on_error_called
-
-
-def test_realtime__write_succeeds(mocker: MockFixture):
- """
- Tests the `_write` method of the `_RealtimeTranscriberImpl` class.
- """
- audio_chunks = [
- bytes([1, 2, 3, 4, 5]),
- bytes([6, 7, 8, 9, 10]),
- ]
-
- actual_sent = []
-
- def mocked_send(data: str):
- nonlocal actual_sent
- actual_sent.append(data)
-
- transcriber = aai.RealtimeTranscriber(
- on_data=lambda _: None,
- on_error=lambda _: None,
- sample_rate=44_100,
- )
-
- transcriber._impl._websocket = MagicMock()
- transcriber._impl._websocket.send = mocked_send
- transcriber._impl._stop_event.is_set = MagicMock(side_effect=[False, False, True])
-
- transcriber.stream(audio_chunks[0])
- transcriber.stream(audio_chunks[1])
-
- transcriber._impl._write()
-
- # assert that the correct data was sent (= the exact input bytes)
- assert len(actual_sent) == 2
- assert actual_sent[0] == audio_chunks[0]
- assert actual_sent[1] == audio_chunks[1]
-
-
-def test_realtime__handle_message_session_begins(mocker: MockFixture):
- """
- Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
- with the `SessionBegins` message.
- """
-
- test_message = {
- "message_type": "SessionBegins",
- "session_id": str(uuid.uuid4()),
- "expires_at": datetime.datetime.now().isoformat(),
- }
-
- on_open_called = False
-
- def on_open(session_opened: aai.RealtimeSessionOpened):
- nonlocal on_open_called
- on_open_called = True
- assert isinstance(session_opened, aai.RealtimeSessionOpened)
- assert session_opened.session_id == uuid.UUID(test_message["session_id"])
- assert session_opened.expires_at.isoformat() == test_message["expires_at"]
-
- transcriber = aai.RealtimeTranscriber(
- on_open=on_open,
- on_data=lambda _: None,
- on_error=lambda _: None,
- sample_rate=44_100,
- )
-
- transcriber._impl._handle_message(test_message)
-
- assert on_open_called
-
-
-def test_realtime__handle_message_partial_transcript(mocker: MockFixture):
- """
- Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
- with the `PartialTranscript` message.
- """
-
- test_message = {
- "message_type": "PartialTranscript",
- "text": "hello world",
- "audio_start": 0,
- "audio_end": 1500,
- "confidence": 0.99,
- "created": datetime.datetime.now().isoformat(),
- "words": [
- {
- "text": "hello",
- "start": 0,
- "end": 500,
- "confidence": 0.99,
- },
- {
- "text": "world",
- "start": 500,
- "end": 1500,
- "confidence": 0.99,
- },
- ],
- }
-
- on_data_called = False
-
- def on_data(data: aai.RealtimePartialTranscript):
- nonlocal on_data_called
- on_data_called = True
- assert isinstance(data, aai.RealtimePartialTranscript)
- assert data.text == test_message["text"]
- assert data.audio_start == test_message["audio_start"]
- assert data.audio_end == test_message["audio_end"]
- assert data.confidence == test_message["confidence"]
- assert data.created.isoformat() == test_message["created"]
- assert data.words == [
- aai.RealtimeWord(
- text=test_message["words"][0]["text"],
- start=test_message["words"][0]["start"],
- end=test_message["words"][0]["end"],
- confidence=test_message["words"][0]["confidence"],
- ),
- aai.RealtimeWord(
- text=test_message["words"][1]["text"],
- start=test_message["words"][1]["start"],
- end=test_message["words"][1]["end"],
- confidence=test_message["words"][1]["confidence"],
- ),
- ]
-
- transcriber = aai.RealtimeTranscriber(
- on_data=on_data,
- on_error=lambda _: None,
- sample_rate=44_100,
- )
-
- transcriber._impl._handle_message(test_message)
-
- assert on_data_called
-
-
-def test_realtime__handle_message_final_transcript(mocker: MockFixture):
- """
- Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
- with the `FinalTranscript` message.
- """
-
- test_message = {
- "message_type": "FinalTranscript",
- "text": "Hello, world!",
- "audio_start": 0,
- "audio_end": 1500,
- "confidence": 0.99,
- "created": datetime.datetime.now().isoformat(),
- "punctuated": True,
- "text_formatted": True,
- "words": [
- {
- "text": "Hello,",
- "start": 0,
- "end": 500,
- "confidence": 0.99,
- },
- {
- "text": "world!",
- "start": 500,
- "end": 1500,
- "confidence": 0.99,
- },
- ],
- }
-
- on_data_called = False
-
- def on_data(data: aai.RealtimeFinalTranscript):
- nonlocal on_data_called
- on_data_called = True
- assert isinstance(data, aai.RealtimeFinalTranscript)
- assert data.text == test_message["text"]
- assert data.audio_start == test_message["audio_start"]
- assert data.audio_end == test_message["audio_end"]
- assert data.confidence == test_message["confidence"]
- assert data.created.isoformat() == test_message["created"]
- assert data.punctuated == test_message["punctuated"]
- assert data.text_formatted == test_message["text_formatted"]
- assert data.words == [
- aai.RealtimeWord(
- text=test_message["words"][0]["text"],
- start=test_message["words"][0]["start"],
- end=test_message["words"][0]["end"],
- confidence=test_message["words"][0]["confidence"],
- ),
- aai.RealtimeWord(
- text=test_message["words"][1]["text"],
- start=test_message["words"][1]["start"],
- end=test_message["words"][1]["end"],
- confidence=test_message["words"][1]["confidence"],
- ),
- ]
-
- transcriber = aai.RealtimeTranscriber(
- on_data=on_data,
- on_error=lambda _: None,
- sample_rate=44_100,
- )
-
- transcriber._impl._handle_message(test_message)
-
- assert on_data_called
-
-
-def test_realtime__handle_message_error_message(mocker: MockFixture):
- """
- Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
- with the error message.
- """
-
- test_message = {
- "error": "test error",
- }
-
- on_error_called = False
-
- def on_error(error: aai.RealtimeError):
- nonlocal on_error_called
- on_error_called = True
- assert isinstance(error, aai.RealtimeError)
- assert str(error) == test_message["error"]
-
- transcriber = aai.RealtimeTranscriber(
- on_data=lambda _: None,
- on_error=on_error,
- sample_rate=44_100,
- )
-
- transcriber._impl._handle_message(test_message)
-
- assert on_error_called
-
-
-def test_realtime__handle_message_session_information_message(mocker: MockFixture):
- """
- Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
- with the session information message.
- """
-
- test_message = {
- "message_type": "SessionInformation",
- "audio_duration_seconds": 3000.0,
- }
-
- on_extra_session_information_called = False
-
- def on_extra_session_information(data: aai.RealtimeSessionInformation):
- nonlocal on_extra_session_information_called
- on_extra_session_information_called = True
- assert isinstance(data, aai.RealtimeSessionInformation)
- assert data.audio_duration_seconds == test_message["audio_duration_seconds"]
-
- transcriber = aai.RealtimeTranscriber(
- on_data=lambda _: None,
- on_error=lambda _: None,
- sample_rate=44_100,
- on_extra_session_information=on_extra_session_information,
- )
-
- transcriber._impl._handle_message(test_message)
-
- assert on_extra_session_information_called
-
-
-def test_realtime__handle_message_unknown_message(mocker: MockFixture):
- """
- Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
- with an unknown message.
- """
-
- test_message = {
- "message_type": "Unknown",
- }
-
- on_data_called = False
-
- def on_data(data: aai.RealtimeTranscript):
- nonlocal on_data_called
- on_data_called = True
-
- on_error_called = False
-
- def on_error(error: aai.RealtimeError):
- nonlocal on_error_called
- on_error_called = True
-
- transcriber = aai.RealtimeTranscriber(
- on_data=on_data,
- on_error=on_error,
- sample_rate=44_100,
- )
-
- transcriber._impl._handle_message(test_message)
-
- assert not on_data_called
- assert not on_error_called
-
-
-def test_create_temporary_token(httpx_mock: HTTPXMock):
- """
- Tests whether the creation of a temporary token is successful.
- """
-
- # mock the specific endpoint
- httpx_mock.add_response(
- url=f"{aai.settings.base_url}{ENDPOINT_REALTIME_TOKEN}",
- status_code=httpx.codes.OK,
- method="POST",
- json={"token": "123456"},
- )
-
- token = aai.RealtimeTranscriber.create_temporary_token(expires_in=3000)
-
- assert token == "123456"
-
-
-# TODO: create tests for the `RealtimeTranscriber.close` method
diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py
index eb40850..a637eba 100644
--- a/tests/unit/test_streaming.py
+++ b/tests/unit/test_streaming.py
@@ -681,6 +681,37 @@ def mocked_websocket_connect(
assert "interruption_delay=500" in actual_url
+def test_client_connect_with_turn_left_pad_ms(mocker: MockFixture):
+ # Given: client + turn_left_pad_ms=1024 (U3-Pro left-pad window override)
+ actual_url = None
+
+ def mocked_websocket_connect(
+ url: str, additional_headers: dict, open_timeout: float
+ ):
+ nonlocal actual_url
+ actual_url = url
+
+ mocker.patch(
+ "assemblyai.streaming.v3.client.websocket_connect",
+ new=mocked_websocket_connect,
+ )
+ _disable_rw_threads(mocker)
+ client = StreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ params = StreamingParameters(
+ sample_rate=16000,
+ speech_model=SpeechModel.u3_rt_pro,
+ turn_left_pad_ms=1024,
+ )
+
+ # When: connect
+ client.connect(params)
+
+ # Then: parameter reaches the URL
+ assert "turn_left_pad_ms=1024" in actual_url
+
+
def test_customer_support_audio_capture_warns_when_enabled(
mocker: MockFixture, caplog: pytest.LogCaptureFixture
):
diff --git a/tests/unit/test_streaming_async.py b/tests/unit/test_streaming_async.py
new file mode 100644
index 0000000..bf00701
--- /dev/null
+++ b/tests/unit/test_streaming_async.py
@@ -0,0 +1,1153 @@
+import asyncio
+import json
+import logging
+from urllib.parse import urlencode
+
+import pytest
+from pytest_mock import MockFixture
+from websockets.exceptions import ConnectionClosed, InvalidStatus
+from websockets.frames import Close
+
+from assemblyai.streaming.v3 import (
+ AsyncStreamingClient,
+ SpeechModel,
+ StreamingClientOptions,
+ StreamingEvents,
+ StreamingParameters,
+)
+from assemblyai.streaming.v3.models import TerminateSession
+
+pytestmark = pytest.mark.asyncio
+
+
+def _default_params() -> StreamingParameters:
+ return StreamingParameters(
+ sample_rate=16000,
+ speech_model=SpeechModel.universal_streaming_english,
+ )
+
+
+class _FakeAsyncWebSocket:
+ """Programmable async websocket stand-in for driving AsyncStreamingClient
+ in tests. Inbound messages are queued via ``push_message`` /
+ ``push_close``; outbound sends accumulate in ``sent``.
+ """
+
+ def __init__(self, send_raises=None):
+ self._inbound: "asyncio.Queue[object]" = asyncio.Queue()
+ self._send_raises = send_raises
+ self.sent: list = []
+ self.send_call_count = 0
+ self.close_call_count = 0
+ self._closed = False
+
+ def push_message(self, data) -> None:
+ self._inbound.put_nowait(data)
+
+ def push_close(self, exc: BaseException) -> None:
+ self._inbound.put_nowait(exc)
+
+ async def recv(self):
+ item = await self._inbound.get()
+ if isinstance(item, BaseException):
+ raise item
+ return item
+
+ async def send(self, data) -> None:
+ self.send_call_count += 1
+ if self._send_raises is not None:
+ raise self._send_raises
+ self.sent.append(data)
+
+ async def close(self) -> None:
+ self.close_call_count += 1
+ self._closed = True
+
+
+def _patch_connect(mocker: MockFixture, fake_ws):
+ """Patch ``websocket_connect_async`` to return the given fake websocket."""
+
+ async def fake_connect(uri, additional_headers=None, **_kwargs):
+ fake_connect.uri = uri
+ fake_connect.additional_headers = additional_headers
+ return fake_ws
+
+ fake_connect.uri = None
+ fake_connect.additional_headers = None
+ mocker.patch(
+ "assemblyai.streaming.v3.async_client.websocket_connect_async",
+ new=fake_connect,
+ )
+ return fake_connect
+
+
+async def _wait_for_tasks(client: AsyncStreamingClient, timeout: float = 2.0) -> None:
+ """Wait until both read/write tasks have exited and stop is set. Raises
+ ``AssertionError`` on timeout so stalls fail tests deterministically
+ instead of silently passing."""
+ loop = asyncio.get_running_loop()
+ deadline = loop.time() + timeout
+ while loop.time() < deadline:
+ read_done = client._read_task is None or client._read_task.done()
+ write_done = client._write_task is None or client._write_task.done()
+ if read_done and write_done and client._stop_event.is_set():
+ return
+ await asyncio.sleep(0.01)
+ raise AssertionError(
+ f"AsyncStreamingClient read/write tasks did not finish within {timeout}s"
+ )
+
+
+async def test_client_connect_builds_uri_and_headers(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ fake_connect = _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+
+ params = _default_params()
+ await client.connect(params)
+
+ expected_qs = urlencode(
+ {
+ "sample_rate": params.sample_rate,
+ "speech_model": str(params.speech_model),
+ }
+ )
+ assert fake_connect.uri == f"wss://api.example.com/v3/ws?{expected_qs}"
+ assert fake_connect.additional_headers["Authorization"] == "test"
+ assert fake_connect.additional_headers["AssemblyAI-Version"] == "2025-05-12"
+ assert "AssemblyAI/1.0" in fake_connect.additional_headers["User-Agent"]
+
+ await client.disconnect()
+
+
+async def test_client_connect_with_token(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ fake_connect = _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(token="tok-value", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ assert fake_connect.additional_headers["Authorization"] == "tok-value"
+
+ await client.disconnect()
+
+
+async def test_stream_bytes_writes_to_socket(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ await client.stream(b"\x00" * 320)
+
+ # Give the write task a moment to drain the queue.
+ for _ in range(50):
+ if fake_ws.sent:
+ break
+ await asyncio.sleep(0.01)
+
+ assert fake_ws.sent == [b"\x00" * 320]
+
+ await client.disconnect()
+
+
+async def test_stream_sync_iterable(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ chunks = [b"a", b"bb", b"ccc"]
+ await client.stream(iter(chunks))
+
+ for _ in range(50):
+ if len(fake_ws.sent) == 3:
+ break
+ await asyncio.sleep(0.01)
+
+ assert fake_ws.sent == chunks
+
+ await client.disconnect()
+
+
+async def test_stream_async_iterable(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ async def gen():
+ for chunk in (b"x", b"yy", b"zzz"):
+ yield chunk
+
+ await client.stream(gen())
+
+ for _ in range(50):
+ if len(fake_ws.sent) == 3:
+ break
+ await asyncio.sleep(0.01)
+
+ assert fake_ws.sent == [b"x", b"yy", b"zzz"]
+
+ await client.disconnect()
+
+
+async def test_disconnect_terminate_sends_terminate_then_closes(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ await client.disconnect(terminate=True)
+
+ sent_terminate = [
+ s for s in fake_ws.sent if isinstance(s, str) and "Terminate" in s
+ ]
+ assert len(sent_terminate) == 1
+ assert fake_ws.close_call_count >= 1
+
+
+async def test_begin_event_dispatched_to_handler(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ received = []
+
+ def on_begin(_client, event):
+ received.append(event)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Begin, on_begin)
+ await client.connect(_default_params())
+
+ fake_ws.push_message(
+ json.dumps(
+ {
+ "type": "Begin",
+ "id": "abc",
+ "expires_at": "2030-01-01T00:00:00",
+ }
+ )
+ )
+
+ for _ in range(50):
+ if received:
+ break
+ await asyncio.sleep(0.01)
+
+ assert len(received) == 1
+ assert received[0].id == "abc"
+
+ await client.disconnect()
+
+
+async def test_async_handler_is_awaited(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ seen = []
+
+ async def on_begin(_client, event):
+ await asyncio.sleep(0)
+ seen.append(event.id)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Begin, on_begin)
+ await client.connect(_default_params())
+
+ fake_ws.push_message(
+ json.dumps(
+ {"type": "Begin", "id": "async-id", "expires_at": "2030-01-01T00:00:00"}
+ )
+ )
+
+ for _ in range(50):
+ if seen:
+ break
+ await asyncio.sleep(0.01)
+
+ assert seen == ["async-id"]
+
+ await client.disconnect()
+
+
+async def test_sync_and_async_handlers_can_mix(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ sync_seen = []
+ async_seen = []
+
+ def sync_handler(_client, event):
+ sync_seen.append(event.id)
+
+ async def async_handler(_client, event):
+ async_seen.append(event.id)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Begin, sync_handler)
+ client.on(StreamingEvents.Begin, async_handler)
+ await client.connect(_default_params())
+
+ fake_ws.push_message(
+ json.dumps({"type": "Begin", "id": "mix", "expires_at": "2030-01-01T00:00:00"})
+ )
+
+ for _ in range(50):
+ if sync_seen and async_seen:
+ break
+ await asyncio.sleep(0.01)
+
+ assert sync_seen == ["mix"]
+ assert async_seen == ["mix"]
+
+ await client.disconnect()
+
+
+async def test_error_event_then_close_fires_only_once(
+ mocker: MockFixture, caplog: pytest.LogCaptureFixture
+):
+ caplog.set_level(logging.ERROR)
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ received = []
+
+ def on_error(_client, err):
+ received.append(err)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Error, on_error)
+ await client.connect(_default_params())
+
+ fake_ws.push_message(
+ json.dumps({"type": "Error", "error": "Invalid API key", "error_code": 4001})
+ )
+ fake_ws.push_close(ConnectionClosed(rcvd=Close(4001, "Not Authorized"), sent=None))
+
+ await _wait_for_tasks(client)
+
+ assert len(received) == 1
+ assert str(received[0]) == "Invalid API key"
+ assert received[0].code == 4001
+
+ error_logs = [
+ rec
+ for rec in caplog.records
+ if "Streaming error" in rec.message and "4001" in rec.message
+ ]
+ close_logs = [
+ rec
+ for rec in caplog.records
+ if "Connection closed" in rec.message and "4001" in rec.message
+ ]
+ assert len(error_logs) == 1
+ # ``_report_server_error`` closes the websocket locally and sets stop, so
+ # the read loop exits before the pushed trailing close is recv'd. No close
+ # log is emitted in this path — the Error event already captured the cause.
+ assert close_logs == []
+
+ await client.disconnect()
+
+
+async def test_server_error_without_trailing_close_tears_down(mocker: MockFixture):
+ """Regression: a server ``Error`` frame with no trailing close must still
+ drive the read loop to exit. Without local teardown in
+ ``_report_server_error``, ``await ws.recv()`` would block indefinitely."""
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ received = []
+
+ def on_error(_client, err):
+ received.append(err)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Error, on_error)
+ await client.connect(_default_params())
+
+ # Push an Error frame and nothing else — no trailing close.
+ fake_ws.push_message(
+ json.dumps({"type": "Error", "error": "boom", "error_code": 4002})
+ )
+
+ # If teardown is missing this raises AssertionError after timeout.
+ await _wait_for_tasks(client)
+
+ assert len(received) == 1
+ assert received[0].code == 4002
+ assert fake_ws.close_call_count >= 1
+
+ await client.disconnect()
+
+
+async def test_clean_close_emits_no_error_or_log(
+ mocker: MockFixture, caplog: pytest.LogCaptureFixture
+):
+ caplog.set_level(logging.ERROR)
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ received = []
+
+ def on_error(_client, err):
+ received.append(err)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Error, on_error)
+ await client.connect(_default_params())
+
+ fake_ws.push_close(ConnectionClosed(rcvd=Close(1000, "session ended"), sent=None))
+
+ await _wait_for_tasks(client)
+
+ assert received == []
+ error_logs = [rec for rec in caplog.records if rec.levelno >= logging.ERROR]
+ assert error_logs == []
+
+ await client.disconnect()
+
+
+async def test_turn_handler_exception_does_not_kill_read_task(mocker: MockFixture):
+ """A raising Turn handler must not propagate out of the read task; the
+ next inbound message should still be delivered."""
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ seen = []
+
+ def bad_handler(_client, _turn):
+ raise RuntimeError("boom")
+
+ def good_handler(_client, turn):
+ seen.append(turn.end_of_turn)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Turn, bad_handler)
+ client.on(StreamingEvents.Turn, good_handler)
+ await client.connect(_default_params())
+
+ turn_payload = {
+ "type": "Turn",
+ "turn_order": 1,
+ "turn_is_formatted": False,
+ "end_of_turn": False,
+ "transcript": "hello",
+ "end_of_turn_confidence": 0.5,
+ "words": [],
+ }
+ fake_ws.push_message(json.dumps(turn_payload))
+ fake_ws.push_message(json.dumps({**turn_payload, "turn_order": 2}))
+
+ for _ in range(100):
+ if len(seen) == 2:
+ break
+ await asyncio.sleep(0.01)
+
+ assert seen == [False, False]
+
+ await client.disconnect()
+
+
+async def test_warning_handler_exception_does_not_kill_read_task(mocker: MockFixture):
+ """A raising Warning handler must not propagate out of the read task."""
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ received = []
+
+ def bad_handler(_client, _warning):
+ raise RuntimeError("boom")
+
+ def good_handler(_client, warning):
+ received.append(warning.warning_code)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Warning, bad_handler)
+ client.on(StreamingEvents.Warning, good_handler)
+ await client.connect(_default_params())
+
+ fake_ws.push_message(
+ json.dumps({"type": "Warning", "warning": "first", "warning_code": 1})
+ )
+ fake_ws.push_message(
+ json.dumps({"type": "Warning", "warning": "second", "warning_code": 2})
+ )
+
+ for _ in range(100):
+ if len(received) == 2:
+ break
+ await asyncio.sleep(0.01)
+
+ assert received == [1, 2]
+
+ await client.disconnect()
+
+
+async def test_stream_before_connect_raises_runtime_error():
+ """``stream()`` called before ``connect()`` must raise RuntimeError rather
+ than silently dropping data. A silent drop would diverge from the sync client
+ (which buffers pre-connect data) in a way that's easy to miss — explicit
+ failure surfaces the misuse."""
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+
+ async def gen():
+ yield b"x"
+
+ for data in (b"\x00" * 10, iter([b"a", b"b"]), gen()):
+ with pytest.raises(RuntimeError, match="not connected"):
+ await client.stream(data)
+
+
+async def test_set_params_before_connect_raises_runtime_error():
+ from assemblyai.streaming.v3 import (
+ StreamingSessionParameters,
+ )
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ with pytest.raises(RuntimeError, match="not connected"):
+ await client.set_params(StreamingSessionParameters(min_turn_silence=200))
+
+
+async def test_force_endpoint_before_connect_raises_runtime_error():
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ with pytest.raises(RuntimeError, match="not connected"):
+ await client.force_endpoint()
+
+
+async def test_stream_after_close_is_noop(mocker: MockFixture):
+ """Post-close ``stream()`` must stay a silent no-op so user cleanup paths
+ (e.g. a finally block draining a queue) don't have to wrap each call in
+ try/except. Pre-connect raise + post-close no-op gives both: misuse is
+ loud, cleanup is quiet."""
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ # Simulate a clean close — read task exits, _stop_event is set.
+ fake_ws.push_close(ConnectionClosed(rcvd=Close(1000, "bye"), sent=None))
+ await _wait_for_tasks(client)
+
+ # No raise: post-close stream is safe for cleanup.
+ await client.stream(b"\x00" * 10)
+ await client.disconnect()
+
+
+async def test_handler_exception_does_not_block_shutdown(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ def bad_handler(_client, _err):
+ raise RuntimeError("boom")
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Error, bad_handler)
+ await client.connect(_default_params())
+
+ fake_ws.push_close(ConnectionClosed(rcvd=Close(1011, "server error"), sent=None))
+
+ await _wait_for_tasks(client)
+ # If the handler exception had escaped, _wait_for_tasks would time out.
+ assert client._read_task.done()
+
+ await client.disconnect()
+
+
+async def test_invalid_status_during_connect_dispatches_error(mocker: MockFixture):
+ received = []
+
+ def on_error(_client, err):
+ received.append(err)
+
+ response = type("R", (), {"status_code": 401})()
+ err = InvalidStatus(response=response)
+
+ async def failing_connect(*_args, **_kwargs):
+ raise err
+
+ mocker.patch(
+ "assemblyai.streaming.v3.async_client.websocket_connect_async",
+ new=failing_connect,
+ )
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Error, on_error)
+
+ await client.connect(_default_params())
+
+ assert len(received) == 1
+ assert received[0].code == 401
+ assert "HTTP 401" in str(received[0])
+
+
+async def test_terminate_session_bypasses_stop_gate(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ # Pre-set stop, then queue a TerminateSession directly. The write loop must
+ # still send it before exiting.
+ client._stop_event.set()
+ await client._write_queue.put(TerminateSession())
+
+ for _ in range(100):
+ if fake_ws.send_call_count >= 1:
+ break
+ await asyncio.sleep(0.01)
+
+ assert fake_ws.send_call_count >= 1
+ assert any(isinstance(s, str) and "Terminate" in s for s in fake_ws.sent)
+
+ await client.disconnect()
+
+
+async def test_create_temporary_token(mocker: MockFixture):
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+
+ captured = {}
+
+ async def fake_get(self, url, params=None):
+ captured["url"] = url
+ captured["params"] = params
+
+ class R:
+ def raise_for_status(self_inner):
+ pass
+
+ def json(self_inner):
+ return {"token": "tmp-tok"}
+
+ return R()
+
+ mocker.patch("httpx.AsyncClient.get", new=fake_get)
+
+ token = await client.create_temporary_token(
+ expires_in_seconds=60, max_session_duration_seconds=600
+ )
+ assert token == "tmp-tok"
+ assert captured["url"] == "/v3/token"
+ assert captured["params"] == {
+ "expires_in_seconds": 60,
+ "max_session_duration_seconds": 600,
+ }
+
+ await client._client.aclose()
+
+
+async def test_create_temporary_token_forwards_zero_expires(mocker: MockFixture):
+ """Regression: ``expires_in_seconds=0`` must reach the server (so it can
+ reject it with a clear error) rather than being silently dropped by a
+ falsy check."""
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+
+ captured = {}
+
+ async def fake_get(self, url, params=None):
+ captured["params"] = params
+
+ class R:
+ def raise_for_status(self_inner):
+ pass
+
+ def json(self_inner):
+ return {"token": "tmp-tok"}
+
+ return R()
+
+ mocker.patch("httpx.AsyncClient.get", new=fake_get)
+
+ await client.create_temporary_token(expires_in_seconds=0)
+
+ assert captured["params"] == {"expires_in_seconds": 0}
+
+ await client._client.aclose()
+
+
+async def test_connect_twice_raises(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ with pytest.raises(RuntimeError, match="already been connected"):
+ await client.connect(_default_params())
+
+ await client.disconnect()
+
+
+async def test_connect_after_handshake_failure_raises(mocker: MockFixture):
+ """Regression: a failed connect leaves ``_connection_closed_reported`` set
+ and ``_stop_event`` set. A second ``connect()`` attempt on the same client
+ must surface a clear error, not silently produce a dead read/write loop."""
+ response = type("R", (), {"status_code": 401})()
+ err = InvalidStatus(response=response)
+
+ async def failing_connect(*_args, **_kwargs):
+ raise err
+
+ mocker.patch(
+ "assemblyai.streaming.v3.async_client.websocket_connect_async",
+ new=failing_connect,
+ )
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+
+ await client.connect(_default_params())
+
+ with pytest.raises(RuntimeError, match="already been connected"):
+ await client.connect(_default_params())
+
+
+async def test_set_params_enqueues_update_configuration(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ from assemblyai.streaming.v3.models import (
+ StreamingSessionParameters,
+ )
+
+ await client.set_params(
+ StreamingSessionParameters(end_of_turn_confidence_threshold=0.5)
+ )
+
+ for _ in range(100):
+ update_frames = [
+ s for s in fake_ws.sent if isinstance(s, str) and "UpdateConfiguration" in s
+ ]
+ if update_frames:
+ break
+ await asyncio.sleep(0.01)
+
+ update_frames = [
+ s for s in fake_ws.sent if isinstance(s, str) and "UpdateConfiguration" in s
+ ]
+ assert len(update_frames) == 1
+ payload = json.loads(update_frames[0])
+ assert payload["type"] == "UpdateConfiguration"
+ assert payload["end_of_turn_confidence_threshold"] == 0.5
+
+ await client.disconnect()
+
+
+async def test_force_endpoint_enqueues_force_endpoint_frame(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+
+ await client.force_endpoint()
+
+ for _ in range(100):
+ force_frames = [
+ s for s in fake_ws.sent if isinstance(s, str) and "ForceEndpoint" in s
+ ]
+ if force_frames:
+ break
+ await asyncio.sleep(0.01)
+
+ force_frames = [
+ s for s in fake_ws.sent if isinstance(s, str) and "ForceEndpoint" in s
+ ]
+ assert len(force_frames) == 1
+ payload = json.loads(force_frames[0])
+ assert payload["type"] == "ForceEndpoint"
+
+ await client.disconnect()
+
+
+async def test_warning_event_dispatched_to_handler(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ received = []
+
+ def on_warning(_client, event):
+ received.append(event)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Warning, on_warning)
+ await client.connect(_default_params())
+
+ fake_ws.push_message(
+ json.dumps({"type": "Warning", "warning": "slow audio", "warning_code": 1234})
+ )
+
+ for _ in range(100):
+ if received:
+ break
+ await asyncio.sleep(0.01)
+
+ assert len(received) == 1
+ assert received[0].warning == "slow audio"
+ assert received[0].warning_code == 1234
+
+ await client.disconnect()
+
+
+async def test_termination_event_sets_stop_and_dispatches(mocker: MockFixture):
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ received = []
+
+ def on_termination(_client, event):
+ received.append(event)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Termination, on_termination)
+ await client.connect(_default_params())
+
+ fake_ws.push_message(
+ json.dumps(
+ {
+ "type": "Termination",
+ "audio_duration_seconds": 12,
+ "session_duration_seconds": 15,
+ }
+ )
+ )
+
+ # Termination sets stop_event but doesn't close the socket; wait for the
+ # handler to fire and stop_event to flip.
+ for _ in range(100):
+ if received and client._stop_event is not None and client._stop_event.is_set():
+ break
+ await asyncio.sleep(0.01)
+
+ assert len(received) == 1
+ assert client._stop_event is not None
+ assert client._stop_event.is_set()
+
+ await client.disconnect()
+
+
+async def test_disconnect_before_connect_is_safe_noop(mocker: MockFixture):
+ """``disconnect()`` is safe before ``connect()``. With the httpx client
+ lazy-constructed (no work done in ``__init__``), there is nothing to close
+ on a never-used client, so ``aclose`` should not be invoked."""
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+
+ closed = []
+
+ async def fake_aclose(self):
+ closed.append(True)
+
+ mocker.patch("httpx.AsyncClient.aclose", new=fake_aclose)
+
+ await client.disconnect()
+
+ # Nothing was ever instantiated, so nothing to close.
+ assert closed == []
+ assert client._read_task is None
+ assert client._write_task is None
+
+
+async def test_construct_only_does_not_instantiate_httpx_client(
+ mocker: MockFixture,
+):
+ """Constructing an ``AsyncStreamingClient`` and never calling
+ ``connect()`` / ``create_temporary_token()`` / ``disconnect()`` must not
+ instantiate an ``httpx.AsyncClient`` — otherwise an unused client leaks
+ the pool. The HTTP client should be built lazily on first use."""
+ import httpx
+
+ constructed = []
+ real_init = httpx.AsyncClient.__init__
+
+ def counting_init(self, *args, **kwargs):
+ constructed.append(True)
+ return real_init(self, *args, **kwargs)
+
+ mocker.patch.object(httpx.AsyncClient, "__init__", counting_init)
+
+ AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+
+ assert constructed == [], (
+ "AsyncStreamingClient should not eagerly instantiate httpx.AsyncClient; "
+ "got constructions: " + str(constructed)
+ )
+
+
+async def test_async_context_manager_calls_disconnect_on_exit(mocker: MockFixture):
+ """``async with AsyncStreamingClient(opts) as c:`` must invoke
+ ``disconnect()`` on block exit so callers can't forget cleanup."""
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ async with AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ ) as client:
+ await client.connect(_default_params())
+ await client.stream(b"\x00" * 32)
+
+ # On exit, disconnect should have torn down read/write tasks.
+ assert client._read_task is not None and client._read_task.done()
+ assert client._write_task is not None and client._write_task.done()
+ assert client._stop_event is not None and client._stop_event.is_set()
+
+
+async def test_async_context_manager_disconnect_runs_on_exception(
+ mocker: MockFixture,
+):
+ """Exception inside the ``async with`` body must still trigger
+ ``disconnect()`` so the websocket / http client don't leak when user
+ code raises."""
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ class _Boom(Exception):
+ pass
+
+ client_ref = {}
+
+ with pytest.raises(_Boom):
+ async with AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ ) as client:
+ client_ref["c"] = client
+ await client.connect(_default_params())
+ raise _Boom()
+
+ client = client_ref["c"]
+ assert client._stop_event is not None and client._stop_event.is_set()
+ assert client._websocket is None or fake_ws.close_call_count >= 1
+
+
+async def test_disconnect_closes_http_client_when_used(mocker: MockFixture):
+ """Once the lazy ``httpx.AsyncClient`` has been instantiated (by a call
+ that goes through HTTP — e.g. ``create_temporary_token``), ``disconnect``
+ must close it so the pool doesn't leak."""
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ async def fake_get(self, url, params=None):
+ class _R:
+ def raise_for_status(self):
+ pass
+
+ def json(self):
+ return {"token": "t"}
+
+ return _R()
+
+ mocker.patch("httpx.AsyncClient.get", new=fake_get)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ await client.connect(_default_params())
+ # Force the http client to be instantiated.
+ await client.create_temporary_token(expires_in_seconds=60)
+
+ closed = []
+
+ async def fake_aclose(self):
+ closed.append(True)
+
+ mocker.patch("httpx.AsyncClient.aclose", new=fake_aclose)
+
+ await client.disconnect()
+
+ assert closed == [True]
+
+
+async def test_server_error_dedups_concurrent_write_side_close(mocker: MockFixture):
+ """Regression: a slow async ``on_error`` handler must not race a concurrent
+ write-side ``ConnectionClosed`` into a duplicate dispatch. The
+ ``_server_error_reported`` flag is set synchronously before the first
+ ``await`` in ``_report_server_error`` — this test locks in that ordering."""
+ close_exc = ConnectionClosed(rcvd=Close(1011, "send-side close"), sent=None)
+ fake_ws = _FakeAsyncWebSocket(send_raises=close_exc)
+ _patch_connect(mocker, fake_ws)
+
+ received = []
+ handler_started = asyncio.Event()
+ handler_release = asyncio.Event()
+
+ async def slow_on_error(_client, err):
+ received.append(err)
+ handler_started.set()
+ await handler_release.wait()
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Error, slow_on_error)
+ await client.connect(_default_params())
+
+ # Push a server Error frame; the read task enters the slow handler.
+ fake_ws.push_message(
+ json.dumps({"type": "Error", "error": "boom", "error_code": 4002})
+ )
+ await asyncio.wait_for(handler_started.wait(), timeout=1.0)
+
+ # While the handler is parked, trigger a write-side close concurrently.
+ await client.stream(b"\x00" * 32)
+ for _ in range(50):
+ if fake_ws.send_call_count >= 1:
+ break
+ await asyncio.sleep(0.01)
+
+ # Release the handler; the read task finishes dispatch and exits.
+ handler_release.set()
+
+ await _wait_for_tasks(client)
+
+ assert len(received) == 1, (
+ f"expected exactly one on_error despite concurrent write-side close, "
+ f"got {received}"
+ )
+ assert received[0].code == 4002
+
+ await client.disconnect()
+
+
+async def test_disconnect_during_slow_handler_tears_down(mocker: MockFixture):
+ """Regression: ``disconnect()`` while an async handler is parked in a long
+ ``await`` must cleanly cancel the read task. ``CancelledError`` is a
+ ``BaseException`` (not ``Exception``), so it propagates through
+ ``_invoke_handler`` and out of the read task — ``disconnect()`` then
+ completes the cleanup."""
+ fake_ws = _FakeAsyncWebSocket()
+ _patch_connect(mocker, fake_ws)
+
+ handler_started = asyncio.Event()
+
+ async def slow_handler(_client, _event):
+ handler_started.set()
+ await asyncio.sleep(60)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Begin, slow_handler)
+ await client.connect(_default_params())
+
+ fake_ws.push_message(
+ json.dumps({"type": "Begin", "id": "abc", "expires_at": "2030-01-01T00:00:00"})
+ )
+ await asyncio.wait_for(handler_started.wait(), timeout=1.0)
+
+ # disconnect() should not hang waiting for the parked sleep — the read
+ # task is cancelled, CancelledError propagates, and disconnect returns.
+ await asyncio.wait_for(client.disconnect(), timeout=2.0)
+
+ assert client._read_task.done()
+
+
+async def test_write_side_close_is_dispatched_when_read_short_circuits_on_stop(
+ mocker: MockFixture, caplog: pytest.LogCaptureFixture
+):
+ """Regression: if the read task observes ``_stop_event`` at the top of its
+ loop (e.g. after processing a buffered message) before its next ``recv()``
+ raises, the write task must still dispatch the connection-closed event.
+ Previously the write task only set stop and exited, so this close went
+ unreported."""
+ caplog.set_level(logging.ERROR)
+
+ close_exc = ConnectionClosed(rcvd=Close(1011, "send-side close"), sent=None)
+ fake_ws = _FakeAsyncWebSocket(send_raises=close_exc)
+ _patch_connect(mocker, fake_ws)
+
+ received = []
+
+ def on_error(_client, err):
+ received.append(err)
+
+ client = AsyncStreamingClient(
+ StreamingClientOptions(api_key="test", api_host="api.example.com")
+ )
+ client.on(StreamingEvents.Error, on_error)
+ await client.connect(_default_params())
+
+ # Queue a write so the write task hits send() and raises ConnectionClosed.
+ await client.stream(b"\x00" * 32)
+
+ # Wait for write task to finish dispatching the close.
+ for _ in range(200):
+ if received:
+ break
+ await asyncio.sleep(0.01)
+
+ assert len(received) == 1, (
+ f"expected exactly one on_error from write-side close, got {received}"
+ )
+ assert received[0].code == 1011
+
+ await client.disconnect()
diff --git a/tox.ini b/tox.ini
index 3bfddd2..23daedf 100644
--- a/tox.ini
+++ b/tox.ini
@@ -27,7 +27,14 @@ deps =
pytest-xdist
pytest-mock
pytest-cov
+ pytest-asyncio
factory-boy
allowlist_externals = pytest
commands = pytest -n auto --cov-report term --cov-report xml:coverage.xml --cov=assemblyai
+
+[pytest]
+# Streaming async tests use explicit ``pytestmark = pytest.mark.asyncio``.
+# ``strict`` keeps that opt-in pattern and silences the pytest-asyncio
+# unset-mode deprecation warning on >=0.21.
+asyncio_mode = strict