

test: switch to async http client
@875890de74d5e375fc33225b3abdf4a140cfa9f7
--- pyproject.toml
+++ pyproject.toml
... | ... | @@ -30,6 +30,7 @@ |
30 | 30 |
"basedpyright==1.13.0", |
31 | 31 |
"pytest-xdist==3.6.1", |
32 | 32 |
"pytest-asyncio>=0.24.0", |
33 |
+ "anyio>=4.4.0", |
|
33 | 34 |
] |
34 | 35 |
|
35 | 36 |
[build-system] |
--- tests/api_timestamp_granularities_test.py
+++ tests/api_timestamp_granularities_test.py
... | ... | @@ -1,5 +1,7 @@ |
1 | 1 |
"""See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_format` and `timestamp_granularities`.""" # noqa: E501 |
2 | 2 |
|
3 |
+from pathlib import Path |
|
4 |
+ |
|
3 | 5 |
from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities |
4 | 6 |
from openai import AsyncOpenAI |
5 | 7 |
import pytest |
... | ... | @@ -11,10 +13,10 @@ |
11 | 13 |
openai_client: AsyncOpenAI, |
12 | 14 |
timestamp_granularities: TimestampGranularities, |
13 | 15 |
) -> None: |
14 |
- audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 |
|
16 |
+ file_path = Path("audio.wav") |
|
15 | 17 |
|
16 | 18 |
await openai_client.audio.transcriptions.create( |
17 |
- file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities |
|
19 |
+ file=file_path, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities |
|
18 | 20 |
) |
19 | 21 |
|
20 | 22 |
|
... | ... | @@ -24,10 +26,10 @@ |
24 | 26 |
openai_client: AsyncOpenAI, |
25 | 27 |
timestamp_granularities: TimestampGranularities, |
26 | 28 |
) -> None: |
27 |
- audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 |
|
29 |
+ file_path = Path("audio.wav") |
|
28 | 30 |
|
29 | 31 |
transcription = await openai_client.audio.transcriptions.create( |
30 |
- file=audio_file, |
|
32 |
+ file=file_path, |
|
31 | 33 |
model="whisper-1", |
32 | 34 |
response_format="verbose_json", |
33 | 35 |
timestamp_granularities=timestamp_granularities, |
--- tests/conftest.py
+++ tests/conftest.py
... | ... | @@ -18,6 +18,7 @@ |
18 | 18 |
logger.disabled = True |
19 | 19 |
|
20 | 20 |
|
21 |
+# NOTE: not being used. Keeping just in case |
|
21 | 22 |
@pytest.fixture() |
22 | 23 |
def client() -> Generator[TestClient, None, None]: |
23 | 24 |
os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en" |
--- tests/openai_timestamp_granularities_test.py
+++ tests/openai_timestamp_granularities_test.py
... | ... | @@ -1,5 +1,7 @@ |
1 | 1 |
"""OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501 |
2 | 2 |
|
3 |
+from pathlib import Path |
|
4 |
+ |
|
3 | 5 |
from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities |
4 | 6 |
from openai import AsyncOpenAI, BadRequestError |
5 | 7 |
import pytest |
... | ... | @@ -12,19 +14,18 @@ |
12 | 14 |
actual_openai_client: AsyncOpenAI, |
13 | 15 |
timestamp_granularities: TimestampGranularities, |
14 | 16 |
) -> None: |
15 |
- audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 |
|
16 |
- |
|
17 |
+ file_path = Path("audio.wav") |
|
17 | 18 |
if "word" in timestamp_granularities: |
18 | 19 |
with pytest.raises(BadRequestError): |
19 | 20 |
await actual_openai_client.audio.transcriptions.create( |
20 |
- file=audio_file, |
|
21 |
+ file=file_path, |
|
21 | 22 |
model="whisper-1", |
22 | 23 |
response_format="json", |
23 | 24 |
timestamp_granularities=timestamp_granularities, |
24 | 25 |
) |
25 | 26 |
else: |
26 | 27 |
await actual_openai_client.audio.transcriptions.create( |
27 |
- file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities |
|
28 |
+ file=file_path, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities |
|
28 | 29 |
) |
29 | 30 |
|
30 | 31 |
|
... | ... | @@ -35,10 +36,10 @@ |
35 | 36 |
actual_openai_client: AsyncOpenAI, |
36 | 37 |
timestamp_granularities: TimestampGranularities, |
37 | 38 |
) -> None: |
38 |
- audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 |
|
39 |
+ file_path = Path("audio.wav") |
|
39 | 40 |
|
40 | 41 |
transcription = await actual_openai_client.audio.transcriptions.create( |
41 |
- file=audio_file, |
|
42 |
+ file=file_path, |
|
42 | 43 |
model="whisper-1", |
43 | 44 |
response_format="verbose_json", |
44 | 45 |
timestamp_granularities=timestamp_granularities, |
--- tests/sse_test.py
+++ tests/sse_test.py
... | ... | @@ -1,12 +1,13 @@ |
1 | 1 |
import json |
2 | 2 |
import os |
3 | 3 |
|
4 |
-from fastapi.testclient import TestClient |
|
4 |
+import anyio |
|
5 | 5 |
from faster_whisper_server.api_models import ( |
6 | 6 |
CreateTranscriptionResponseJson, |
7 | 7 |
CreateTranscriptionResponseVerboseJson, |
8 | 8 |
) |
9 |
-from httpx_sse import connect_sse |
|
9 |
+from httpx import AsyncClient |
|
10 |
+from httpx_sse import aconnect_sse |
|
10 | 11 |
import pytest |
11 | 12 |
import srt |
12 | 13 |
import webvtt |
... | ... | @@ -22,57 +23,61 @@ |
22 | 23 |
parameters = [(file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS] |
23 | 24 |
|
24 | 25 |
|
26 |
+@pytest.mark.asyncio() |
|
25 | 27 |
@pytest.mark.parametrize(("file_path", "endpoint"), parameters) |
26 |
-def test_streaming_transcription_text(client: TestClient, file_path: str, endpoint: str) -> None: |
|
28 |
+async def test_streaming_transcription_text(aclient: AsyncClient, file_path: str, endpoint: str) -> None: |
|
27 | 29 |
extension = os.path.splitext(file_path)[1] |
28 |
- with open(file_path, "rb") as f: |
|
29 |
- data = f.read() |
|
30 |
+ async with await anyio.open_file(file_path, "rb") as f: |
|
31 |
+ data = await f.read() |
|
30 | 32 |
kwargs = { |
31 | 33 |
"files": {"file": (f"audio.{extension}", data, f"audio/{extension}")}, |
32 | 34 |
"data": {"response_format": "text", "stream": True}, |
33 | 35 |
} |
34 |
- with connect_sse(client, "POST", endpoint, **kwargs) as event_source: |
|
35 |
- for event in event_source.iter_sse(): |
|
36 |
+ async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source: |
|
37 |
+ async for event in event_source.aiter_sse(): |
|
36 | 38 |
print(event) |
37 | 39 |
assert len(event.data) > 1 # HACK: 1 because of the space character that's always prepended |
38 | 40 |
|
39 | 41 |
|
42 |
+@pytest.mark.asyncio() |
|
40 | 43 |
@pytest.mark.parametrize(("file_path", "endpoint"), parameters) |
41 |
-def test_streaming_transcription_json(client: TestClient, file_path: str, endpoint: str) -> None: |
|
44 |
+async def test_streaming_transcription_json(aclient: AsyncClient, file_path: str, endpoint: str) -> None: |
|
42 | 45 |
extension = os.path.splitext(file_path)[1] |
43 |
- with open(file_path, "rb") as f: |
|
44 |
- data = f.read() |
|
46 |
+ async with await anyio.open_file(file_path, "rb") as f: |
|
47 |
+ data = await f.read() |
|
45 | 48 |
kwargs = { |
46 | 49 |
"files": {"file": (f"audio.{extension}", data, f"audio/{extension}")}, |
47 | 50 |
"data": {"response_format": "json", "stream": True}, |
48 | 51 |
} |
49 |
- with connect_sse(client, "POST", endpoint, **kwargs) as event_source: |
|
50 |
- for event in event_source.iter_sse(): |
|
52 |
+ async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source: |
|
53 |
+ async for event in event_source.aiter_sse(): |
|
51 | 54 |
CreateTranscriptionResponseJson(**json.loads(event.data)) |
52 | 55 |
|
53 | 56 |
|
57 |
+@pytest.mark.asyncio() |
|
54 | 58 |
@pytest.mark.parametrize(("file_path", "endpoint"), parameters) |
55 |
-def test_streaming_transcription_verbose_json(client: TestClient, file_path: str, endpoint: str) -> None: |
|
59 |
+async def test_streaming_transcription_verbose_json(aclient: AsyncClient, file_path: str, endpoint: str) -> None: |
|
56 | 60 |
extension = os.path.splitext(file_path)[1] |
57 |
- with open(file_path, "rb") as f: |
|
58 |
- data = f.read() |
|
61 |
+ async with await anyio.open_file(file_path, "rb") as f: |
|
62 |
+ data = await f.read() |
|
59 | 63 |
kwargs = { |
60 | 64 |
"files": {"file": (f"audio.{extension}", data, f"audio/{extension}")}, |
61 | 65 |
"data": {"response_format": "verbose_json", "stream": True}, |
62 | 66 |
} |
63 |
- with connect_sse(client, "POST", endpoint, **kwargs) as event_source: |
|
64 |
- for event in event_source.iter_sse(): |
|
67 |
+ async with aconnect_sse(aclient, "POST", endpoint, **kwargs) as event_source: |
|
68 |
+ async for event in event_source.aiter_sse(): |
|
65 | 69 |
CreateTranscriptionResponseVerboseJson(**json.loads(event.data)) |
66 | 70 |
|
67 | 71 |
|
68 |
-def test_transcription_vtt(client: TestClient) -> None: |
|
69 |
- with open("audio.wav", "rb") as f: |
|
70 |
- data = f.read() |
|
72 |
+@pytest.mark.asyncio() |
|
73 |
+async def test_transcription_vtt(aclient: AsyncClient) -> None: |
|
74 |
+ async with await anyio.open_file("audio.wav", "rb") as f: |
|
75 |
+ data = await f.read() |
|
71 | 76 |
kwargs = { |
72 | 77 |
"files": {"file": ("audio.wav", data, "audio/wav")}, |
73 | 78 |
"data": {"response_format": "vtt", "stream": False}, |
74 | 79 |
} |
75 |
- response = client.post("/v1/audio/transcriptions", **kwargs) |
|
80 |
+ response = await aclient.post("/v1/audio/transcriptions", **kwargs) |
|
76 | 81 |
assert response.status_code == 200 |
77 | 82 |
assert response.headers["content-type"] == "text/vtt; charset=utf-8" |
78 | 83 |
text = response.text |
... | ... | @@ -82,14 +87,15 @@ |
82 | 87 |
webvtt.from_string(text) |
83 | 88 |
|
84 | 89 |
|
85 |
-def test_transcription_srt(client: TestClient) -> None: |
|
86 |
- with open("audio.wav", "rb") as f: |
|
87 |
- data = f.read() |
|
90 |
+@pytest.mark.asyncio() |
|
91 |
+async def test_transcription_srt(aclient: AsyncClient) -> None: |
|
92 |
+ async with await anyio.open_file("audio.wav", "rb") as f: |
|
93 |
+ data = await f.read() |
|
88 | 94 |
kwargs = { |
89 | 95 |
"files": {"file": ("audio.wav", data, "audio/wav")}, |
90 | 96 |
"data": {"response_format": "srt", "stream": False}, |
91 | 97 |
} |
92 |
- response = client.post("/v1/audio/transcriptions", **kwargs) |
|
98 |
+ response = await aclient.post("/v1/audio/transcriptions", **kwargs) |
|
93 | 99 |
assert response.status_code == 200 |
94 | 100 |
assert "text/plain" in response.headers["content-type"] |
95 | 101 |
|
--- uv.lock
+++ uv.lock
... | ... | @@ -295,6 +295,7 @@ |
295 | 295 |
{ name = "keyboard" }, |
296 | 296 |
] |
297 | 297 |
dev = [ |
298 |
+ { name = "anyio" }, |
|
298 | 299 |
{ name = "basedpyright" }, |
299 | 300 |
{ name = "pytest" }, |
300 | 301 |
{ name = "pytest-asyncio" }, |
... | ... | @@ -306,6 +307,7 @@ |
306 | 307 |
|
307 | 308 |
[package.metadata] |
308 | 309 |
requires-dist = [ |
310 |
+ { name = "anyio", marker = "extra == 'dev'", specifier = ">=4.4.0" }, |
|
309 | 311 |
{ name = "basedpyright", marker = "extra == 'dev'", specifier = "==1.13.0" }, |
310 | 312 |
{ name = "fastapi", specifier = "==0.112.4" }, |
311 | 313 |
{ name = "faster-whisper", specifier = "==1.0.3" }, |