Websocket stress test with Autobahn Testsuite - python-2.7

I try to do some stress test against my websocket server. On client side I run the following script from this site :
import time, sys
from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred, returnValue, inlineCallbacks
from autobahn.twisted.websocket import connectWS, \
WebSocketClientFactory, \
WebSocketClientProtocol
class MassConnectProtocol(WebSocketClientProtocol):
didHandshake = False
def onOpen(self):
print("websocket connection opened")
self.factory.test.onConnected()
self.factory.test.protos.append(self)
self.didHandshake = True
class MassConnectFactory(WebSocketClientFactory):
protocol = MassConnectProtocol
def clientConnectionFailed(self, connector, reason):
if self.test.onFailed():
reactor.callLater(float(self.retrydelay)/1000., connector.connect)
def clientConnectionLost(self, connector, reason):
if self.test.onLost():
reactor.callLater(float(self.retrydelay)/1000., connector.connect)
class MassConnect:
def __init__(self, name, uri, connections, batchsize, batchdelay, retrydelay):
print('MassConnect init')
self.name = name
self.uri = uri
self.batchsize = batchsize
self.batchdelay = batchdelay
self.retrydelay = retrydelay
self.failed = 0
self.lost = 0
self.targetCnt = connections
self.currentCnt = 0
self.actual = 0
self.protos = []
def run(self):
print('MassConnect runned')
self.d = Deferred()
self.started = time.clock()
self.connectBunch()
return self.d
def onFailed(self):
self.failed += 1
sys.stdout.write("!")
return True
def onLost(self):
self.lost += 1
#sys.stdout.write("*")
return False
return True
def onConnected(self):
print("onconnected")
self.actual += 1
if self.actual % self.batchsize == 0:
sys.stdout.write(".")
if self.actual == self.targetCnt:
self.ended = time.clock()
duration = self.ended - self.started
print " connected %d clients to %s at %s in %s seconds (retries %d = failed %d + lost %d)" % (self.currentCnt, self.name, self.uri, duration, self.failed + self.lost, self.failed, self.lost)
result = {'name': self.name,
'uri': self.uri,
'connections': self.targetCnt,
'retries': self.failed + self.lost,
'lost': self.lost,
'failed': self.failed,
'duration': duration}
for p in self.protos:
p.sendClose()
#self.d.callback(result)
def connectBunch(self):
if self.currentCnt + self.batchsize < self.targetCnt:
c = self.batchsize
redo = True
else:
c = self.targetCnt - self.currentCnt
redo = False
for i in xrange(0, c):
factory = MassConnectFactory(self.uri)
factory.test = self
factory.retrydelay = self.retrydelay
connectWS(factory)
self.currentCnt += 1
if redo:
reactor.callLater(float(self.batchdelay)/1000., self.connectBunch)
class MassConnectTest:
def __init__(self, spec):
self.spec = spec
print('MassConnetest init')
#inlineCallbacks
def run(self):
print self.spec
res = []
for s in self.spec['servers']:
print s['uri']
t = MassConnect(s['name'],
s['uri'],
self.spec['options']['connections'],
self.spec['options']['batchsize'],
self.spec['options']['batchdelay'],
self.spec['options']['retrydelay'])
r = yield t.run()
res.append(r)
returnValue(res)
def startClient(spec, debug = False):
test = MassConnectTest(spec)
d = test.run()
return d
if __name__ == '__main__':
spec = {}
spec['servers'] = [{'name': 'test', 'uri':"ws://127.0.0.1:8080"} ]
spec['options'] ={'connections': 1000,'batchsize': 500, 'batchdelay': 1000, 'retrydelay': 200 }
startClient(spec,False)
But after running this script there are no connections established on the server side. Server seems to be configured properly, because when I connect to my server using different client side (for example web browser), it works fine and websocket connection is established. I also checked network sniffer and it seems that script doesn't produce any websocket connections.
What did I do wrong in this script?

The massconnect.py script you used was supposed to be invoked from another part of the autobahntestsuite, such as the wstest command:
$ echo '{"servers": [{"name": "test", "uri":"ws://127.0.0.1:8080"} ], "options": {"connections": 1000,"batchsize": 500, "batchdelay": 1000, "retrydelay": 200 }}' > spec.json
$ wstest -m massconnect --spec spec.json
If you want to copy massconnect directly, I think it's missing the command to start the Twisted deferred tasks:
if __name__ == '__main__':
spec = {}
spec['servers'] = [{'name': 'test', 'uri':"ws://127.0.0.1:8080"} ]
spec['options'] ={'connections': 1000,'batchsize': 500, 'batchdelay': 1000, 'retrydelay': 200 }
startClient(spec,False)
reactor.run() # <-- add this
And check your Python indentations, either some of them got corrupted when pasting here, or the original code had incorrect indentations in some class and function definitions.

Related

boto3 EC2 script email: One email for all status checks

I am getting one email per instance that fails status checks. I want to get one email for all status checks.
Here is my code:
import boto3
import smtplib
client = boto3.client("ec2")
clientsns = boto3.client("sns")
status = client.describe_instance_status(IncludeAllInstances = True)
#failed_instances = []
for i in status["InstanceStatuses"]:
# failed_instances.append(i[{'Instance'})]
in_status = i['InstanceStatus']['Details'][0]['Status']
sys_status = i['SystemStatus']['Details'][0]['Status']
# check statuses failed instances
if ((in_status != 'passed') or (sys_status != 'passed')):
msg = f'The following instances failed status checks, {i["InstanceId"]}'
clientsns.publish(TopicArn='arn:aws:sns:us-west-1:462518063038:test',Message=msg)
Try something like this:
import boto3
import botocore
from boto3 import Session
boto3.setup_default_session(profile_name='account2')
def get_tag(tags, key='Name'):
if not tags: return ''
for tag in tags:
if tag['Key'] == key:
return tag['Value']
return ''
client = boto3.client("ec2")
conn = boto3.resource('ec2')
#instances = conn.instances.filter()
instances = conn.instances.filter(
Filters=[{'Name': 'instance-state-name', 'Values': ['running']}])
filter_for = {
"running": [{"Name": "instance-state-name", "Values": ["running"]}],
}
ec2instance = client.describe_instance_status(IncludeAllInstances = True, Filters=filter_for["running"])
failed_instances = []
for i in ec2instance["InstanceStatuses"]:
in_status = i['InstanceStatus']['Details'][0]['Status']
sys_status = i['SystemStatus']['Details'][0]['Status']
# check statuses failed instances
if ((in_status != 'passed') or (sys_status != 'passed')):
failed_instances.append(i["InstanceId"])
if len(failed_instances)>0:
# new_line = '\n'
# msg = f'The following instances failed status checks:{new_line} {new_line.join(failed_instances)}'
# #msg = f'The following instances failed status checks, {failed_instances}'
# clientsns.publish(TopicArn='arn:aws:sns:us-west-1:462518063038:test',Message=msg)
for j in failed_instances:
instance = [x for x in list(instances) if x.id == j][0]
instance_name = get_tag(instance.tags)
print (instance_name, instance.id, instance.instance_type)

Cross Account Cloudtrail log transfer through Cloudwatch and Kinesis data stream

I am using Cloudwatch subscriptions to send over cloudtrail log of one account into another. The Account receiving the logs has a Kinesis data stream which receives the logs from the cloudwatch subscription and invokes the standard lambda function provided by AWS to parse and store the logs to an S3 bucket of the log receiver account.
The log files getting written to s3 bucket are in the form of :
{"eventVersion":"1.08","userIdentity":{"type":"AssumedRole","principalId":"AA:i-096379450e69ed082","arn":"arn:aws:sts::34502sdsdsd:assumed-role/RDSAccessRole/i-096379450e69ed082","accountId":"34502sdsdsd","accessKeyId":"ASIAVAVKXAXXXXXXXC","sessionContext":{"sessionIssuer":{"type":"Role","principalId":"AROAVAVKXAKDDDDD","arn":"arn:aws:iam::3450291sdsdsd:role/RDSAccessRole","accountId":"345029asasas","userName":"RDSAccessRole"},"webIdFederationData":{},"attributes":{"mfaAuthenticated":"false","creationDate":"2021-04-27T04:38:52Z"},"ec2RoleDelivery":"2.0"}},"eventTime":"2021-04-27T07:24:20Z","eventSource":"ssm.amazonaws.com","eventName":"ListInstanceAssociations","awsRegion":"us-east-1","sourceIPAddress":"188.208.227.188","userAgent":"aws-sdk-go/1.25.41 (go1.13.15; linux; amd64) amazon-ssm-agent/","requestParameters":{"instanceId":"i-096379450e69ed082","maxResults":20},"responseElements":null,"requestID":"a5c63b9d-aaed-4a3c-9b7d-a4f7c6b774ab","eventID":"70de51df-c6df-4a57-8c1e-0ffdeb5ac29d","readOnly":true,"resources":[{"accountId":"34502914asasas","ARN":"arn:aws:ec2:us-east-1:3450291asasas:instance/i-096379450e69ed082"}],"eventType":"AwsApiCall","managementEvent":true,"eventCategory":"Management","recipientAccountId":"345029149342"}
{"eventVersion":"1.08","userIdentity":{"type":"AssumedRole","principalId":"AROAVAVKXAKPKZ25XXXX:AmazonMWAA-airflow","arn":"arn:aws:sts::3450291asasas:assumed-role/dev-1xdcfd/AmazonMWAA-airflow","accountId":"34502asasas","accessKeyId":"ASIAVAVKXAXXXXXXX","sessionContext":{"sessionIssuer":{"type":"Role","principalId":"AROAVAVKXAKPKZXXXXX","arn":"arn:aws:iam::345029asasas:role/service-role/AmazonMWAA-dlp-dev-1xdcfd","accountId":"3450291asasas","userName":"dlp-dev-1xdcfd"},"webIdFederationData":{},"attributes":{"mfaAuthenticated":"false","creationDate":"2021-04-27T07:04:08Z"}},"invokedBy":"airflow.amazonaws.com"},"eventTime":"2021-04-27T07:23:46Z","eventSource":"logs.amazonaws.com","eventName":"CreateLogStream","awsRegion":"us-east-1","sourceIPAddress":"airflow.amazonaws.com","userAgent":"airflow.amazonaws.com","errorCode":"ResourceAlreadyExistsException","errorMessage":"The specified log stream already exists","requestParameters":{"logStreamName":"scheduler.py.log","logGroupName":"dlp-dev-DAGProcessing"},"responseElements":null,"requestID":"40b48ef9-fc4b-4d1a-8fd1-4f2584aff1e9","eventID":"ef608d43-4765-4a3a-9c92-14ef35104697","readOnly":false,"eventType":"AwsApiCall","apiVersion":"20140328","managementEvent":true,"eventCategory":"Management","recipientAccountId":"3450291asasas"}
The problem with this type of log lines is that Athena is not able to Parse these log lines and I am not able to query the logs using Athena.
I tried modifying the blueprint lambda function to save the log file as a standard JSON result which would make it easy for Athena to parse the files.
Eg:
{'Records': ['{"eventVersion":"1.08","userIdentity":{"type":"AssumedRole","principalId":"AROAVAVKXAKPBRW2S3TAF:i-096379450e69ed082","arn":"arn:aws:sts::345029149342:assumed-role/RightslineRDSAccessRole/i-096379450e69ed082","accountId":"345029149342","accessKeyId":"ASIAVAVKXAKPBL653UOC","sessionContext":{"sessionIssuer":{"type":"Role","principalId":"AROAVAVKXAKPXXXXXXX","arn":"arn:aws:iam::34502asasas:role/RDSAccessRole","accountId":"345029asasas","userName":"RDSAccessRole"},"webIdFederationData":{},"attributes":{"mfaAuthenticated":"false","creationDate":"2021-04-27T04:38:52Z"},"ec2RoleDelivery":"2.0"}},"eventTime":"2021-04-27T07:24:20Z","eventSource":"ssm.amazonaws.com","eventName":"ListInstanceAssociations","awsRegion":"us-east-1","sourceIPAddress":"188.208.227.188","userAgent":"aws-sdk-go/1.25.41 (go1.13.15; linux; amd64) amazon-ssm-agent/","requestParameters":{"instanceId":"i-096379450e69ed082","maxResults":20},"responseElements":null,"requestID":"a5c63b9d-aaed-4a3c-9b7d-a4f7c6b774ab","eventID":"70de51df-c6df-4a57-8c1e-0ffdeb5ac29d","readOnly":true,"resources":[{"accountId":"3450291asasas","ARN":"arn:aws:ec2:us-east-1:34502asasas:instance/i-096379450e69ed082"}],"eventType":"AwsApiCall","managementEvent":true,"eventCategory":"Management","recipientAccountId":"345029asasas"}]}
The modified code for Blueprint Lambda function that I looks like:
import base64
import json
import gzip
from io import BytesIO
import boto3
def transformLogEvent(log_event):
return log_event['message'] + '\n'
def processRecords(records):
for r in records:
data = base64.b64decode(r['data'])
striodata = BytesIO(data)
with gzip.GzipFile(fileobj=striodata, mode='r') as f:
data = json.loads(f.read())
recId = r['recordId']
if data['messageType'] == 'CONTROL_MESSAGE':
yield {
'result': 'Dropped',
'recordId': recId
}
elif data['messageType'] == 'DATA_MESSAGE':
result = {}
result["Records"] = {}
events = []
for e in data['logEvents']:
events.append(e["message"])
result["Records"] = events
print(result)
if len(result) <= 6000000:
yield {
'data': result,
'result': 'Ok',
'recordId': recId
}
else:
yield {
'result': 'ProcessingFailed',
'recordId': recId
}
else:
yield {
'result': 'ProcessingFailed',
'recordId': recId
}
def putRecordsToFirehoseStream(streamName, records, client, attemptsMade, maxAttempts):
failedRecords = []
codes = []
errMsg = ''
# if put_record_batch throws for whatever reason, response['xx'] will error out, adding a check for a valid
# response will prevent this
response = None
try:
response = client.put_record_batch(DeliveryStreamName=streamName, Records=records)
except Exception as e:
failedRecords = records
errMsg = str(e)
# if there are no failedRecords (put_record_batch succeeded), iterate over the response to gather results
if not failedRecords and response and response['FailedPutCount'] > 0:
for idx, res in enumerate(response['RequestResponses']):
# (if the result does not have a key 'ErrorCode' OR if it does and is empty) => we do not need to re-ingest
if 'ErrorCode' not in res or not res['ErrorCode']:
continue
codes.append(res['ErrorCode'])
failedRecords.append(records[idx])
errMsg = 'Individual error codes: ' + ','.join(codes)
if len(failedRecords) > 0:
if attemptsMade + 1 < maxAttempts:
print('Some records failed while calling PutRecordBatch to Firehose stream, retrying. %s' % (errMsg))
putRecordsToFirehoseStream(streamName, failedRecords, client, attemptsMade + 1, maxAttempts)
else:
raise RuntimeError('Could not put records after %s attempts. %s' % (str(maxAttempts), errMsg))
def putRecordsToKinesisStream(streamName, records, client, attemptsMade, maxAttempts):
failedRecords = []
codes = []
errMsg = ''
# if put_records throws for whatever reason, response['xx'] will error out, adding a check for a valid
# response will prevent this
response = None
try:
response = client.put_records(StreamName=streamName, Records=records)
except Exception as e:
failedRecords = records
errMsg = str(e)
# if there are no failedRecords (put_record_batch succeeded), iterate over the response to gather results
if not failedRecords and response and response['FailedRecordCount'] > 0:
for idx, res in enumerate(response['Records']):
# (if the result does not have a key 'ErrorCode' OR if it does and is empty) => we do not need to re-ingest
if 'ErrorCode' not in res or not res['ErrorCode']:
continue
codes.append(res['ErrorCode'])
failedRecords.append(records[idx])
errMsg = 'Individual error codes: ' + ','.join(codes)
if len(failedRecords) > 0:
if attemptsMade + 1 < maxAttempts:
print('Some records failed while calling PutRecords to Kinesis stream, retrying. %s' % (errMsg))
putRecordsToKinesisStream(streamName, failedRecords, client, attemptsMade + 1, maxAttempts)
else:
raise RuntimeError('Could not put records after %s attempts. %s' % (str(maxAttempts), errMsg))
def createReingestionRecord(isSas, originalRecord):
if isSas:
return {'data': base64.b64decode(originalRecord['data']), 'partitionKey': originalRecord['kinesisRecordMetadata']['partitionKey']}
else:
return {'data': base64.b64decode(originalRecord['data'])}
def getReingestionRecord(isSas, reIngestionRecord):
if isSas:
return {'Data': reIngestionRecord['data'], 'PartitionKey': reIngestionRecord['partitionKey']}
else:
return {'Data': reIngestionRecord['data']}
def lambda_handler(event, context):
print(event)
isSas = 'sourceKinesisStreamArn' in event
streamARN = event['sourceKinesisStreamArn'] if isSas else event['deliveryStreamArn']
region = streamARN.split(':')[3]
streamName = streamARN.split('/')[1]
records = list(processRecords(event['records']))
projectedSize = 0
dataByRecordId = {rec['recordId']: createReingestionRecord(isSas, rec) for rec in event['records']}
putRecordBatches = []
recordsToReingest = []
totalRecordsToBeReingested = 0
for idx, rec in enumerate(records):
if rec['result'] != 'Ok':
continue
projectedSize += len(rec['data']) + len(rec['recordId'])
# 6000000 instead of 6291456 to leave ample headroom for the stuff we didn't account for
if projectedSize > 6000000:
totalRecordsToBeReingested += 1
recordsToReingest.append(
getReingestionRecord(isSas, dataByRecordId[rec['recordId']])
)
records[idx]['result'] = 'Dropped'
del(records[idx]['data'])
# split out the record batches into multiple groups, 500 records at max per group
if len(recordsToReingest) == 500:
putRecordBatches.append(recordsToReingest)
recordsToReingest = []
if len(recordsToReingest) > 0:
# add the last batch
putRecordBatches.append(recordsToReingest)
# iterate and call putRecordBatch for each group
recordsReingestedSoFar = 0
if len(putRecordBatches) > 0:
client = boto3.client('kinesis', region_name=region) if isSas else boto3.client('firehose', region_name=region)
for recordBatch in putRecordBatches:
if isSas:
putRecordsToKinesisStream(streamName, recordBatch, client, attemptsMade=0, maxAttempts=20)
else:
putRecordsToFirehoseStream(streamName, recordBatch, client, attemptsMade=0, maxAttempts=20)
recordsReingestedSoFar += len(recordBatch)
print('Reingested %d/%d records out of %d' % (recordsReingestedSoFar, totalRecordsToBeReingested, len(event['records'])))
else:
print('No records to be reingested')
return {"records": records}
My end goal is to store the result on S3 as JSON so that it can be queried easily with Athena.
the line where the transformation is happening is:
elif data['messageType'] == 'DATA_MESSAGE':
Any help in this would be greatly appreciated.

AWS Lambda is writing wrong output in the CloudWatch metrics

I'm new to Devops and coding. I'm working on building a monitoring tool (grafana) with CloudWatch and Lambda.
I have a code which is not working properly. It pings the server. If it is returning 200 it will push 0 in the metrics and when the site is down it should push 1 but when I'm mentioning in the write metrics to write 1, instead of writing 1 its writing 100 and if I try to do any other values its greater than 100 its posting but less than 100 its just post 100.
Here is the code:
import boto3
import urllib2
def write_metric(value, metric):
d = boto3.client('cloudwatch')
d.put_metric_data(Namespace='WebsiteStatus',
MetricData=[
{
'MetricName':metric,
'Dimensions':[
{
'Name': 'Status',
'Value': 'WebsiteStatusCode',
},
],
'Value': value,
},
]
)
def check_site(url, metric):
STAT = 1
print("Checking %s " % url)
request = urllib2.Request("https://" +url)
try:
response = urllib2.urlopen(request)
response.close()
except urllib2.URLError as e:
if hasattr(e, 'code'):
print ("[Error:] Connection to %s failed with code: " %url +str(e.code))
STAT = 100
write_metric(STAT, metric)
if hasattr(e, 'reason'):
print ("[Error:] Connection to %s failed with code: " % url +str(e.reason))
STAT = 100
write_metric(STAT, metric)
except urllib2.HTTPError as e:
if hasattr(e, 'code'):
print ("[Error:] Connection to %s failed with code: " % url + str(e.code))
STAT = 100
write_metric(STAT, metric)
if hasattr(e, 'reason'):
print ("[Error:] Connection to %s failed with code: " % url + str(e.reason))
STAT = 100
write_metric(STAT, metric)
print('HTTPError!!!')
if STAT != 100:
STAT = response.getcode()
return STAT
def lambda_handler(event, context):
websiteurls = [
"website.com"
]
metricname = 'SiteAvailability'
for site in websiteurls:
r = check_site(site,metricname)
if r == 200:
print("Site %s is up" %site)
write_metric(0, metricname)
else:
print("[Error:] Site %s down" %site)
write_metric(1, metricname)
These lines:
STAT = 100
write_metric(STAT, metric)
will always send 100 as your value.

why my XMPP client resend messages?

i have a XMPP cliente on heroku and it works with Google Cloud Messaging but i have a bad behavior on my app
i have cheked my code many times but i have not found any mistake, but some messages is been resent,the problem is not for Acknowledging of my messages, because i am Acknowledging each message and i am not receiving any nack message from GCM server, so i can not know what is the problem
i would appreciate for any help
this is my code
SERVER = 'gcm.googleapis.com'
PORT = 5235
USERNAME = "secret"
PASSWORD = "secret"
N_TIMER=40
EXP_TIMER=1
unacked_messages_quota = 100
send_queue = []
error_send_queue = []
lock=threading.Lock()
def unique_id():
return str(uuid.uuid4().hex)
#synchronized
def message_callback(session, message):
global unacked_messages_quota
gcm = message.getTags('gcm')
if gcm:
gcm_json = gcm[0].getData()
msg = json.loads(gcm_json)
if not msg.has_key('message_type'):
# Acknowledge the incoming message immediately.
send({'to': msg['from'],
'message_type': 'ack',
'message_id': msg['message_id']})
# Queue a response back to the server.
if msg.has_key('from'):
# Send a response back to the app that sent the upstream message.
try:
msg['data']['idCel'] = msg['from']
payloadObj= payload("command", msg['data']['command'] , msg['data'])
rpc = RpcClient()
response = rpc.call(payloadObj)
if 'response' in response and response['response'] == 'ok':
pass
elif response['type'] == 'response' :
send_queue.append({'to': msg['from'],
'priority':'high',
'delay_while_idle':True,
'message_id': unique_id(),
'data': {'response': response['response'],'type': 'response'}
})
else:
send_queue.append({'to': msg['from'],
'message_id': unique_id(),
'data': {'error': response['error'],'type': 'error'}})
except Exception as e:
traceback.print_exc()
print str(e)
elif msg['message_type'] == 'ack' or msg['message_type'] == 'nack':
if msg['message_type'] == 'nack':
error_send_queue.append(
{'to': msg['from'],
'message_type': 'ack',
'message_id': msg['message_id']})
unacked_messages_quota += 1
def send(json_dict):
template = ("<message><gcm xmlns='google:mobile:data'>{1}</gcm></message>")
try:
client.send(xmpp.protocol.Message(
node=template.format(client.Bind.bound[0], json.dumps(json_dict))))
except Exception as e:
traceback.print_exc()
print str(e)
def flush_queued_messages():
global unacked_messages_quota
while len(send_queue) and unacked_messages_quota > 0:
send(send_queue.pop(0))
unacked_messages_quota -= 1
def flush_queued_errors_messages():
lock.acquire()
global unacked_messages_quota
global error_send_queue
global EXP_TIMER
while len(error_send_queue) and unacked_messages_quota > 0:
send(error_send_queue.pop(0))
unacked_messages_quota -= 1
time.sleep( (2**EXP_TIMER) )
EXP_TIMER += 1
EXP_TIMER=1
lock.release()
client = xmpp.Client('gcm.googleapis.com',debug=['always', 'roster'],
port=int(os.environ.get("PORT")))
client.connect(server=(SERVER,PORT), secure=1, use_srv=False)
auth = client.auth(USERNAME, PASSWORD)
if not auth:
print 'Authentication failed!'
sys.exit(1)
client.RegisterHandler('message', message_callback)
t1 = threading.Thread(target=flush_queued_errors_messages)
while True:
client.Process(1)
flush_queued_messages()
if N_TIMER == 0:
client.send(" ")
N_TIMER = 40
if not t1.isAlive():
t1 = threading.Thread(target=flush_queued_errors_messages)
t1.start()
N_TIMER -= 1
I could know my real problem, the problem was not from my code but the problem was from GCM servers.
Here is the explication.

Mock Stripe Methods in Python for testing

So I am trying to mock all the stripe web hooks in the method so that I can write the Unit test for it. I am using the mock library for mocking the stripe methods. Here is the method I am trying to mock:
class AddCardView(APIView):
"""
* Add card for the customer
"""
permission_classes = (
CustomerPermission,
)
def post(self, request, format=None):
name = request.DATA.get('name', None)
cvc = request.DATA.get('cvc', None)
number = request.DATA.get('number', None)
expiry = request.DATA.get('expiry', None)
expiry_month, expiry_year = expiry.split("/")
customer_obj = request.user.contact.business.customer
customer = stripe.Customer.retrieve(customer_obj.stripe_id)
try:
card = customer.sources.create(
source={
"object": "card",
"number": number,
"exp_month": expiry_month,
"exp_year": expiry_year,
"cvc": cvc,
"name": name
}
)
# making it the default card
customer.default_source = card.id
customer.save()
except CardError as ce:
logger.error("Got CardError for customer_id={0}, CardError={1}".format(customer_obj.pk, ce.json_body))
return Response({"success": False, "error": "Failed to add card"})
else:
customer_obj.card_last_4 = card.get('last4')
customer_obj.card_kind = card.get('type', '')
customer_obj.card_fingerprint = card.get('fingerprint')
customer_obj.save()
return Response({"success": True})
This is the method for unit testing:
#mock.patch('stripe.Customer.retrieve')
#mock.patch('stripe.Customer.create')
def test_add_card(self,create_mock,retrieve_mock):
response = {
'default_card': None,
'cards': {
"count": 0,
"data": []
}
}
# save_mock.return_value = response
create_mock.return_value = response
retrieve_mock.return_value = response
self.api_client.client.login(username = self.username, password = self.password)
res = self.api_client.post('/biz/api/auth/card/add')
print res
Now stripe.Customer.retrieve is being mocked properly. But I am not able to mock customer.sources.create. I am really stuck on this.
This is the right way of doing it:
#mock.patch('stripe.Customer.retrieve')
def test_add_card_failure(self, retrieve_mock):
data = {
'name': "shubham",
'cvc': 123,
'number': "4242424242424242",
'expiry': "12/23",
}
e = CardError("Card Error", "", "")
retrieve_mock.return_value.sources.create.return_value = e
self.api_client.client.login(username=self.username, password=self.password)
res = self.api_client.post('/biz/api/auth/card/add', data=data)
self.assertEqual(self.deserialize(res)['success'], False)
Even though the given answer is correct, there is a way more comfortable solution using vcrpy. That is creating a cassette (record) once a given record does not exist yet. When it does, the mocking is done transparently and the record will be replayed. Beautiful.
Having a vanilla pyramid application, using py.test, my test now looks like this:
import vcr
# here we have some FactoryBoy fixtures
from tests.fixtures import PaymentServiceProviderFactory, SSOUserFactory
def test_post_transaction(sqla_session, test_app):
# first we need a PSP and a User existent in the DB
psp = PaymentServiceProviderFactory() # type: PaymentServiceProvider
user = SSOUserFactory()
sqla_session.add(psp, user)
sqla_session.flush()
with vcr.use_cassette('tests/casettes/tests.checkout.services.transaction_test.test_post_transaction.yaml'):
# with that PSP we create a new PSPTransaction ...
res = test_app.post(url='/psps/%s/transaction' % psp.id,
params={
'token': '4711',
'amount': '12.44',
'currency': 'EUR',
})
assert 201 == res.status_code
assert 'id' in res.json_body
IMO, the following method is better than the rest of the answers
import unittest
import stripe
import json
from unittest.mock import patch
from stripe.http_client import RequestsClient # to mock the request session
stripe.api_key = "foo"
stripe.default_http_client = RequestsClient() # assigning the default HTTP client
null = None
false = False
true = True
charge_resp = {
"id": "ch_1FgmT3DotIke6IEFVkwh2N6Y",
"object": "charge",
"amount": 1000,
"amount_captured": 1000,
"amount_refunded": 0,
"billing_details": {
"address": {
"city": "Los Angeles",
"country": "USA",
},
"email": null,
"name": "Jerin",
"phone": null
},
"captured": true,
}
def get_customer_city_from_charge(stripe_charge_id):
# this is our function and we are writing unit-test for this function
charge_response = stripe.Charge.retrieve("foo-bar")
return charge_response.billing_details.address.city
class TestStringMethods(unittest.TestCase):
#patch("stripe.default_http_client._session")
def test_get_customer_city_from_charge(self, mock_session):
mock_response = mock_session.request.return_value
mock_response.content.decode.return_value = json.dumps(charge_resp)
mock_response.status_code = 200
city_name = get_customer_city_from_charge("some_id")
self.assertEqual(city_name, "Los Angeles")
if __name__ == '__main__':
unittest.main()
Advantages of this method
You can generate the corresponding class objects (here, the charge_response variable is a type of Charge--(source code))
You can use the dot (.) operator over the response (as we can do with real stripe SDK)
dot operator support for deep attributes