I'm using Django REST Framework for our web API and Solr to power search. Currently, in a subclass of ListAPIView, I override get_queryset() to return a QuerySet built from the Solr search results:
class ClipList(generics.ListAPIView):
    """
    List all Clips.

    Permissions: IsAuthenticatedOrReadOnly

    Parameters:
    query -- Search for Clips. EX: clips/?query=aaron%20rodgers
    """
    model = Clip
    serializer_class = ClipSerializer
    permission_classes = (permissions.IsAuthenticatedOrReadOnly,)

    def get_queryset(self):
        params = self.request.GET
        query = params.get('query', None)
        queryset = Clip.objects.all()
        if query is not None:
            conn = solr.Solr(settings.SOLR_URL)
            sh = solr.SearchHandler(conn, "/select")
            response = sh(query)
            ids = []
            for result in response.results:
                ids.append(result['id'])
            # filter the initial queryset based on the Clip identifiers from the Solr response
            # PROBLEM: This does not preserve the order of the results as it equates to
            # `SELECT * FROM clips WHERE id in (75, 89, 106, 45)`.
            # SQL is not guaranteed to return the results in the same order used in the WHERE clause.
            # There is no way (that I'm aware of) to do this in SQL.
            queryset = queryset.filter(pk__in=ids)
        return queryset
However, as explained in the comments, this does not preserve the order of the results. I realize I could make a Python set of Clip objects, but I would then lose the lazy evaluation of a Django QuerySet, and the results are likely to be large and will be paginated.
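For what it's worth, the ranked order can usually be pushed into SQL with a CASE expression in the ORDER BY clause. The following is only a minimal sketch using Python's built-in sqlite3 module rather than the Django ORM; the clips table, its title column, and the order_by_ids helper are made up for illustration, and the ids are the ones from the comment above.

```python
import sqlite3


def order_by_ids(conn, ids):
    """Fetch rows whose id is in `ids`, in the same order as `ids`."""
    if not ids:
        return []
    placeholders = ", ".join("?" for _ in ids)
    # One WHEN branch per id maps each id to its position in the ranked list.
    whens = " ".join("WHEN ? THEN %d" % pos for pos in range(len(ids)))
    sql = (
        "SELECT id, title FROM clips WHERE id IN (%s) "
        "ORDER BY CASE id %s END" % (placeholders, whens)
    )
    # The ids are bound twice: once for IN (...), once for the WHEN branches.
    return conn.execute(sql, list(ids) + list(ids)).fetchall()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE clips (id INTEGER PRIMARY KEY, title TEXT)")
conn.executemany(
    "INSERT INTO clips (id, title) VALUES (?, ?)",
    [(45, "d"), (75, "a"), (89, "b"), (106, "c")],
)

ranked_ids = [75, 89, 106, 45]  # order as returned by the search engine
rows = order_by_ids(conn, ranked_ids)
print([row[0] for row in rows])  # prints [75, 89, 106, 45]
```

In Django 1.8 and later, the same idea can be expressed in the ORM with Case and When expressions, which keeps the QuerySet lazy.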
I looked into Haystack, and my understanding is that the code above using Haystack would look like:
def get_queryset(self):
    params = self.request.GET
    query = params.get('query', None)
    search_queryset = SearchQuerySet().filter(content=query)
    return search_queryset
This is super simple and will maintain the order of the results, but Django REST Framework does not serialize SearchQuerySets.
Is there a method in REST Framework that I can override that would allow for the serialization of SearchQuerySets? Or is there a way to maintain the ranked results order without using Haystack or Python sets?
We came up with this solution. It mimics the idioms of Django and REST Framework.
Briefly, I'll explain (hopefully, the code and comments are a more in-depth explanation :)). Instead of using a regular Django Page to power pagination, we have a FakePage that can be initialized with a total count that is not coupled to the object list. This is the hack. Yes, it's a "hack", but it is a very simple solution. The alternative for us would have been to reimplement a version of QuerySet. For us, the simple solution always wins. Despite being simple, it is reusable and performant.
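To make the idea concrete, here is a minimal sketch of the arithmetic a page needs once the total count is supplied separately from the object list. It is plain Python and independent of Django; page_info and its page_size parameter are hypothetical names, and unlike the FakePage below it takes the page size explicitly instead of inferring it from the length of the list.

```python
def page_info(number, page_size, total_count):
    """Describe page `number` (1-based) of `total_count` hits."""
    if page_size <= 0:
        raise ValueError("page_size must be positive")
    # Ceiling division so a partial last page still counts as a page.
    num_pages = -(-total_count // page_size)
    has_items = total_count > 0 and 1 <= number <= num_pages
    return {
        "num_pages": num_pages,
        "has_next": number < num_pages,
        "has_previous": number > 1,
        # 1-based indexes relative to all hits; 0 when the page is empty.
        "start_index": (number - 1) * page_size + 1 if has_items else 0,
        "end_index": min(number * page_size, total_count) if has_items else 0,
    }


# Last page of 10 hits at 3 per page: a single orphan item.
info = page_info(number=4, page_size=3, total_count=10)
print(info)
```

Passing the page size explicitly keeps the page count stable on a short last page, where the length of the object list is smaller than the configured page size.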
Using the fake page, we have an abstract SearchListModelMixin class that knows how to get a serializer that uses our fake page. The mixin is applied to concrete view classes like SearchListCreateAPIView. The code is below. If other people have a need, I can explain more thoroughly.
# Imports assumed by the code below (Django with REST Framework 2.x, and the
# solrpy client for the `solr` module).
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.paginator import EmptyPage
from django.utils import six
from rest_framework import generics, pagination, permissions, serializers
from rest_framework.response import Response
import solr


class FakePage(object):
    """
    Fake page used by Django paginator.
    Required for wrapping responses from Solr.
    """

    def __init__(self, object_list, number, total_count):
        """
        Create fake page instance.

        Args:
            object_list: list of objects, represented by a page
            number: 1-based page number
            total_count: total count of objects (in all pages)
        """
        self.object_list = object_list
        self.number = number
        # The count of objects per page is taken to be the length of the list.
        # NOTE: this is only accurate for full pages; on a short last page
        # per_page (and therefore num_pages) differs from the other pages.
        self.per_page = len(object_list)
        if self.per_page > 0:
            # Ceiling division, so that a partial last page is counted.
            self.num_pages = -(-total_count // self.per_page)
        else:
            self.num_pages = 0
        # Plain attribute, read by the serializer via source='total_count'.
        # (A method of the same name would be shadowed by it, so none is defined.)
        self.total_count = total_count

    def __repr__(self):
        return '<Page %s of %s>' % (self.number, self.num_pages)

    def __len__(self):
        return len(self.object_list)

    def __getitem__(self, index):
        if not isinstance(index, (slice,) + six.integer_types):
            raise TypeError
        # The object_list is converted to a list so that if it was a QuerySet
        # it won't be a database hit per __getitem__.
        if not isinstance(self.object_list, list):
            self.object_list = list(self.object_list)
        return self.object_list[index]

    def has_next(self):
        return self.number < self.num_pages

    def has_previous(self):
        return self.number > 1

    def has_other_pages(self):
        return self.has_previous() or self.has_next()

    def next_page_number(self):
        if self.has_next():
            return self.number + 1
        raise EmptyPage('Next page does not exist')

    def previous_page_number(self):
        if self.has_previous():
            return self.number - 1
        raise EmptyPage('Previous page does not exist')

    def start_index(self):
        """
        Returns the 1-based index of the first object on this page,
        relative to total objects in the paginator.
        """
        # Special case, return zero if no items.
        if self.total_count == 0:
            return 0
        return (self.per_page * (self.number - 1)) + 1

    def end_index(self):
        """
        Returns the 1-based index of the last object on this page,
        relative to total objects found (hits).
        """
        # Special case for the last page because there can be orphans.
        if self.number == self.num_pages:
            return self.total_count
        return self.number * self.per_page

class SearchListModelMixin(object):
    """
    List search results or a queryset.
    """
    # Set this attribute to make a custom URL parameter for the query.
    # EX: ../widgets?query=something
    query_param = 'query'

    # Set this attribute to define the type of search.
    # This determines the source of the query value.
    # For example, for regular text search,
    # the query comes from a URL param. However, for a related search,
    # the query is a unique identifier in the URL itself.
    search_type = SearchType.QUERY

    def search_list(self, request, *args, **kwargs):
        # Get the query from the URL parameters dictionary.
        query = self.request.GET.get(self.query_param, None)

        # If there is no query use default REST Framework behavior.
        if query is None and self.search_type == SearchType.QUERY:
            return self.list(request, *args, **kwargs)

        # Get the page of objects and the total count of the results.
        if hasattr(self, 'get_search_results'):
            self.object_list, total_count = self.get_search_results()
            # Reject the result if either element has the wrong type.
            if (not isinstance(self.object_list, list)
                    or not isinstance(total_count, six.integer_types)):
                raise ImproperlyConfigured("'%s.get_search_results()' must return (list, int)"
                                           % self.__class__.__name__)
        else:
            raise ImproperlyConfigured("'%s' must define 'get_search_results()'"
                                       % self.__class__.__name__)

        # Normally, we would be serializing a QuerySet,
        # which is lazily evaluated and has the entire result set.
        # Here, we just have a Python list containing only the elements for the
        # requested page. Thus, we must generate a fake page,
        # simulating a Django page in order to fully support pagination.
        # Otherwise, the `count` field would be equal to the page size.
        page = FakePage(self.object_list,
                        int(self.request.GET.get(self.page_kwarg, 1)),
                        total_count)

        # Prepare a SearchPaginationSerializer
        # with the object_serializer_class
        # set to the serializer_class of the APIView.
        serializer = self.get_search_pagination_serializer(page)
        return Response(serializer.data)

    def get_search_pagination_serializer(self, page):
        """
        Return a serializer instance to use with paginated search data.
        """
        class SerializerClass(SearchPaginationSerializer):
            class Meta:
                object_serializer_class = self.get_serializer_class()

        pagination_serializer_class = SerializerClass
        context = self.get_serializer_context()
        return pagination_serializer_class(instance=page, context=context)

    def get_solr_results(self, sort=None):
        """
        This method is optional. It encapsulates logic for a Solr request
        for a list of ids using pagination. Another method could be provided to
        do the search request.
        """
        # Use the view's own get_queryset(). Calling
        # super(self.__class__, self) from a mixin skips any override on the
        # concrete view and recurses forever if that view is subclassed.
        queryset = self.get_queryset()
        conn = solr.Solr(settings.SOLR_URL)

        # Create a SearchHandler and get the query with respect to
        # the type of search.
        if self.search_type == SearchType.QUERY:
            query = self.request.GET.get(self.query_param, None)
            sh = solr.SearchHandler(conn, "/select")
        elif self.search_type == SearchType.RELATED:
            query = str(self.kwargs[self.lookup_field])
            sh = solr.SearchHandler(conn, "/morelikethis")

        # Get the pagination information and populate kwargs.
        page_num = int(self.request.GET.get(self.page_kwarg, 1))
        per_page = self.get_paginate_by()
        offset = (page_num - 1) * per_page
        kwargs = {'rows': per_page, 'start': offset}
        if sort:
            kwargs['sort'] = sort

        # Perform the Solr request and build a list of results.
        # For now, this only gets the id field, but this could be
        # customized later.
        response = sh(query, 'id', **kwargs)
        results = [int(r['id']) for r in response.results]

        # Get a dictionary representing a page of objects.
        # The dict has pk keys and object values, sorted by id.
        object_dict = queryset.in_bulk(results)

        # Build the list of objects, sorted by Solr rank.
        sorted_objects = []
        for pk in results:
            obj = object_dict.get(pk, None)
            if obj:
                sorted_objects.append(obj)
        return sorted_objects, response.numFound

class SearchType(object):
    """
    This enum-like class defines types of Solr searches
    that can be used in APIViews.
    """
    QUERY = 1
    RELATED = 2


class SearchPaginationSerializer(pagination.BasePaginationSerializer):
    count = serializers.Field(source='total_count')
    next = pagination.NextPageField(source='*')
    previous = pagination.PreviousPageField(source='*')


class SearchListCreateAPIView(SearchListModelMixin, generics.ListCreateAPIView):
    """
    Concrete view for listing search results, a queryset, or creating a model instance.
    """
    def get(self, request, *args, **kwargs):
        return self.search_list(request, *args, **kwargs)


class SearchListAPIView(SearchListModelMixin, generics.ListAPIView):
    """
    Concrete view for listing search results or a queryset.
    """
    def get(self, request, *args, **kwargs):
        return self.search_list(request, *args, **kwargs)

class WidgetList(SearchListCreateAPIView):
    """
    List all Widgets or create a new Widget.
    """
    model = Widget
    queryset = Widget.objects.all()
    serializer_class = WidgetSerializer
    permission_classes = (permissions.IsAuthenticatedOrReadOnly,)
    search_type = SearchType.QUERY  # this is default, but explicitly declared here

    def get_queryset(self):
        """
        The method implemented for database powered results.
        I'm simply illustrating the default behavior here.
        """
        queryset = super(WidgetList, self).get_queryset()
        return queryset

    def get_search_results(self):
        """
        The method implemented for search engine powered results.
        """
        # `solr_sort` is a Solr sort string; its definition is not shown here.
        return self.get_solr_results(solr_sort)


class RelatedList(SearchListAPIView):
    """
    List all related Widgets.
    """
    model = Widget
    queryset = Widget.objects.all()
    serializer_class = WidgetSerializer
    permission_classes = (permissions.IsAuthenticatedOrReadOnly,)
    search_type = SearchType.RELATED

    def get_search_results(self):
        return self.get_solr_results()
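
As a recap of the two steps get_solr_results() performs around the Solr call, the sketch below shows them in isolation with plain Python. The functions solr_window and order_by_rank are hypothetical names, and a plain dict stands in for the result of in_bulk(), so neither Solr nor a database is needed to run it.

```python
def solr_window(page_num, per_page):
    """Translate a 1-based page number into Solr's `rows`/`start` parameters."""
    return {"rows": per_page, "start": (page_num - 1) * per_page}


def order_by_rank(ranked_ids, objects_by_pk):
    """Return objects in search-rank order, skipping ids with no object.

    `objects_by_pk` plays the role of QuerySet.in_bulk(ranked_ids): a dict
    keyed by primary key, with no useful ordering of its own.
    """
    return [objects_by_pk[pk] for pk in ranked_ids if pk in objects_by_pk]


# Page 3 with 20 results per page starts at offset 40.
window = solr_window(page_num=3, per_page=20)

# Id 106 is in the index but missing from the database (e.g. a stale index).
ranked_ids = [75, 89, 106, 45]
objects_by_pk = {45: "clip-45", 75: "clip-75", 89: "clip-89"}
ordered = order_by_rank(ranked_ids, objects_by_pk)
print(window, ordered)
```

Note that when an id is skipped this way, the page comes back shorter than the requested size while numFound stays the same.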