Solr search results with Django REST Framework


I'm using Django REST Framework for our web API and Solr to power a search. Currently, in a subclass of ListAPIView, I override get_queryset() to get a QuerySet with the Solr search results:

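import solr  # solrpy
from django.conf import settings
from rest_framework import generics, permissions

from .models import Clip  # assumed app-local module
from .serializers import ClipSerializer  # assumed app-local module


class ClipList(generics.ListAPIView):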
    """
    List all Clips. 

    Permissions: IsAuthenticatedOrReadOnly

    Parameters:
    query -- Search for Clips. EX: clips/?query=aaron%20rodgers
    """
    model = Clip
    serializer_class = ClipSerializer
    permission_classes = (permissions.IsAuthenticatedOrReadOnly,)

    def get_queryset(self):
        # `request` is not in scope inside get_queryset(); use the view's request.
        params = self.request.GET
        query = params.get('query', None)

        queryset = Clip.objects.all()

        if query is not None:
            conn = solr.Solr(settings.SOLR_URL)
            sh = solr.SearchHandler(conn, "/select")
            response = sh(query)
            ids = []

            for result in response.results:
                ids.append(result['id'])

            # filter the initial queryset based on the Clip identifiers from the Solr response
            # PROBLEM: This does not preserve the order of the results as it equates to
            # `SELECT * FROM clips WHERE id in (75, 89, 106, 45)`.
            # SQL is not guaranteed to return the results in the same order used in the WHERE clause.
            # There is no way (that I'm aware of) to do this in SQL.
            queryset = queryset.filter(pk__in=ids)

        return queryset

However, as explained in the comments, this does not preserve the order of the results. I realize I could build a Python list of Clip objects in the right order, but I would then lose the lazy evaluation of the Django QuerySet, and the results are likely to be large and will be paginated.
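As a minimal, self-contained sketch of the ordering problem, the snippet below uses Python's built-in sqlite3 in place of the real database; the `clips` table, its rows, and the ranked id list are made up for illustration. The `IN (...)` query plays the role of `filter(pk__in=ids)`, and the ranking is then reapplied in Python with one dictionary lookup per id. Materializing a list is only cheap here because just one page of ids is passed in.

```python
import sqlite3


def fetch_in_rank_order(conn, ranked_ids):
    """Fetch clips whose ids are in ranked_ids, returned in the ranked order."""
    if not ranked_ids:
        return []
    placeholders = ", ".join("?" for _ in ranked_ids)
    # Like Django's filter(pk__in=ids): the database may return rows in any order.
    rows = conn.execute(
        "SELECT id, title FROM clips WHERE id IN (%s)" % placeholders, ranked_ids
    ).fetchall()
    by_id = {row[0]: row for row in rows}
    # Reapply the search engine's ranking; skip ids missing from the table.
    return [by_id[pk] for pk in ranked_ids if pk in by_id]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE clips (id INTEGER PRIMARY KEY, title TEXT)")
conn.executemany(
    "INSERT INTO clips (id, title) VALUES (?, ?)",
    [(45, "d"), (75, "a"), (89, "b"), (106, "c")],
)
ranked = [75, 89, 106, 45]  # hypothetical Solr ranking, best match first
print([row[0] for row in fetch_in_rank_order(conn, ranked)])  # → [75, 89, 106, 45]
```

Ids the search engine returns but the database no longer has (for example, rows deleted after indexing) are silently dropped rather than raising.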

I looked into Haystack, and my understanding is that with Haystack the code above would look like this:

    def get_queryset(self):
        params = self.request.GET
        query = params.get('query', None)

        search_queryset = SearchQuerySet().filter(content=query)

        return search_queryset

This is super simple and will maintain the order of the results, but Django REST Framework does not serialize SearchQuerySets.

Is there a method in REST Framework that I can override that would allow for the serialization of SearchQuerySets? Or is there a way to maintain the ranked results order without using Haystack or Python sets?


Solution

  • We came up with this solution. It mimics the idioms of Django and REST Framework.

    Briefly, here is how it works (hopefully the code and comments give a more in-depth explanation :)). Instead of using a regular Django Page to power pagination, we have a FakePage that can be initialized with a total count that is not coupled to the object list. This is the hack. Yes, it's a "hack", but it is a very simple one. The alternative for us would have been to reimplement a version of QuerySet, and for us the simple solution always wins. Despite being simple, it is reusable and performant.
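The key point of the fake page is that its object list and its total count come from different places: the list holds one page of rows, while the count comes from Solr's `numFound`. The sketch below uses a hypothetical `PageInfo` dataclass, with the page size passed in explicitly, to show the arithmetic such a page needs. Rounding up matters: 25 hits at 10 per page is 3 pages, not 2.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class PageInfo:
    """Hypothetical page whose total count is independent of its object list."""
    object_list: Sequence
    number: int        # 1-based page number
    per_page: int      # configured page size, not len(object_list)
    total_count: int   # total hits across all pages (e.g. Solr numFound)

    @property
    def num_pages(self) -> int:
        if self.per_page <= 0:
            return 0
        # Ceiling division so a partial last page is still counted.
        return -(-self.total_count // self.per_page)

    def has_next(self) -> bool:
        return self.number < self.num_pages

    def has_previous(self) -> bool:
        return self.number > 1

    def start_index(self) -> int:
        """1-based index of the first object on this page (0 if no hits)."""
        if self.total_count == 0:
            return 0
        return self.per_page * (self.number - 1) + 1

    def end_index(self) -> int:
        """1-based index of the last object on this page."""
        return min(self.number * self.per_page, self.total_count)


last = PageInfo(object_list=["x"] * 5, number=3, per_page=10, total_count=25)
print(last.num_pages, last.has_next(), last.start_index(), last.end_index())
# → 3 False 21 25
```

Passing the configured page size, rather than inferring it from the length of the current list, keeps the numbers right on a short last page.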

    Using the fake page, we have an abstract SearchListModelMixin class that knows how to get a serializer that uses our fake page. The mixin is applied to concrete view classes like SearchListCreateAPIView. The code is below. If other people have a need, I can explain more thoroughly.
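The mixin follows a template-method pattern with a hook: `search_list()` falls back to the ordinary `list()` when there is no query, and otherwise asks the concrete view for `(objects, total_count)` through `get_search_results()`. The framework-free sketch below mirrors only that control flow; `BaseListView`, `SearchListMixin`, `DemoSearchView`, and the local `ImproperlyConfigured` are hypothetical stand-ins, not Django or REST Framework classes.

```python
class ImproperlyConfigured(Exception):
    """Stand-in for django.core.exceptions.ImproperlyConfigured."""


class BaseListView:
    """Hypothetical stand-in for a REST Framework list view."""

    def __init__(self, params):
        self.params = params  # plays the role of request.GET

    def list(self):
        return {"source": "database"}


class SearchListMixin:
    query_param = "query"

    def search_list(self):
        query = self.params.get(self.query_param)
        if query is None:
            # No search term: behave like a normal list view.
            return self.list()
        if not hasattr(self, "get_search_results"):
            raise ImproperlyConfigured(
                "%s must define get_search_results()" % type(self).__name__)
        objects, total_count = self.get_search_results()
        # Reject the result if either element has the wrong type.
        if not isinstance(objects, list) or not isinstance(total_count, int):
            raise ImproperlyConfigured("get_search_results() must return (list, int)")
        return {"source": "search", "count": total_count, "results": objects}


class DemoSearchView(SearchListMixin, BaseListView):
    def get_search_results(self):
        # Pretend the search engine ranked these ids and reported 42 hits.
        return [75, 89, 106], 42


print(DemoSearchView({}).search_list())
# → {'source': 'database'}
print(DemoSearchView({"query": "aaron rodgers"}).search_list())
# → {'source': 'search', 'count': 42, 'results': [75, 89, 106]}
```

Listing the mixin before the base view in the class bases lets its methods call `self.list()` from the base class, the same ordering the concrete views use.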

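    import solr  # solrpy client
    from django.conf import settings
    from django.core.exceptions import ImproperlyConfigured
    from django.core.paginator import EmptyPage
    from django.utils import six
    from rest_framework import generics, pagination, permissions, serializers
    from rest_framework.response import Response

    from .models import Widget  # assumed app-local module
    from .serializers import WidgetSerializer  # assumed app-local module


    class FakePage(object):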
        """
        Fake page used by Django paginator.
        Required for wrapping responses from Solr.
        """
    
        def __init__(self, object_list, number, total_count):
            """
            Create fake page instance.
    
            Args:
                object_list: list of objects, represented by a page
                number: 1-based page number
                total_count: total count of objects (in all pages)
            """
            self.object_list = object_list
            self.number = number
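            # The page size is inferred from the length of this page's list,
            # which is only exact for full pages (not a short last page).
            self.per_page = len(object_list)
            if self.per_page > 0:
                # Round up so a partial last page still counts as a page.
                self.num_pages = (total_count + self.per_page - 1) // self.per_page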
            else:
                self.num_pages = 0
            self.total_count = total_count
    
        def __repr__(self):
            return '<Page %s of %s>' % (self.number, self.num_pages)
    
        def __len__(self):
            return len(self.object_list)
    
        def __getitem__(self, index):
            if not isinstance(index, (slice,) + six.integer_types):
                raise TypeError
            # The object_list is converted to a list so that if it was a QuerySet
            # it won't be a database hit per __getitem__.
            if not isinstance(self.object_list, list):
                self.object_list = list(self.object_list)
            return self.object_list[index]
    
        def has_next(self):
            return self.number < self.num_pages
    
        def has_previous(self):
            return self.number > 1
    
        def has_other_pages(self):
            return self.has_previous() or self.has_next()
    
        def next_page_number(self):
            if self.has_next():
                return self.number + 1
            raise EmptyPage('Next page does not exist')
    
        def previous_page_number(self):
            if self.has_previous():
                return self.number - 1
            raise EmptyPage('Previous page does not exist')
    
        def start_index(self):
            """
            Returns the 1-based index of the first object on this page,
            relative to total objects in the paginator.
            """
            # Special case, return zero if no items.
            if self.total_count == 0:
                return 0
            return (self.per_page * (self.number - 1)) + 1
    
        def end_index(self):
            """
            Returns the 1-based index of the last object on this page,
            relative to total objects found (hits).
            """
            # Special case for the last page because there can be orphans.
            if self.number == self.num_pages:
                return self.total_count
            return self.number * self.per_page    
    
    class SearchType(object):
        """
        This enum-like class defines types of Solr searches
        that can be used in APIViews.
        """
        # Defined before SearchListModelMixin because its class body
        # references SearchType.QUERY at import time.
        QUERY = 1
        RELATED = 2
    
    
    class SearchListModelMixin(object):
        """
        List search results or a queryset.
        """
        # Set this attribute to use a custom URL parameter for the query.
        # EX: ../widgets?query=something
        query_param = 'query'
    
        # Set this attribute to define the type of search.
        # This determines the source of the query value.
        # For example, for regular text search,
        # the query comes from a URL param. However, for a related search,
        # the query is a unique identifier in the URL itself.
        search_type = SearchType.QUERY
    
        def search_list(self, request, *args, **kwargs):
            # Get the query from the URL parameters dictionary.
            query = self.request.GET.get(self.query_param, None)
    
            # If there is no query, use the default REST Framework behavior.
            if query is None and self.search_type == SearchType.QUERY:
                return self.list(request, *args, **kwargs)
    
            # Get the page of objects and the total count of the results.
            if not hasattr(self, 'get_search_results'):
                raise ImproperlyConfigured("'%s' must define 'get_search_results()'"
                                           % self.__class__.__name__)
            self.object_list, total_count = self.get_search_results()
            # Reject the result if either element has the wrong type.
            if not isinstance(self.object_list, list) or not isinstance(total_count, int):
                raise ImproperlyConfigured("'%s.get_search_results()' must return (list, int)"
                                           % self.__class__.__name__)
    
            # Normally, we would be serializing a QuerySet,
            # which is lazily evaluated and has the entire result set.
            # Here, we just have a Python list containing only the elements for the
            # requested page. Thus, we must generate a fake page,
            # simulating a Django page in order to fully support pagination.
            # Otherwise, the `count` field would be equal to the page size.
            page = FakePage(self.object_list,
                            int(self.request.GET.get(self.page_kwarg, 1)),
                            total_count)
    
            # Prepare a SearchPaginationSerializer
            # with the object_serializer_class
            # set to the serializer_class of the APIView.
            serializer = self.get_search_pagination_serializer(page)
    
            return Response(serializer.data)
    
        def get_search_pagination_serializer(self, page):
            """
            Return a serializer instance to use with paginated search data.
            """
            class SerializerClass(SearchPaginationSerializer):
                class Meta:
                    object_serializer_class = self.get_serializer_class()
    
            pagination_serializer_class = SerializerClass
            context = self.get_serializer_context()
            return pagination_serializer_class(instance=page, context=context)
    
        def get_solr_results(self, sort=None):
            """
            This method is optional. It encapsulates logic for a Solr request
            for a list of ids using pagination. Another method could be provided to
            do the search request.
            """
            # Base queryset from the next class in the MRO (the generic view).
            # super(self.__class__, self) resolves differently in every
            # subclass, which is fragile.
            queryset = super(SearchListModelMixin, self).get_queryset()
    
            conn = solr.Solr(settings.SOLR_URL)
    
            # Create a SearchHandler and get the query with respect to
            # the type of search.
            if self.search_type == SearchType.QUERY:
                query = self.request.GET.get(self.query_param, None)
                sh = solr.SearchHandler(conn, "/select")
            elif self.search_type == SearchType.RELATED:
                query = str(self.kwargs[self.lookup_field])
                sh = solr.SearchHandler(conn, "/morelikethis")
            else:
                raise ImproperlyConfigured("Unknown search_type %r on '%s'"
                                           % (self.search_type, self.__class__.__name__))
    
            # Get the pagination information and populate kwargs.
            page_num = int(self.request.GET.get(self.page_kwarg, 1))
            per_page = self.get_paginate_by()
            offset = (page_num - 1) * per_page
            kwargs = {'rows': per_page, 'start': offset}
            if sort:
                kwargs['sort'] = sort
    
            # Perform the Solr request and build a list of results.
            # For now, this only gets the id field, but this could be
            # customized later.
            response = sh(query, 'id', **kwargs)
            results = [int(r['id']) for r in response.results]
    
            # Fetch this page of objects in one query. in_bulk() returns a dict
            # mapping pk to object; its order says nothing about Solr's ranking.
            object_dict = queryset.in_bulk(results)
    
            # Rebuild the list in Solr's ranked order, skipping ids that no
            # longer exist in the database (e.g. rows deleted after indexing).
            sorted_objects = [object_dict[pk] for pk in results if pk in object_dict]
            return sorted_objects, int(response.numFound)
    
    
    class SearchPaginationSerializer(pagination.BasePaginationSerializer):
        count = serializers.Field(source='total_count')
        next = pagination.NextPageField(source='*')
        previous = pagination.PreviousPageField(source='*')
    
    
    class SearchListCreateAPIView(SearchListModelMixin, generics.ListCreateAPIView):
        """
        Concrete view for listing search results, a queryset, or creating a model instance.
        """
        def get(self, request, *args, **kwargs):
            return self.search_list(request, *args, **kwargs)
    
    
    class SearchListAPIView(SearchListModelMixin, generics.ListAPIView):
        """
        Concrete view for listing search results or a queryset.
        """
        def get(self, request, *args, **kwargs):
            return self.search_list(request, *args, **kwargs)
    
    
    class WidgetList(SearchListCreateAPIView):
        """
        List all Widgets or create a new Widget.
        """
        model = Widget
        queryset = Widget.objects.all()
        serializer_class = WidgetSerializer
        permission_classes = (permissions.IsAuthenticatedOrReadOnly,)
        search_type = SearchType.QUERY # this is default, but explicitly declared here    
    
        def get_queryset(self):
            """
            The method implemented for database powered results.
            I'm simply illustrating the default behavior here.
            """
            queryset = super(WidgetList, self).get_queryset()
            return queryset
    
        def get_search_results(self):
            """
            The method implemented for search engine powered results.
            """
            # Set a Solr sort spec here if needed; None keeps Solr's relevance order.
            solr_sort = None
            return self.get_solr_results(solr_sort)
    
    class RelatedList(SearchListAPIView):
        """
        List all related Widgets.
        """
        model = Widget
        queryset = Widget.objects.all()
        serializer_class = WidgetSerializer
        permission_classes = (permissions.IsAuthenticatedOrReadOnly,)
        search_type = SearchType.RELATED
    
        def get_search_results(self):
            return self.get_solr_results()
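End to end, each search request turns the page number into Solr's `start`/`rows` parameters, reorders the returned ids, and wraps them with the total hit count. The sketch below fakes the Solr call with a hypothetical in-memory `fake_solr_select()` and uses a dict in place of `in_bulk()`. It builds a response shaped like REST Framework 2's paginated output (`count`, `next`, `previous`, `results`); the `?page=` URL format is an assumption.

```python
from typing import Dict, List, Optional, Tuple

RANKED_HITS = [75, 89, 106, 45, 12, 7, 300]  # hypothetical Solr ranking


def fake_solr_select(start: int, rows: int) -> Tuple[List[int], int]:
    """Hypothetical stand-in for a Solr /select call returning (ids, numFound)."""
    return RANKED_HITS[start:start + rows], len(RANKED_HITS)


def page_url(base: str, page: Optional[int]) -> Optional[str]:
    return None if page is None else "%s?page=%d" % (base, page)


def search_page(page: int, per_page: int, db: Dict[int, str], base: str) -> dict:
    # Translate the 1-based page number into Solr's start/rows parameters.
    start = (page - 1) * per_page
    ids, num_found = fake_solr_select(start, per_page)
    # Keep Solr's order; drop ids missing from the database.
    results = [db[pk] for pk in ids if pk in db]
    # Page count comes from numFound, not from len(results).
    num_pages = -(-num_found // per_page) if per_page > 0 else 0
    return {
        "count": num_found,
        "next": page_url(base, page + 1 if page < num_pages else None),
        "previous": page_url(base, page - 1 if page > 1 else None),
        "results": results,
    }


db = {pk: "clip-%d" % pk for pk in RANKED_HITS if pk != 12}  # clip 12 was deleted
print(search_page(2, 3, db, "/clips/"))
# → {'count': 7, 'next': '/clips/?page=3', 'previous': '/clips/?page=1', 'results': ['clip-45', 'clip-7']}
```

Only one page of ids crosses from Solr to the database per request, which is what keeps this approach cheap even when the total hit count is large.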