Adds qs._clone() to fix ref problems, fixes get()
This commit is contained in:
Родитель
6eaadcf507
Коммит
8107137735
|
@ -1,10 +1,10 @@
|
|||
import copy
|
||||
import re
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
|
||||
from django.db import connection, models
|
||||
from django.db.models import Q
|
||||
from django.db.models.sql.query import AND, OR
|
||||
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
|
||||
from django.utils.tree import Node
|
||||
|
||||
|
||||
|
@ -45,9 +45,11 @@ class RawSQLManager(object):
|
|||
self.base_query['having'] = []
|
||||
if 'order_by' not in self.base_query:
|
||||
self.base_query['order_by'] = []
|
||||
self.base_query['limit'] = []
|
||||
if 'limit' not in self.base_query:
|
||||
self.base_query['limit'] = []
|
||||
if '_args' not in self.base_query:
|
||||
self.base_query['_args'] = {}
|
||||
self._cursor = None
|
||||
self._args = {}
|
||||
self._record_set = []
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -69,8 +71,7 @@ class RawSQLManager(object):
|
|||
return self._record_set
|
||||
elif isinstance(key, int):
|
||||
if not len(self._record_set):
|
||||
# Fetch just enough rows to get this one item:
|
||||
self.base_query['limit'] = [key + 1]
|
||||
# Get all rows! Better to use a limit with slices.
|
||||
self._build_cursor()
|
||||
self._build_record_set()
|
||||
return self._record_set[key]
|
||||
|
@ -78,8 +79,7 @@ class RawSQLManager(object):
|
|||
raise TypeError('Key must be a slice or integer.')
|
||||
|
||||
def all(self):
|
||||
return self.__class__(self.sql_model,
|
||||
base_query=copy.deepcopy(self.base_query))
|
||||
return self._clone()
|
||||
|
||||
def count(self):
|
||||
"""Count of all results, preserving aggregate grouping."""
|
||||
|
@ -87,15 +87,15 @@ class RawSQLManager(object):
|
|||
return self._cursor.fetchone()[0]
|
||||
|
||||
def get(self):
|
||||
rows = list(self)
|
||||
cnt = len(rows)
|
||||
clone = self._clone()
|
||||
cnt = clone.count()
|
||||
if cnt > 1:
|
||||
raise self.sql_model.MultipleObjectsReturned(
|
||||
raise clone.sql_model.MultipleObjectsReturned(
|
||||
'get() returned more than one row -- it returned %s!' % cnt)
|
||||
elif cnt == 0:
|
||||
raise self.sql_model.DoesNotExist('No rows matching query')
|
||||
raise clone.sql_model.DoesNotExist('No rows matching query')
|
||||
else:
|
||||
return rows[0]
|
||||
return clone[0:1][0]
|
||||
|
||||
def exclude(self, *args, **kw):
|
||||
raise NotImplementedError()
|
||||
|
@ -109,17 +109,18 @@ class RawSQLManager(object):
|
|||
qs = qs.filter(Q(type=1) | Q(name='foo'))
|
||||
|
||||
"""
|
||||
clone = self._clone()
|
||||
for arg in args:
|
||||
if isinstance(arg, Q):
|
||||
self.base_query['where'].append('(%s)'
|
||||
% (self._kw_clause_from_q(arg)))
|
||||
clone.base_query['where'].append(
|
||||
'(%s)' % (clone._kw_clause_from_q(arg)))
|
||||
else:
|
||||
raise TypeError(
|
||||
'non keyword args should be Q objects, got %r' % arg)
|
||||
for field, val in kw.items():
|
||||
self.base_query['where'].append(self._kw_filter_to_clause(field,
|
||||
val))
|
||||
return self
|
||||
clone.base_query['where'].append(clone._kw_filter_to_clause(field,
|
||||
val))
|
||||
return clone
|
||||
|
||||
def filter_raw(self, *args):
|
||||
"""Adds a where clause in limited SQL.
|
||||
|
@ -135,16 +136,17 @@ class RawSQLManager(object):
|
|||
That is, it will be replaced with the actual expression when the
|
||||
query is built.
|
||||
"""
|
||||
clone = self._clone()
|
||||
specs = []
|
||||
for arg in args:
|
||||
if isinstance(arg, Q):
|
||||
self.base_query['where'].append(
|
||||
'(%s)' % (self._raw_clause_from_q(arg)))
|
||||
clone.base_query['where'].append(
|
||||
'(%s)' % (clone._raw_clause_from_q(arg)))
|
||||
else:
|
||||
specs.append(arg)
|
||||
if len(specs):
|
||||
self.base_query['where'].append(self._filter_to_clause(*specs))
|
||||
return self
|
||||
clone.base_query['where'].append(clone._filter_to_clause(*specs))
|
||||
return clone
|
||||
|
||||
def having(self, spec, val):
|
||||
"""Adds a having clause in limited SQL.
|
||||
|
@ -158,13 +160,18 @@ class RawSQLManager(object):
|
|||
That is, it will be replaced with the actual expression when the
|
||||
query is built.
|
||||
"""
|
||||
self.base_query['having'].append(self._filter_to_clause(spec, val))
|
||||
return self
|
||||
clone = self._clone()
|
||||
clone.base_query['having'].append(clone._filter_to_clause(spec, val))
|
||||
return clone
|
||||
|
||||
def as_sql(self):
|
||||
stmt = self._compile(self.base_query)
|
||||
return stmt
|
||||
|
||||
def _clone(self):
|
||||
return self.__class__(self.sql_model,
|
||||
base_query=copy.deepcopy(self.base_query))
|
||||
|
||||
def _parse_q(self, q_object):
|
||||
"""Returns a parsed Q object.
|
||||
|
||||
|
@ -231,8 +238,9 @@ class RawSQLManager(object):
|
|||
# Support filtering by alias, similar to how a view works
|
||||
field = self.base_query['select'][field]
|
||||
if clause.group('op').lower() == 'in':
|
||||
# eg. WHERE foo IN (1, 2, 3)
|
||||
parts = ['%%(%s)s' % self._param(p) for p in iter(val)]
|
||||
# eg. WHERE foo IN (%(param_0)s, %(param_1)s, %(param_2)s)
|
||||
# WHERE foo IN (1, 2, 3)
|
||||
parts = ['%(' + self._param(p) + ')s' for p in iter(val)]
|
||||
param = '(%s)' % ', '.join(parts)
|
||||
else:
|
||||
param = '%%(%s)s' % self._param(val)
|
||||
|
@ -254,8 +262,9 @@ class RawSQLManager(object):
|
|||
else:
|
||||
dir = 'ASC'
|
||||
field = spec
|
||||
self.base_query['order_by'].append('%s %s' % (field, dir))
|
||||
return self
|
||||
clone = self._clone()
|
||||
clone.base_query['order_by'].append('%s %s' % (field, dir))
|
||||
return clone
|
||||
|
||||
def _compile(self, parts):
|
||||
sep = ",\n"
|
||||
|
@ -279,11 +288,11 @@ class RawSQLManager(object):
|
|||
def _execute(self, sql):
|
||||
self._record_set = []
|
||||
self._cursor = connection.cursor()
|
||||
self._cursor.execute(sql, self._args)
|
||||
self._cursor.execute(sql, self.base_query['_args'])
|
||||
|
||||
def _param(self, val):
|
||||
param_k = 'param_%s' % len(self._args.keys())
|
||||
self._args[param_k] = val
|
||||
param_k = 'param_%s' % len(self.base_query['_args'].keys())
|
||||
self.base_query['_args'][param_k] = val
|
||||
return param_k
|
||||
|
||||
def _build_cursor(self):
|
||||
|
|
|
@ -118,6 +118,11 @@ class TestSQLModel(unittest.TestCase):
|
|||
c = Summary.objects.all().order_by('category')[0]
|
||||
eq_(c.category, 'apparel')
|
||||
|
||||
def test_get_by_index(self):
|
||||
qs = Summary.objects.all().order_by('category')
|
||||
eq_(qs[0].category, 'apparel')
|
||||
eq_(qs[1].category, 'safety')
|
||||
|
||||
def test_get(self):
|
||||
c = Summary.objects.all().having('total =', 1).get()
|
||||
eq_(c.category, 'apparel')
|
||||
|
|
Загрузка…
Ссылка в новой задаче