I am trying to create a test class for my custom middleware. The project uses Django REST framework. The middleware class works fine when the server is running, but when I run the test it does not behave as I expect. Maybe I misunderstood something, as I am quite new to testing in Django.
my_middleware.py:
class FX:
    """Simple object that MyMiddleware attaches to each request as ``request.fx``."""

    # Class-level defaults, shared by every instance unless overridden.
    a = False
    b = None
    c = ''

    def __init__(self) -> None:
        """Create an FX carrying only the class-level defaults; takes no arguments."""

    def __str__(self):
        return 'fx ok'
class MyMiddleware:
    """Attach a fresh FX to every request, then print 'done' after the view returns."""

    def __init__(self, get_response):
        # The next handler in the chain (another middleware or the view).
        self.get_response = get_response

    def __call__(self, request):
        request.fx = FX()  # visible to the view as request.fx
        result = self.get_response(request)
        # Reached only after the downstream handler has produced its response.
        print('done')
        return result
views.py:
class TestView(APIView):
    """Minimal DRF view that prints the middleware-attached ``request.fx``."""

    def get(self, request, format=None):
        """Print diagnostics and return the fixed payload ``{'result': 'ok'}``."""
        print('View ok')
        print('FX: ', request.fx)
        payload = {'result': 'ok'}
        return Response(payload)
tests.py:
class TestMyMiddleware(APITestCase):
    """Check that MyMiddleware attaches ``fx`` to the request (version from the question)."""

    # Django's TestCase calls cls.setUpTestData(), so this must be a classmethod.
    @classmethod
    def setUpTestData(cls):
        pass

    def setUp(self):
        pass

    def test_fx(self):
        response = self.client.get(reverse('TestView'), content_type="application/json")
        # NOTE: response.request is the dict of request data the test client
        # sent, not the HttpRequest the middleware annotated, so it has no
        # ``fx`` attribute and this assertion fails. response.wsgi_request is
        # the request object that actually went through the middleware.
        request = response.request
        self.assertTrue(hasattr(request, 'fx'))
The code above actually runs the middleware. It prints "done" from the middleware call, then prints 'View ok' and also prints the FX instance. However, request.fx is not available in the test_fx method, which gives an assertion failure:
self.assertTrue(hasattr(request, 'fx'))
AssertionError: False is not true
Any idea what I might be doing wrong?
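You need to access the request object from the response with response.wsgi_request instead of response.request. The test client's response.request is the dictionary of request data used to make the call, not the HttpRequest object that passed through the middleware, so it never has an fx attribute. response.wsgi_request is the WSGIRequest instance that the handler (and therefore your middleware) actually processed.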
class TestMyMiddleware(APITestCase):
    """Check that MyMiddleware attaches ``fx`` to the request."""

    # Django's TestCase calls cls.setUpTestData(), so this must be a classmethod.
    @classmethod
    def setUpTestData(cls):
        pass

    def setUp(self):
        pass

    def test_fx(self):
        """The request that went through the middleware carries ``fx``."""
        response = self.client.get(reverse('TestView'), content_type="application/json")
        # wsgi_request is the request object the handler (and so the middleware)
        # processed; response.request is only the dict of request data.
        request = response.wsgi_request
        self.assertTrue(hasattr(request, 'fx'))
I have a class-based view that works as designed in the browser. I'm trying to write unit tests for the view, and they keep failing. I'm wondering why. The view (the UserPassesTestMixin check is whether the user is a superuser):
class EditUserView(LoginRequiredMixin, UserPassesTestMixin, TemplateView):
    """Render the form for editing an existing AEUser.

    Access is gated by LoginRequiredMixin and UserPassesTestMixin.
    """

    template_name = 'editUser.html'
    title = 'Edit User'

    def get(self, request, *args, **kwargs):
        """Render the edit form pre-filled with the target user's flags.

        Raises Http404 (via get_object_or_404) when no AEUser has primary
        key ``kwargs['user_id']``.
        """
        user_id = kwargs['user_id']
        post_url = reverse('edit_user', args=[user_id])
        target = get_object_or_404(AEUser, pk=user_id)
        initial = {
            'is_active': target.is_active,
            'is_superuser': target.is_superuser,
        }
        form = EditUserForm(initial=initial, user=target, request=request)
        context = {
            'title_text': self.title,
            'post_url': post_url,
            'form': form,
        }
        return render(request, self.template_name, context)
The Test Case:
class TestEditUser(TestCase):
    """Test the EditUser view (version from the question)."""

    # Django's TestCase calls cls.setUpTestData(), so this must be a classmethod.
    @classmethod
    def setUpTestData(cls):
        """Create one superuser shared by every test in this class."""
        cls.user = AEUser.objects.create_user(
            username='shawn', email='shawn@gmail.com', password='test')
        cls.user.is_superuser = True
        cls.user.save()

    def setUp(self):
        self.client = Client()

    def test_get(self):
        """Test the GET."""
        # NOTE(review): self.user.password is the stored password hash, not the
        # raw 'test', so this login fails and the view is never reached as the
        # logged-in superuser.
        self.client.login(username=self.user.username, password=self.user.password)
        get_url = reverse('edit_user', args=[self.user.id])
        response = self.client.get(get_url, follow=True)
        self.assertEqual(self.user.is_superuser, True)
        self.assertEqual(response.status_code, 200)
        self.assertTemplateUsed(response, 'editUser.html')
I have 3 asserts in the test case. If I comment out the last two, and only assert that the user is a superuser, the test passes. For whatever reason, though, on the other two asserts, I get failures. The error I receive is:
AssertionError: False is not true : Template 'editUser.html' was not a template used to render the response. Actual template(s) used: 404.html, base.html
This leads me to believe the get_object_or_404 call is what's triggering the failure. Where am I going wrong with this test case? Thanks!
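The test should log in with the raw password. self.user.password holds the hashed password, so client.login() fails. LoginRequiredMixin then redirects to the login URL, and with follow=True the client follows that redirect. That URL presumably does not exist in this project, hence the 404.html template. The test should be: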
class TestEditUser(TestCase):
    """Test the EditUser view."""

    # Django's TestCase calls cls.setUpTestData(), so this must be a classmethod.
    @classmethod
    def setUpTestData(cls):
        """Create one superuser shared by every test in this class."""
        cls.user = AEUser.objects.create_user(
            username='shawn', email='shawn@gmail.com', password='test')
        cls.user.is_superuser = True
        cls.user.save()

    def setUp(self):
        self.client = Client()

    def test_get(self):
        """GET renders editUser.html for a logged-in superuser."""
        # Use the raw password: self.user.password holds only the hash.
        # Asserting the result makes a failed login fail here, instead of
        # surfacing later as a confusing 404/template mismatch.
        self.assertTrue(
            self.client.login(username=self.user.username, password='test'))
        get_url = reverse('edit_user', args=[self.user.id])
        response = self.client.get(get_url, follow=True)
        self.assertTrue(self.user.is_superuser)
        self.assertEqual(response.status_code, 200)
        self.assertTemplateUsed(response, 'editUser.html')
I am writing a test for a DetailView whose get_object() looks up the object using a value set in a middleware. This is for a Companies application and a Company model. Each user belongs to a Company.
To access the company throughout the project, I set the current user's Company.uuid on the request via a custom middleware.
Middleware
from django.utils.deprecation import MiddlewareMixin
class DynamicCompanyUUIDMiddleware(MiddlewareMixin):
    """Adds the current organization's UUID from the current user."""

    def process_request(self, request):
        """Set ``request.company_uuid`` from ``request.user.company_uuid``.

        Falls back to None when the attribute is missing, presumably for
        anonymous users (TODO confirm) or when no ``request.user`` is set.
        Only AttributeError is caught. A bare ``except:`` would also hide
        real errors (and even KeyboardInterrupt/SystemExit).
        """
        try:
            company_uuid = request.user.company_uuid
        except AttributeError:
            company_uuid = None
        request.company_uuid = company_uuid
That is used in the CompanyDetailView's get_object() method via a Mixin that I use for the other Company Views.
Mixin
class CompanyMixin(LoginRequiredMixin, SetHeadlineMixin):
    """Shared base for the Company views: require login and resolve the user's company."""

    model = Company

    def get_object(self):
        """Return the ``model`` row whose uuid equals the current user's ``company_uuid``.

        Raises Http404 (via get_object_or_404) when no such row exists.
        """
        # NOTE(review): this reads request.user.company_uuid directly rather
        # than the request.company_uuid attribute set by the middleware.
        wanted_uuid = self.request.user.company_uuid
        return get_object_or_404(self.model, uuid=wanted_uuid)
Test
The test that I'm trying to write is:
from django.test import RequestFactory
from django.urls import reverse, resolve
from test_plus.test import TestCase
from ..models import Company
from ..views import CompanyDetailView
class BaseCompanyTestCase(TestCase):
    """Shared fixture: a test user who owns one Company, plus a RequestFactory."""

    def setUp(self):
        self.factory = RequestFactory()
        self.user = self.make_user()
        self.object = Company.objects.create(owner=self.user, name="testcompany")
        # Set on the in-memory user only (no save() here); requests built with
        # request.user = self.user carry it.
        self.user.company_uuid = self.object.uuid
class TestCompanyDetailView(BaseCompanyTestCase):
    """Tests for CompanyDetailView, driven through RequestFactory."""

    def setUp(self):
        super().setUp()
        self.client.login(username="testuser", password="password")
        self.view = CompanyDetailView()
        self.view.object = self.object
        request = self.factory.get(reverse('companies:detail'))
        # RequestFactory does not run middleware, so attach what the auth and
        # company middleware would normally have set.
        request.user = self.user
        request.company_uuid = self.user.company_uuid
        response = CompanyDetailView.as_view()(request)
        self.assertEqual(response.status_code, 200)

    def test_get_headline(self):
        self.assertEqual(
            self.view.get_headline(),
            '%s Members' % self.object.name,
        )
Result
This results in a 404 with the testuser's company not being found.
Walking through it:
I create the user
Create the company for this new testuser
Set the user.company_uuid
This should allow the mixin to access the company_uuid
Therefore return the user's company in the request
However, as the 404 shows, the company is not being returned.
Question
Where am I going wrong on this? Thanks in advance for your help.
Answer
I was mixing Django's Client and RequestFactory. I have corrected the code above, which is now correct.
I was mixing Django's Client and RequestFactory. After stepping away, I figured it out; see below:
from django.test import RequestFactory
from django.urls import reverse, resolve
from test_plus.test import TestCase
from ..models import Company
from ..views import CompanyDetailView
class BaseCompanyTestCase(TestCase):
    """Shared fixture: a test user who owns one Company, plus a RequestFactory."""

    def setUp(self):
        self.factory = RequestFactory()
        self.user = self.make_user()
        self.object = Company.objects.create(owner=self.user, name="testcompany")
        # Set on the in-memory user only (no save() here); requests built with
        # request.user = self.user carry it.
        self.user.company_uuid = self.object.uuid
class TestCompanyDetailView(BaseCompanyTestCase):
    """Tests for CompanyDetailView, driven through RequestFactory."""

    def setUp(self):
        super().setUp()
        self.client.login(username="testuser", password="password")
        self.view = CompanyDetailView()
        self.view.object = self.object
        request = self.factory.get(reverse('companies:detail'))
        # RequestFactory does not run middleware, so attach what the auth and
        # company middleware would normally have set.
        request.user = self.user
        request.company_uuid = self.user.company_uuid
        response = CompanyDetailView.as_view()(request)
        self.assertEqual(response.status_code, 200)

    def test_get_headline(self):
        self.assertEqual(
            self.view.get_headline(),
            '%s Members' % self.object.name,
        )
I searched Google for the returned error message but did not find a solution.
Why doesn't my code POST the values correctly?
I need help finding the correct solution.
How do I use form-based generic views?
I am developing a REST API and I don't understand the problem in my code. When I run it, it returns:
I get the message that follows.
views.py :
from snippets.models import Equipamento, Colaborador
from snippets.serializers import EquipamentoSerializer, ColaboradorSerializer
from rest_framework import mixins
from rest_framework import generics
class EquipamentoList(generics.ListCreateAPIView):
    """List Equipamento rows (optionally filtered by ``?id=``) or create one."""

    serializer_class = EquipamentoSerializer

    def get_queryset(self):
        """Return all Equipamento rows, filtered by ``id`` when ``?id=`` is given."""
        queryset = Equipamento.objects.all()
        # Named equip_id so the builtin ``id`` is not shadowed.
        equip_id = self.request.query_params.get('id')
        if equip_id is not None:
            queryset = queryset.filter(id=equip_id)
        return queryset
# class ColaboradorList(generics.CreateAPIView):
# queryset = Colaborador.objects.all()
# serializer_class = ColaboradorSerializer
# def get_queryset(self):
# queryset = Colaborador.objects.all()
# id = self.request.query_params.get('id', None)
# if id is not None:
# queryset = queryset.filter(pk=pk)
# return queryset
# def create(self, request, pk):
# queryset = Colaborador.objects.all()
# return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
# class ColaboradorDetail(generics.RetrieveUpdateDestroyAPIView):
# queryset = Colaborador.objects.all()
# serializer_class = ColaboradorSerializer
class ColaboradorList(mixins.ListModelMixin,
                      mixins.CreateModelMixin,
                      generics.GenericAPIView):
    """List all Colaborador rows (GET) or create a new one (POST)."""

    serializer_class = ColaboradorSerializer
    queryset = Colaborador.objects.all()

    def get(self, request, *args, **kwargs):
        """Delegate GET to ListModelMixin.list."""
        return self.list(request, *args, **kwargs)

    def post(self, request, *args, **kwargs):
        """Delegate POST to CreateModelMixin.create."""
        return self.create(request, *args, **kwargs)
class ColaboradorDetail(mixins.RetrieveModelMixin,
                        mixins.UpdateModelMixin,
                        mixins.DestroyModelMixin,
                        generics.GenericAPIView):
    """Retrieve (GET), replace (PUT) or delete (DELETE) a single Colaborador.

    No ``post`` handler is defined here; creation goes through ColaboradorList.
    """

    serializer_class = ColaboradorSerializer
    queryset = Colaborador.objects.all()

    def get(self, request, *args, **kwargs):
        """Delegate GET to RetrieveModelMixin.retrieve."""
        return self.retrieve(request, *args, **kwargs)

    def put(self, request, *args, **kwargs):
        """Delegate PUT to UpdateModelMixin.update."""
        return self.update(request, *args, **kwargs)

    def delete(self, request, *args, **kwargs):
        """Delegate DELETE to DestroyModelMixin.destroy."""
        return self.destroy(request, *args, **kwargs)
serializers.py
from rest_framework import serializers
from rest_framework.validators import UniqueValidator
from snippets.models import Equipamento, Colaborador, Propriedade, MotivoParada, Apontamento
class EquipamentoSerializer(serializers.Serializer):
    """Serialize Equipamento instances and persist them via the DRF 3 hooks."""

    id = serializers.IntegerField(read_only=True)
    cod_equip = serializers.IntegerField(
        validators=[UniqueValidator(queryset=Equipamento.objects.all())])
    desc_equip = serializers.CharField(allow_blank=True, max_length=15, required=False)

    def create(self, validated_data):
        """Create, save and return a new Equipamento.

        DRF 3's ``save()`` calls this when no instance was given. Without it,
        a plain Serializer raises NotImplementedError on POST.
        """
        return Equipamento.objects.create(**validated_data)

    def update(self, instance, validated_data):
        """Copy validated fields onto *instance*, save it and return it.

        DRF 3's ``save()`` calls this when an instance was given. Fields absent
        from *validated_data* are left untouched.
        """
        for field_name, value in validated_data.items():
            setattr(instance, field_name, value)
        instance.save()
        return instance

    def restore_object(self, attrs, instance=None):
        """DRF 2.x hook, not called by DRF 3; returns an unsaved Equipamento.

        Kept only for callers that still invoke it directly.
        """
        if instance is None:
            return Equipamento(**attrs)
        for field_name, value in attrs.items():
            setattr(instance, field_name, value)
        return instance
class ColaboradorSerializer(serializers.Serializer):
    """Serialize Colaborador instances and persist them via the DRF 3 hooks."""

    id = serializers.IntegerField(read_only=True)
    cod_colab = serializers.IntegerField(
        validators=[UniqueValidator(queryset=Colaborador.objects.all())])
    nome_colab = serializers.CharField(max_length=30)

    def create(self, validated_data):
        """Create, save and return a new Colaborador.

        DRF 3's ``save()`` calls this when no instance was given. Without it,
        a plain Serializer raises NotImplementedError on POST.
        """
        return Colaborador.objects.create(**validated_data)

    def update(self, instance, validated_data):
        """Copy validated fields onto *instance*, save it and return it.

        DRF 3's ``save()`` calls this when an instance was given. Fields absent
        from *validated_data* are left untouched.
        """
        for field_name, value in validated_data.items():
            setattr(instance, field_name, value)
        instance.save()
        return instance

    def restore_object(self, attrs, instance=None):
        """DRF 2.x hook, not called by DRF 3; returns an unsaved Colaborador.

        Kept only for callers that still invoke it directly.
        """
        if instance is None:
            return Colaborador(**attrs)
        for field_name, value in attrs.items():
            setattr(instance, field_name, value)
        return instance
class ApontamentoSerializer(serializers.Serializer):
    """Serialize Apontamento instances and persist them via the DRF 3 hooks."""

    id = serializers.IntegerField(read_only=True)
    criado = serializers.DateTimeField(read_only=True)
    apont_inicio = serializers.TimeField()
    apont_fim = serializers.TimeField()
    duracao = serializers.TimeField()
    # Related fields: input is a primary key, validated data holds the instance.
    equipamento = serializers.PrimaryKeyRelatedField(queryset=Equipamento.objects.all())
    colaborador = serializers.PrimaryKeyRelatedField(queryset=Colaborador.objects.all())
    propriedade = serializers.PrimaryKeyRelatedField(queryset=Propriedade.objects.all())
    m_parada = serializers.PrimaryKeyRelatedField(queryset=MotivoParada.objects.all())

    def create(self, validated_data):
        """Create, save and return a new Apontamento.

        DRF 3's ``save()`` calls this when no instance was given. Without it,
        a plain Serializer raises NotImplementedError on POST.
        """
        return Apontamento.objects.create(**validated_data)

    def update(self, instance, validated_data):
        """Copy validated fields onto *instance*, save it and return it.

        DRF 3's ``save()`` calls this when an instance was given. Read-only
        fields (``id``, ``criado``) never appear in *validated_data*.
        """
        for field_name, value in validated_data.items():
            setattr(instance, field_name, value)
        instance.save()
        return instance

    def restore_object(self, attrs, instance=None):
        """DRF 2.x hook, not called by DRF 3; returns an unsaved Apontamento.

        Kept only for callers that still invoke it directly.
        """
        if instance is None:
            return Apontamento(**attrs)
        for field_name, value in attrs.items():
            setattr(instance, field_name, value)
        return instance
class PropriedadeSerializer(serializers.Serializer):
    """Serialize Propriedade instances and persist them via the DRF 3 hooks."""

    id = serializers.IntegerField(read_only=True)
    cod_prop = serializers.IntegerField(
        validators=[UniqueValidator(queryset=Propriedade.objects.all())])
    desc_prop = serializers.CharField(max_length=30)

    def create(self, validated_data):
        """Create, save and return a new Propriedade.

        DRF 3's ``save()`` calls this when no instance was given. Without it,
        a plain Serializer raises NotImplementedError on POST.
        """
        return Propriedade.objects.create(**validated_data)

    def update(self, instance, validated_data):
        """Copy validated fields onto *instance*, save it and return it.

        DRF 3's ``save()`` calls this when an instance was given. Fields absent
        from *validated_data* are left untouched.
        """
        for field_name, value in validated_data.items():
            setattr(instance, field_name, value)
        instance.save()
        return instance

    def restore_object(self, attrs, instance=None):
        """DRF 2.x hook, not called by DRF 3; returns an unsaved Propriedade.

        Kept only for callers that still invoke it directly. It now writes
        ``desc_prop`` rather than the misspelled ``des_prop``.
        """
        if instance is None:
            return Propriedade(**attrs)
        for field_name, value in attrs.items():
            setattr(instance, field_name, value)
        return instance
class MotivoParadaSerializer(serializers.Serializer):
    """Serialize MotivoParada instances and persist them via the DRF 3 hooks."""

    id = serializers.IntegerField(read_only=True)
    cod_mparada = serializers.IntegerField(
        validators=[UniqueValidator(queryset=MotivoParada.objects.all())])
    desc_mparada = serializers.CharField(max_length=30)

    def create(self, validated_data):
        """Create, save and return a new MotivoParada.

        DRF 3's ``save()`` calls this when no instance was given. Without it,
        a plain Serializer raises NotImplementedError on POST.
        """
        return MotivoParada.objects.create(**validated_data)

    def update(self, instance, validated_data):
        """Copy validated fields onto *instance*, save it and return it.

        DRF 3's ``save()`` calls this when an instance was given. Fields absent
        from *validated_data* are left untouched.
        """
        for field_name, value in validated_data.items():
            setattr(instance, field_name, value)
        instance.save()
        return instance

    def restore_object(self, attrs, instance=None):
        """DRF 2.x hook, not called by DRF 3; returns an unsaved MotivoParada.

        Kept only for callers that still invoke it directly. It now writes
        ``desc_mparada`` rather than the misspelled ``des_mparada``.
        """
        if instance is None:
            return MotivoParada(**attrs)
        for field_name, value in attrs.items():
            setattr(instance, field_name, value)
        return instance
urls.py
from django.conf.urls import url
# from snippets.views import EquipamentoList, ColaboradorList, ColaboradorDetail
from snippets import views
from rest_framework.urlpatterns import format_suffix_patterns
urlpatterns = [
    # Earlier routes, currently disabled:
    # url(r'^snippets/$', views.snippet_list),
    # url(r'^snippets/(?P<pk>[0-9]+)/$', views.snippet_detail),
    # url('^equipamento/(?P<id>.+)/$', EquipamentoList.as_view()),
    # url('^colab/(?P<id>.+)/$', ColaboradorList.as_view()),

    # GET lists Colaborador rows, POST creates one.
    url(r'^colab/$', views.ColaboradorList.as_view()),
    # NOTE(review): despite the "add" in its name, this route serves
    # ColaboradorDetail (GET/PUT/DELETE by pk), which has no POST handler.
    url(r'^colab_add/(?P<pk>[0-9]+)/$', views.ColaboradorDetail.as_view()),
]
urlpatterns = format_suffix_patterns(urlpatterns)
I searched Google for the error message but could not find a solution to the problem.
Help.
Why doesn't my code POST the values?
I need help finding the correct solution.
The error message is pretty straightforward. Your serializers use the restore_object method, which Django REST framework 3 no longer uses. In DRF 3, Serializer.save() calls create() or update() instead, and a plain Serializer that does not define them raises NotImplementedError. Either downgrade Django REST framework to 2.x, or (recommended) rewrite your serializers for the latest version by implementing create() and update().
I have a decorator that checks whether the user has any of the given permissions. The code works, but I want to write a test for it.
How can I test the any_permission_required function?
from django.contrib.auth.decorators import user_passes_test
def any_permission_required(*perms):
    """Build a view decorator that admits users holding at least one of *perms*.

    The check is delegated to ``user_passes_test``, so users who fail it are
    handled as that decorator does. If no *perms* are given, nobody passes.
    """
    def has_any_perm(user):
        return any(user.has_perm(perm) for perm in perms)

    return user_passes_test(has_any_perm)
@any_permission_required('app.ticket_admin', 'app.ticket_read')
def ticket_list(request):
    """Ticket list view, restricted to users with ticket_admin or ticket_read."""
    ...
Finally, with the help of Alasdair and Django's own test code, I found a solution.
from django.test import RequestFactory
class TestFoo(TestCase):
    """Tests for the ``any_permission_required`` decorator."""

    def setUp(self):
        # objects.create stores 'bar' unhashed. That is harmless here because
        # these tests never log in: RequestFactory requests get request.user directly.
        self.user = models.User.objects.create(username='foo', password='bar')
        self.factory = RequestFactory()

    def _call_protected_view(self):
        """Call a trivial view guarded by the decorator, as ``self.user``."""
        @any_permission_required('app.ticket_admin', 'app.ticket_read')
        def a_view(request):
            return HttpResponse()

        request = self.factory.get('/foo')
        request.user = self.user
        return a_view(request)

    def test_any_permissions_pass(self):
        """A user holding one of the listed permissions reaches the view."""
        perms = Permission.objects.filter(codename__in=('ticket_admin', 'ticket_read'))
        # Fail loudly if the permissions don't exist, rather than via a confusing 302.
        self.assertTrue(perms.exists())
        self.user.user_permissions.add(*perms)
        resp = self._call_protected_view()
        self.assertEqual(resp.status_code, 200)

    def test_any_permissions_fail(self):
        """A user with none of the listed permissions is redirected, not served."""
        resp = self._call_protected_view()
        self.assertEqual(resp.status_code, 302)