Adds qs._clone() to fix ref problems, fixes get()

This commit is contained in:
Kumar McMillan 2011-02-18 13:22:29 -06:00
Родитель 6eaadcf507
Коммит 8107137735
2 изменённых файлов: 44 добавлений и 30 удалений

Просмотреть файл

@ -1,10 +1,10 @@
import copy
import re
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
from django.db import connection, models
from django.db.models import Q
from django.db.models.sql.query import AND, OR
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
from django.utils.tree import Node
@ -45,9 +45,11 @@ class RawSQLManager(object):
self.base_query['having'] = []
if 'order_by' not in self.base_query:
self.base_query['order_by'] = []
self.base_query['limit'] = []
if 'limit' not in self.base_query:
self.base_query['limit'] = []
if '_args' not in self.base_query:
self.base_query['_args'] = {}
self._cursor = None
self._args = {}
self._record_set = []
def __iter__(self):
@ -69,8 +71,7 @@ class RawSQLManager(object):
return self._record_set
elif isinstance(key, int):
if not len(self._record_set):
# Fetch just enough rows to get this one item:
self.base_query['limit'] = [key + 1]
# Get all rows! Better to use a limit with slices.
self._build_cursor()
self._build_record_set()
return self._record_set[key]
@ -78,8 +79,7 @@ class RawSQLManager(object):
raise TypeError('Key must be a slice or integer.')
def all(self):
return self.__class__(self.sql_model,
base_query=copy.deepcopy(self.base_query))
return self._clone()
def count(self):
"""Count of all results, preserving aggregate grouping."""
@ -87,15 +87,15 @@ class RawSQLManager(object):
return self._cursor.fetchone()[0]
def get(self):
rows = list(self)
cnt = len(rows)
clone = self._clone()
cnt = clone.count()
if cnt > 1:
raise self.sql_model.MultipleObjectsReturned(
raise clone.sql_model.MultipleObjectsReturned(
'get() returned more than one row -- it returned %s!' % cnt)
elif cnt == 0:
raise self.sql_model.DoesNotExist('No rows matching query')
raise clone.sql_model.DoesNotExist('No rows matching query')
else:
return rows[0]
return clone[0:1][0]
def exclude(self, *args, **kw):
raise NotImplementedError()
@ -109,17 +109,18 @@ class RawSQLManager(object):
qs = qs.filter(Q(type=1) | Q(name='foo'))
"""
clone = self._clone()
for arg in args:
if isinstance(arg, Q):
self.base_query['where'].append('(%s)'
% (self._kw_clause_from_q(arg)))
clone.base_query['where'].append(
'(%s)' % (clone._kw_clause_from_q(arg)))
else:
raise TypeError(
'non keyword args should be Q objects, got %r' % arg)
for field, val in kw.items():
self.base_query['where'].append(self._kw_filter_to_clause(field,
val))
return self
clone.base_query['where'].append(clone._kw_filter_to_clause(field,
val))
return clone
def filter_raw(self, *args):
"""Adds a where clause in limited SQL.
@ -135,16 +136,17 @@ class RawSQLManager(object):
That is, it will be replaced with the actual expression when the
query is built.
"""
clone = self._clone()
specs = []
for arg in args:
if isinstance(arg, Q):
self.base_query['where'].append(
'(%s)' % (self._raw_clause_from_q(arg)))
clone.base_query['where'].append(
'(%s)' % (clone._raw_clause_from_q(arg)))
else:
specs.append(arg)
if len(specs):
self.base_query['where'].append(self._filter_to_clause(*specs))
return self
clone.base_query['where'].append(clone._filter_to_clause(*specs))
return clone
def having(self, spec, val):
"""Adds a having clause in limited SQL.
@ -158,13 +160,18 @@ class RawSQLManager(object):
That is, it will be replaced with the actual expression when the
query is built.
"""
self.base_query['having'].append(self._filter_to_clause(spec, val))
return self
clone = self._clone()
clone.base_query['having'].append(clone._filter_to_clause(spec, val))
return clone
def as_sql(self):
stmt = self._compile(self.base_query)
return stmt
def _clone(self):
return self.__class__(self.sql_model,
base_query=copy.deepcopy(self.base_query))
def _parse_q(self, q_object):
"""Returns a parsed Q object.
@ -231,8 +238,9 @@ class RawSQLManager(object):
# Support filtering by alias, similar to how a view works
field = self.base_query['select'][field]
if clause.group('op').lower() == 'in':
# eg. WHERE foo IN (1, 2, 3)
parts = ['%%(%s)s' % self._param(p) for p in iter(val)]
# eg. WHERE foo IN (%(param_0)s, %(param_1)s, %(param_2)s)
# WHERE foo IN (1, 2, 3)
parts = ['%(' + self._param(p) + ')s' for p in iter(val)]
param = '(%s)' % ', '.join(parts)
else:
param = '%%(%s)s' % self._param(val)
@ -254,8 +262,9 @@ class RawSQLManager(object):
else:
dir = 'ASC'
field = spec
self.base_query['order_by'].append('%s %s' % (field, dir))
return self
clone = self._clone()
clone.base_query['order_by'].append('%s %s' % (field, dir))
return clone
def _compile(self, parts):
sep = ",\n"
@ -279,11 +288,11 @@ class RawSQLManager(object):
def _execute(self, sql):
self._record_set = []
self._cursor = connection.cursor()
self._cursor.execute(sql, self._args)
self._cursor.execute(sql, self.base_query['_args'])
def _param(self, val):
param_k = 'param_%s' % len(self._args.keys())
self._args[param_k] = val
param_k = 'param_%s' % len(self.base_query['_args'].keys())
self.base_query['_args'][param_k] = val
return param_k
def _build_cursor(self):

Просмотреть файл

@ -118,6 +118,11 @@ class TestSQLModel(unittest.TestCase):
c = Summary.objects.all().order_by('category')[0]
eq_(c.category, 'apparel')
def test_get_by_index(self):
qs = Summary.objects.all().order_by('category')
eq_(qs[0].category, 'apparel')
eq_(qs[1].category, 'safety')
def test_get(self):
c = Summary.objects.all().having('total =', 1).get()
eq_(c.category, 'apparel')