SQL Model now supports very basic or'ing with Q objects.
This commit is contained in:
Родитель
6aabc28260
Коммит
6eaadcf507
|
@ -2,7 +2,10 @@ import copy
|
|||
import re
|
||||
|
||||
from django.db import connection, models
|
||||
from django.db.models import Q
|
||||
from django.db.models.sql.query import AND, OR
|
||||
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
|
||||
from django.utils.tree import Node
|
||||
|
||||
|
||||
ORDER_PATTERN = re.compile(r'^[-+]?[a-zA-Z0-9_]+$')
|
||||
|
@ -47,10 +50,42 @@ class RawSQLManager(object):
|
|||
self._args = {}
|
||||
self._record_set = []
|
||||
|
||||
def __iter__(self):
|
||||
self._build_cursor()
|
||||
for row in self._iter_cursor_results():
|
||||
yield row
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, slice):
|
||||
if key.start and key.stop:
|
||||
self.base_query['limit'] = [self._check_limit(key.start),
|
||||
self._check_limit(key.stop)]
|
||||
elif key.start:
|
||||
self.base_query['limit'] = [self._check_limit(key.start)]
|
||||
elif key.stop:
|
||||
self.base_query['limit'] = [0, self._check_limit(key.stop)]
|
||||
self._build_cursor()
|
||||
self._build_record_set()
|
||||
return self._record_set
|
||||
elif isinstance(key, int):
|
||||
if not len(self._record_set):
|
||||
# Fetch just enough rows to get this one item:
|
||||
self.base_query['limit'] = [key + 1]
|
||||
self._build_cursor()
|
||||
self._build_record_set()
|
||||
return self._record_set[key]
|
||||
else:
|
||||
raise TypeError('Key must be a slice or integer.')
|
||||
|
||||
def all(self):
|
||||
return self.__class__(self.sql_model,
|
||||
base_query=copy.deepcopy(self.base_query))
|
||||
|
||||
def count(self):
|
||||
"""Count of all results, preserving aggregate grouping."""
|
||||
self._execute('SELECT count(*) from (%s) as q' % self.as_sql())
|
||||
return self._cursor.fetchone()[0]
|
||||
|
||||
def get(self):
|
||||
rows = list(self)
|
||||
cnt = len(rows)
|
||||
|
@ -62,47 +97,53 @@ class RawSQLManager(object):
|
|||
else:
|
||||
return rows[0]
|
||||
|
||||
def as_sql(self):
|
||||
stmt = self._compile(self.base_query)
|
||||
return stmt
|
||||
def exclude(self, *args, **kw):
|
||||
raise NotImplementedError()
|
||||
|
||||
def count(self):
|
||||
"""Count of all results, preserving aggregate grouping."""
|
||||
self._execute('SELECT count(*) from (%s) as q' % self.as_sql())
|
||||
return self._cursor.fetchone()[0]
|
||||
|
||||
def filter(self, **kw):
|
||||
def filter(self, *args, **kw):
|
||||
"""Adds a where clause with keyword args.
|
||||
|
||||
Example::
|
||||
|
||||
qs = qs.filter(category='trees')
|
||||
qs = qs.filter(Q(type=1) | Q(name='foo'))
|
||||
|
||||
"""
|
||||
# NOTE: or is not supported, i.e. no Q objects
|
||||
for arg in args:
|
||||
if isinstance(arg, Q):
|
||||
self.base_query['where'].append('(%s)'
|
||||
% (self._kw_clause_from_q(arg)))
|
||||
else:
|
||||
raise TypeError(
|
||||
'non keyword args should be Q objects, got %r' % arg)
|
||||
for field, val in kw.items():
|
||||
if not FIELD_PATTERN.match(field):
|
||||
raise ValueError(
|
||||
'Not a valid field for where clause: %r' % field)
|
||||
param_k = self._param(val)
|
||||
if field in self.base_query['select']:
|
||||
field = self.base_query['select'][field]
|
||||
self.base_query['where'].append('%s = %%(%s)s' % (field, param_k))
|
||||
self.base_query['where'].append(self._kw_filter_to_clause(field,
|
||||
val))
|
||||
return self
|
||||
|
||||
def filter_raw(self, spec, val):
|
||||
def filter_raw(self, *args):
|
||||
"""Adds a where clause in limited SQL.
|
||||
|
||||
Examples::
|
||||
|
||||
qs = qs.where('total >', 1)
|
||||
qs = qs.where('total >=', 1)
|
||||
qs = qs.filter_raw('total >', 1)
|
||||
qs = qs.filter_raw('total >=', 1)
|
||||
qs = qs.filter_raw(Q('name LIKE', '%foo%') |
|
||||
Q('status IN', [1, 2, 3]))
|
||||
|
||||
The field on the leftside can be a key in the select dictionary.
|
||||
That is, it will be replaced with the actual expression when the
|
||||
query is built.
|
||||
"""
|
||||
self.base_query['where'].append(self._filter_to_clause(spec, val))
|
||||
specs = []
|
||||
for arg in args:
|
||||
if isinstance(arg, Q):
|
||||
self.base_query['where'].append(
|
||||
'(%s)' % (self._raw_clause_from_q(arg)))
|
||||
else:
|
||||
specs.append(arg)
|
||||
if len(specs):
|
||||
self.base_query['where'].append(self._filter_to_clause(*specs))
|
||||
return self
|
||||
|
||||
def having(self, spec, val):
|
||||
|
@ -120,23 +161,88 @@ class RawSQLManager(object):
|
|||
self.base_query['having'].append(self._filter_to_clause(spec, val))
|
||||
return self
|
||||
|
||||
def _filter_to_clause(self, spec, val):
|
||||
clause = RAW_FILTER_PATTERN.match(spec)
|
||||
if not clause:
|
||||
raise ValueError(
|
||||
'This is not a valid clause: %r; must match: %s' % (
|
||||
spec, RAW_FILTER_PATTERN.pattern))
|
||||
field = clause.group('field')
|
||||
def as_sql(self):
|
||||
stmt = self._compile(self.base_query)
|
||||
return stmt
|
||||
|
||||
def _parse_q(self, q_object):
|
||||
"""Returns a parsed Q object.
|
||||
|
||||
eg. [([('product =', 'AND'), ('life jacket', 'AND')], 'OR'),
|
||||
([('product =', 'AND'), ('defilbrilator', 'AND')], 'OR')]
|
||||
"""
|
||||
specs = []
|
||||
# TODO(Kumar): construct NOT clause:
|
||||
if q_object.negated:
|
||||
raise NotImplementedError('negated Q objects')
|
||||
for child in q_object.children:
|
||||
connector = q_object.connector
|
||||
if isinstance(child, Node):
|
||||
sp = self._parse_q(child)
|
||||
specs.append((sp, connector))
|
||||
else:
|
||||
specs.append((child, connector))
|
||||
return specs
|
||||
|
||||
def _raw_clause_from_q(self, q_object):
|
||||
parts = self._parse_q(q_object)
|
||||
clause = []
|
||||
# TODO(Kumar) this doesn't handle nesting!
|
||||
for part in parts:
|
||||
specs, connector = part
|
||||
# Remove the AND in each spec part:
|
||||
specs = [s[0] for s in specs]
|
||||
clause.extend([self._filter_to_clause(*specs),
|
||||
connector])
|
||||
return u' '.join(clause[:-1]) # skip the last connector
|
||||
|
||||
def _kw_clause_from_q(self, q_object):
|
||||
parts = self._parse_q(q_object)
|
||||
clause = []
|
||||
for part in parts:
|
||||
specs, connector = part
|
||||
clause.extend([self._kw_filter_to_clause(*specs),
|
||||
connector])
|
||||
return u' '.join(clause[:-1]) # skip the last connector
|
||||
|
||||
def _kw_filter_to_clause(self, field, val):
|
||||
if not FIELD_PATTERN.match(field):
|
||||
raise ValueError('Not a valid field for where clause: %r' % field)
|
||||
param_k = self._param(val)
|
||||
if field in self.base_query['select']:
|
||||
# Support filtering by alias, similar to how a view works
|
||||
field = self.base_query['select'][field]
|
||||
if clause.group('op').lower() == 'in':
|
||||
# eg. WHERE foo IN (1, 2, 3)
|
||||
parts = ['%%(%s)s' % self._param(p) for p in iter(val)]
|
||||
param = '(%s)' % ', '.join(parts)
|
||||
else:
|
||||
param = '%%(%s)s' % self._param(val)
|
||||
return '%s %s %s' % (field, clause.group('op'), param)
|
||||
return '%s = %%(%s)s' % (field, param_k)
|
||||
|
||||
def _filter_to_clause(self, *specs):
|
||||
specs = list(specs)
|
||||
if (len(specs) % 2) != 0:
|
||||
raise TypeError(
|
||||
"Expected pairs of 'spec =', 'val'. Got: %r" % specs)
|
||||
full_clause = []
|
||||
while len(specs):
|
||||
spec, val = specs.pop(0), specs.pop(0)
|
||||
clause = RAW_FILTER_PATTERN.match(spec)
|
||||
if not clause:
|
||||
raise ValueError(
|
||||
'This is not a valid clause: %r; must match: %s' % (
|
||||
spec, RAW_FILTER_PATTERN.pattern))
|
||||
field = clause.group('field')
|
||||
if field in self.base_query['select']:
|
||||
# Support filtering by alias, similar to how a view works
|
||||
field = self.base_query['select'][field]
|
||||
if clause.group('op').lower() == 'in':
|
||||
# eg. WHERE foo IN (1, 2, 3)
|
||||
parts = ['%%(%s)s' % self._param(p) for p in iter(val)]
|
||||
param = '(%s)' % ', '.join(parts)
|
||||
else:
|
||||
param = '%%(%s)s' % self._param(val)
|
||||
full_clause.append('%s %s %s' % (field, clause.group('op'), param))
|
||||
and_ = u' %s ' % AND
|
||||
c = and_.join(full_clause)
|
||||
if len(full_clause) > 1:
|
||||
# Protect OR clauses
|
||||
c = u'(%s)' % c
|
||||
return c
|
||||
|
||||
def order_by(self, spec):
|
||||
"""Order by column (ascending) or -column (descending)."""
|
||||
|
@ -153,12 +259,12 @@ class RawSQLManager(object):
|
|||
|
||||
def _compile(self, parts):
|
||||
sep = ",\n"
|
||||
and_ = ' %s\n' % AND
|
||||
select = ['%s AS %s' % (v, k) for k, v in parts['select'].items()]
|
||||
stmt = "SELECT\n%s\nFROM\n%s" % (sep.join(select),
|
||||
"\n".join(parts['from']))
|
||||
if parts.get('where'):
|
||||
stmt = "%s\nWHERE\n%s" % (stmt,
|
||||
" AND\n".join(parts['where']))
|
||||
stmt = "%s\nWHERE\n%s" % (stmt, and_.join(parts['where']))
|
||||
if parts.get('group_by'):
|
||||
stmt = "%s\nGROUP BY\n%s" % (stmt, parts['group_by'])
|
||||
if parts.get('having'):
|
||||
|
@ -188,11 +294,6 @@ class RawSQLManager(object):
|
|||
for row in self._iter_cursor_results():
|
||||
self._record_set.append(row)
|
||||
|
||||
def __iter__(self):
|
||||
self._build_cursor()
|
||||
for row in self._iter_cursor_results():
|
||||
yield row
|
||||
|
||||
def _iter_cursor_results(self):
|
||||
col_names = [c[0] for c in self._cursor.description]
|
||||
while 1:
|
||||
|
@ -211,28 +312,6 @@ class RawSQLManager(object):
|
|||
raise IndexError("Negative indexing is not supported")
|
||||
return i
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, slice):
|
||||
if key.start and key.stop:
|
||||
self.base_query['limit'] = [self._check_limit(key.start),
|
||||
self._check_limit(key.stop)]
|
||||
elif key.start:
|
||||
self.base_query['limit'] = [self._check_limit(key.start)]
|
||||
elif key.stop:
|
||||
self.base_query['limit'] = [0, self._check_limit(key.stop)]
|
||||
self._build_cursor()
|
||||
self._build_record_set()
|
||||
return self._record_set
|
||||
elif isinstance(key, int):
|
||||
if not len(self._record_set):
|
||||
# Fetch just enough rows to get this one item:
|
||||
self.base_query['limit'] = [key + 1]
|
||||
self._build_cursor()
|
||||
self._build_record_set()
|
||||
return self._record_set[key]
|
||||
else:
|
||||
raise TypeError('Key must be a slice or integer.')
|
||||
|
||||
|
||||
class RawSQLModelMeta(type):
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# -*- coding: utf8 -*-
|
||||
"""Tests for SQL Model.
|
||||
|
||||
Currently these tests are coupled tighly with MySQL
|
||||
|
@ -6,6 +7,7 @@ from datetime import datetime
|
|||
import unittest
|
||||
|
||||
from django.db import connection, models
|
||||
from django.db.models import Q
|
||||
from nose.tools import eq_, raises
|
||||
|
||||
from editors.sql_model import RawSQLModel
|
||||
|
@ -85,6 +87,24 @@ class Summary(RawSQLModel):
|
|||
}
|
||||
|
||||
|
||||
class ProductDetail(RawSQLModel):
|
||||
product = models.CharField(max_length=255)
|
||||
category = models.CharField(max_length=255)
|
||||
|
||||
def base_query(self):
|
||||
return {
|
||||
'select': {
|
||||
'product': 'p.name',
|
||||
'category': 'c.name'
|
||||
},
|
||||
'from': [
|
||||
'sql_model_test_product p',
|
||||
'join sql_model_test_product_cat x on x.product_id=p.id',
|
||||
'join sql_model_test_cat c on x.cat_id=c.id'],
|
||||
'where': []
|
||||
}
|
||||
|
||||
|
||||
class TestSQLModel(unittest.TestCase):
|
||||
|
||||
def test_all(self):
|
||||
|
@ -153,6 +173,30 @@ class TestSQLModel(unittest.TestCase):
|
|||
['apparel', 'safety'])
|
||||
eq_([c.category for c in qs], ['apparel', 'safety'])
|
||||
|
||||
def test_filter_raw_non_ascii(self):
|
||||
uni = 'フォクすけといっしょ'.decode('utf8')
|
||||
qs = (Summary.objects.all().filter_raw('category =', uni)
|
||||
.filter_raw(Q('category =', uni) | Q('category !=', uni)))
|
||||
eq_([c.category for c in qs], [])
|
||||
|
||||
def test_combining_filters_with_or(self):
|
||||
qs = (ProductDetail.objects.all()
|
||||
.filter(Q(product='life jacket') | Q(product='defilbrilator')))
|
||||
eq_(sorted([r.product for r in qs]), ['defilbrilator', 'life jacket'])
|
||||
|
||||
def test_combining_raw_filters_with_or(self):
|
||||
qs = (ProductDetail.objects.all()
|
||||
.filter_raw(Q('product =', 'life jacket') |
|
||||
Q('product =', 'defilbrilator')))
|
||||
eq_(sorted([r.product for r in qs]), ['defilbrilator', 'life jacket'])
|
||||
|
||||
def test_nested_raw_filters_with_or(self):
|
||||
qs = (ProductDetail.objects.all()
|
||||
.filter_raw(Q('category =', 'apparel',
|
||||
'product =', 'defilbrilator') |
|
||||
Q('product =', 'life jacket')))
|
||||
eq_(sorted([r.product for r in qs]), ['life jacket'])
|
||||
|
||||
def test_having_gte(self):
|
||||
c = Summary.objects.all().having('total >=', 2)[0]
|
||||
eq_(c.category, 'safety')
|
||||
|
|
Загрузка…
Ссылка в новой задаче