SQL Model now supports very basic or'ing with Q objects.

This commit is contained in:
Kumar McMillan 2011-02-17 12:39:08 -06:00
Родитель 6aabc28260
Коммит 6eaadcf507
2 изменённых файлов: 188 добавлений и 65 удалений

Просмотреть файл

@ -2,7 +2,10 @@ import copy
import re
from django.db import connection, models
from django.db.models import Q
from django.db.models.sql.query import AND, OR
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
from django.utils.tree import Node
ORDER_PATTERN = re.compile(r'^[-+]?[a-zA-Z0-9_]+$')
@ -47,10 +50,42 @@ class RawSQLManager(object):
self._args = {}
self._record_set = []
def __iter__(self):
self._build_cursor()
for row in self._iter_cursor_results():
yield row
def __getitem__(self, key):
if isinstance(key, slice):
if key.start and key.stop:
self.base_query['limit'] = [self._check_limit(key.start),
self._check_limit(key.stop)]
elif key.start:
self.base_query['limit'] = [self._check_limit(key.start)]
elif key.stop:
self.base_query['limit'] = [0, self._check_limit(key.stop)]
self._build_cursor()
self._build_record_set()
return self._record_set
elif isinstance(key, int):
if not len(self._record_set):
# Fetch just enough rows to get this one item:
self.base_query['limit'] = [key + 1]
self._build_cursor()
self._build_record_set()
return self._record_set[key]
else:
raise TypeError('Key must be a slice or integer.')
def all(self):
return self.__class__(self.sql_model,
base_query=copy.deepcopy(self.base_query))
def count(self):
"""Count of all results, preserving aggregate grouping."""
self._execute('SELECT count(*) from (%s) as q' % self.as_sql())
return self._cursor.fetchone()[0]
def get(self):
rows = list(self)
cnt = len(rows)
@ -62,47 +97,53 @@ class RawSQLManager(object):
else:
return rows[0]
def as_sql(self):
stmt = self._compile(self.base_query)
return stmt
def exclude(self, *args, **kw):
raise NotImplementedError()
def count(self):
"""Count of all results, preserving aggregate grouping."""
self._execute('SELECT count(*) from (%s) as q' % self.as_sql())
return self._cursor.fetchone()[0]
def filter(self, **kw):
def filter(self, *args, **kw):
"""Adds a where clause with keyword args.
Example::
qs = qs.filter(category='trees')
qs = qs.filter(Q(type=1) | Q(name='foo'))
"""
# NOTE: or is not supported, i.e. no Q objects
for arg in args:
if isinstance(arg, Q):
self.base_query['where'].append('(%s)'
% (self._kw_clause_from_q(arg)))
else:
raise TypeError(
'non keyword args should be Q objects, got %r' % arg)
for field, val in kw.items():
if not FIELD_PATTERN.match(field):
raise ValueError(
'Not a valid field for where clause: %r' % field)
param_k = self._param(val)
if field in self.base_query['select']:
field = self.base_query['select'][field]
self.base_query['where'].append('%s = %%(%s)s' % (field, param_k))
self.base_query['where'].append(self._kw_filter_to_clause(field,
val))
return self
def filter_raw(self, spec, val):
def filter_raw(self, *args):
"""Adds a where clause in limited SQL.
Examples::
qs = qs.where('total >', 1)
qs = qs.where('total >=', 1)
qs = qs.filter_raw('total >', 1)
qs = qs.filter_raw('total >=', 1)
qs = qs.filter_raw(Q('name LIKE', '%foo%') |
Q('status IN', [1, 2, 3]))
The field on the leftside can be a key in the select dictionary.
That is, it will be replaced with the actual expression when the
query is built.
"""
self.base_query['where'].append(self._filter_to_clause(spec, val))
specs = []
for arg in args:
if isinstance(arg, Q):
self.base_query['where'].append(
'(%s)' % (self._raw_clause_from_q(arg)))
else:
specs.append(arg)
if len(specs):
self.base_query['where'].append(self._filter_to_clause(*specs))
return self
def having(self, spec, val):
@ -120,23 +161,88 @@ class RawSQLManager(object):
self.base_query['having'].append(self._filter_to_clause(spec, val))
return self
def _filter_to_clause(self, spec, val):
clause = RAW_FILTER_PATTERN.match(spec)
if not clause:
raise ValueError(
'This is not a valid clause: %r; must match: %s' % (
spec, RAW_FILTER_PATTERN.pattern))
field = clause.group('field')
def as_sql(self):
stmt = self._compile(self.base_query)
return stmt
def _parse_q(self, q_object):
"""Returns a parsed Q object.
eg. [([('product =', 'AND'), ('life jacket', 'AND')], 'OR'),
([('product =', 'AND'), ('defilbrilator', 'AND')], 'OR')]
"""
specs = []
# TODO(Kumar): construct NOT clause:
if q_object.negated:
raise NotImplementedError('negated Q objects')
for child in q_object.children:
connector = q_object.connector
if isinstance(child, Node):
sp = self._parse_q(child)
specs.append((sp, connector))
else:
specs.append((child, connector))
return specs
def _raw_clause_from_q(self, q_object):
parts = self._parse_q(q_object)
clause = []
# TODO(Kumar) this doesn't handle nesting!
for part in parts:
specs, connector = part
# Remove the AND in each spec part:
specs = [s[0] for s in specs]
clause.extend([self._filter_to_clause(*specs),
connector])
return u' '.join(clause[:-1]) # skip the last connector
def _kw_clause_from_q(self, q_object):
parts = self._parse_q(q_object)
clause = []
for part in parts:
specs, connector = part
clause.extend([self._kw_filter_to_clause(*specs),
connector])
return u' '.join(clause[:-1]) # skip the last connector
def _kw_filter_to_clause(self, field, val):
if not FIELD_PATTERN.match(field):
raise ValueError('Not a valid field for where clause: %r' % field)
param_k = self._param(val)
if field in self.base_query['select']:
# Support filtering by alias, similar to how a view works
field = self.base_query['select'][field]
if clause.group('op').lower() == 'in':
# eg. WHERE foo IN (1, 2, 3)
parts = ['%%(%s)s' % self._param(p) for p in iter(val)]
param = '(%s)' % ', '.join(parts)
else:
param = '%%(%s)s' % self._param(val)
return '%s %s %s' % (field, clause.group('op'), param)
return '%s = %%(%s)s' % (field, param_k)
def _filter_to_clause(self, *specs):
specs = list(specs)
if (len(specs) % 2) != 0:
raise TypeError(
"Expected pairs of 'spec =', 'val'. Got: %r" % specs)
full_clause = []
while len(specs):
spec, val = specs.pop(0), specs.pop(0)
clause = RAW_FILTER_PATTERN.match(spec)
if not clause:
raise ValueError(
'This is not a valid clause: %r; must match: %s' % (
spec, RAW_FILTER_PATTERN.pattern))
field = clause.group('field')
if field in self.base_query['select']:
# Support filtering by alias, similar to how a view works
field = self.base_query['select'][field]
if clause.group('op').lower() == 'in':
# eg. WHERE foo IN (1, 2, 3)
parts = ['%%(%s)s' % self._param(p) for p in iter(val)]
param = '(%s)' % ', '.join(parts)
else:
param = '%%(%s)s' % self._param(val)
full_clause.append('%s %s %s' % (field, clause.group('op'), param))
and_ = u' %s ' % AND
c = and_.join(full_clause)
if len(full_clause) > 1:
# Protect OR clauses
c = u'(%s)' % c
return c
def order_by(self, spec):
"""Order by column (ascending) or -column (descending)."""
@ -153,12 +259,12 @@ class RawSQLManager(object):
def _compile(self, parts):
sep = ",\n"
and_ = ' %s\n' % AND
select = ['%s AS %s' % (v, k) for k, v in parts['select'].items()]
stmt = "SELECT\n%s\nFROM\n%s" % (sep.join(select),
"\n".join(parts['from']))
if parts.get('where'):
stmt = "%s\nWHERE\n%s" % (stmt,
" AND\n".join(parts['where']))
stmt = "%s\nWHERE\n%s" % (stmt, and_.join(parts['where']))
if parts.get('group_by'):
stmt = "%s\nGROUP BY\n%s" % (stmt, parts['group_by'])
if parts.get('having'):
@ -188,11 +294,6 @@ class RawSQLManager(object):
for row in self._iter_cursor_results():
self._record_set.append(row)
def __iter__(self):
self._build_cursor()
for row in self._iter_cursor_results():
yield row
def _iter_cursor_results(self):
col_names = [c[0] for c in self._cursor.description]
while 1:
@ -211,28 +312,6 @@ class RawSQLManager(object):
raise IndexError("Negative indexing is not supported")
return i
def __getitem__(self, key):
if isinstance(key, slice):
if key.start and key.stop:
self.base_query['limit'] = [self._check_limit(key.start),
self._check_limit(key.stop)]
elif key.start:
self.base_query['limit'] = [self._check_limit(key.start)]
elif key.stop:
self.base_query['limit'] = [0, self._check_limit(key.stop)]
self._build_cursor()
self._build_record_set()
return self._record_set
elif isinstance(key, int):
if not len(self._record_set):
# Fetch just enough rows to get this one item:
self.base_query['limit'] = [key + 1]
self._build_cursor()
self._build_record_set()
return self._record_set[key]
else:
raise TypeError('Key must be a slice or integer.')
class RawSQLModelMeta(type):

Просмотреть файл

@ -1,3 +1,4 @@
# -*- coding: utf8 -*-
"""Tests for SQL Model.
Currently these tests are coupled tighly with MySQL
@ -6,6 +7,7 @@ from datetime import datetime
import unittest
from django.db import connection, models
from django.db.models import Q
from nose.tools import eq_, raises
from editors.sql_model import RawSQLModel
@ -85,6 +87,24 @@ class Summary(RawSQLModel):
}
class ProductDetail(RawSQLModel):
product = models.CharField(max_length=255)
category = models.CharField(max_length=255)
def base_query(self):
return {
'select': {
'product': 'p.name',
'category': 'c.name'
},
'from': [
'sql_model_test_product p',
'join sql_model_test_product_cat x on x.product_id=p.id',
'join sql_model_test_cat c on x.cat_id=c.id'],
'where': []
}
class TestSQLModel(unittest.TestCase):
def test_all(self):
@ -153,6 +173,30 @@ class TestSQLModel(unittest.TestCase):
['apparel', 'safety'])
eq_([c.category for c in qs], ['apparel', 'safety'])
def test_filter_raw_non_ascii(self):
uni = 'フォクすけといっしょ'.decode('utf8')
qs = (Summary.objects.all().filter_raw('category =', uni)
.filter_raw(Q('category =', uni) | Q('category !=', uni)))
eq_([c.category for c in qs], [])
def test_combining_filters_with_or(self):
qs = (ProductDetail.objects.all()
.filter(Q(product='life jacket') | Q(product='defilbrilator')))
eq_(sorted([r.product for r in qs]), ['defilbrilator', 'life jacket'])
def test_combining_raw_filters_with_or(self):
qs = (ProductDetail.objects.all()
.filter_raw(Q('product =', 'life jacket') |
Q('product =', 'defilbrilator')))
eq_(sorted([r.product for r in qs]), ['defilbrilator', 'life jacket'])
def test_nested_raw_filters_with_or(self):
qs = (ProductDetail.objects.all()
.filter_raw(Q('category =', 'apparel',
'product =', 'defilbrilator') |
Q('product =', 'life jacket')))
eq_(sorted([r.product for r in qs]), ['life jacket'])
def test_having_gte(self):
c = Summary.objects.all().having('total >=', 2)[0]
eq_(c.category, 'safety')