0

I have an app that basically ingests documents, stores their embeddings in vector form, then responds to queries that relate to the ingested docs.

I have this retrieval.py function which searches through ingested and selected documents and returns which files have matching vectors.

I'm testing a FastAPI service that queries a database using SQLAlchemy's async execution. However, when I mock the database query, execute().scalars().all() returns an empty list instead of the expected [1, 2, 3].

import numpy as np
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import text
from src.backend.database.config import AsyncSessionLocal
from src.services.ingestion_service.embedding_generator import EmbeddingGenerator
import asyncio


class RetrievalService:
    """
    Handles retrieval of similar documents based on query embeddings.
    """

    def __init__(self):
        self.embedding_generator = EmbeddingGenerator()

    async def retrieve_relevant_docs(self, query: str, top_k: int = 5):
        """
        Converts the query into an embedding and retrieves the most similar documents asynchronously.
        """
        async with AsyncSessionLocal() as db:
            async with db.begin():
                # ✅ Generate embedding in a separate thread
                query_embedding = await self.embedding_generator.generate_embedding(query)

                
                print('query_embedding', query_embedding)

                # ✅ Convert NumPy array to PostgreSQL-compatible format
                query_embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"

                # ✅ Fetch selected document IDs asynchronously
                selected_ids_result = await db.execute(text("SELECT document_id FROM selected_documents;"))
                print('selected_ids_result', selected_ids_result)
                selected_ids = (await selected_ids_result.scalars()).all()

                # ✅ Ensure selected_ids is not empty to prevent SQL errors
                if not selected_ids:
                    selected_ids = [-1]  # Dummy ID to avoid SQL failure

                # ✅ Execute the vector similarity search query
                search_query = text("""
                    SELECT document_id FROM embeddings
                    WHERE document_id = ANY(:selected_ids)
                    ORDER BY vector <-> CAST(:query_embedding AS vector)
                    LIMIT :top_k;
                """).execution_options(cacheable=False)

                results = await db.execute(
                    search_query,
                    {
                        "query_embedding": query_embedding_str,  # Pass as string
                        "top_k": top_k,
                        "selected_ids": selected_ids,
                    },
                )

                # document_ids = (await results.scalars()).all()\
                print('debug results', vars(results))
                document_ids = list(await results.scalars())

                print('document_ids', document_ids)
                return document_ids

    async def get_document_texts(self, document_ids: list[int]):
        """
        Fetches the actual document texts for the given document IDs.
        """
        if not document_ids:
            return []

        async with AsyncSessionLocal() as db:
            async with db.begin():
                query = text("SELECT content FROM documents WHERE id = ANY(:document_ids);")
                results = await db.execute(query, {"document_ids": document_ids})
                return (await results.scalars()).all()

I have a simple tests file:

import pytest
from unittest.mock import patch, AsyncMock
from src.services.retrieval_service.retrieval import RetrievalService
from sqlalchemy.ext.asyncio import AsyncSession

import pytest
from unittest.mock import patch, AsyncMock
from src.services.retrieval_service.retrieval import RetrievalService
from sqlalchemy.ext.asyncio import AsyncSession

@pytest.mark.asyncio
async def test_retrieve_relevant_docs_valid_query():
    service = RetrievalService()
    query = "What is AI?"
    top_k = 3

    with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
         patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        
        mock_generate_embedding.return_value = [0.1] * 384

        # Mock DB query for selected documents
        mock_scalars_selected = AsyncMock()
        mock_scalars_selected.scalars.return_value.all.return_value = [1, 2, 3]
        
        mock_execute.side_effect = [mock_scalars_selected, mock_scalars_selected]
        # Mock the `execute` method
        mock_execute.return_value = mock_scalars_selected

        # Call the method
        document_ids = await service.retrieve_relevant_docs(query, top_k)

        # Assertion
        assert document_ids == [1, 2, 3], f"Expected [1, 2, 3] but got {document_ids}"


@pytest.mark.asyncio
async def test_retrieve_relevant_docs_valid_query_1():
    service = RetrievalService()
    query = "What is AI?"
    top_k = 3

    with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
         patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        
        mock_generate_embedding.return_value = [0.1] * 384

        # Mock DB query for selected documents
        mock_scalars_selected = AsyncMock()
        mock_scalars_selected.all = AsyncMock(return_value=[1, 2, 3])
        mock_execute.side_effect = [AsyncMock(return_value=AsyncMock(scalars=mock_scalars_selected)), AsyncMock(return_value=AsyncMock(scalars=mock_scalars_selected))]

        document_ids = await service.retrieve_relevant_docs(query, top_k)
        assert document_ids == [1, 2, 3]


@pytest.mark.asyncio
async def test_retrieve_relevant_docs_no_selected_docs():
    service = RetrievalService()
    query = "What is AI?"
    top_k = 3

    with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
         patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        
        mock_generate_embedding.return_value = [0.1] * 384

        # Mock DB returning no selected docs
        mock_scalars_selected = AsyncMock()
        mock_scalars_selected.all.return_value = []
        mock_execute.return_value = mock_scalars_selected

        document_ids = await service.retrieve_relevant_docs(query, top_k)
        assert document_ids == []

@pytest.mark.asyncio
async def test_retrieve_relevant_docs_empty_query():
    service = RetrievalService()
    query = ""
    top_k = 3

    with patch.object(service.embedding_generator, 'generate_embedding', new_callable=AsyncMock) as mock_generate_embedding, \
         patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        
        mock_generate_embedding.return_value = [0.1] * 384

        # Mock DB returning no documents
        mock_scalars_selected = AsyncMock()
        mock_scalars_selected.all.return_value = []
        mock_execute.return_value = mock_scalars_selected

        document_ids = await service.retrieve_relevant_docs(query, top_k)
        assert document_ids == []

@pytest.mark.asyncio
async def test_get_document_texts_valid_ids():
    service = RetrievalService()
    document_ids = [1, 2, 3]

    with patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        # Mock query result
        mock_scalars = AsyncMock()
        mock_scalars.all.return_value = ["Document 1 text", "Document 2 text", "Document 3 text"]
        mock_execute.return_value = mock_scalars

        document_texts = await service.get_document_texts(document_ids)
        assert document_texts == ["Document 1 text", "Document 2 text", "Document 3 text"]

@pytest.mark.asyncio
async def test_get_document_texts_no_ids():
    service = RetrievalService()
    document_ids = []

    with patch.object(AsyncSession, 'execute', new_callable=AsyncMock) as mock_execute:
        document_texts = await service.get_document_texts(document_ids)
        assert document_texts == []

I have added so much debugging information, but i do not understand why when I am mocking the retreivalservice to have the side effect [1,2,3] and just yield that value after going through the service. I keep getting error that shows that my mock_execute.side_effect is not working at all.

These are the logs that get printed:

selected_ids_result debug results {'_mock_return_value': sentinel.DEFAULT, '_mock_parent': None, '_mock_name': None, '_mock_new_name': '()', '_mock_new_parent': , '_mock_sealed': False, '_spec_class': None, '_spec_set': None, '_spec_signature': None, '_mock_methods': None, '_spec_asyncs': [], '_mock_children': {'scalars': , 'str': }, '_mock_wraps': None, '_mock_delegate': None, '_mock_called': False, '_mock_call_args': None, '_mock_call_count': 0, '_mock_call_args_list': [], '_mock_mock_calls': [call.str(), call.scalars(), call.scalars().all()], 'method_calls': [call.scalars()], '_mock_unsafe': False, '_mock_side_effect': None, '_is_coroutine': <object object at 0x0000029C06E16A30>, '_mock_await_count': 0, '_mock_await_args': None, '_mock_await_args_list': [], 'code': , 'str': } document_ids []

and these errors:

short test summary info ==================================== FAILED src/tests/unit/test_retrieve_docs.py::test_retrieve_relevant_docs_valid_query - AssertionError: Expected [1, 2, 3] but got [] FAILED src/tests/unit/test_retrieve_docs.py::test_retrieve_relevant_docs_valid_query_1 - AssertionError: assert [] == [1, 2, 3] FAILED src/tests/unit/test_retrieve_docs.py::test_get_document_texts_valid_ids - AssertionError: assert <coroutine object AsyncMockMixin._execute_mock_call at 0x000002373EA3B...

Observed Behavior: results.scalars().all() unexpectedly returns [], even though I attempted to mock it. Debugging vars(results) shows _mock_side_effect = None, suggesting the mock isn't working as expected.

Expected Behavior: document_ids should contain [1, 2, 3], matching the mocked return value.

What I've Tried: Explicitly setting scalars().all().return_value = [1, 2, 3]. Checking vars(results) for missing attributes. Ensuring mock_execute.side_effect is properly assigned. Calling await session.execute(...).scalars().all() instead of wrapping it in list(). What is the correct way to mock SQLAlchemy's async execution (session.execute().scalars().all()) in a FastAPI test using AsyncMock? Or Can someone point why my is not behaving as I expect it to?

I feel if i fix one test, all my tests should get fixed the same way. I am not new to python, but very new to sqlalchemy

1 Answer 1

0

Try to use sync Mock for scalars. It worked for me in a similar case. Maybe someone know better way, but this at least works.

override_get_db.execute.return_value = Mock()

# scalar needs to be sync!
mock_scalars = Mock()
mock_scalars.first.return_value = existing_metrics

mock_result = Mock()
mock_result.scalars = Mock(return_value=mock_scalars)

# Only the execute method should be async
mock_db_session.execute = AsyncMock(return_value=mock_result)
Sign up to request clarification or add additional context in comments.

Comments

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.