passing commandline arguments to a selenium python webdriver test case - python-2.7

The following code uses the Selenium Python WebDriver and runs on Sauce Labs. I am providing the browser name, version and platform in a list; how do I provide the browser details through command-line arguments instead? I am using py.test to execute the test cases.
import os
import sys
import httplib
import base64
import json
import new
import unittest
import sauceclient
from selenium import webdriver
from sauceclient import SauceClient
# Credentials are read from the environment only. Hard-coding them (as the
# original defaults did) publishes the access key to anyone who can read this
# file. A missing variable fails fast with KeyError at import time.
USERNAME = os.environ['SAUCE_USERNAME']
ACCESS_KEY = os.environ['SAUCE_ACCESS_KEY']
# Sauce REST client; tearDown uses it to report each job's pass/fail status.
sauce = SauceClient(USERNAME, ACCESS_KEY)
# Desired-capabilities dicts, intended for on_platforms(), which generates one
# TestCase subclass per entry.
browsers = [{"platform": "Mac OS X 10.9",
             "browserName": "chrome",
             "version": ""},
            ]
def on_platforms(platforms):
    """Class decorator that clones a test class once per platform.

    For each capabilities dict in *platforms*, a subclass named
    ``<BaseName>_<n>`` (n counting from 1) is created with its own
    ``desired_capabilities`` attribute. The subclass is injected into the
    base class's module namespace, where test runners discover it.

    The decorator deliberately returns ``None``. The decorated name is
    rebound to ``None``, so the base class (which has no
    ``desired_capabilities``) is not collected and run on its own.
    """
    def decorator(base_class):
        module = sys.modules[base_class.__module__].__dict__
        for i, platform in enumerate(platforms, start=1):
            attrs = dict(base_class.__dict__)
            # Give each generated class its own copy, so mutations made by
            # tests do not leak into the caller's platform dict.
            attrs['desired_capabilities'] = dict(platform)
            name = "%s_%s" % (base_class.__name__, i)
            # type() works on Python 2.7 and 3; the `new` module is 2-only.
            module[name] = type(name, (base_class,), attrs)
    return decorator
@on_platforms(browsers)
class SauceSampleTest(unittest.TestCase):
    """Sample Sauce Labs test.

    ``on_platforms`` generates one concrete subclass per entry in
    ``browsers`` and gives each a ``desired_capabilities`` attribute. The
    decorated name itself is rebound to ``None``, so this base class is not
    run directly.
    """

    def setUp(self):
        # Copy first: desired_capabilities is a class attribute, so writing
        # 'name' into it directly would mutate state shared by every test
        # method of the class.
        caps = dict(self.desired_capabilities)
        caps['name'] = self.id()
        # user:key@host -- the '@' separates the credentials from the host.
        sauce_url = "http://%s:%s@ondemand.saucelabs.com:80/wd/hub"
        self.driver = webdriver.Remote(
            desired_capabilities=caps,
            command_executor=sauce_url % (USERNAME, ACCESS_KEY)
        )
        self.driver.implicitly_wait(30)

    def test_sauce(self):
        self.driver.get('http://saucelabs.com/test/guinea-pig')
        assert "I am a page title - Sauce Labs" in self.driver.title
        comments = self.driver.find_element_by_id('comments')
        comments.send_keys('Hello! I am some example comments.'
                           ' I should be in the page after submitting the form')
        self.driver.find_element_by_id('submit').click()
        commented = self.driver.find_element_by_id('your_comments')
        assert ('Your comments: Hello! I am some example comments.'
                ' I should be in the page after submitting the form'
                in commented.text)
        body = self.driver.find_element_by_xpath('//body')
        assert 'I am some other page content' not in body.text
        self.driver.find_elements_by_link_text('i am a link')[0].click()
        body = self.driver.find_element_by_xpath('//body')
        assert 'I am some other page content' in body.text

    def tearDown(self):
        print("Link to your job: https://saucelabs.com/jobs/%s" % self.driver.session_id)
        try:
            # NOTE(review): on Python 3 the test's exception has already been
            # handled when unittest calls tearDown, so sys.exc_info() is
            # (None, None, None) here and every job is reported as passed.
            if sys.exc_info() == (None, None, None):
                sauce.jobs.update_job(self.driver.session_id, passed=True)
            else:
                sauce.jobs.update_job(self.driver.session_id, passed=False)
        finally:
            self.driver.quit()

This is a bit complicated because you can pass an array of browsers into the @on_platforms decorator. My solution only works for a single browser, which appears to be what you are doing right now.
For the current, single browser, situation -- you're looking for argparse. Here's my suggested fix:
import argparse
def setup_parser(argv=None):
    """Parse Sauce Labs desired capabilities from the command line.

    :param argv: argument list to parse. ``None`` (the default) means
        ``sys.argv[1:]``, as with ``ArgumentParser.parse_args``.
    :returns: dict with keys ``platform``, ``browserName`` and ``version``,
        usable directly as a ``desired_capabilities`` entry.

    Unrecognised arguments are ignored (``parse_known_args``), so options
    meant for another program sharing the command line, such as py.test,
    do not make argparse exit. NOTE: ``-v`` also means "verbose" to
    py.test; prefer ``--version`` when running under it.
    """
    parser = argparse.ArgumentParser(description='Automation Testing!')
    parser.add_argument('-p', '--platform', help='Platform for desired_caps', default='Mac OS X 10.9')
    # Selenium expects the key 'browserName'. Without dest, argparse would
    # name it 'browser_name' and the remote end would ignore it.
    parser.add_argument('-b', '--browser-name', dest='browserName',
                        help='Browser Name for desired_caps', default='chrome')
    parser.add_argument('-v', '--version', default='')
    args, _unknown = parser.parse_known_args(argv)
    return vars(args)
# Build the single-browser list from command-line arguments.
desired_caps = setup_parser()
browsers = [desired_caps]
# print() call form works on both Python 2 and 3 (the statement form is 2-only).
print(browsers)
But if you're looking to test multiple browsers (which I suggest you do!), you should not try and use command line arguments for the desired_caps of each individual browser. You should instead load a json config file for the browsers and the desired_caps for each one that you want Sauce to run.
Maybe have a different config file for each set of browsers, and then use command line arguments to pass in the config files you want to load.

Related

Start and Stop a periodically background Task with Django

I would like to build a Bitcoin price notifier with Django. I have managed to get a working Telegram bot that sends me the Bitcoin stats when I ask it to. Now I would like it to send me a message when Bitcoin reaches a specific value. There are tutorials on running a Python script on a server, but not with Django. I read some answers and descriptions about Django Channels but couldn't adapt them to my project.
I would like to send, by telegram, a command about the amount and duration. Django would then start a process with these values and values of the channel I'm sending from in the background. If now, within the duration, the amount is reached, Django sends a message back to my channel. This should also be possible for more than one person.
Is this possible with Django out of the box, maybe with decorators, or do I need django-channels or something else?
Edit 2018-08-10:
Maybe my code explains a little bit better what I want to do.
import requests
import json
from datetime import datetime
from django.shortcuts import render
from django.http import HttpResponse
from django.conf import settings
from django.views.generic import TemplateView
# Must be one statement: a bare "from x" line followed by "import y" on the
# next line is a SyntaxError.
from django.views.decorators.csrf import csrf_exempt
class AboutView(TemplateView):
    # Static "about" page; only the template is configured here.
    template_name = 'telapi/about.html'
# Telegram bot token from Django settings. Used to build Bot API URLs and to
# check the token segment of incoming webhook requests.
bot_token = settings.BOT_TOKEN
def get_url(method):
    """Return the Telegram Bot API endpoint URL for *method* (e.g. 'sendMessage')."""
    template = 'https://api.telegram.org/bot{}/{}'
    return template.format(bot_token, method)
def process_message(update):
    """Reply with a fixed acknowledgement to the sender of a Telegram update."""
    sender_id = update['message']['from']['id']
    payload = {'chat_id': sender_id, 'text': "I can hear you!"}
    requests.post(get_url('sendMessage'), data=payload)
@csrf_exempt
def process_update(request, r_bot_token):
    '''Webhook endpoint called by Telegram for every bot update.

    Only POST requests whose URL token matches ``bot_token`` are processed.
    A "give me news" text triggers a price reply; any other message gets the
    generic acknowledgement. Every request is answered with HTTP 200.
    '''
    from django.utils.crypto import constant_time_compare
    # Constant-time comparison so the secret token cannot be guessed from
    # response timing.
    if request.method == 'POST' and constant_time_compare(r_bot_token, bot_token):
        update = json.loads(request.body.decode('utf-8'))
        if 'message' in update:
            # Non-text messages (photos, stickers, ...) have no 'text' key;
            # indexing it directly raised KeyError and produced a 500.
            if update['message'].get('text') == 'give me news':
                new_bitcoin_price(update)
            else:
                process_message(update)
    return HttpResponse(status=200)
bitconin_api_uri = 'https://api.coinmarketcap.com/v2/ticker/1/?convert=EUR'


def get_latest_bitcoin_price(timeout=10):
    """Fetch the current Bitcoin price in EUR from the CoinMarketCap ticker.

    :param timeout: seconds to wait for the HTTP response (default 10).
    :returns: ``(euro_price, date)``: the price as a float, and the response
        timestamp formatted as ``'%Y-%m-%d %H:%M:%S'`` in server-local time.
    :raises requests.HTTPError: on a non-2xx response.
    """
    # requests waits indefinitely without a timeout, which would stall the
    # webhook that calls this.
    response = requests.get(bitconin_api_uri, timeout=timeout)
    # Fail with a clear HTTP error instead of a confusing KeyError on an
    # error payload.
    response.raise_for_status()
    response_json = response.json()
    euro_price = float(response_json['data']['quotes']['EUR']['price'])
    timestamp = int(response_json['metadata']['timestamp'])
    date = datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
    return euro_price, date
def new_bitcoin_price(update):
    """Send the latest Bitcoin EUR price (with its timestamp) to the update's sender."""
    chat_id = update['message']['from']['id']
    euro_price, date = get_latest_bitcoin_price()
    text = "Aktuel ({}) beträgt der Preis {:.2f}€".format(date, euro_price)
    requests.post(get_url('sendMessage'), data={'chat_id': chat_id, 'text': text})
Edit 2018-08-13:
I think the solution would be celery-beat and channels. Does anyone know a good tutorial?
One of my teammates uses django-celery-beat, that is available at https://github.com/celery/django-celery-beat to do this and he gave me some excellent feedback from it. You can schedule the celery tasks using the crontab syntax.
I had same issue, there are several typical approaches: Celery, Django-Channels, etc.
But you can avoid all of them with a simple approach: https://docs.djangoproject.com/en/2.1/howto/custom-management-commands/
I have used Django management commands in my project to run periodic tasks that rebuild user statistics:
Implement your own management command. For example, if your application is named myapp and you place my_periodic_task.py in the myapp/management/commands folder, you can run the task once by typing python manage.py my_periodic_task
Place a new file, for example background.py, next to manage.py with the following code:
-
import os
import sys
from subprocess import call
from time import sleep  # was used but never imported

# Seconds to wait between runs of the management command; adjust as needed.
YOUR_TIMEOUT = 60

# abspath: __file__ may be relative, which would make dirname() return ''.
BASE = os.path.dirname(os.path.abspath(__file__))
MANAGE_BASE = os.path.join(BASE, 'manage.py')

# Run the command forever, pausing YOUR_TIMEOUT seconds before each run.
# sys.executable reuses the current interpreter (and virtualenv) rather than
# whichever 'python' happens to be first on PATH.
while True:
    sleep(YOUR_TIMEOUT)
    call([sys.executable, MANAGE_BASE, 'my_periodic_task'])
Run your server for example: python background.py & python manage.py runserver 0.0.0.0:8000

Automating pulling csv files off google Trends

pyGTrends does not seem to work. Giving errors in Python.
pyGoogleTrendsCsvDownloader seems to work, logs in, but after getting 1-3 requests (per day!) complains about exhausted quota, even though manual download with the same login/IP works flawlessly.
Bottom line: neither work. Searching through stackoverflow: many questions from people trying to pull csv's from Google, but no workable solution I could find...
Thank you in advance: whoever will be able to help. How should the code be changed? Do you know of another solution that works?
Here's the code of pyGoogleTrendsCsvDownloader.py
import httplib
import urllib
import urllib2
import re
import csv
import lxml.etree as etree
import lxml.html as html
import traceback
import gzip
import random
import time
import sys
from cookielib import Cookie, CookieJar
from StringIO import StringIO
class pyGoogleTrendsCsvDownloader(object):
    '''
    Google Trends Downloader
    Recommended usage:
    from pyGoogleTrendsCsvDownloader import pyGoogleTrendsCsvDownloader
    r = pyGoogleTrendsCsvDownloader(username, password)
    r.get_csv(cat='0-958', geo='US-ME-500')

    Written for Python 2 (urllib2, cookielib, StringIO).
    '''

    def __init__(self, username, password):
        '''
        Provide login and password to be used to connect to Google Trends
        All immutable system variables are also defined here
        '''
        # The amount of time (in secs) that the script should wait before making a request.
        # This can be used to throttle the downloading speed to avoid hitting servers too hard.
        # It is further randomized.
        self.download_delay = 0.25
        self.service = "trendspro"
        self.url_service = "http://www.google.com/trends/"
        self.url_download = self.url_service + "trendsReport?"
        self.login_params = {}
        # These headers are necessary, otherwise Google will flag the request at your account level
        self.headers = [('User-Agent', 'Mozilla/5.0 (Windows NT 6.1; WOW64; rv:12.0) Gecko/20100101 Firefox/12.0'),
                        ("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"),
                        ("Accept-Language", "en-gb,en;q=0.5"),
                        ("Accept-Encoding", "gzip, deflate"),
                        ("Connection", "keep-alive")]
        self.url_login = 'https://accounts.google.com/ServiceLogin?service='+self.service+'&passive=1209600&continue='+self.url_service+'&followup='+self.url_service
        self.url_authenticate = 'https://accounts.google.com/accounts/ServiceLoginAuth'
        self.header_dictionary = {}
        self._authenticate(username, password)

    @staticmethod
    def _read_body(resp):
        '''Return the body of *resp*, gunzipping it when Content-Encoding is gzip.'''
        # Needed because the request headers advertise gzip support.
        if resp.info().get('Content-Encoding') == 'gzip':
            return gzip.GzipFile(fileobj=StringIO(resp.read())).read()
        return resp.read()

    def _authenticate(self, username, password):
        '''
        Authenticate to Google:
        1 - make a GET request to the Login webpage so we can get the login form
        2 - make a POST request with email, password and login form input values
        '''
        # Make sure we get CSV results in English
        ck = Cookie(version=0, name='I4SUserLocale', value='en_US', port=None, port_specified=False, domain='www.google.com', domain_specified=False, domain_initial_dot=False, path='/trends', path_specified=True, secure=False, expires=None, discard=False, comment=None, comment_url=None, rest=None)
        self.cj = CookieJar()
        self.cj.set_cookie(ck)
        self.opener = urllib2.build_opener(urllib2.HTTPCookieProcessor(self.cj))
        self.opener.addheaders = self.headers
        # Get all of the login form input values. XPath attribute tests use
        # '@id'; '#id' is not valid XPath and fails to compile.
        find_inputs = etree.XPath("//form[@id='gaia_loginform']//input")
        try:
            data = self._read_body(self.opener.open(self.url_login))
            xml_tree = etree.fromstring(data, parser=html.HTMLParser(recover=True, remove_comments=True))
            for field in find_inputs(xml_tree):
                name = field.get('name')
                if name:
                    name = name.encode('utf8')
                    value = field.get('value', '').encode('utf8')
                    self.login_params[name] = value
        except Exception:
            # Best effort: report the problem, then still attempt the login
            # POST below with whatever form values were collected.
            print("Exception while parsing: %s\n" % traceback.format_exc())
        self.login_params["Email"] = username
        self.login_params["Passwd"] = password
        params = urllib.urlencode(self.login_params)
        self.opener.open(self.url_authenticate, params)

    def get_csv(self, throttle=False, **kwargs):
        '''
        Download CSV reports

        Keyword arguments become query parameters of the report URL (with
        export=1 added). They also name the output file
        ``trends_<key>-<value>_....csv`` in the current directory.
        '''
        # Randomized download delay
        if throttle:
            delay = random.uniform(0.5 * self.download_delay, 1.5 * self.download_delay)
            time.sleep(delay)
        params = {'export': 1}
        params.update(kwargs)
        resp = self.opener.open(self.url_download + urllib.urlencode(params))
        # A CSV download carries Content-Disposition; its absence is treated
        # as the quota-exceeded page.
        if 'Content-Disposition' not in resp.info():
            print("You've exceeded your quota. Continue tomorrow...")
            sys.exit(0)
        data = self._read_body(resp)
        filename = 'trends_%s.csv' % '_'.join(['%s-%s' % (key, value) for (key, value) in kwargs.items()])
        # Binary mode: the downloaded body is raw bytes. 'with' closes the
        # file even if the write fails.
        with open(filename, 'wb') as my_file:
            my_file.write(data)
Although I don't know Python, I may have a solution. I am currently doing the same thing in C#. I didn't get the .csv file, but I created a custom URL through code, downloaded that HTML and saved it to a text file (also through code). This HTML (at line 12) contains all the information needed to create the graph shown on Google Trends. It also contains a lot of unnecessary text that needs to be trimmed, but either way you end up with the same result: the Google Trends data. I posted a more detailed answer to my question here:
Downloading .csv file from Google Trends
There is an alternative module named pytrends - https://pypi.org/project/pytrends/ It is really cool. I would recommend this.
Example usage:
import numpy as np
import pandas as pd
from pytrends.request import TrendReq
pytrend = TrendReq()
# The term(s) you want to search for.
pytrend.build_payload(kw_list=["Eminem is the Rap God"])
# Find which regions have searched for the term.
df = pytrend.interest_by_region()
# Raw string: '\E' is an invalid escape sequence in a normal literal
# (DeprecationWarning, SyntaxWarning since Python 3.12). The path is unchanged.
df.to_csv(r"path\Eminem_InterestbyRegion.csv")
Potentially if you have a list of terms to search you could make use of "for loop" to automate the insights as per your wish.

how to unit test file upload in django

In my Django app, I have a view that handles file uploads. The core snippet looks like this:
...
import os
if request.method == 'POST':
    # .get() covers both a missing 'file' key and an explicit None value.
    # dict.has_key() does not exist on Python 3.
    file = request.FILES.get('file')
    if file is not None:
        # file.name comes from the client. Django already strips directory
        # parts from upload names; basename() here is defense-in-depth so
        # the write stays inside destfolder.
        dest_path = os.path.join(settings.destfolder, os.path.basename(file.name))
        with open(dest_path, 'wb+') as dest:
            for chunk in file.chunks():
                dest.write(chunk)
I would like to unit test the view. I am planning to test the happy path as well as the failure paths, i.e. the case where request.FILES has no key 'file' and the case where request.FILES['file'] is None.
How do I set up the POST data for the happy path? Can somebody tell me?
I used to do the same with open('some_file.txt') as fp: but then I needed images, videos and other real files in the repo and also I was testing a part of a Django core component that is well tested, so currently this is what I have been doing:
from django.core.files.uploadedfile import SimpleUploadedFile
def test_upload_video(self):
    # SimpleUploadedFile wraps its content in BytesIO, so on Python 3 it must
    # be bytes; a str raises TypeError.
    video = SimpleUploadedFile("file.mp4", b"file_content", content_type="video/mp4")
    self.client.post(reverse('app:some_view'), {'video': video})
    # some important assertions ...
In Python 3 you need to use a bytes object instead of a str: change "file_content" to b"file_content".
It's been working fine, SimpleUploadedFile creates an InMemoryFile that behaves like a regular upload and you can pick the name, content and content type.
From Django docs on Client.post:
Submitting files is a special case. To POST a file, you need only
provide the file field name as a key, and a file handle to the file
you wish to upload as a value. For example:
c = Client()
# Binary mode sends the file's bytes unchanged. Text mode decodes the file,
# and may fail on non-text files such as .doc.
with open('wishlist.doc', 'rb') as fp:
    c.post('/customers/wishes/', {'name': 'fred', 'attachment': fp})
I recommend you to take a look at Django RequestFactory. It's the best way to mock data provided in the request.
That said, I found several flaws in your code.
"unit" testing means to test just one "unit" of functionality. So,
if you want to test that view you'd be testing the view, and the file
system, ergo, not really unit test. To make this point more clear. If
you run that test, and the view works fine, but you don't have
permissions to save that file, your test would fail because of that.
Other important thing is test speed. If you're doing something like
TDD the speed of execution of your tests is really important.
Accessing any I/O is not a good idea.
So, I recommend you to refactor your view to use a function like:
def upload_file_to_location(request, location=None): # Can use the default configured
    """Sketch only: the answer leaves this body unwritten.

    Presumably it saves the uploaded file from *request* to *location*.
    Per the inline comment, ``location=None`` is meant to fall back to a
    default configured location. Extracting this out of the view lets the
    view be unit-tested with this function mocked.
    """
And do some mocking on that. You can use Python Mock.
PS: You could also use the Django test Client, but that would mean adding more things to test, because the client makes use of sessions, middleware, etc. That is nothing like unit testing.
I do something like this for my own event related application but you should have more than enough code to get on with your own use case
import tempfile, csv, os
class UploadPaperTest(TestCase):
    """Upload a generated CSV of papers and check that it is imported."""

    def generate_file(self):
        """Write a small papers CSV to ``test.csv`` and return the (closed) file object.

        Callers use only the returned object's ``.name``.
        """
        # Text mode with newline='' is what the csv module requires on
        # Python 3; 'wb' makes csv.writer raise TypeError there. 'with'
        # replaces the try/finally, which raised NameError if open() failed.
        with open('test.csv', 'w', newline='') as myfile:
            wr = csv.writer(myfile)
            wr.writerow(('Paper ID', 'Paper Title', 'Authors'))
            wr.writerow(('1', 'Title1', 'Author1'))
            wr.writerow(('2', 'Title2', 'Author2'))
            wr.writerow(('3', 'Title3', 'Author3'))
        return myfile

    def setUp(self):
        self.user = create_fuser()
        self.profile = ProfileFactory(user=self.user)
        self.event = EventFactory()
        self.client = Client()
        self.module = ModuleFactory()
        self.event_module = EventModule.objects.get_or_create(event=self.event,
                                                              module=self.module)[0]
        add_to_admin(self.event, self.user)

    def test_paper_upload(self):
        response = self.client.login(username=self.user.email, password='foz')
        self.assertTrue(response)

        file_path = self.generate_file().name
        # Remove the generated CSV even if an assertion below fails.
        self.addCleanup(os.remove, file_path)
        url = reverse('registration_upload_papers', args=[self.event.slug])

        # post wrong data type
        # NOTE(review): the original posted an undefined name `i` here
        # (NameError). A non-CSV upload is assumed to be what was intended;
        # confirm against the view's type check.
        from django.core.files.uploadedfile import SimpleUploadedFile
        wrong_type = SimpleUploadedFile('test.txt', b'not a csv', content_type='text/plain')
        response = self.client.post(url, {'uploaded_file': wrong_type})
        self.assertContains(response, 'File type is not supported.')

        # Open for the request only, so the handle is always closed.
        with open(file_path, 'rb') as f:
            response = self.client.post(url, {'uploaded_file': f})
        self.assertEqual(SubmissionImportFile.objects.all().count(), 1)
        import_file = SubmissionImportFile.objects.all()[0]
        os.remove(import_file.uploaded_file.path)
I did something like that :
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import TestCase
# django.core.urlresolvers was removed in Django 2.0; django.urls has
# provided reverse() since 1.10.
from django.urls import reverse
from django.core.files import File
# create_image() uses ContentFile, which was never imported.
from django.core.files.base import ContentFile
# django.utils.six was removed in Django 3.0; io.BytesIO is the same class.
from io import BytesIO
from .forms import UploadImageForm
from PIL import Image
from io import StringIO
def create_image(storage, filename, size=(100, 100), image_mode='RGB', image_format='PNG'):
    """
    Generate a blank test image.

    If ``storage`` is falsy (e.g. ``None``), return the BytesIO holding the
    encoded image data, rewound to the start. Otherwise save it to
    ``storage`` under ``filename`` and return the name the storage chose.
    """
    from io import BytesIO
    # ContentFile was referenced but never imported, so the storage branch
    # raised NameError.
    from django.core.files.base import ContentFile

    data = BytesIO()
    Image.new(image_mode, size).save(data, image_format)
    data.seek(0)
    if not storage:
        return data
    image_file = ContentFile(data.read())
    return storage.save(filename, image_file)
class UploadImageTests(TestCase):
    """Tests for the image-upload view."""

    def test_valid_form(self):
        '''
        Valid POST data should redirect. The followed response should render
        the result template with the uploaded image in the context.
        '''
        url = reverse('image')
        avatar = create_image(None, 'avatar.png')
        avatar_file = SimpleUploadedFile('front.png', avatar.getvalue())
        data = {'image': avatar_file}
        response = self.client.post(url, data, follow=True)
        image_src = response.context.get('image_src')
        # assertEquals is a deprecated alias, removed in Python 3.12.
        self.assertEqual(response.status_code, 200)
        self.assertTrue(image_src)
        # Pass the response. Called with only a template name,
        # assertTemplateUsed returns a context manager that is never entered,
        # so it checks nothing.
        self.assertTemplateUsed(response, 'content_upload/result_image.html')
The create_image function generates the image, so you don't need to provide a static path to an image file.
Note: adapt the code to your own project.
This code for Python 3.6.
# Build a DRF request with APIRequestFactory, authenticate it, and call the
# view directly (no URL routing involved).
from rest_framework.test import force_authenticate
from rest_framework.test import APIRequestFactory
factory = APIRequestFactory()
# Placeholder username: replace '#####' with a real test user.
user = User.objects.get(username='#####')
# <your_view_name> is a placeholder for the view class under test.
view = <your_view_name>.as_view()
# Binary mode so the PDF bytes are sent unchanged; the request is built
# while the file is still open.
with open('<file_name>.pdf', 'rb') as fp:
    request=factory.post('<url_path>',{'file_name':fp})
force_authenticate(request, user)
response = view(request)
As mentioned in Django's official documentation:
Submitting files is a special case. To POST a file, you need only provide the file field name as a key, and a file handle to the file you wish to upload as a value. For example:
c = Client()
# Binary mode sends the file's bytes unchanged. Text mode decodes the file,
# and may fail on non-text files such as .doc.
with open('wishlist.doc', 'rb') as fp:
    c.post('/customers/wishes/', {'name': 'fred', 'attachment': fp})
More Information: How to check if the file is passed as an argument to some function?
While testing, sometimes we want to make sure that the file is passed as an argument to some function.
e.g.
...
class AnyView(CreateView):
    ...
    def post(self, request, *args, **kwargs):
        # The uploaded file object from the multipart POST.
        attachment = request.FILES['attachment']
        # pass the file as an argument
        my_function(attachment)
        ...
In tests, use Python's mock something like this:
# Mock 'my_function' and then check the following:
response = do_a_post_request()
# It was called exactly once...
self.assertEqual(mock_my_function.call_count, 1)
# ...with the file object Django parsed from this request
# (response.wsgi_request is the request the test client sent).
self.assertEqual(
    mock_my_function.call_args,
    call(response.wsgi_request.FILES['attachment']),
)
if you want to add other data with file upload then follow the below method
# Open in binary mode, and let 'with' close the file after the request
# (the original never closed it).
with open('path/to/file.txt', 'rb') as file:
    data = {
        'file_name_to_receive_on_backend': file,
        'param1': 1,
        'param2': 2,
        # ... more form fields as needed
    }
    # Only the file-valued key arrives in request.FILES; the other keys
    # arrive as ordinary POST fields.
    response = self.client.post("/url/to/view", data, format='multipart')
Only file_name_to_receive_on_backend will be received as a file; the other params are received as normal POST params.
In Django 1.7 there's an issue with the TestCase which can be resolved by using open(filepath, 'rb'), but when using the test client we have no control over it. I think it's probably best to ensure file.read() always returns bytes.
source: https://code.djangoproject.com/ticket/23912, by KevinEtienne
Without rb option, a TypeError is raised:
TypeError: sequence item 4: expected bytes, bytearray, or an object with the buffer interface, str found
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import Client

client = Client()
# Read the template in binary mode; SimpleUploadedFile needs bytes.
with open(template_path, 'rb') as f:
    file = SimpleUploadedFile('Name of the django file', f.read())
# Django's Client.post already sends multipart/form-data by default. It
# returns a Django test response, not a requests.Response, so the original
# requests import and annotation were misleading.
response = client.post(url, data={'file': file})
Hope this helps.
Very handy solution with mock
from django.test import TestCase, override_settings
#use your own client request factory
from my_framework.test import APIClient
from django.core.files import File
import tempfile
from pathlib import Path
import mock

# Stand-in upload: spec=File restricts the mock's attributes to File's API.
image_mock = mock.MagicMock(spec=File)
image_mock.name = 'image.png' # or smt else

class MyTest(TestCase):
    # I assume we want to put this file in storage
    # so to avoid putting garbage in our MEDIA_ROOT
    # we're using temporary storage for test purposes
    @override_settings(MEDIA_ROOT=Path(tempfile.gettempdir()))
    def test_send_file(self):
        client = APIClient()
        client.post(
            '/endpoint/',  # the missing comma here was a SyntaxError
            {'file': image_mock},
            format="multipart"
        )
I am using Python==3.8.2 , Django==3.0.4, djangorestframework==3.11.0
I tried self.client.post but got a Resolver404 exception.
Following worked for me:
import requests

# Include the scheme: requests rejects scheme-less URLs with MissingSchema.
upload_url = 'https://www.some.com/oaisjdoasjd' # your url to upload
# NOTE: this performs a real HTTP request; it does not use Django's test client.
with open('/home/xyz/video1.webm', 'rb') as video_file:
    # if it was a text file we would perhaps do
    # file = video_file.read()
    # Passing the open file streams it as the raw request body.
    response_upload = requests.put(
        upload_url,
        data=video_file,
        headers={'content-type': 'video/webm'}
    )
I am using django rest framework and I had to test the upload of multiple files.
I finally got it working by using format="multipart" in my APIClient.post request.
from rest_framework.test import APIClient
...
self.client = APIClient()
# Binary mode. A list under one key sends each file as a separate part
# with the same field name, which is how multiple files are uploaded.
with open('./photo.jpg', 'rb') as fp:
    resp = self.client.post('/upload/',
                            {'images': [fp]},
                            format="multipart")
I am using GraphQL, upload for test:
# NOTE(review): assumes self.client is a GraphQL test client whose execute()
# forwards data= as multipart POST data; confirm against the client in use.
with open('test.jpg', 'rb') as fp:
    response = self.client.execute(query, variables, data={'image': [fp]})
Code in the mutation class:
@classmethod
def mutate(cls, root, info, **kwargs):
    """Get or create a TestingMainModel by ``id`` from the mutation arguments.

    If the request carried an uploaded 'image' file, it is added to the
    field values used when creating the row.
    """
    # The uploaded file is read from the request object (info.context),
    # not from the GraphQL arguments.
    if image := info.context.FILES.get("image", None):
        kwargs["image"] = image
    TestingMainModel.objects.get_or_create(
        id=kwargs["id"],
        defaults=kwargs
    )

Getting all queries that django run on postgresql

I am working on a django-postgresql project and I need to see every query that django run on database(so I can fine-tune queries). Is there a way to get those queries.
Update: My development environment is on ubuntu linux
Well, you could just set the pgsql server to log every query. Or just to log the slow ones. Look in the postgresql.conf file, it's pretty close to self-documenting.
Check out this Question (and the two top most answers):
django orm, how to view (or log) the executed query?
You can also have a look at the Django documentation:
https://docs.djangoproject.com/en/dev/faq/models/#how-can-i-see-the-raw-sql-queries-django-is-running
Hope this helps,
Anton
You can decorate a request handler or other function with this, and it will print the SQL nicely formatted, with totals at the end.
from functools import wraps
from django.utils import termcolors

# Bold green / yellow / red terminal styles for fast, slow and very slow queries.
format_ok, format_warning, format_error = (
    termcolors.make_style(opts=('bold',), fg=colour)
    for colour in ('green', 'yellow', 'red')
)

# Pygments is optional: colourise SQL when it is installed, otherwise pass
# the text through untouched.
try:
    from pygments import highlight
    from pygments.lexers import SqlLexer
    from pygments.formatters import TerminalFormatter

    pygments_sql_lexer = SqlLexer()
    pygments_terminal_formatter = TerminalFormatter()

    def highlight_sql(s):
        """Return *s* syntax-highlighted as SQL for a terminal."""
        return highlight(s, pygments_sql_lexer, pygments_terminal_formatter)
except ImportError:
    def highlight_sql(s):
        """Fallback when Pygments is missing: return *s* unchanged."""
        return s
def debug_sql(f):
    """
    Turn SQL statement debugging on for a test run.

    Decorator: forces ``settings.DEBUG`` on while *f* runs so Django logs
    queries. Afterwards it prints each query (timing coloured red above 1s,
    yellow above 0.3s, green otherwise), the total time and the query count,
    then restores the previous DEBUG value.
    """
    @wraps(f)
    def wrapper(*a, **kw):
        from django.conf import settings
        from django.db import connection, reset_queries
        debug = settings.DEBUG
        try:
            settings.DEBUG = True
            # reset_queries() clears the log. Assigning
            # connection.queries = [] fails on Django >= 1.8, where queries
            # is a read-only property.
            reset_queries()
            return f(*a, **kw)
        finally:
            total_time = 0
            for q in connection.queries:
                t = float(q['time'])
                total_time += t
                if t > 1:
                    fmt = format_error
                elif t > 0.3:
                    fmt = format_warning
                else:
                    fmt = format_ok
                print('[%s] %s' % (fmt(q['time']), highlight_sql(q['sql'])))
            print("total time = %s" % total_time)
            print("num queries = %s" % len(connection.queries))
            settings.DEBUG = debug
    return wrapper
Try the django debug toolbar. It'll show you all the SQL executed over the request. When something is executing way too many queries, it becomes really slow, though. For that, I've been meaning to try out this profiler. However, I've rolled this middleware on a couple of projects:
try:
    from cStringIO import StringIO
except ImportError:
    # Python 3 has no cStringIO; io.StringIO is the replacement. The original
    # fallback `import StringIO` bound the *module*, so calling StringIO()
    # raised TypeError.
    from io import StringIO
from django.conf import settings
from django.db import connection
class DatabaseProfilerMiddleware(object):
    """Middleware (old-style API: process_response only) for profiling SQL.

    When DEBUG is on and the URL carries a ``dbprof`` GET parameter, the
    response body is replaced by the request's SQL queries, slowest first.
    """

    def can(self, request):
        """Return True when profiling output was requested and DEBUG is on."""
        return settings.DEBUG and 'dbprof' in request.GET

    def process_response(self, request, response):
        if self.can(request):
            out = StringIO()
            out.write('time sql\n')
            total_time = 0.0
            # Sort numerically: query times are strings, and a string sort
            # would order '10.0' before '2.0'.
            queries = sorted(connection.queries, key=lambda q: float(q['time']), reverse=True)
            for query in queries:
                total_time += float(query['time'])
                out.write('%s %s\n' % (query['time'], query['sql']))
            response.content = '<pre style="white-space:pre-wrap">%d queries executed in %.3f seconds\n%s</pre>' \
                % (len(queries), total_time, out.getvalue())
        return response
Just go to the relevant URL for the request you are interested in and add a dbprof GET parameter, you'll see the profiling output instead of the normal response.

Django: is there a way to count SQL queries from an unit test?

I am trying to find out the number of queries executed by a utility function. I have written a unit test for this function and the function is working well. What I would like to do is track the number of SQL queries executed by the function so that I can see if there is any improvement after some refactoring.
def do_something_in_the_database():
    """Placeholder for the function under test (body elided in the question)."""
    # Does something in the database
    # return result
class DoSomethingTests(django.test.TestCase):
    def test_function_returns_correct_values(self):
        # Goal: assert how many SQL queries the function ran. `n` and the
        # <...> expression are placeholders, not runnable code; see the
        # answers below for working approaches.
        self.assertEqual(n, <number of SQL queries executed>)
EDIT: I found out that there is a pending Django feature request for this. However the ticket is still open. In the meantime is there another way to go about this?
Since Django 1.3, assertNumQueries is available exactly for this purpose.
One way to use it (as of Django 3.2) is as a context manager:
# measure queries of some_func and some_func2
# Fails unless exactly 2 queries run inside the block (both calls combined).
with self.assertNumQueries(2):
    result = some_func()
    result2 = some_func2()
Vinay's response is correct, with one minor addition.
Django's unit test framework actually sets DEBUG to False when it runs, so no matter what you have in settings.py, you will not have anything populated in connection.queries in your unit test unless you re-enable debug mode. The Django docs explain the rationale for this as:
Regardless of the value of the DEBUG setting in your configuration file, all Django tests run with DEBUG=False. This is to ensure that the observed output of your code matches what will be seen in a production setting.
If you're certain that enabling debug will not affect your tests (such as if you're specifically testing DB hits, as it sounds like you are), the solution is to temporarily re-enable debug in your unit test, then set it back afterward:
def test_myself(self):
    from django.conf import settings
    from django.db import connection, reset_queries
    # Django runs tests with DEBUG=False, which disables query logging.
    # Enable it for this test and always switch it back off.
    settings.DEBUG = True
    try:
        # connection.queries is read-only on Django >= 1.8; reset_queries()
        # is the supported way to clear the log.
        reset_queries()
        # Test code as normal
        # assert_ is a deprecated alias, removed in Python 3.12.
        self.assertTrue(connection.queries)
    finally:
        settings.DEBUG = False
If you are using pytest, pytest-django has django_assert_num_queries fixture for this purpose:
def test_queries(django_assert_num_queries):
    # pytest-django fixture: fails unless exactly 3 queries run in the block.
    with django_assert_num_queries(3):
        # NOTE(review): QuerySet.create() accepts keyword arguments only, so
        # create('foo') raises TypeError. Use e.g. Item.objects.create(name='foo')
        # with the model's real field name.
        Item.objects.create('foo')
        Item.objects.create('bar')
        Item.objects.create('baz')
If you don't want to use TestCase (with assertNumQueries) or change settings to DEBUG=True, you can use the CaptureQueriesContext context manager (which assertNumQueries uses internally).
from django.db import connections
from django.test.utils import CaptureQueriesContext

DB_NAME = "default"  # name of db configured in settings you want to use - "default" is standard
# Use Django's shared `connections` handler. A fresh ConnectionHandler()
# creates separate connection objects that the ORM never uses, so nothing
# would be captured.
connection = connections[DB_NAME]
with CaptureQueriesContext(connection) as context:
    ...  # do your thing
# final - initial; the original subtracted the other way round and got a
# negative count. len(context) gives the same number.
num_queries = context.final_queries - context.initial_queries
assert num_queries == expected_num_queries
db settings
In modern Django (>=1.8) this is well documented (it's also documented for 1.7): use reset_queries() instead of assigning connection.queries = [], which indeed raises an error. Something like this works on Django >= 1.8:
class QueriesTests(django.test.TestCase):
    def test_queries(self):
        from django.conf import settings
        from django.db import connection, reset_queries
        try:
            # DEBUG enables query logging; reset so counting starts at zero.
            settings.DEBUG = True
            reset_queries()
            # [... your ORM code ...]
            # assertEquals is a deprecated alias, removed in Python 3.12.
            self.assertEqual(len(connection.queries), num_of_expected_queries)
        finally:
            settings.DEBUG = False
            reset_queries()
You may also consider resetting queries on setUp/tearDown to ensure queries are reset for each test instead of doing it on finally clause, but this way is more explicit (although more verbose), or you can use reset_queries in the try clause as many times as you need to evaluate queries counting from 0.
Here is the working prototype of context manager withAssertNumQueriesLessThan
import json
from contextlib import contextmanager
from django.test.utils import CaptureQueriesContext
from django.db import connections
@contextmanager
def withAssertNumQueriesLessThan(self, value, using='default', verbose=False):
    """Assert that the wrapped block runs fewer than *value* SQL queries.

    Meant to be a TestCase method: it takes ``self`` and calls ``assertLess``.

    :param value: exclusive upper bound on the number of queries.
    :param using: database alias whose connection is watched.
    :param verbose: include the captured queries, pretty-printed as JSON, in
        the failure message.
    """
    with CaptureQueriesContext(connections[using]) as context:
        yield  # your test will be run here
    if verbose:
        msg = "\r\n%s" % json.dumps(context.captured_queries, indent=4)
    else:
        msg = None
    self.assertLess(len(context.captured_queries), value, msg=msg)
It can be simply used in your unit tests for example for checking the number of queries per Django REST API call
# The test client needs an absolute path: without the leading slash,
# 'contacts/' does not match the root URLconf and raises Resolver404.
with self.withAssertNumQueriesLessThan(10):
    response = self.client.get('/contacts/')
    self.assertEqual(response.status_code, 200)
You can also pass `using` to select a specific database, and `verbose=True` to include a pretty-printed list of the actual queries in the failure message.
If you have DEBUG set to True in your settings.py (presumably so in your test environment) then you can count queries executed in your test as follows:
from django.db import connection
class DoSomethingTests(django.test.TestCase):
    def test_something_or_other(self):
        # connection.queries is only populated while settings.DEBUG is True.
        # Count the queries logged before and after the call.
        num_queries_old = len(connection.queries)
        do_something_in_the_database()
        num_queries_new = len(connection.queries)
        # `n` is a placeholder for the expected number of queries.
        self.assertEqual(n, num_queries_new - num_queries_old)
If you want to use a decorator for that there is a nice gist:
import functools
import sys
import re
from django.conf import settings
from django.db import connection
def shrink_select(sql):
    """Collapse a SELECT's column list: 'SELECT a, b FROM t' -> 'SELECT .. FROM t'.

    The match is greedy, so it runs up to the *last* FROM on the line.
    """
    leading_select = re.compile("^SELECT(.+)FROM")
    return leading_select.sub("SELECT .. FROM", sql)
def shrink_update(sql):
    """Collapse everything between SET and (the last) WHERE into 'SET .. WHERE'."""
    set_clause = re.compile("SET(.+)WHERE")
    shortened = set_clause.sub("SET .. WHERE", sql)
    return shortened
def shrink_insert(sql):
    """Collapse everything from the first '(' to the last ')' into '(..)'.

    For an INSERT this hides both the column list and the values, e.g.
    'INSERT INTO t (a, b) VALUES (1, 2)' -> 'INSERT INTO t (..)'.
    """
    # Raw string: "\(" in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    return re.sub(r"\((.+)\)", "(..)", sql)
def shrink_sql(sql):
    """Shorten *sql* by applying the SELECT, INSERT and UPDATE shrinkers in turn."""
    for shrink in (shrink_select, shrink_insert, shrink_update):
        sql = shrink(sql)
    return sql
def _err_msg(num, expected_num, verbose, func=None):
    """Build the failure message for a query-count mismatch.

    :param num: actual number of queries run.
    :param expected_num: expected number of queries.
    :param verbose: 0 = counts only; 1 = also list the last *num* queries
        shortened by shrink_sql; greater than 1 = list them in full.
    :param func: optional callable whose name prefixes the message.
    """
    func_name = "%s:" % func.__name__ if func else ""
    msg = "%s Expected number of queries is %d, actual number is %d.\n" % (func_name, expected_num, num,)
    if verbose > 0:
        # Guard num == 0: queries[-0:] is the whole log, not an empty slice.
        recent = connection.queries[-num:] if num > 0 else []
        queries = [query['sql'] for query in recent]
        if verbose == 1:
            queries = [shrink_sql(sql) for sql in queries]
        msg += "== Queries == \n" + "\n".join(queries)
    return msg
def assertNumQueries(expected_num, verbose=1):
    """Assert that exactly *expected_num* SQL queries run.

    Usable as a decorator or as a context manager. ``settings.DEBUG`` is
    forced on while active so Django logs queries. It is restored afterwards,
    even when the wrapped code or the count check fails. The count is only
    checked when the wrapped code did not raise.

    :param verbose: forwarded to ``_err_msg`` (0 counts only, 1 shortened
        SQL, greater than 1 full SQL).
    :raises AssertionError: on a count mismatch.
    """
    class DecoratorOrContextManager(object):
        def __call__(self, func):  # decorator
            @functools.wraps(func)
            def inner(*args, **kwargs):
                # __exit__ never suppresses exceptions, so `with` matches the
                # original manual enter/exit bookkeeping.
                with self:
                    return func(*args, **kwargs)
            return inner

        def __enter__(self):
            self.old_debug = settings.DEBUG
            self.old_query_count = len(connection.queries)
            settings.DEBUG = True

        def __exit__(self, type, value, traceback):
            try:
                if not type:
                    num = len(connection.queries) - self.old_query_count
                    # Explicit raise instead of `assert`, which `python -O` strips.
                    if expected_num != num:
                        raise AssertionError(_err_msg(num, expected_num, verbose))
            finally:
                # Restore DEBUG even when the count check fails.
                settings.DEBUG = self.old_debug

    return DecoratorOrContextManager()