I'm having trouble checking the user token inside my middleware. I get the token from the cookies, and then I need to query the database to check whether this token exists and belongs to the user who made the request.
routing.py
from channels.routing import ProtocolTypeRouter, URLRouter
import game.routing
from authentication.utils import TokenAuthMiddlewareStack
# Top-level ASGI router. HTTP is routed to the Django views by default;
# websocket connections pass through TokenAuthMiddlewareStack and are then
# dispatched by URL using the game app's websocket patterns.
application = ProtocolTypeRouter(
    {
        "websocket": TokenAuthMiddlewareStack(
            URLRouter(game.routing.websocket_urlpatterns),
        ),
    }
)
middleware.py
from rest_framework.authentication import TokenAuthentication
from rest_framework.exceptions import AuthenticationFailed
from rest_auth.models import TokenModel
from channels.auth import AuthMiddlewareStack
from django.contrib.auth.models import AnonymousUser
from django.db import close_old_connections
...
class TokenAuthMiddleware:
    """
    Token authorization middleware for Django Channels 2
    """
    # NOTE(review): this is the version the question reports as failing with
    # SynchronousOnlyOperation. The synchronous ORM access below (the
    # TokenModel query and the related `token.user` lookup) presumably triggers
    # it when this runs in an async context.
    # `re` is used below; its import is presumably in the elided `...` part.

    def __init__(self, inner):
        # The ASGI application this middleware wraps.
        self.inner = inner

    def __call__(self, scope):
        """Set scope['user'] from an ``Authorization=Token <key>`` cookie and
        return the inner application instantiated with that scope."""
        close_old_connections()
        headers = dict(scope['headers'])
        # NOTE(review): raises KeyError when the request carries no cookie header.
        if b'Authorization' in headers[b'cookie']:
            try:
                cookie_str = headers[b'cookie'].decode('utf-8')
                try: # no cookie Authorization=Token in the request
                    # The pattern needs a leading space, so an Authorization
                    # cookie that comes first in the header is not matched.
                    token_str = [x for x in cookie_str.split(';') if re.search(' Authorization=Token', x)][0].strip()
                except IndexError:
                    scope['user'] = AnonymousUser()
                    return self.inner(scope)
                token_name, token_key = token_str.replace('Authorization=', '').split()
                if token_name == 'Token':
                    # Synchronous database query.
                    token = TokenModel.objects.get(key=token_key)
                    scope['user'] = token.user
            except TokenModel.DoesNotExist:
                scope['user'] = AnonymousUser()
        return self.inner(scope)

# Put TokenAuthMiddleware outside Channels' AuthMiddlewareStack.
TokenAuthMiddlewareStack = lambda inner: TokenAuthMiddleware(AuthMiddlewareStack(inner))
And this gives me
django.core.exceptions.SynchronousOnlyOperation: You cannot call this from an async context - use a thread or sync_to_async.
I also tried the following approaches
async def __call__(self, scope):
    # Fragment from the question; the elided parts (`...`) are presumably the
    # same as in the synchronous version.
    # NOTE(review): declaring __call__ `async` makes calling the middleware
    # return a coroutine instead of the inner application. This is presumably
    # the source of the "'coroutine' object is not callable" error.
    ...
    if token_name == 'Token':
        token = await self.get_token(token_key)
        scope['user'] = token.user
    ...
# approach 1
# NOTE(review): the decorator lines below appear to have lost their leading
# '@' (presumably @sync_to_async / @database_sync_to_async). As written they
# are plain comments, and get_token is an ordinary synchronous method.
#sync_to_async
def get_token(self, token_key):
    return TokenModel.objects.get(key=token_key)
# approach 2
#database_sync_to_async
def get_token(self, token_key):
    return TokenModel.objects.get(key=token_key)
Those approaches give the following error
[Failure instance: Traceback: <class 'TypeError'>: 'coroutine' object is not callable
/Users/nikitatonkoshkur/Documents/work/svoya_igra/venv/lib/python3.8/site-packages/autobahn/websocket/protocol.py:2847:processHandshake
/Users/nikitatonkoshkur/Documents/work/svoya_igra/venv/lib/python3.8/site-packages/txaio/tx.py:366:as_future
/Users/nikitatonkoshkur/Documents/work/svoya_igra/venv/lib/python3.8/site-packages/twisted/internet/defer.py:151:maybeDeferred
/Users/nikitatonkoshkur/Documents/work/svoya_igra/venv/lib/python3.8/site-packages/daphne/ws_protocol.py:72:onConnect
--- <exception caught here> ---
/Users/nikitatonkoshkur/Documents/work/svoya_igra/venv/lib/python3.8/site-packages/twisted/internet/defer.py:151:maybeDeferred
/Users/nikitatonkoshkur/Documents/work/svoya_igra/venv/lib/python3.8/site-packages/daphne/server.py:206:create_application
]
I am not sure whether this will work, but you can try the following.
First, write the get_token function outside the class.
from asgiref.sync import sync_to_async


# approach 1
@sync_to_async
def get_token(token_key):
    """Fetch the token row for ``token_key`` in a worker thread.

    Defined at module level, so it takes no ``self`` and is called as
    ``get_token(token_key)``. Awaiting it from an async ``__call__`` keeps the
    synchronous ORM query out of the async context.
    ``select_related('user')`` loads the owner in the same query. Reading
    ``token.user`` afterwards in async code therefore does not trigger another
    synchronous query.

    Raises:
        TokenModel.DoesNotExist: if no token has this key.
    """
    return TokenModel.objects.select_related('user').get(key=token_key)
Then, in your async function, call get_token() instead of self.get_token():
async def __call__(self, scope):
    # NOTE(review): in Channels 2 an `async` __call__(self, scope) makes the
    # middleware return a coroutine instead of an application instance. This
    # presumably still produces "'coroutine' object is not callable".
    ...
    if token_name == 'Token':
        # get_token must be the module-level, sync_to_async-wrapped helper.
        token = await get_token(token_key)
        # Reading token.user performs a related-object query unless get_token
        # pre-loaded it (e.g. with select_related('user')). That query is
        # synchronous and presumably raises SynchronousOnlyOperation here.
        scope['user'] = token.user
    ...
After looking at the source code of Django Channels — especially at how SessionMiddleware works and uses the Django ORM — I ended up writing my TokenAuthMiddleware in a similar fashion.
from channels.db import database_sync_to_async
from django.http.cookie import parse_cookie


class TokenAuthMiddlewareInstance:
    """
    Token authorization middleware for Django Channels 2.

    One instance is created per connection scope. It reads the
    ``Authorization`` cookie, whose value is expected to be ``"Token <key>"``,
    looks the key up in the database and stores the owning user in
    ``scope['user']``. ``AnonymousUser`` is stored in every other case:
    - no cookie header;
    - no ``Authorization`` cookie;
    - another scheme or a malformed value;
    - an unknown key.
    """

    def __init__(self, scope, middleware):
        self.middleware = middleware
        # Copy the scope so changes made here do not leak back to the caller.
        self.scope = dict(scope)
        self.inner = self.middleware.inner

    async def __call__(self, receive, send):
        # close_old_connections() is not called here. It is synchronous
        # database code and must not run in the async context;
        # database_sync_to_async handles connection cleanup around the query.
        self.scope['user'] = await self._user_from_cookies()
        # Instantiate the inner application only after scope['user'] has been
        # set, so the user is part of the scope it receives.
        inner = self.inner(self.scope)
        return await inner(receive, send)

    async def _user_from_cookies(self):
        """Return the user named by the ``Authorization`` cookie, or AnonymousUser."""
        headers = dict(self.scope.get('headers', {}))
        if b'cookie' not in headers:
            return AnonymousUser()
        # Header values are bytes. latin-1 decodes any byte sequence, so a
        # stray non-ASCII byte cannot crash the handshake.
        cookie_dict = parse_cookie(headers[b'cookie'].decode('latin-1'))
        parts = cookie_dict.get('Authorization', '').split()
        # Anything other than exactly "Token <key>" is unauthenticated.
        if len(parts) != 2 or parts[0] != 'Token':
            return AnonymousUser()
        return await self.get_user_from_token(parts[1])

    @staticmethod
    @database_sync_to_async
    def get_user_from_token(token_key):
        """Return the user owning ``token_key``, or AnonymousUser if unknown.

        Runs in a worker thread via database_sync_to_async. select_related
        fetches the user in the same query.
        """
        try:
            return TokenModel.objects.select_related('user').get(key=token_key).user
        except TokenModel.DoesNotExist:
            return AnonymousUser()


class TokenAuthMiddleware:
    """Wraps ``inner`` and creates one TokenAuthMiddlewareInstance per scope."""

    def __init__(self, inner):
        # The ASGI application to run once the user has been resolved.
        self.inner = inner

    def __call__(self, scope):
        return TokenAuthMiddlewareInstance(scope, self)


def TokenAuthMiddlewareStack(inner):
    """Wrap ``inner`` in Channels' AuthMiddlewareStack, with token auth outermost."""
    return TokenAuthMiddleware(AuthMiddlewareStack(inner))
Related
I would like to return a 401 response if the user is not enabled. When I try returning a response instead of a token, it doesn't work, which I understand is because the serializer expects a token. How do I customise it to send a 401 response when the user is not enabled?
My custom token class is as below:
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from rest_framework_simplejwt.views import TokenObtainPairView
from rest_framework import status
from rest_framework.response import Response
class CustomTokenObtainPairSerializer(TokenObtainPairSerializer):
    # NOTE(review): the next line appears to be a garbled `@classmethod`
    # decorator (the '@' became '#'); as written it is only a comment.
    #classmethod
    def get_token(cls, user):
        """Build the token for ``user`` and add ``name``/``gender`` claims.

        NOTE(review): for a disabled user this returns a DRF Response instead
        of a token. The serializer expects a token, so this does not turn into
        a 401; this is the problem the question describes.
        """
        if user.is_enabled:
            token = super().get_token(user)
            # Add custom claims
            token['name'] = user.name
            token['gender'] = user.gender
            return token
        else:
            return Response({'detail':'Account not enabled'}, status=status.HTTP_401_UNAUTHORIZED)


class CustomTokenObtainPairView(TokenObtainPairView):
    # Token-obtain view that issues tokens through the custom serializer.
    serializer_class = CustomTokenObtainPairSerializer
The URL route looks like this:
# urlpatterns entry routing /authenticate/ to the custom token-obtain view.
re_path(
    r'^authenticate/',
    CustomTokenObtainPairView.as_view(),
    name='authenticate',
),
This is what works for me. It's a simple solution that gives me the result I need, which is why I'm posting it as an answer.
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from rest_framework_simplejwt.views import TokenObtainPairView
from rest_framework import status
from rest_framework.response import Response
from rest_framework_simplejwt.exceptions import InvalidToken
class CustomTokenObtainPairSerializer(TokenObtainPairSerializer):
    """Token-pair serializer that refuses disabled users and adds profile claims."""

    @classmethod
    def get_token(cls, user):
        """Return a token for ``user`` carrying extra ``name`` and ``gender`` claims.

        Raises:
            InvalidToken: if ``user.is_enabled`` is false. Raising an
                authentication error (instead of returning a Response) lets
                DRF turn it into a 401 response.
        """
        if not user.is_enabled:
            raise InvalidToken("User is not enabled.")
        token = super().get_token(user)
        # Add custom claims
        token['name'] = user.name
        token['gender'] = user.gender
        return token
You can return a sentinel such as None from get_token when the user is not enabled, and then override the post method of CustomTokenObtainPairView to return a 401 response when no token was produced.
EDIT: I found a better way: move the is_enabled check out of the serializer and into the view's post method. Below is the code.
class CustomTokenObtainPairSerializer(TokenObtainPairSerializer):
    """Token-pair serializer that adds ``name`` and ``gender`` claims."""

    @classmethod
    def get_token(cls, user):
        """Return the parent's token for ``user`` with extra profile claims."""
        token = super().get_token(user)
        # Add custom claims
        token['name'] = user.name
        token['gender'] = user.gender
        return token
from rest_framework.permissions import IsAuthenticated
class CustomTokenObtainPairView(TokenObtainPairView):
    """Token-obtain view that answers 401 for users whose account is not enabled."""

    # NOTE(review): IsAuthenticated requires the caller to be authenticated
    # already before it can obtain a token. A login endpoint usually cannot
    # satisfy that, because the client is only now presenting its
    # credentials. Confirm this is intended.
    permission_classes = [IsAuthenticated]
    serializer_class = CustomTokenObtainPairSerializer

    def post(self, request, *args, **kwargs):
        """Issue a token pair, or return 401 if ``request.user`` is not enabled."""
        if not request.user.is_enabled:
            return Response({'detail':'Account not enabled'}, status=status.HTTP_401_UNAUTHORIZED)
        return super().post(request, *args, **kwargs)
I'm trying to test a view of my project with the following TestCase:
def test_jump_story(self):
    # Question's test: POST to the jumpstory view as user "test1" and expect 200.
    c = APIClient()
    user = User.objects.get(username='test1')
    c.login(username=user.username, password='123')
    room_id = PokerRoom.objects.get(name='planning').id
    room_index = PokerRoom.objects.get(name='planning').index
    # NOTE(review): despite its name, `request` holds the response to the POST.
    request = c.post(reverse('jumpstory', kwargs={'pk': room_id, 'index': room_index}))
    # NOTE(review): force_authenticate runs only after the POST has been
    # sent, so it cannot affect that request. It must come before c.post(...)
    # for the request to be authenticated.
    c.force_authenticate(user=user)
    self.assertEqual(200,request.status_code)
But it returns <Response status_code=401, "application/json">, even when using force_authenticate.
The view that i'm testing:
class jumpStory(APIView):
    """Broadcast a JUMP_STORY message for an existing poker room.

    POST with ``pk`` (room id) and ``index`` URL kwargs. Returns 200 when the
    room exists and the message was handed to ``message_socket``, and 400
    otherwise. Requires an authenticated user.
    """

    permission_classes = [IsAuthenticated]

    def post(self, request, pk, index):
        # DRF passes the URL kwargs both as arguments and in self.kwargs, so
        # pk and index are used directly rather than re-read.
        if not PokerRoom.objects.filter(id=pk).exists():
            return Response({'error':'message not sended'}, status=400)
        message_socket("JUMP_STORY", pk, {'index': index})
        return Response({'success':"JUMP_STORY"}, status=200)
What is wrong with my test?
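Use APITestCase (from rest_framework.test import APITestCase) instead of Django's TestCase, and authenticate the client before sending the request: in the question, force_authenticate is called only after c.post() has already run.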
from rest_framework.test import APITestCase
class MyTests(APITestCase):
    """API tests that run with an already-authenticated client."""

    def setUp(self):
        user = User.objects.create(username='john')
        user.set_password("1234")
        user.save()
        # Authenticate before any request is made, so requests this test case
        # sends through self.client are authenticated as `user`.
        self.client.force_authenticate(user=user)

    def test_jump_story(self):
        # do your test
        pass
I have created a custom token authentication middleware.
from rest_framework.authtoken.models import Token
from django.contrib.auth.models import AnonymousUser
from django.db import close_old_connections
from asgiref.sync import sync_to_async
class TokenAuthMiddleware:
    """
    Token authorization middleware for Django Channels 2
    """

    def __init__(self, inner):
        # The ASGI application this middleware wraps.
        self.inner = inner

    def __call__(self, scope):
        # Close old database connections to prevent usage of timed out connections
        # NOTE(review): sync_to_async(...)() only creates a coroutine. Nothing
        # awaits it here, so close_old_connections() never actually runs.
        sync_to_async(close_old_connections)()
        headers = dict(scope['headers'])
        try:
            # The token travels in the Sec-WebSocket-Protocol header as
            # "Token, <key>".
            token_name, token_key = headers[b'sec-websocket-protocol'].decode().split(', ')
            if token_name == 'Token':
                # NOTE(review): `token` is an un-awaited coroutine, which is why
                # `token.user` fails with "'coroutine' object has no attribute
                # 'user'". The lookup also passes token_name ('Token') as the
                # key instead of token_key.
                token = sync_to_async(Token.objects.get, thread_sensitive=True)(key=token_name)
                scope['user'] = token.user
            else:
                scope['user'] = AnonymousUser()
        except Token.DoesNotExist:
            scope['user'] = AnonymousUser()
        return self.inner(scope)
When I run it, an exception is raised at scope['user'] = token.user:
[Failure instance: Traceback: <class 'AttributeError'>: 'coroutine' object has no attribute 'user'
I tried awaiting the Token query like this:
# NOTE(review): `await` is only valid inside an `async def`. key= still
# receives token_name rather than token_key.
token = await sync_to_async(Token.objects.get, thread_sensitive=True)(key=token_name)
and I added async in front of the __call__ function, but then the following error is raised before any of the code inside the __call__ function runs:
[Failure instance: Traceback: <class 'TypeError'>: 'coroutine' object is not callable
I am using Django v3.0.6 and Django Channels v2.4.0
Here is the solution that worked for me:
from rest_framework.authtoken.models import Token
from django.contrib.auth.models import AnonymousUser
from channels.db import database_sync_to_async
@database_sync_to_async
def get_user(token):
    """Return the user owning the DRF token ``token``, or AnonymousUser.

    Runs in a worker thread via database_sync_to_async, so it can be awaited
    from async code. select_related loads the user in the same query.
    """
    try:
        return Token.objects.select_related('user').get(key=token).user
    except Token.DoesNotExist:
        return AnonymousUser()


class TokenAuthMiddleware:
    """
    Token authorization middleware for Django Channels 2
    """

    def __init__(self, inner):
        # Store the ASGI application we were passed
        self.inner = inner

    def __call__(self, scope):
        # Called once per scope; the returned instance is then called with
        # receive/send.
        return TokenAuthMiddlewareInstance(scope, self)


class TokenAuthMiddlewareInstance:
    """
    Inner class that is instantiated once per scope.

    Reads the token from the ``Sec-WebSocket-Protocol`` header, which is
    expected to be ``"Token, <key>"``. Stores the matching user (or
    AnonymousUser) in ``scope['user']`` before running the inner application.
    """

    def __init__(self, scope, middleware):
        self.middleware = middleware
        # Copy so the user added here does not leak into the caller's scope.
        self.scope = dict(scope)
        self.inner = self.middleware.inner

    async def __call__(self, receive, send):
        headers = dict(self.scope['headers'])
        protocol = headers.get(b'sec-websocket-protocol', b'').decode()
        parts = protocol.split(', ')
        # A missing header, or any value other than exactly "Token, <key>",
        # makes the connection anonymous instead of crashing the handshake.
        if len(parts) == 2 and parts[0] == 'Token':
            self.scope['user'] = await get_user(parts[1])
        else:
            self.scope['user'] = AnonymousUser()
        # Instantiate our inner application
        inner = self.inner(self.scope)
        return await inner(receive, send)
Just wrap your function in database_sync_to_async; it will handle the database connections for you.
class TokenAuthMiddleware:
    """
    Token authorization middleware in single-callable ASGI style.

    Reads the token from the ``Sec-WebSocket-Protocol`` header, which is
    expected to be ``"Token, <key>"``. Stores the matching user (or
    AnonymousUser) in ``scope['user']`` and then runs the wrapped
    application with ``scope``, ``receive`` and ``send``.
    """

    def __init__(self, inner):
        # The wrapped ASGI application.
        self.inner = inner

    async def __call__(self, scope, receive=None, send=None):
        # No explicit close_old_connections(): database_sync_to_async handles
        # connection cleanup around the query in get_user.
        headers = dict(scope['headers'])
        protocol = headers.get(b'sec-websocket-protocol', b'').decode()
        parts = protocol.split(', ')
        # A missing header, or any value other than exactly "Token, <key>",
        # makes the connection anonymous instead of raising.
        if len(parts) == 2 and parts[0] == 'Token':
            scope['user'] = await self.get_user(parts[1])
        else:
            scope['user'] = AnonymousUser()
        # receive/send default to None only so the original single-argument
        # call keeps working. In that case the scope is populated and nothing
        # else runs.
        if receive is None and send is None:
            return None
        return await self.inner(scope, receive, send)

    @database_sync_to_async
    def get_user(self, token):
        """Return the user owning ``token``, or AnonymousUser if it is unknown."""
        try:
            return Token.objects.select_related('user').get(key=token).user
        except Token.DoesNotExist:
            return AnonymousUser()
In case it helps anyone: I was having the same issue, along with others. I updated Django Channels and then changed my routing from
websocket_urlpatterns = [
    # Channels 2.x style: the consumer class is routed directly.
    re_path(r"ws/chat/(?P<room_name>\w+)/$", consumers.ChatConsumer),
]
to
websocket_urlpatterns = [
    # Channels 3.0+ style: route the consumer through as_asgi().
    re_path(r"ws/chat/(?P<room_name>\w+)/$", consumers.ChatConsumer.as_asgi()),
]
You need Django Channels 3.0+ to do this (https://channels.readthedocs.io/en/stable/releases/3.0.0.html). Then you can follow https://channels.readthedocs.io/en/stable/topics/authentication.html#django-authentication to set up your custom middleware.
Reference:
Django Channel Custom Authentication Middleware __call__() missing 2 required positional arguments: 'receive' and 'send'
I upgraded to Django 3.0 and now I get this error when using websockets + TokenAuthMiddleware:
SynchronousOnlyOperation
You cannot call this from an async context - use a thread or sync_to_async.
The problem is that you can't access synchronous code from an asynchronous context. Here is a TokenAuthMiddleware for Django 3.0:
# myproject.myapi.utils.py
from channels.auth import AuthMiddlewareStack
from channels.db import database_sync_to_async
from django.contrib.auth.models import AnonymousUser
from rest_framework.authtoken.models import Token
@database_sync_to_async
def get_user(headers):
    """Return the user for an ``Authorization: Token <key>`` header.

    ``headers`` maps header names (bytes) to values (bytes); the token is
    looked up under ``b'authorization'``. Returns AnonymousUser in these cases:
    - the header is missing or malformed;
    - it uses another scheme;
    - it names an unknown token.

    Runs in a worker thread via database_sync_to_async.
    """
    parts = headers.get(b'authorization', b'').decode().split()
    if len(parts) != 2 or parts[0] != 'Token':
        return AnonymousUser()
    try:
        return Token.objects.select_related('user').get(key=parts[1]).user
    except Token.DoesNotExist:
        return AnonymousUser()
class TokenAuthMiddleware:
    """Channels 2 middleware factory: one TokenAuthMiddlewareInstance per scope."""

    def __init__(self, inner):
        # Wrapped ASGI application.
        self.inner = inner

    def __call__(self, scope):
        return TokenAuthMiddlewareInstance(scope, self)


class TokenAuthMiddlewareInstance:
    """
    Per-scope half of TokenAuthMiddleware.

    The middleware is first called with the scope, and the object it returns
    (an instance of this class) is then awaited with ``receive``/``send``.
    See https://github.com/django/channels/issues/1399
    """

    def __init__(self, scope, middleware):
        self.middleware = middleware
        self.scope = dict(scope)  # private copy; the user is added to it below
        self.inner = middleware.inner

    async def __call__(self, receive, send):
        header_map = dict(self.scope['headers'])
        if b'authorization' in header_map:
            # Only resolve a user when an Authorization header was sent.
            self.scope['user'] = await get_user(header_map)
        application = self.inner(self.scope)
        return await application(receive, send)


def TokenAuthMiddlewareStack(inner):
    """Wrap ``inner`` in AuthMiddlewareStack, with token auth applied outermost."""
    return TokenAuthMiddleware(AuthMiddlewareStack(inner))
Use it like this:
# myproject/routing.py
from myapi.utils import TokenAuthMiddlewareStack
from myapi.websockets import WSAPIConsumer
# ASGI entry point. Websocket connections go through token authentication and
# then URL routing. The whole router is wrapped in Sentry's ASGI middleware.
application = SentryAsgiMiddleware(
    ProtocolTypeRouter({
        "websocket": TokenAuthMiddlewareStack(
            URLRouter([path("api/v1/ws", WSAPIConsumer)]),
        ),
    })
)
As @tapion stated, this solution no longer works as of Channels 3.x.
For Channels 3, the solution can be tweaked slightly:
class TokenAuthMiddleware:
    """Single-callable ASGI token middleware.

    When the connection sent an ``Authorization`` header, resolves the user
    with ``get_user`` and stores it in ``scope['user']``. It then hands the
    scope, ``receive`` and ``send`` to the wrapped application.
    """

    def __init__(self, inner):
        # Wrapped ASGI application.
        self.inner = inner

    async def __call__(self, scope, receive, send):
        header_map = dict(scope['headers'])
        if b'authorization' in header_map:
            scope['user'] = await get_user(header_map)
        return await self.inner(scope, receive, send)
CoinbaseWalletAuth.py
from django import template
register = template.Library()  # registry for this module's template tags

# Coinbase API credentials (redacted in this snippet).
# NOTE(review): avoid committing real keys to source; load them from
# settings or the environment instead.
API_KEY = '******************'
API_SECRET = '***************'
import hashlib
import hmac
import time


class CoinbaseWalletAuth(AuthBase):
    """``requests`` auth hook that signs Coinbase API requests.

    Adds three headers to each outgoing request:
    - ``CB-ACCESS-SIGN``: hex HMAC-SHA256 of timestamp + method + path + body,
      keyed with the secret;
    - ``CB-ACCESS-TIMESTAMP``;
    - ``CB-ACCESS-KEY``.
    """

    def __init__(self, api_key, secret_key):
        self.api_key = api_key
        self.secret_key = secret_key

    def __call__(self, request):
        timestamp = str(int(time.time()))
        # hmac needs bytes for both key and message, and the request body may
        # be str, bytes or None, so everything is normalised to bytes.
        body = request.body or b''
        if isinstance(body, str):
            body = body.encode('utf-8')
        message = (timestamp + request.method + request.path_url).encode('utf-8') + body
        key = self.secret_key
        if isinstance(key, str):
            key = key.encode('utf-8')
        signature = hmac.new(key, message, hashlib.sha256).hexdigest()
        request.headers.update({
            'CB-ACCESS-SIGN': signature,
            'CB-ACCESS-TIMESTAMP': timestamp,
            'CB-ACCESS-KEY': self.api_key,
        })
        return request
# NOTE(review): Django's Library.tag() is meant for tag compilation functions
# (called with parser and token). This registers a requests auth object whose
# __call__ takes a single `request`, so using the tag in a template would
# presumably fail. Confirm, and keep request signing out of template tags.
register.tag('CoinbaseWalletAuth', CoinbaseWalletAuth(API_KEY,API_SECRET))
Views.py
def test(request):
    # Question's view: on POST, call the Coinbase "user" endpoint and return the result.
    if request.method == 'POST':
        api_url = 'https://api.coinbase.com/v2/'
        # NOTE(review): this passes the class itself, not an instance built
        # with the API key and secret. requests calls `auth(prepared_request)`,
        # which here presumably invokes the constructor with one argument and
        # raises TypeError.
        auth = CoinbaseWalletAuth # call the class based function in views(This is not working)
        r = requests.get(api_url + 'user', auth=auth)
        data= r.json()
        return HttpResponse(data)
    # NOTE(review): non-POST requests fall through and the view returns None.
To answer your question:
from .templatetags.CoinbaseWalletAuth import CoinbaseWalletAuth
def test(request):
    """Fetch the Coinbase ``user`` resource on POST and return it as JSON.

    Non-POST requests get a 405 response.
    """
    # Local imports keep this snippet self-contained.
    from django.http import HttpResponseNotAllowed, JsonResponse
    from .templatetags.CoinbaseWalletAuth import API_KEY, API_SECRET

    if request.method != 'POST':
        return HttpResponseNotAllowed(['POST'])
    api_url = 'https://api.coinbase.com/v2/'
    # requests calls the auth object with the prepared request. It therefore
    # must be an instance configured with the credentials, not the class itself.
    auth = CoinbaseWalletAuth(API_KEY, API_SECRET)
    # A timeout keeps a stalled Coinbase call from hanging the worker.
    r = requests.get(api_url + 'user', auth=auth, timeout=10)
    data = r.json()
    # JsonResponse serialises the decoded payload. HttpResponse(dict) would
    # iterate the dict and send only its keys.
    return JsonResponse(data, safe=False)
The point of template tags is to run some logic while a template is rendering. In your views you can already write logic in Python. So if your template tag contains logic that the rest of your application also needs, write it as a regular function or method (class, static or any other kind) and call that from your template tag. That way you can share the logic across your application.
Put it in some file in your project (maybe models.py, maybe utils.py):
CoinbaseWalletAuth.py
# Coinbase API credentials (redacted here).
API_KEY = '******************'
API_SECRET = '***************'
class CoinbaseWalletAuth(AuthBase):
    """``requests`` auth class, moved out of the templatetags module so views
    and template tags can share it. Method bodies are elided (``...``) in this
    answer; presumably they match the question's version."""
    def __init__(self, api_key, secret_key):
        ...
    def __call__(self, request):
        ...
        return request
coinbase_templatetag.py
from django import template
# utils: the module (e.g. utils.py) that holds the CoinbaseWalletAuth code
# and the API credentials.
from .utils import API_KEY, API_SECRET, CoinbaseWalletAuth

register = template.Library()
# NOTE(review): Django's Library.tag() is meant for tag compilation functions
# (called with parser and token), and a requests auth object does not fit that
# signature. Confirm, or expose the shared logic through a helper registered
# with @register.simple_tag instead.
register.tag('CoinbaseWalletAuth', CoinbaseWalletAuth(API_KEY,API_SECRET))