Fedir Zadniprovskyi 2024-05-23
refactor: simplify tests
@a9ee91b071c59d4d2ba3e8ead55f032b0c0e0e1e
tests/app_test.py
--- tests/app_test.py
+++ tests/app_test.py
@@ -15,6 +15,9 @@
 from speaches.server_models import TranscriptionVerboseResponse
 
 SIMILARITY_THRESHOLD = 0.97
+AUDIO_FILES_LIMIT = 5
+AUDIO_FILE_DIR = "tests/data"
+TRANSCRIBE_ENDPOINT = "/v1/audio/transcriptions?response_format=verbose_json"
 
 
 @pytest.fixture()
@@ -23,12 +26,17 @@
         yield client
 
 
+@pytest.fixture()
+def ws(client: TestClient) -> Generator[WebSocketTestSession, None, None]:
+    with client.websocket_connect(TRANSCRIBE_ENDPOINT) as ws:
+        yield ws
+
+
 def get_audio_file_paths():
     file_paths = []
     directory = "tests/data"
-    for filename in reversed(os.listdir(directory)[5:6]):
-        if filename.endswith(".raw"):
-            file_paths.append(os.path.join(directory, filename))
+    for filename in sorted(os.listdir(directory)[:AUDIO_FILES_LIMIT]):
+        file_paths.append(os.path.join(directory, filename))
     return file_paths
 
 
@@ -48,7 +56,7 @@
     client: TestClient, data: bytes
 ) -> TranscriptionVerboseResponse:
     response = client.post(
-        "/v1/audio/transcriptions?response_format=verbose_json",
+        TRANSCRIBE_ENDPOINT,
         files={"file": ("audio.raw", data, "audio/raw")},
     )
     data = json.loads(response.json())  # TODO: figure this out
@@ -56,29 +64,26 @@
 
 
 @pytest.mark.parametrize("file_path", file_paths)
-def test_ws_audio_transcriptions(client: TestClient, file_path: str):
+def test_ws_audio_transcriptions(
+    client: TestClient, ws: WebSocketTestSession, file_path: str
+):
     with open(file_path, "rb") as file:
         data = file.read()
-        streaming_transcription: TranscriptionVerboseResponse = None  # type: ignore
-        with client.websocket_connect(
-            "/v1/audio/transcriptions?response_format=verbose_json"
-        ) as ws:
-            thread = threading.Thread(
-                target=stream_audio_data, args=(ws, data), kwargs={"speed": 4.0}
-            )
-            thread.start()
-            while True:
-                try:
-                    streaming_transcription = TranscriptionVerboseResponse(
-                        **ws.receive_json()
-                    )
-                except WebSocketDisconnect:
-                    break
-            ws.close()
-        file_transcription = transcribe_audio_data(client, data)
-        s = SequenceMatcher(
-            lambda x: x == " ", file_transcription.text, streaming_transcription.text
-        )
-        assert (
-            s.ratio() > SIMILARITY_THRESHOLD
-        ), f"\nExpected: {file_transcription.text}\nReceived: {streaming_transcription.text}"
+
+    streaming_transcription: TranscriptionVerboseResponse = None  # type: ignore
+    thread = threading.Thread(
+        target=stream_audio_data, args=(ws, data), kwargs={"speed": 4.0}
+    )
+    thread.start()
+    while True:
+        try:
+            streaming_transcription = TranscriptionVerboseResponse(**ws.receive_json())
+        except WebSocketDisconnect:
+            break
+    file_transcription = transcribe_audio_data(client, data)
+    s = SequenceMatcher(
+        lambda x: x == " ", file_transcription.text, streaming_transcription.text
+    )
+    assert (
+        s.ratio() > SIMILARITY_THRESHOLD
+    ), f"\nExpected: {file_transcription.text}\nReceived: {streaming_transcription.text}"
Add a comment
List