Testing when a ValidationError is raised - django

I am new to programming and Django in general. I am trying to test one of my functions to make sure that a validation error is raised. The traceback shows the error being raised, yet the test is still reported as failed. How is this possible?
**models.py**
def check_user_words(sender, instance, **kwargs):
    for field in instance._meta.get_fields():
        # field_name = getattr(instance, field.attname)
        if (isinstance(field, models.CharField) and
                contains_bad_words(getattr(instance, field.attname))):
            raise ValidationError("We don't use words like '{}' around here!".format(getattr(instance, field.attname)))
**tests.py**
from __future__ import unicode_literals
import datetime
from django.test import TestCase
from django.utils import timezone
from django.urls import reverse
from .models import Question, Choice, contains_bad_words, check_user_words
from django.core.exceptions import ValidationError
def create_question(question_text, days):
    time = timezone.now() + datetime.timedelta(days=days)
    return Question.objects.create(question_text=question_text, pub_date=time)

class ContainsBadWordsTests(TestCase):
    def test_check_user_words(self):
        question = create_question(question_text="What a minute bucko", days=1)
        with self.assertRaises(ValidationError):
            check_user_words(question)
            question.full_clean()
After running `python manage.py test polls`:
......
    raise ValidationError("We don't use words like '{}' around here!".format(getattr(instance, field.attname)))
ValidationError: [u"We don't use words like 'What a minute bucko' around here!"]
**models.py** (how I import):
from __future__ import unicode_literals
# .... (and others)
filepath = "polls/static/polls/blacklist.yaml"
config = yaml_loader(filepath)
blacklist = [word.lower() for word in config['blacklist']]

def contains_bad_words(user_input_txt):
    """Remove punctuation from the text and make the check case-insensitive."""
    user_typ = user_input_txt.encode()
    translate_table = maketrans(string.punctuation, 32 * " ")
    words = user_typ.translate(translate_table).lower().split()
    for bad_word in blacklist:
        for word in words:
            if word == bad_word:
                return True
    return False
@receiver(pre_save)
def check_user_words(sender, instance, **kwargs):
    for field in instance._meta.get_fields():
        if (isinstance(field, models.CharField) and
                contains_bad_words(getattr(instance, field.attname))):
            raise ValidationError("We don't use words like '{}' around here!".format(getattr(instance, field.attname)))

We need to see more of your code (specifically create_question() and how check_user_words is connected to a signal) to be sure, but I think the issue is that check_user_words() is being executed by a save signal handler (your models.py shows @receiver(pre_save)).
If that is the case, then the reason your test is failing is that create_question() fires the signal, so check_user_words() raises the ValidationError immediately - i.e. before you enter the with self.assertRaises(...) block - and the exception escapes the test instead of being caught by the assertion, hence the failure.
If this is the case, then try this:
def test_check_user_words(self):
    with self.assertRaises(ValidationError):
        create_question(question_text="What a minute bucko", days=1)
This test should now pass, because the validation error will be thrown as soon as you try to create the question.
Note however that doing this in a signal will result in an uncaught exception when something tries to save an object. Depending on what your use case is, you might be better off doing this in the clean() method of the model itself (see docs here), because this will cause appropriate errors to be reported on model forms etc:
def clean(self):
    for field in self._meta.get_fields():
        if (isinstance(field, models.CharField) and
                contains_bad_words(getattr(self, field.attname))):
            raise ValidationError("We don't use words like '{}' around here!".format(getattr(self, field.attname)))
(and then drop your signal handler). Then you can test this with:
q = create_question(question_text="What a minute bucko", days=1)
with self.assertRaises(ValidationError):
    q.clean()
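Since full_clean() calls clean() internally (along with the per-field validators), the same assertion also works against full_clean():
q = create_question(question_text="What a minute bucko", days=1)
with self.assertRaises(ValidationError):
    q.full_clean()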

Related

Why are integration tests failing on updates to model objects when the function is run on Django q-cluster?

I am running some django integration tests of code that passes some functions to Django-Q for processing. This code essentially updates some fields on a model object.
The app uses signals.py to listen for a post_save change of the status field of the Foo object. Assuming it is not a newly created object and it has the correct status, the Foo.prepare_foo() function is called. This is handled by services.py, which hands it off to the q-cluster. Assuming this executes without error, hooks.py changes the status field of Foo to published, or keeps it at prepare if it fails. If it passes, it also sets prepared to True. (I know this sounds convoluted, with overlapping variables - part of the desire to get tests running is to be able to refactor.)
The code runs correctly in all environments, but when I try to run tests to assert that the fields have been updated, they fail.
(If I tweak the code to bypass the cluster and have the functions run in-memory then the tests pass.)
Given this, I'm assuming the issue lies in how I've written the tests, but I cannot diagnose the issue.
tests_prepare_foo.py
import time
from django.test import TestCase, override_settings, TransactionTestCase
from app.models import Foo

class CreateCourseTest(TransactionTestCase):
    reset_sequences = True

    @classmethod
    def setUp(cls):
        cls.foo = Foo(status='draft',
                      prepared=False,
                      )
        cls.foo.save()

    def test_foo_prepared(self):
        self.foo.status = 'prepare'
        self.foo.save()
        time.sleep(15)  # to allow the cluster to finish processing the request
        self.assertEquals(self.foo.prepared, True)
models.py
import uuid
from django.db import models
from model_utils.fields import StatusField
from model_utils import Choices

class Foo(models.Model):
    ref = models.UUIDField(default=uuid.uuid4,
                           editable=False,
                           )
    STATUS = Choices('draft', 'prepare', 'published', 'archived')
    status = StatusField(null=True,
                         blank=True,
                         )
    prepared = models.BooleanField(default=False)

    def prepare_foo(self):
        """
        ...
        do stuff
        """
signals.py
from django.dispatch import receiver
from django.db.models.signals import post_save
from django_q.tasks import async_task
from app.models import Foo

@receiver(post_save, sender=Foo)
def make_foo(sender, instance, **kwargs):
    if not kwargs.get('created', False) and instance.status == 'prepare' and not instance.prepared:
        async_task('app.services.prepare_foo',
                   instance,
                   hook='app.hooks.check_foo_prepared',
                   )
services.py
def prepare_foo(foo):
    foo.prepare_foo()
hooks.py
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def check_foo_prepared(task):
    foo = task.args[0]
    if task.success:
        logger.info("Foo Prepared: Successful")
        foo.status = 'published'
        foo.prepared = True
        foo.save()
        logger.info("Foo Status: %s", foo.status)
        logger.info("Foo Prepared: %s", foo.prepared)
    else:
        logger.info("Foo Prepared: Unsuccessful")
        foo.status = 'draft'
        foo.save()
Finally - logs when the test is run with the cluster on:
q-cluster server
INFO:app.hooks:Foo Prepared: Successful
INFO:app.hooks: Foo Status: published
INFO:app.hooks: Foo Prepared: True
django server
h:m:s [Q] INFO Enqueued 1
F
======
FAIL
self.assertEquals(self.foo.prepared, True)
AssertionError: False != True
I think I'm either missing something obvious in my tests, or something really subtle, but I can't work out which. I've tried setting the cluster to run synchronously (sync=True), and in my test reloading Foo just before the assertion:
self.foo.save()
time.sleep(15)
test_foo = Foo.objects.get(pk=1)
self.assertEquals(test_foo.prepared, True)
But this also fails with
self.assertEquals(test_foo.prepared, True)
AssertionError: False != True
Which leads me to believe that the cluster is not updating the object under test (unlikely), or the assertion is being checked before the cluster has updated the object (more likely).
This is the first time I've written tests that require hand-offs to a cluster, so any pointers, suggestions gratefully received!
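For what it's worth, the second hypothesis can be exercised directly: the hook saves a different (pickled) copy of the instance, so the test's in-memory self.foo never changes. A sketch only, assuming the test settings set Q_CLUSTER = {'sync': True} so django-q runs the task and its hook in-process:
from django.test import TransactionTestCase

from app.models import Foo

class CreateCourseSyncTest(TransactionTestCase):
    def test_foo_prepared_sync(self):
        foo = Foo.objects.create(status='draft', prepared=False)
        foo.status = 'prepare'
        foo.save()             # fires the post_save handler, which hands off to django-q
        foo.refresh_from_db()  # reload, because the hook updated a different in-memory instance
        self.assertTrue(foo.prepared)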

AssertionError on unit testing a celery task with autoretry, backoff and jitter

Using celery 4.3.0. I tried to write a unit test for the following task.
from django.core.exceptions import ObjectDoesNotExist

@shared_task(autoretry_for=(ObjectDoesNotExist,), max_retries=5, retry_backoff=10)
def process_something(data):
    product = Product()
    product.process(data)
Unit test:
@mock.patch('proj.tasks.Product')
@mock.patch('proj.tasks.process_something.retry')
def test_process_something_retry_failed_task(self, process_something_retry, mock_product):
    mock_object = mock.MagicMock()
    mock_product.return_value = mock_object
    mock_object.process.side_effect = error = ObjectDoesNotExist()
    with pytest.raises(ObjectDoesNotExist):
        process_something(self.data)
    process_something_retry.assert_called_with(exc=error)
This is the error I get after running the test:
    @wraps(task.run)
    def run(*args, **kwargs):
        try:
            return task._orig_run(*args, **kwargs)
        except autoretry_for as exc:
            if retry_backoff:
                retry_kwargs['countdown'] = \
                    get_exponential_backoff_interval(
                        factor=retry_backoff,
                        retries=task.request.retries,
                        maximum=retry_backoff_max,
                        full_jitter=retry_jitter)
>               raise task.retry(exc=exc, **retry_kwargs)
E               TypeError: exceptions must derive from BaseException
I understand it is because of the exception. I replaced ObjectDoesNotExist everywhere with Exception instead. After running the test, I get this error:
    def assert_called_with(self, /, *args, **kwargs):
        """assert that the last call was made with the specified arguments.

        Raises an AssertionError if the args and keyword args passed in are
        different to the last call to the mock."""
        if self.call_args is None:
            expected = self._format_mock_call_signature(args, kwargs)
            actual = 'not called.'
            error_message = ('expected call not found.\nExpected: %s\nActual: %s'
                             % (expected, actual))
            raise AssertionError(error_message)

        def _error_message():
            msg = self._format_mock_failure_message(args, kwargs)
            return msg

        expected = self._call_matcher((args, kwargs))
        actual = self._call_matcher(self.call_args)
        if expected != actual:
            cause = expected if isinstance(expected, Exception) else None
>           raise AssertionError(_error_message()) from cause
E           AssertionError: expected call not found.
E           Expected: retry(exc=Exception())
E           Actual: retry(exc=Exception(), countdown=7)
Please let me know how I can fix both errors.
I had a similar issue while working on tests to ensure that the celery retry logic covered my specific scenarios. What worked for me was to use an explicit retry instead of the autoretry_for parameter.
I have adjusted your code to my solution. Although my solution didn't use shared_task, I think it should work the same way. Tested on celery==5.1.2.
task:
from django.core.exceptions import ObjectDoesNotExist

@shared_task(bind=True, max_retries=5, retry_backoff=10)
def process_something(self, data):
    try:
        product = Product()
        product.process(data)
    except ObjectDoesNotExist as exc:
        raise self.retry(exc=exc)
test:
from proj.tasks import Product  # I assume the Product class is located here
from django.core.exceptions import ObjectDoesNotExist
import celery

@mock.patch.object(Product, "__init__", Mock(return_value=None))  # just mocking the init method
@mock.patch.object(Product, "process")
@mock.patch('proj.tasks.process_something.retry')
def test_process_something_retry_failed_task(self, retry_mock, process_mock):
    exc = ObjectDoesNotExist()
    process_mock.side_effect = exc
    retry_mock.side_effect = celery.exceptions.Retry
    with pytest.raises(celery.exceptions.Retry):
        process_something(self.data)
    retry_mock.assert_called_with(exc=exc)
In my case I was also using custom exceptions. With this solution I didn't need to change the type of my exceptions.
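If you would rather keep autoretry_for, another option is to let the mocked retry raise a real Retry and match the computed backoff countdown loosely with mock.ANY. This is only a sketch, reusing the task and patch targets from the question:
import celery
import pytest
from unittest import mock

from django.core.exceptions import ObjectDoesNotExist
from proj.tasks import process_something

@mock.patch('proj.tasks.Product')
@mock.patch('proj.tasks.process_something.retry')
def test_process_something_retries(self, retry_mock, product_mock):
    error = ObjectDoesNotExist()
    product_mock.return_value.process.side_effect = error
    retry_mock.side_effect = celery.exceptions.Retry  # so `raise task.retry(...)` propagates a real exception
    with pytest.raises(celery.exceptions.Retry):
        process_something(self.data)
    # the countdown varies because of retry_backoff/jitter, so don't pin it
    retry_mock.assert_called_with(exc=error, countdown=mock.ANY)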

Why don't my Django unittests know that MessageMiddleware is installed?

I'm working on a Django project and am writing unittests for it. However, in a test, when I try and log a user in, I get this error:
MessageFailure: You cannot add messages without installing django.contrib.messages.middleware.MessageMiddleware
Logging in on the actual site works fine -- and a login message is displayed using the MessageMiddleware.
In my tests, if I do this:
from django.conf import settings
print settings.MIDDLEWARE_CLASSES
Then it outputs this:
('django.middleware.cache.UpdateCacheMiddleware',
'django.middleware.common.CommonMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'django.middleware.cache.FetchFromCacheMiddleware',
'debug_toolbar.middleware.DebugToolbarMiddleware')
Which appears to show the MessageMiddleware is installed when tests are run.
Is there an obvious step I'm missing?
UPDATE
After suggestions below, it does look like it's a settings thing.
I currently have settings/__init__.py like this:
try:
    from settings.development import *
except ImportError:
    pass
and settings/defaults.py containing most of the standard settings (including MIDDLEWARE_CLASSES). settings/development.py then overrides some of those defaults like this:
from defaults import *

DEBUG = True
# etc
It looks like my dev site itself works fine, using the development settings. But although the tests seem to load the settings OK (both defaults and development), settings.DEBUG is set to False. I don't know why, or whether that's the cause of the problem.
Django 1.4 has expected behaviour that can trigger this error when you create the request with RequestFactory.
To resolve this issue, create your request with RequestFactory and do this:
from django.contrib.messages.storage.fallback import FallbackStorage

setattr(request, 'session', 'session')
messages = FallbackStorage(request)
setattr(request, '_messages', messages)
Works for me!
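Putting that together with RequestFactory, a self-contained version of this approach might look like the following sketch (my_view and the URL are placeholders, not from the question):
from django.contrib.messages.storage.fallback import FallbackStorage
from django.test import RequestFactory, TestCase

from myapp.views import my_view  # hypothetical view that calls messages.add_message()

class MyViewMessageTest(TestCase):
    def test_view_can_add_messages(self):
        request = RequestFactory().get('/some-url/')
        # give the bare request just enough session/message machinery
        setattr(request, 'session', 'session')
        setattr(request, '_messages', FallbackStorage(request))
        response = my_view(request)
        self.assertEqual(response.status_code, 200)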
A quite elegant way to solve this is to mock the messages module using mock.
Say you have a class-based view named FooView in an app named myapp:
from django.contrib import messages
from django.views.generic import TemplateView

class FooView(TemplateView):
    def post(self, request, *args, **kwargs):
        ...
        messages.add_message(request, messages.SUCCESS, '\o/ Profit \o/')
        ...
You can now test it with:
def test_successful_post(self):
    mock_messages = patch('myapp.views.messages').start()
    mock_messages.SUCCESS = success = 'super duper'
    request = self.rf.post('/', {})
    view = FooView.as_view()
    response = view(request)
    msg = _(u'\o/ Profit \o/')
    mock_messages.add_message.assert_called_with(request, success, msg)
In my case (Django 1.8) this problem occurred when a unit test called the signal handler for the user_logged_in signal; it looks like the messages app has not been set up at that point, i.e. request._messages is not yet set. This fails:
from django.contrib.auth.signals import user_logged_in
...

@receiver(user_logged_in)
def user_logged_in_handler(sender, user, request, **kwargs):
    ...
    messages.warning(request, "user has logged in")
The same call to messages.warning in a normal view function (which runs afterwards) works without any issues.
A workaround, based on one of the suggestions from https://code.djangoproject.com/ticket/17971, is to use the fail_silently argument only in the signal handler function, i.e. this solved my problem:
messages.warning(request, "user has logged in",
                 fail_silently=True)
Do you only have one settings.py?
Tests create a custom (test) database. Maybe you have no messages there or something... Maybe you need setUp() fixtures or something?
Need more info to answer properly.
Why not simply do something like the following? You do run your tests in debug mode, right?
# settings.py
DEBUG = True

from django.conf import settings

# where message is sent:
if not settings.DEBUG:
    # send your message ...
This builds on Tarsis Azevedo's answer by creating a MessagingRequest helper class below.
Given, say, a KittenAdmin that I'd want to get 100% test coverage for:
from django.contrib import admin, messages

class KittenAdmin(admin.ModelAdmin):
    def warm_fuzzy_method(self, request):
        messages.warning(request, 'Can I haz cheezburger?')
I created a MessagingRequest helper class to use in say a test_helpers.py file:
from django.contrib.messages.storage.fallback import FallbackStorage
from django.http import HttpRequest

class MessagingRequest(HttpRequest):
    session = 'session'

    def __init__(self):
        super(MessagingRequest, self).__init__()
        self._messages = FallbackStorage(self)

    def get_messages(self):
        return getattr(self._messages, '_queued_messages')

    def get_message_strings(self):
        return [str(m) for m in self.get_messages()]
Then in a standard Django tests.py:
from django.contrib.admin.sites import AdminSite
from django.test import TestCase

from cats.kitten.admin import KittenAdmin
from cats.kitten.models import Kitten
from cats.kitten.test_helpers import MessagingRequest

class KittenAdminTest(TestCase):
    def test_kitten_admin_message(self):
        admin = KittenAdmin(model=Kitten, admin_site=AdminSite())
        expect = ['Can I haz cheezburger?']
        request = MessagingRequest()
        admin.warm_fuzzy_method(request)
        self.assertEqual(request.get_message_strings(), expect)
Results:
coverage run --include='cats/kitten/*' manage.py test; coverage report -m
Creating test database for alias 'default'...
.
----------------------------------------------------------------------
Ran 1 test in 0.001s
OK
Destroying test database for alias 'default'...
Name Stmts Miss Cover Missing
----------------------------------------------------------------------
cats/kitten/__init__.py 0 0 100%
cats/kitten/admin.py 4 0 100%
cats/kitten/migrations/0001_initial.py 5 0 100%
cats/kitten/migrations/__init__.py 0 0 100%
cats/kitten/models.py 3 0 100%
cats/kitten/test_helpers.py 11 0 100%
cats/kitten/tests.py 12 0 100%
----------------------------------------------------------------------
TOTAL 35 0 100%
This happened to me in the login_callback signal receiver function when called from a unit test, and the way around the problem was:
from django.contrib.messages.storage import default_storage

@receiver(user_logged_in)
def login_callback(sender, user, request, **kwargs):
    if not hasattr(request, '_messages'):  # fails for tests
        request._messages = default_storage(request)
Django 2.0.x
I found that when I had a problem patching messages, the solution was to patch the module as it is referenced from the module under test (obsolete Django version BTW, YMMV). Pseudocode follows.
my_module.py:
from django.contrib import messages

class MyClass:
    def help(self):
        messages.add_message(self.request, messages.ERROR, "Foobar!")
test_my_module.py:
from unittest.mock import patch, MagicMock

from django.test import TestCase

from my_module import MyClass

class TestMyClass(TestCase):
    def test_help(self):
        with patch("my_module.messages") as mock_messages:
            mock_messages.add_message = MagicMock()
            MyClass().help()  # shouldn't complain about middleware
If you're seeing a problem with your middleware, then you're not doing a "unit test". Unit tests test one unit of functionality; if you interact with other parts of your system, you're doing "integration" testing.
You should try to write more focused tests, and this kind of problem shouldn't arise. Try RequestFactory. ;)
def test_some_view(self):
    factory = RequestFactory()
    user = get_mock_user()
    request = factory.get("/my/view")
    request.user = user
    response = my_view(request)
    self.assertEqual(response.status_code, 200)

Getting all queries that django run on postgresql

I am working on a Django/PostgreSQL project and I need to see every query that Django runs on the database (so I can fine-tune the queries). Is there a way to get those queries?
Update: my development environment is on Ubuntu Linux.
Well, you could just set the pgsql server to log every query. Or just to log the slow ones. Look in the postgresql.conf file, it's pretty close to self-documenting.
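For example, these postgresql.conf settings turn on statement logging (the values are only illustrative):
log_statement = 'all'                # log every statement; use 'ddl' or 'mod' for less noise
log_min_duration_statement = 250     # additionally log any statement slower than 250 ms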
Check out this Question (and the two topmost answers):
django orm, how to view (or log) the executed query?
You can also have a look at the Django documentation:
https://docs.djangoproject.com/en/dev/faq/models/#how-can-i-see-the-raw-sql-queries-django-is-running
Hope this helps,
Anton
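That FAQ entry boils down to inspecting connection.queries, which is only populated while DEBUG = True; roughly:
from django.db import connection, reset_queries

# with DEBUG = True, run some ORM code, then:
for q in connection.queries:
    print(q['time'], q['sql'])

reset_queries()  # clear the log before the next block you want to measure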
You can decorate a request handler or other function with this and it will print the SQL nicely formatted, with totals at the end.
from functools import wraps
from django.utils import termcolors

format_ok = termcolors.make_style(opts=('bold',), fg='green')
format_warning = termcolors.make_style(opts=('bold',), fg='yellow')
format_error = termcolors.make_style(opts=('bold',), fg='red')

try:
    from pygments import highlight
    from pygments.lexers import SqlLexer
    from pygments.formatters import TerminalFormatter
    pygments_sql_lexer = SqlLexer()
    pygments_terminal_formatter = TerminalFormatter()
    highlight_sql = lambda s: highlight(s, pygments_sql_lexer,
                                        pygments_terminal_formatter)
except ImportError:
    highlight_sql = lambda s: s

def debug_sql(f):
    """
    Turn SQL statement debugging on for a test run.
    """
    @wraps(f)
    def wrapper(*a, **kw):
        from django.conf import settings
        from django.db import connection
        try:
            debug = settings.DEBUG
            settings.DEBUG = True
            connection.queries = []
            return f(*a, **kw)
        finally:
            total_time = 0
            for q in connection.queries:
                fmt = format_ok
                t = float(q['time'])
                total_time += t
                if t > 1:
                    fmt = format_error
                elif t > 0.3:
                    fmt = format_warning
                print '[%s] %s' % (fmt(q['time']), highlight_sql(q['sql']))
            print "total time =", total_time
            print "num queries =", len(connection.queries)
            settings.DEBUG = debug
    return wrapper
Try the Django Debug Toolbar. It'll show you all the SQL executed over the request. When something is executing way too many queries, it becomes really slow, though. For that, I've been meaning to try out this profiler. However, I've rolled this middleware on a couple of projects:
try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO

from django.conf import settings
from django.db import connection

class DatabaseProfilerMiddleware(object):
    def can(self, request):
        return settings.DEBUG and 'dbprof' in request.GET

    def process_response(self, request, response):
        if self.can(request):
            out = StringIO()
            out.write('time sql\n')
            total_time = 0
            for query in reversed(sorted(connection.queries, key=lambda x: x['time'])):
                total_time += float(query['time'])*1000
                out.write('%s %s\n' % (query['time'], query['sql']))
            response.content = '<pre style="white-space:pre-wrap">%d queries executed in %.3f seconds\n%s</pre>' \
                % (len(connection.queries), total_time/1000, out.getvalue())
        return response
Just go to the relevant URL for the request you are interested in and add a dbprof GET parameter; you'll see the profiling output instead of the normal response.

Django: is there a way to count SQL queries from an unit test?

I am trying to find out the number of queries executed by a utility function. I have written a unit test for this function and the function is working well. What I would like to do is track the number of SQL queries executed by the function so that I can see if there is any improvement after some refactoring.
def do_something_in_the_database():
    # Does something in the database
    # return result

class DoSomethingTests(django.test.TestCase):
    def test_function_returns_correct_values(self):
        self.assertEqual(n, <number of SQL queries executed>)
EDIT: I found out that there is a pending Django feature request for this. However the ticket is still open. In the meantime is there another way to go about this?
Since Django 1.3 there is assertNumQueries available exactly for this purpose.
One way to use it (as of Django 3.2) is as a context manager:
# measure queries of some_func and some_func2
with self.assertNumQueries(2):
    result = some_func()
    result2 = some_func2()
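assertNumQueries also accepts a callable (plus its arguments) instead of being used as a context manager, so the question's helper could be checked in one line (n being the expected query count):
self.assertNumQueries(n, do_something_in_the_database)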
Vinay's response is correct, with one minor addition.
Django's unit test framework actually sets DEBUG to False when it runs, so no matter what you have in settings.py, you will not have anything populated in connection.queries in your unit test unless you re-enable debug mode. The Django docs explain the rationale for this as:
Regardless of the value of the DEBUG setting in your configuration file, all Django tests run with DEBUG=False. This is to ensure that the observed output of your code matches what will be seen in a production setting.
If you're certain that enabling debug will not affect your tests (such as if you're specifically testing DB hits, as it sounds like you are), the solution is to temporarily re-enable debug in your unit test, then set it back afterward:
def test_myself(self):
    from django.conf import settings
    from django.db import connection

    settings.DEBUG = True
    connection.queries = []

    # Test code as normal
    self.assert_(connection.queries)

    settings.DEBUG = False
If you are using pytest, pytest-django has a django_assert_num_queries fixture for this purpose:
def test_queries(django_assert_num_queries):
    with django_assert_num_queries(3):
        Item.objects.create('foo')
        Item.objects.create('bar')
        Item.objects.create('baz')
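pytest-django also ships a django_assert_max_num_queries fixture if you only want to enforce an upper bound rather than an exact count, e.g. (reusing the question's helper):
def test_queries_capped(django_assert_max_num_queries):
    with django_assert_max_num_queries(10):
        do_something_in_the_database()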
If you don't want to use TestCase (with assertNumQueries) or to change settings to DEBUG=True, you can use the context manager CaptureQueriesContext (which assertNumQueries uses internally).
from django.db import ConnectionHandler
from django.test.utils import CaptureQueriesContext

DB_NAME = "default"  # name of the db configured in settings that you want to use - "default" is standard
connection = ConnectionHandler()[DB_NAME]

with CaptureQueriesContext(connection) as context:
    ...  # do your thing

num_queries = context.final_queries - context.initial_queries
assert num_queries == expected_num_queries
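CaptureQueriesContext also records the statements themselves in context.captured_queries, which is handy when the count is off and you want to see what actually ran:
for query in context.captured_queries:
    print(query['time'], query['sql'])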
In modern Django (>= 1.8) this is well documented (it's also documented for 1.7); instead of assigning connection.queries = [] (which raises an error), you can call reset_queries. Something like this works on Django >= 1.8:
class QueriesTests(django.test.TestCase):
    def test_queries(self):
        from django.conf import settings
        from django.db import connection, reset_queries
        try:
            settings.DEBUG = True
            # [... your ORM code ...]
            self.assertEquals(len(connection.queries), num_of_expected_queries)
        finally:
            settings.DEBUG = False
            reset_queries()
You may also consider resetting queries in setUp/tearDown to ensure queries are reset for each test instead of doing it in the finally clause (a minimal sketch of that variant follows), but the try/finally way is more explicit (although more verbose); you can also call reset_queries in the try clause as many times as you need, to count queries from 0 again.
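The setUp/tearDown variant, under the same assumptions (DEBUG toggled manually, num_of_expected_queries being the placeholder from above), would look roughly like:
import django.test
from django.conf import settings
from django.db import connection, reset_queries

class QueriesTests(django.test.TestCase):
    def setUp(self):
        settings.DEBUG = True
        reset_queries()

    def tearDown(self):
        settings.DEBUG = False
        reset_queries()

    def test_queries(self):
        # [... your ORM code ...]
        self.assertEquals(len(connection.queries), num_of_expected_queries)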
Here is a working prototype of a context manager, withAssertNumQueriesLessThan:
import json
from contextlib import contextmanager

from django.test.utils import CaptureQueriesContext
from django.db import connections

@contextmanager
def withAssertNumQueriesLessThan(self, value, using='default', verbose=False):
    with CaptureQueriesContext(connections[using]) as context:
        yield  # your test will be run here
    if verbose:
        msg = "\r\n%s" % json.dumps(context.captured_queries, indent=4)
    else:
        msg = None
    self.assertLess(len(context.captured_queries), value, msg=msg)
It can simply be used in your unit tests, for example to check the number of queries per Django REST API call:
with self.withAssertNumQueriesLessThan(10):
    response = self.client.get('contacts/')
    self.assertEqual(response.status_code, 200)
You can also pass the exact DB alias via using, and verbose=True if you want to pretty-print the list of actual queries to stdout.
If you have DEBUG set to True in your settings.py (presumably so in your test environment) then you can count queries executed in your test as follows:
from django.db import connection

class DoSomethingTests(django.test.TestCase):
    def test_something_or_other(self):
        num_queries_old = len(connection.queries)
        do_something_in_the_database()
        num_queries_new = len(connection.queries)
        self.assertEqual(n, num_queries_new - num_queries_old)
If you want to use a decorator for that there is a nice gist:
import functools
import sys
import re

from django.conf import settings
from django.db import connection

def shrink_select(sql):
    return re.sub("^SELECT(.+)FROM", "SELECT .. FROM", sql)

def shrink_update(sql):
    return re.sub("SET(.+)WHERE", "SET .. WHERE", sql)

def shrink_insert(sql):
    return re.sub("\((.+)\)", "(..)", sql)

def shrink_sql(sql):
    return shrink_update(shrink_insert(shrink_select(sql)))

def _err_msg(num, expected_num, verbose, func=None):
    func_name = "%s:" % func.__name__ if func else ""
    msg = "%s Expected number of queries is %d, actual number is %d.\n" % (func_name, expected_num, num,)
    if verbose > 0:
        queries = [query['sql'] for query in connection.queries[-num:]]
        if verbose == 1:
            queries = [shrink_sql(sql) for sql in queries]
        msg += "== Queries == \n" + "\n".join(queries)
    return msg

def assertNumQueries(expected_num, verbose=1):
    class DecoratorOrContextManager(object):
        def __call__(self, func):  # decorator
            @functools.wraps(func)
            def inner(*args, **kwargs):
                handled = False
                try:
                    self.__enter__()
                    return func(*args, **kwargs)
                except:
                    self.__exit__(*sys.exc_info())
                    handled = True
                    raise
                finally:
                    if not handled:
                        self.__exit__(None, None, None)
            return inner

        def __enter__(self):
            self.old_debug = settings.DEBUG
            self.old_query_count = len(connection.queries)
            settings.DEBUG = True

        def __exit__(self, type, value, traceback):
            if not type:
                num = len(connection.queries) - self.old_query_count
                assert expected_num == num, _err_msg(num, expected_num, verbose)
            settings.DEBUG = self.old_debug

    return DecoratorOrContextManager()