diff --git a/example/tests/integration/test_includes.py b/example/tests/integration/test_includes.py index 3284c79b..d7f6f6d1 100644 --- a/example/tests/integration/test_includes.py +++ b/example/tests/integration/test_includes.py @@ -1,20 +1,29 @@ -import pytest, json +import pytest from django.core.urlresolvers import reverse + from example.tests.utils import load_json pytestmark = pytest.mark.django_db def test_included_data_on_list(multiple_entries, client): - multiple_entries[1].comments = [] - response = client.get(reverse("entry-list") + '?include=comments') + response = client.get(reverse("entry-list") + '?include=comments&page_size=5') included = load_json(response.content).get('included') - assert [x.get('type') for x in included] == ['comments'] + assert len(load_json(response.content)['data']) == len(multiple_entries), 'Incorrect entry count' + assert [x.get('type') for x in included] == ['comments', 'comments'], 'List included types are incorrect' + + comment_count = len([resource for resource in included if resource["type"] == "comments"]) + expected_comment_count = sum([entry.comment_set.count() for entry in multiple_entries]) + assert comment_count == expected_comment_count, 'List comment count is incorrect' + def test_included_data_on_detail(single_entry, client): response = client.get(reverse("entry-detail", kwargs={'pk': single_entry.pk}) + '?include=comments') included = load_json(response.content).get('included') - assert [x.get('type') for x in included] == ['comments'] + assert [x.get('type') for x in included] == ['comments'], 'Detail included types are incorrect' + comment_count = len([resource for resource in included if resource["type"] == "comments"]) + expected_comment_count = single_entry.comment_set.count() + assert comment_count == expected_comment_count, 'Detail comment count is incorrect' diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index 5418141d..147ef177 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -2,6 +2,7 @@ Utils. """ import copy + import inflection from django.conf import settings from django.utils import six, encoding @@ -410,6 +411,7 @@ def extract_included(fields, resource, resource_instance, included_resources): current_serializer = fields.serializer context = current_serializer.context included_serializers = get_included_serializers(current_serializer) + included_resources = copy.copy(included_resources) for field_name, field in six.iteritems(fields): # Skip URL field