SQLModel supports nested Q objects now (+ recursive brain hurt).

This commit is contained in:
Kumar McMillan 2011-02-22 16:09:04 -06:00
Родитель 202fd8ec0b
Коммит 6a24fbbd91
2 изменённых файлов: 77 добавлений и 62 удалений

Просмотреть файл

@ -113,7 +113,7 @@ class RawSQLManager(object):
for arg in args:
if isinstance(arg, Q):
clone.base_query['where'].append(
'(%s)' % (clone._kw_clause_from_q(arg)))
'(%s)' % (clone._flatten_q(arg, clone._kw_clause_from_q)))
else:
raise TypeError(
'non keyword args should be Q objects, got %r' % arg)
@ -141,7 +141,7 @@ class RawSQLManager(object):
for arg in args:
if isinstance(arg, Q):
clone.base_query['where'].append(
'(%s)' % (clone._raw_clause_from_q(arg)))
'(%s)' % (clone._flatten_q(arg, clone._filter_to_clause)))
else:
specs.append(arg)
if len(specs):
@ -164,6 +164,21 @@ class RawSQLManager(object):
clone.base_query['having'].append(clone._filter_to_clause(spec, val))
return clone
def order_by(self, spec):
"""Order by column (ascending) or -column (descending)."""
if not ORDER_PATTERN.match(spec):
raise ValueError('Invalid order by value: %r' % spec)
if spec.startswith('-'):
dir = 'DESC'
field = spec[1:]
else:
dir = 'ASC'
field = spec
clone = self._clone()
clone.base_query['order_by'].append(
'%s %s' % (clone._resolve_alias(field), dir))
return clone
def as_sql(self):
stmt = self._compile(self.base_query)
return stmt
@ -172,54 +187,59 @@ class RawSQLManager(object):
return self.__class__(self.sql_model,
base_query=copy.deepcopy(self.base_query))
def _parse_q(self, q_object):
"""Returns a parsed Q object.
def _flatten_q(self, q_object, join_specs, stack=None):
"""Makes a WHERE clause out of a Q object (supports nested Q objects).
eg. [([('product =', 'AND'), ('life jacket', 'AND')], 'OR'),
([('product =', 'AND'), ('defilbrilator', 'AND')], 'OR')]
Pass in join_specs(*specs) based on what kind of arguments you think
the Q object will have. filter() Qs are different from filter_raw() Qs.
"""
specs = []
if stack is None:
stack = [None]
# TODO(Kumar): construct NOT clause:
if q_object.negated:
raise NotImplementedError('negated Q objects')
connector = q_object.connector
def add(specs):
c = join_specs(*specs, connector=connector)
if stack[-1] in (AND, OR):
c = u'(%s)' % (c)
elif stack[-1] is not None:
stack.append(connector)
if c:
stack.append(c)
for child in q_object.children:
connector = q_object.connector
if isinstance(child, Node):
sp = self._parse_q(child)
specs.append((sp, connector))
add(specs)
specs[:] = []
self._flatten_q(child, join_specs, stack=stack)
else:
specs.append((child, connector))
return specs
specs.append(child)
if len(specs):
add(specs)
return u' '.join([c for c in stack if c])
def _raw_clause_from_q(self, q_object):
parts = self._parse_q(q_object)
clause = []
# TODO(Kumar) this doesn't handle nesting!
for part in parts:
specs, connector = part
# Remove the AND in each spec part:
specs = [s[0] for s in specs]
clause.extend([self._filter_to_clause(*specs),
connector])
return u' '.join(clause[:-1]) # skip the last connector
def _kw_clause_from_q(self, q_object):
parts = self._parse_q(q_object)
clause = []
for part in parts:
specs, connector = part
clause.extend([self._kw_filter_to_clause(*specs),
connector])
return u' '.join(clause[:-1]) # skip the last connector
def _kw_clause_from_q(self, *pairs, **kw):
"""Makes a WHERE clause out of pairs of (key, val) from Q objects."""
connector = kw.get('connector', AND)
stmt = []
for field, val in pairs:
stmt.append(self._kw_filter_to_clause(field, val))
return (u' %s ' % connector).join(stmt)
def _kw_filter_to_clause(self, field, val):
"""Makes a WHERE clause out of field = val."""
if not FIELD_PATTERN.match(field):
raise ValueError('Not a valid field for where clause: %r' % field)
param_k = self._param(val)
field = self._resolve_alias(field)
return '%s = %%(%s)s' % (field, param_k)
return u'%s = %%(%s)s' % (field, param_k)
def _filter_to_clause(self, *specs):
def _filter_to_clause(self, *specs, **kw):
"""Makes a WHERE clause out of filter_raw() arguments."""
connector = kw.get('connector', AND)
specs = list(specs)
if (len(specs) % 2) != 0:
raise TypeError(
@ -242,51 +262,36 @@ class RawSQLManager(object):
else:
param = '%%(%s)s' % self._param(val)
full_clause.append('%s %s %s' % (field, clause.group('op'), param))
and_ = u' %s ' % AND
c = and_.join(full_clause)
c = (u' %s ' % connector).join(full_clause)
if len(full_clause) > 1:
# Protect OR clauses
c = u'(%s)' % c
return c
def _resolve_alias(self, field):
"""Access a field (or expression) by alias, similar to how a view works.
"""
if field in self.base_query['select']:
# Support accedssing a field by alias, similar to how a view works
field = self.base_query['select'][field]
return field
def order_by(self, spec):
"""Order by column (ascending) or -column (descending)."""
if not ORDER_PATTERN.match(spec):
raise ValueError('Invalid order by value: %r' % spec)
if spec.startswith('-'):
dir = 'DESC'
field = spec[1:]
else:
dir = 'ASC'
field = spec
clone = self._clone()
clone.base_query['order_by'].append(
'%s %s' % (clone._resolve_alias(field), dir))
return clone
def _compile(self, parts):
sep = ",\n"
and_ = ' %s\n' % AND
select = ['%s AS %s' % (v, k) for k, v in parts['select'].items()]
stmt = "SELECT\n%s\nFROM\n%s" % (sep.join(select),
"\n".join(parts['from']))
sep = u",\n"
and_ = u' %s\n' % AND
select = [u'%s AS %s' % (v, k) for k, v in parts['select'].items()]
stmt = u"SELECT\n%s\nFROM\n%s" % (sep.join(select),
u"\n".join(parts['from']))
if parts.get('where'):
stmt = "%s\nWHERE\n%s" % (stmt, and_.join(parts['where']))
stmt = u"%s\nWHERE\n%s" % (stmt, and_.join(parts['where']))
if parts.get('group_by'):
stmt = "%s\nGROUP BY\n%s" % (stmt, parts['group_by'])
stmt = u"%s\nGROUP BY\n%s" % (stmt, parts['group_by'])
if parts.get('having'):
stmt = "%s\nHAVING\n%s" % (stmt, sep.join(parts['having']))
stmt = u"%s\nHAVING\n%s" % (stmt, sep.join(parts['having']))
if parts.get('order_by'):
stmt = "%s\nORDER BY\n%s" % (stmt, sep.join(parts['order_by']))
stmt = u"%s\nORDER BY\n%s" % (stmt, sep.join(parts['order_by']))
if len(parts['limit']):
stmt = "%s\nLIMIT %s" % (stmt, ', '.join([str(i) for i in
parts['limit']]))
stmt = u"%s\nLIMIT %s" % (stmt, ', '.join([str(i) for i in
parts['limit']]))
return stmt
def _execute(self, sql):

Просмотреть файл

@ -208,6 +208,16 @@ class TestSQLModel(unittest.TestCase):
Q('product =', 'life jacket')))
eq_(sorted([r.product for r in qs]), ['life jacket'])
def test_crazy_nesting(self):
qs = (ProductDetail.objects.all()
.filter_raw(Q('category =', 'apparel',
'product =', 'defilbrilator',
Q('product =', 'life jacket') |
Q('product =', 'snake skin jacket'),
'category =', 'safety')))
# print qs.as_sql()
eq_(sorted([r.product for r in qs]), ['life jacket'])
def test_having_gte(self):
c = Summary.objects.all().having('total >=', 2)[0]
eq_(c.category, 'safety')