# Source code for rest_witchcraft.filters

"""Provides generic filtering backends that can be used to filter the results
returned by list views."""

from sqlalchemy import func, or_
from sqlalchemy.sql import operators

from django.template import loader
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy

from rest_framework.compat import coreapi, coreschema
from rest_framework.filters import BaseFilterBackend
from rest_framework.settings import api_settings


# Escape character used in the LIKE patterns built below. A forward slash is
# used rather than a backslash so the pattern does not interact with
# backslash handling in SQL string literals.
_LIKE_ESCAPE = "/"


def _escape_like(term, escape=_LIKE_ESCAPE):
    """Return *term* with LIKE wildcards escaped so it is matched literally.

    The escape character itself is escaped first so that the escapes added
    for ``%`` and ``_`` are not doubled afterwards.
    """
    return (
        term.replace(escape, escape + escape)
        .replace("%", escape + "%")
        .replace("_", escape + "_")
    )


class SearchFilter(BaseFilterBackend):
    """Filter backend that searches SQLAlchemy model columns for query terms.

    The view lists the searchable attributes in ``search_fields``. Each entry
    may start with one of the characters in :attr:`lookup_prefixes` to pick
    the comparison; without a prefix a case-insensitive "contains" match is
    used.
    """

    # Name of the query-string parameter holding the search terms.
    search_param = api_settings.SEARCH_PARAM
    # Template rendered by ``to_html``.
    template = "rest_framework/filters/search.html"
    # Maps a field-name prefix to a callable ``(column, term) -> SQL expression``.
    # User-supplied terms are escaped before being embedded in a LIKE pattern,
    # so ``%`` and ``_`` typed by the user match literally instead of acting
    # as wildcards.
    lookup_prefixes = {
        # icontains
        "": lambda c, x: operators.ilike_op(
            c, "%" + _escape_like(x) + "%", escape=_LIKE_ESCAPE
        ),
        # istartswith
        "^": lambda c, x: c.ilike(_escape_like(x) + "%", escape=_LIKE_ESCAPE),
        # iequals
        "=": lambda c, x: func.lower(c) == func.lower(x),
        # equals
        "@": operators.eq,
    }
    search_title = gettext_lazy("Search")
    search_description = gettext_lazy("A search term.")

    def get_schema_fields(self, view):
        """Return the coreapi field describing the search query parameter.

        Raises:
            AssertionError: if ``coreapi`` or ``coreschema`` is not installed.
        """
        assert coreapi is not None, "coreapi must be installed to use `get_schema_fields()`"
        assert coreschema is not None, "coreschema must be installed to use `get_schema_fields()`"
        return [
            coreapi.Field(
                name=self.search_param,
                required=False,
                location="query",
                schema=coreschema.String(
                    title=force_str(self.search_title),
                    description=force_str(self.search_description),
                ),
            )
        ]

    def get_schema_operation_parameters(self, view):
        """Return the schema parameter entry for the search query parameter."""
        return [
            {
                "name": self.search_param,
                "required": False,
                "in": "query",
                "description": force_str(self.search_description),
                "schema": {"type": "string"},
            }
        ]

    def get_search_fields(self, view, request):
        """Return the view's ``search_fields`` attribute, or ``None`` if unset.

        ``request`` is unused here; it is accepted so that overrides can vary
        the searchable fields per request.
        """
        return getattr(view, "search_fields", None)

    def get_search_terms(self, request):
        """Return the list of search terms from the request's query string.

        Terms are separated by whitespace and/or commas; null characters are
        removed. A missing parameter yields an empty list.
        """
        params = request.query_params.get(self.search_param, "")
        params = params.replace("\x00", "")  # strip null characters
        params = params.replace(",", " ")
        return params.split()

    def to_html(self, request, queryset, view):
        """Render the search form, or ``""`` when the view has no search fields.

        Only the first search term is passed to the template.
        """
        if not getattr(view, "search_fields", None):
            return ""

        term = self.get_search_terms(request)
        term = term[0] if term else ""
        context = {"param": self.search_param, "term": term}
        template = loader.get_template(self.template)
        return template.render(context)

    def filter_queryset(self, request, queryset, view):
        """Filter ``queryset`` to rows where any search field matches any term.

        The queryset is returned unchanged when the view declares no search
        fields, the request carries no search terms, or no expression could
        be built. The model is obtained from ``view.get_model()``.
        """
        search_fields = self.get_search_fields(view, request)
        search_terms = self.get_search_terms(request)

        if not search_fields or not search_terms:
            return queryset

        model = view.get_model()
        expressions = []
        for field in search_fields:
            for term in search_terms:
                expr = self.get_expression(model, field, term)
                if expr is not None:
                    expressions.append(expr)

        # ``get_expression`` may be overridden to return ``None``; avoid
        # calling ``or_()`` with no clauses in that case.
        if not expressions:
            return queryset

        return queryset.filter(or_(*expressions))

    def get_expression(self, model, field, term):
        """Build the SQL expression comparing one model attribute with ``term``.

        If the first character of ``field`` is a key of ``lookup_prefixes``,
        it selects the operator and is stripped from the attribute name;
        otherwise the default (``""``) operator is used.
        """
        op = self.lookup_prefixes[""]
        if field[0] in self.lookup_prefixes:
            op = self.lookup_prefixes[field[0]]
            field = field[1:]

        return op(getattr(model, field), term)