Python 3.4 has a new enum module and Enum data type. If you are unable to switch to 3.4 yet, Enum has been backported.
Since Enum members support docstrings, as pretty much all Python objects do, I would like to set them. Is there an easy way to do that?
Yes there is, and it's my favorite recipe so far. As a bonus, one does not have to specify the integer value either. Here's an example:
class AddressSegment(AutoEnum):
    misc = "not currently tracked"
    ordinal = "N S E W NE NW SE SW"
    secondary = "apt bldg floor etc"
    street = "st ave blvd etc"
You might ask why I don't just have "N S E W NE NW SE SW" be the value of ordinal? Because when I get its repr seeing <AddressSegment.ordinal: 'N S E W NE NW SE SW'> gets a bit clunky, but having that information readily available in the docstring is a good compromise.
Here's the recipe for the Enum:
import enum
class AutoEnum(enum.Enum):
    """
    Automatically numbers enum members starting from 1.
    Includes support for a custom docstring per member.
    """
    #
    def __new__(cls, *args):
        """Ignores arguments (will be handled in __init__)."""
        value = len(cls) + 1
        obj = object.__new__(cls)
        obj._value_ = value
        return obj
    #
    def __init__(self, *args):
        """Can handle 0 or 1 argument; more requires a custom __init__.
        0 = auto-number w/o docstring
        1 = auto-number w/ docstring
        2+ = needs custom __init__
        """
        if len(args) == 1 and isinstance(args[0], str):  # on Python 2, check against (str, unicode)
            self.__doc__ = args[0]
        elif args:
            raise TypeError('%s not dealt with -- need custom __init__' % (args,))
And in use:
>>> list(AddressSegment)
[<AddressSegment.ordinal: 1>, <AddressSegment.secondary: 2>, <AddressSegment.misc: 3>, <AddressSegment.street: 4>]
>>> AddressSegment.secondary
<AddressSegment.secondary: 2>
>>> AddressSegment.secondary.__doc__
'apt bldg floor etc'
The reason I handle the arguments in __init__ instead of in __new__ is to make subclassing AutoEnum easier should I want to extend it further.
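As a rough sketch of the kind of extension this allows, assuming Python 3: the subclass below overrides only __init__ to accept two arguments per member, while the auto-numbering in __new__ is inherited unchanged. The Planet class, its members and the radius_km attribute are made up for illustration and are not part of the recipe.

```python
import enum


class AutoEnum(enum.Enum):
    """Auto-numbers members from 1; a single string argument becomes the docstring."""

    def __new__(cls, *args):
        # Arguments are ignored here; __init__ deals with them.
        value = len(cls) + 1
        obj = object.__new__(cls)
        obj._value_ = value
        return obj

    def __init__(self, *args):
        if len(args) == 1 and isinstance(args[0], str):
            self.__doc__ = args[0]
        elif args:
            raise TypeError('%s not dealt with -- need custom __init__' % (args,))


class Planet(AutoEnum):
    """Hypothetical subclass: each member takes a docstring and a radius."""

    def __init__(self, doc, radius_km):
        # Custom __init__ for the two-argument case; numbering still comes from __new__.
        self.__doc__ = doc
        self.radius_km = radius_km

    mercury = "closest to the sun", 2440
    earth = "home", 6371


print(Planet.earth, Planet.earth.value, Planet.earth.__doc__, Planet.earth.radius_km)
```

Because __new__ never looks at its arguments, the subclass only has to say what the extra arguments mean.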
For anyone arriving here from a Google search:
In many IDEs, as of 2022, the following will populate IntelliSense:
class MyEnum(Enum):
    """
    MyEnum purpose and general doc string
    """
    VALUE = "Value"
    """
    This is the Value selection. Use this for Values
    """
    BUILD = "Build"
    """
    This is the Build selection. Use this for Buildings
    """
This does not directly answer the question, but I wanted to add a more robust version of @Ethan Furman's AutoEnum class which uses the enum module's auto() function.
The implementation below works with Pydantic and does fuzzy-matching of values to the corresponding enum type.
Usage:
In [2]: class Weekday(AutoEnum): ## Assume AutoEnum class has been defined.
...: Monday = auto()
...: Tuesday = auto()
...: Wednesday = auto()
...: Thursday = auto()
...: Friday = auto()
...: Saturday = auto()
...: Sunday = auto()
...:
In [3]: Weekday('MONDAY') ## Fuzzy matching: case-insensitive
Out[3]: Monday
In [4]: Weekday(' MO NDAY') ## Fuzzy matching: ignores extra spaces
Out[4]: Monday
In [5]: Weekday('_M_onDa y') ## Fuzzy matching: ignores underscores
Out[5]: Monday
In [6]: %timeit Weekday('_M_onDay') ## Fuzzy matching takes ~1 microsecond.
1.15 µs ± 10.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [7]: %timeit Weekday.from_str('_M_onDay') ## You can further speedup matching using from_str (this is because _missing_ is not called)
736 ns ± 8.89 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [8]: list(Weekday) ## Get all the enums
Out[8]: [Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday]
In [9]: Weekday.Monday.matches('Tuesday') ## Check if a string matches a particular enum value
Out[9]: False
In [10]: Weekday.matches_any('__TUESDAY__') ## Check if a string matches any enum
Out[10]: True
In [11]: Weekday.Tuesday is Weekday(' Tuesday') and Weekday.Tuesday == Weekday('_Tuesday_') ## `is` and `==` work as expected
Out[11]: True
In [12]: Weekday.Tuesday == 'Tuesday' ## Strings don't match enum values, because strings aren't enums!
Out[12]: False
In [13]: Weekday.convert_keys({ ## Convert matching dict keys to an enum. Similar: .convert_list, .convert_set
'monday': 'alice',
'tuesday': 'bob',
'not_wednesday': 'charles',
'THURSDAY ': 'denise',
})
Out[13]:
{Monday: 'alice',
Tuesday: 'bob',
'not_wednesday': 'charles',
Thursday: 'denise'}
The code for AutoEnum can be found below.
If you want to change the fuzzy-matching logic, override the classmethod _normalize (e.g. returning the input unchanged from _normalize performs exact matching).
from typing import *
from enum import Enum, auto
class AutoEnum(str, Enum):
"""
Utility class which can be subclassed to create enums using auto().
Also provides utility methods for common enum operations.
"""
@classmethod
def _missing_(cls, enum_value: Any):
## Ref: https://stackoverflow.com/a/60174274/4900327
## This is needed to allow Pydantic to perform case-insensitive conversion to AutoEnum.
return cls.from_str(enum_value=enum_value, raise_error=True)
def _generate_next_value_(name, start, count, last_values):
return name
@property
def str(self) -> str:
return self.__str__()
def __repr__(self):
return self.__str__()
def __str__(self):
return self.name
def __hash__(self):
return hash(self.__class__.__name__ + '.' + self.name)
def __eq__(self, other):
return self is other
def __ne__(self, other):
return self is not other
def matches(self, enum_value: str) -> bool:
return self is self.from_str(enum_value, raise_error=False)
@classmethod
def matches_any(cls, enum_value: str) -> bool:
return cls.from_str(enum_value, raise_error=False) is not None
@classmethod
def does_not_match_any(cls, enum_value: str) -> bool:
return not cls.matches_any(enum_value)
@classmethod
def _initialize_lookup(cls):
if '_value2member_map_normalized_' not in cls.__dict__: ## Caching values for fast retrieval.
cls._value2member_map_normalized_ = {}
for e in list(cls):
normalized_e_name: str = cls._normalize(e.value)
if normalized_e_name in cls._value2member_map_normalized_:
raise ValueError(
f'Cannot register enum "{e.value}"; '
f'another enum with the same normalized name "{normalized_e_name}" already exists.'
)
cls._value2member_map_normalized_[normalized_e_name] = e
@classmethod
def from_str(cls, enum_value: str, raise_error: bool = True) -> Optional:
"""
Performs a case-insensitive lookup of the enum value string among the members of the current AutoEnum subclass.
:param enum_value: enum value string
:param raise_error: whether to raise an error if the string is not found in the enum
:return: an enum value which matches the string
:raises: ValueError if raise_error is True and no enum value matches the string
"""
if isinstance(enum_value, cls):
return enum_value
if enum_value is None and raise_error is False:
return None
if not isinstance(enum_value, str) and raise_error is True:
raise ValueError(f'Input should be a string; found type {type(enum_value)}')
cls._initialize_lookup()
enum_obj: Optional[AutoEnum] = cls._value2member_map_normalized_.get(cls._normalize(enum_value))
if enum_obj is None and raise_error is True:
raise ValueError(f'Could not find enum with value {enum_value}; available values are: {list(cls)}.')
return enum_obj
@classmethod
def _normalize(cls, x: str) -> str:
## Found to be faster than .translate() and re.sub() on Python 3.10.6
return str(x).replace(' ', '').replace('-', '').replace('_', '').lower()
@classmethod
def convert_keys(cls, d: Dict) -> Dict:
"""
Converts string dict keys to the matching members of the current AutoEnum subclass.
Leaves non-string keys untouched.
:param d: dict to transform
:return: dict with matching string keys transformed to enum values
"""
out_dict = {}
for k, v in d.items():
if isinstance(k, str) and cls.from_str(k, raise_error=False) is not None:
out_dict[cls.from_str(k, raise_error=False)] = v
else:
out_dict[k] = v
return out_dict
@classmethod
def convert_keys_to_str(cls, d: Dict) -> Dict:
"""
Converts dict keys of the current AutoEnum subclass to the matching string key.
Leaves other keys untouched.
:param d: dict to transform
:return: dict with matching keys of the current AutoEnum transformed to strings.
"""
out_dict = {}
for k, v in d.items():
if isinstance(k, cls):
out_dict[str(k)] = v
else:
out_dict[k] = v
return out_dict
@classmethod
def convert_values(
cls,
d: Union[Dict, Set, List, Tuple],
raise_error: bool = False
) -> Union[Dict, Set, List, Tuple]:
"""
Converts string values to the matching members of the current AutoEnum subclass.
Leaves non-string values untouched.
:param d: dict, set, list or tuple to transform.
:param raise_error: raise an error if unsupported type.
:return: data structure with matching string values transformed to enum values.
"""
if isinstance(d, dict):
return cls.convert_dict_values(d)
if isinstance(d, list):
return cls.convert_list(d)
if isinstance(d, tuple):
return tuple(cls.convert_list(d))
if isinstance(d, set):
return cls.convert_set(d)
if raise_error:
raise ValueError(f'Unrecognized data structure of type {type(d)}')
return d
@classmethod
def convert_dict_values(cls, d: Dict) -> Dict:
"""
Converts string dict values to the matching members of the current AutoEnum subclass.
Leaves non-string values untouched.
:param d: dict to transform
:return: dict with matching string values transformed to enum values
"""
out_dict = {}
for k, v in d.items():
if isinstance(v, str) and cls.from_str(v, raise_error=False) is not None:
out_dict[k] = cls.from_str(v, raise_error=False)
else:
out_dict[k] = v
return out_dict
@classmethod
def convert_list(cls, l: List) -> List:
"""
Converts string list items to the matching members of the current AutoEnum subclass.
Leaves non-string items untouched.
:param l: list to transform
:return: list with matching string items transformed to enum values
"""
out_list = []
for item in l:
if isinstance(item, str) and cls.matches_any(item):
out_list.append(cls.from_str(item))
else:
out_list.append(item)
return out_list
@classmethod
def convert_set(cls, s: Set) -> Set:
"""
Converts string set items to the matching members of the current AutoEnum subclass.
Leaves non-string items untouched.
:param s: set to transform
:return: set with matching string items transformed to enum values
"""
out_set = set()
for item in s:
if isinstance(item, str) and cls.matches_any(item):
out_set.add(cls.from_str(item))
else:
out_set.add(item)
return out_set
@classmethod
def convert_values_to_str(cls, d: Dict) -> Dict:
"""
Converts dict values of the current AutoEnum subclass to the matching string value.
Leaves other values untouched.
:param d: dict to transform
:return: dict with matching values of the current AutoEnum transformed to strings.
"""
out_dict = {}
for k, v in d.items():
if isinstance(v, cls):
out_dict[k] = str(v)
else:
out_dict[k] = v
return out_dict
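To see the central mechanism in isolation, here is a stripped-down sketch of the _missing_ hook, which Enum calls when a plain value lookup fails. The Color class and its normalization rule are made up for illustration; this is not the full AutoEnum above.

```python
from enum import Enum


class Color(str, Enum):
    """Hypothetical enum used only to illustrate the _missing_ hook."""
    RED = "red"
    GREEN = "green"

    @classmethod
    def _missing_(cls, value):
        # Called by Enum only when the exact value lookup fails.
        if isinstance(value, str):
            normalized = value.replace(" ", "").replace("-", "").replace("_", "").lower()
            for member in cls:
                if member.value == normalized:
                    return member
        # Returning None lets Enum raise its usual ValueError.
        return None


print(Color(" R_E D"))
```

An exact match such as Color("red") never reaches _missing_, which is why the fuzzy path only costs something for inputs that are not already canonical.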
Functions and classes have docstrings, but most objects don't and do not even need them at all. There is no native docstring syntax for instance attributes, as they can be described exhaustively in the class's docstring, which is also what I recommend you do. Instances of classes normally don't have their own docstrings either, and enum members are nothing more than that.
Sure enough, you could add a docstring to almost anything. Indeed, you can add anything to almost anything, as this is the way Python was designed. But it is neither useful nor clean, and even what @Ethan Furman posted seems like way too much overhead just for adding a docstring to a static property.
Long story short, even though you might not like it at first:
Just don't do it and go with your enum's docstring. It is more than enough to explain the meaning of its members.
What does
if self.transforms:
    data = self.transforms(data)
do? I don't understand the logic behind this line - what is the condition the line is using?
I'm reading an article on creating a custom dataset with pytorch based on the below implementation:
#custom dataset
class MNISTDataset(Dataset):
    def __init__(self, images, labels=None, transforms=None):
        self.X = images
        self.y = labels
        self.transforms = transforms
    def __len__(self):
        return (len(self.X))
    def __getitem__(self, i):
        data = self.X.iloc[i, :]
        data = np.asarray(data).astype(np.uint8).reshape(28, 28, 1)
        if self.transforms:
            data = self.transforms(data)
        if self.y is not None:
            return (data, self.y[i])
        else:
            return data
train_data = MNISTDataset(train_images, train_labels, transform)
test_data = MNISTDataset(test_images, test_labels, transform)
# dataloaders
trainloader = DataLoader(train_data, batch_size=128, shuffle=True)
testloader = DataLoader(test_data, batch_size=128, shuffle=True)
thank you! i'm basically trying to understand why it works & how it applies transforms to the data.
The dataset MNISTDataset can optionally be initialized with a transform function. If such a function is given, it is saved in self.transforms; otherwise self.transforms keeps its default value, None. When an item is fetched through __getitem__, the code first checks whether self.transforms is truthy. A callable object (a function, a lambda, a transform pipeline) is truthy by default, so the transform is applied to data. None is falsy, which means no transform was provided in the first place and data is returned untransformed.
Here's a general example, out of a torch/torchvision context:
def do(x, callback=None):
    if callback:  # will be True if callback is a function/lambda
        return callback(x)
    return x
do(2)  # returns 2
do(2, callback=lambda x: 2*x)  # returns 4
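The same pattern can be tried without torch installed. The sketch below uses a made-up TinyDataset stand-in and swaps the truthiness test for an explicit `is not None` check, a slightly stricter variant: an object that defines __bool__ or __len__ can be callable and still falsy, and the explicit check sidesteps that.

```python
class TinyDataset:
    """Hypothetical stand-in for the Dataset above; plain lists instead of tensors."""

    def __init__(self, items, transforms=None):
        self.items = items
        self.transforms = transforms

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        data = self.items[i]
        # Apply the transform only if one was supplied at construction time.
        if self.transforms is not None:
            data = self.transforms(data)
        return data


raw = TinyDataset([1, 2, 3])
doubled = TinyDataset([1, 2, 3], transforms=lambda x: 2 * x)
print(raw[0], doubled[0])
```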
I'm trying to programmatically set a value in a dictionary, potentially nested, given a list of indices and a value.
So for example, let's say my list of indices is:
['person', 'address', 'city']
and the value is
'New York'
I want as a result a dictionary object like:
{ 'Person': { 'address': { 'city': 'New York' } } }
Basically, the list represents a 'path' into a nested dictionary.
I think I can construct the dictionary itself, but where I'm stumbling is how to set the value. Obviously if I was just writing code for this manually it would be:
dict['Person']['address']['city'] = 'New York'
But how do I index into the dictionary and set the value like that programmatically if I just have a list of the indices and the value?
Something like this could help:
def nested_set(dic, keys, value):
    for key in keys[:-1]:
        dic = dic.setdefault(key, {})
    dic[keys[-1]] = value
And you can use it like this:
>>> d = {}
>>> nested_set(d, ['person', 'address', 'city'], 'New York')
>>> d
{'person': {'address': {'city': 'New York'}}}
I took the freedom to extend the code from the answer of Bakuriu. Therefore upvotes on this are optional, as his code is in and of itself a witty solution, which I wouldn't have thought of.
def nested_set(dic, keys, value, create_missing=True):
    d = dic
    for key in keys[:-1]:
        if key in d:
            d = d[key]
        elif create_missing:
            d = d.setdefault(key, {})
        else:
            return dic
    if keys[-1] in d or create_missing:
        d[keys[-1]] = value
    return dic
When setting create_missing to False, you're making sure to only set already existing values:
# Trying to set a value of a nonexistent key DOES NOT create a new value when create_missing is False
print(nested_set({"A": {"B": 1}}, ["A", "8"], 2, False))
>>> {'A': {'B': 1}}
# Trying to set a value of a nonexistent key DOES create a new value when create_missing is True
print(nested_set({"A": {"B": 1}}, ["A", "8"], 2, True))
>>> {'A': {'B': 1, '8': 2}}
# Set the value of an existing key
print(nested_set({"A": {"B": 1}}, ["A", "B"], 2))
>>> {'A': {'B': 2}}
Here's another option:
from collections import defaultdict
recursivedict = lambda: defaultdict(recursivedict)
mydict = recursivedict()
I originally got this from here: Set nested dict value and create intermediate keys.
It is quite clever and elegant if you ask me.
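To connect this with the list-of-keys requirement from the question, a small sketch follows. The helper name set_path is an assumption of mine, not part of the linked answer.

```python
from collections import defaultdict


def recursivedict():
    # Every missing key yields another recursivedict, to any depth.
    return defaultdict(recursivedict)


def set_path(d, keys, value):
    """Hypothetical helper: walk all keys but the last, then assign."""
    for key in keys[:-1]:
        d = d[key]  # missing intermediate levels are created on access
    d[keys[-1]] = value


mydict = recursivedict()
set_path(mydict, ['person', 'address', 'city'], 'New York')
print(mydict['person']['address']['city'])
```

One thing to keep in mind with this approach is that merely reading a missing key also creates it, so lookups on such a dict never raise KeyError.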
First off, you probably want to look at setdefault.
As a function I'd write it as
def get_leaf_dict(dct, key_list):
    res = dct
    for key in key_list:
        res = res.setdefault(key, {})
    return res
This would be used as:
get_leaf_dict(dct, ['Person', 'address'])['city'] = 'New York'
This could be cleaned up with error handling and such. Also using *args rather than a single key-list argument might be nice; but the idea is that
you can iterate over the keys, pulling up the appropriate dictionary at each level.
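A possible shape for the *args variant with some basic error handling is sketched below. The name set_leaf and the specific checks are assumptions for illustration, not part of the answer above.

```python
def set_leaf(dct, value, *keys):
    """Hypothetical helper: set dct[k1][k2]...[kn] = value, creating levels as needed."""
    if not keys:
        raise ValueError("at least one key is required")
    *parents, last = keys
    for key in parents:
        dct = dct.setdefault(key, {})
        if not isinstance(dct, dict):
            # An existing non-dict value is in the way; refuse to overwrite it silently.
            raise TypeError("value at key %r is not a dict" % (key,))
    dct[last] = value


data = {}
set_leaf(data, 'New York', 'Person', 'address', 'city')
print(data)
```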
Here is my simple solution: just write
terms = ['person', 'address', 'city']
result = nested_dict(3, str)
result[terms] = 'New York' # as easy as it can be
You can even do:
terms = ['John', 'Tinkoff', '1094535332'] # account in Tinkoff Bank
result = nested_dict(3, float)
result[terms] += 2375.30
Now the backstage:
from collections import defaultdict
class nesteddict(defaultdict):
    def __getitem__(self, key):
        if isinstance(key, list):
            d = self
            for i in key:
                d = defaultdict.__getitem__(d, i)
            return d
        else:
            return defaultdict.__getitem__(self, key)
    def __setitem__(self, key, value):
        if isinstance(key, list):
            d = self[key[:-1]]
            defaultdict.__setitem__(d, key[-1], value)
        else:
            defaultdict.__setitem__(self, key, value)
def nested_dict(n, type):
    if n == 1:
        return nesteddict(type)
    else:
        return nesteddict(lambda: nested_dict(n-1, type))
The dotty_dict library for Python 3 can do this. See documentation, Dotty Dict for more clarity.
from dotty_dict import dotty
dot = dotty()
string = '.'.join(['person', 'address', 'city'])
dot[string] = 'New York'
print(dot)
Output:
{'person': {'address': {'city': 'New York'}}}
Use this pair of methods:
def gattr(d, *attrs):
    """
    This method receives a dict and list of attributes to return the innermost value of the given dict
    """
    try:
        for at in attrs:
            d = d[at]
        return d
    except (KeyError, IndexError, TypeError):
        return None
def sattr(d, *attrs):
    """
    Adds "val" to dict in the hierarchy mentioned via *attrs
    For ex:
    sattr(animals, "cat", "leg","fingers", 4) is equivalent to animals["cat"]["leg"]["fingers"]=4
    This method creates necessary objects until it reaches the final depth
    This behaviour is also known as autovivification and plenty of implementations are around
    This implementation addresses the corner case of replacing existing primitives
    https://gist.github.com/hrldcpr/2012250#gistcomment-1779319
    """
    for attr in attrs[:-2]:
        # If such key is not found or the value is primitive supply an empty dict
        if not isinstance(d.get(attr), dict):
            d[attr] = {}
        d = d[attr]
    d[attrs[-2]] = attrs[-1]
Here's a variant of Bakuriu's answer that doesn't rely on a separate function:
keys = ['Person', 'address', 'city']
value = 'New York'
nested_dict = {}
# Build nested dictionary up until 2nd to last key
# (Effectively nested_dict['Person']['address'] = {})
sub_dict = nested_dict
for key_ind, key in enumerate(keys[:-1]):
if not key_ind:
# Point to newly added piece of dictionary
sub_dict = nested_dict.setdefault(key, {})
else:
# Point to newly added piece of sub-dictionary
# that is also added to original dictionary
sub_dict = sub_dict.setdefault(key, {})
# Add value to last key of nested structure of keys
# (Effectively nested_dict['Person']['address']['city'] = value)
sub_dict[keys[-1]] = value
print(nested_dict)
>>> {'Person': {'address': {'city': 'New York'}}}
This is a pretty good use case for a recursive function. So you can do something like this:
def parse(l: list, v: str) -> dict:
    copy = dict()
    k, *s = l
    if len(s) > 0:
        copy[k] = parse(s, v)
    else:
        copy[k] = v
    return copy
This effectively pops the first element of the passed list l off as a key for the dict copy that we initialize, then runs the remaining list through the same function, nesting a new dict under that key each time until there's nothing left in the list, whereupon it assigns v as the value of the last key.
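For completeness, here is how the function might be called with the values from the question. The definition is repeated so the snippet runs on its own. Note that it builds a fresh dict rather than updating an existing one, and an empty key list fails at the unpacking step.

```python
def parse(l: list, v: str) -> dict:
    copy = dict()
    k, *s = l  # first key, then the remaining keys
    if len(s) > 0:
        copy[k] = parse(s, v)  # recurse on the remaining keys
    else:
        copy[k] = v  # last key: store the value
    return copy


result = parse(['person', 'address', 'city'], 'New York')
print(result)
```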
This is much easier in Perl:
my %hash;
$hash{"aaa"}{"bbb"}{"ccc"}=1; # auto creates each of the intermediate levels
# of the hash (aka: dict or associated array)
I currently have these two utility functions.
The only difference between unique_account_link_generator and unique_order_id is what they filter within qs_exists. It's either .filter(slug=new_id) or .filter(order_id=new_id)
I now wonder whether there is a way to combine them and define the filter field when I call the function: unique_id_generator(instance, _filter = "order_id")
import random
import string
def random_string_generator(size=10, chars=string.ascii_lowercase + string.digits):
return ''.join(random.choice(chars) for _ in range(size))
def unique_account_link_generator(instance):
"""
1. Generates random string
2. Check if string unique in database
3. If already exists, generate new string
"""
new_id = random_string_generator()
myClass = instance.__class__
qs_exists = myClass.objects.filter(slug=new_id).exists()
if qs_exists:
return unique_account_link_generator(instance)
return new_id
# How to send field_name via function?
def unique_id_generator(instance):
"""
1. Generates random string
2. Check if string unique in database
3. If already exists, generate new string
"""
new_id = random_string_generator()
myClass = instance.__class__
qs_exists = myClass.objects.filter(order_id=new_id).exists()
if qs_exists:
return unique_id_generator(instance)
return new_id
Not sure I understood the question, as the answer is very simple:
def unique_id_generator(instance, _filter="order_id"):
    new_id = random_string_generator()
    myClass = instance.__class__
    qs_exists = myClass.objects.filter(**{_filter: new_id}).exists()
    if qs_exists:
        return unique_id_generator(instance, _filter)
    return new_id
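The same idea can also be written as a loop instead of recursion. The sketch below runs without Django by taking a hypothetical `exists` callable; in a Django project that callable might wrap something like `Model.objects.filter(**{field_name: candidate}).exists()`, which is an assumption about how you would wire it up.

```python
import random
import string


def random_string_generator(size=10, chars=string.ascii_lowercase + string.digits):
    return ''.join(random.choice(chars) for _ in range(size))


def unique_id_generator(exists, size=10):
    """Keep generating candidates until `exists` (hypothetical callable) reports no clash."""
    while True:
        new_id = random_string_generator(size)
        if not exists(new_id):
            return new_id


# Stand-in for the database lookup: a set of ids that are already taken.
taken = {"abc"}
new_id = unique_id_generator(lambda candidate: candidate in taken)
print(new_id)
```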
I want to give you an answer to your question in the comments. Since the comment section doesn't allow much text I would like to attach this as an addition to the accepted answer.
It's actually correct that **{_filter: new_id} will unpack a dictionary built from the _filter parameter.
If you call the function with (instance, _filter="order_id"),
the part **{_filter: new_id} becomes **{"order_id": "randomGeneratedCode123"}.
Now you have a dictionary with the key "order_id" and the value "randomGeneratedCode123".
Your goal is to turn the key "order_id" into a parameter name and its value into the value of that parameter:
order_id = "randomGeneratedCode123"
As you already said, you can unpack a dictionary with the double stars **.
After unpacking, the keys of the dictionary become the parameter names and their values become the parameter values.
Here is a small example for better understanding:
Let's say you have a dictionary and a function
dict = {'a': 1, 'b': 2}
def example(a, b):
    print("Variable a is {}, b is {}".format(a, b))
example(**dict)
**dict is converted to:
a = 1, b = 2 so the function will be called with
example(a = 1, b = 2)
It's important that the keys in your dictionary have the same name as your function parameter names
So this wouldn't work:
dict = {'a': 1, 'c': 2}
example(**dict)
because it's "translated" as
example(a = 1, c = 2)
and the function doesn't have a parameter with the name c
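Both cases can be checked in a few lines. This sketch returns the string instead of printing it and uses the names good and bad rather than dict, so the built-in is not shadowed; those names are mine.

```python
def example(a, b):
    return "Variable a is {}, b is {}".format(a, b)


good = {'a': 1, 'b': 2}
bad = {'a': 1, 'c': 2}

ok = example(**good)  # same as example(a=1, b=2)

try:
    example(**bad)  # same as example(a=1, c=2)
except TypeError as exc:
    # Python rejects the call because example() has no parameter named c.
    error = str(exc)

print(ok)
print(error)
```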
In Django admin, if I want to display a list of Iron and their respective formatted weights, I would have to do this.
class IronAdmin(admin.ModelAdmin):
model = Iron
fields = ('weight_formatted',)
def weight_formatted(self, object):
return '{0:.2f} Kg'.format(object.weight)
weight_formatted.short_description = 'Weight'
I.e: 500.00 Kg
The problem with this however is that I would have to write a method for every field that I want to format, making it redundant when I have 10 or more objects to format.
Is there a method that I could override to "catch" these values and specify formatting before they get rendered onto the html? I.e. instead of having to write a method for each Admin class, I could just write the following and have it be formatted.
class IronAdmin(admin.ModelAdmin):
model = Iron
fields = ('weight__kg',)
def overriden_method(field):
if field.name.contains('__kg'):
field.value = '{0:.2f} Kg'.format(field.value)
I.e: 500.00 Kg
After hours scouring the source, I finally figured it out! I realize this isn't the most efficient code and it's probably more trouble than it's worth in most use cases, but it's enough for me. In case anyone else needs a quick and dirty way to do it:
In order to automate it, I had to override django.contrib.admin.templatetags.admin_list.result_list with the following:
def result_list_larz(cl):
"""
Displays the headers and data list together
"""
resultz = list(results(cl)) # Where we override
""" Overriding starts here """
""" Have to scrub the __kg's as result_header(cl) will error out """
for k in cl.list_display:
cl.list_display[cl.list_display.index(k)] = k.replace('__kg','').replace('__c','')
headers = list(result_headers(cl))
num_sorted_fields = 0
for h in headers:
if h['sortable'] and h['sorted']:
num_sorted_fields += 1
return {'cl': cl,
'result_hidden_fields': list(result_hidden_fields(cl)),
'result_headers': headers,
'num_sorted_fields': num_sorted_fields,
'results': resultz}
Then overriding results(cl)'s call to items_for_result() wherein we then override its call to lookup_field() as follows:
def lookup_field(name, obj, model_admin=None):
opts = obj._meta
try:
f = _get_non_gfk_field(opts, name)
except (FieldDoesNotExist, FieldIsAForeignKeyColumnName):
# For non-field values, the value is either a method, property or
# returned via a callable.
if callable(name):
attr = name
value = attr(obj)
elif (model_admin is not None and
hasattr(model_admin, name) and
not name == '__str__' and
not name == '__unicode__'):
attr = getattr(model_admin, name)
value = attr(obj)
""" Formatting code here """
elif '__kg' in name or '__c' in name: # THE INSERT FOR FORMATTING!
actual_name = name.replace('__kg','').replace('__c', '')
value = getattr(obj, actual_name)
value = '{0:,.2f}'.format(value)
prefix = ''
postfix = ''
if '__kg' in name:
postfix = ' Kg'
elif '__c' in name:
prefix = 'P'
value = '{}{}{}'.format(prefix, value, postfix)
attr = value
else:
attr = getattr(obj, name)
if callable(attr):
value = attr()
else:
value = attr
f = None
""" Overriding code END """
else:
attr = None
value = getattr(obj, name)
return f, attr, value
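If patching the admin internals feels too heavy, a lighter option is to generate the per-field methods with a small factory. This is only a sketch of a different technique, not the approach described above; the helper name formatted and the stand-in classes are assumptions, and in a real project the attribute would sit on an admin.ModelAdmin subclass.

```python
def formatted(field_name, template, short_description):
    """Hypothetical factory: build an admin-style display method for one field."""
    def _formatter(self, obj):
        return template.format(getattr(obj, field_name))
    _formatter.short_description = short_description
    _formatter.__name__ = field_name + '_formatted'
    return _formatter


class Iron:
    """Stand-in for the Django model."""
    weight = 500


class IronAdminSketch:
    """Stand-in for admin.ModelAdmin; one line per formatted field."""
    weight_formatted = formatted('weight', '{0:.2f} Kg', 'Weight')


print(IronAdminSketch().weight_formatted(Iron()))
```

Each formatted column still needs one line in the admin class, but the formatting logic itself lives in a single place.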
I'm attempting to compose a class hierarchy using unittest.TestCase. These aren't strictly unit-tests I'm running here, I'm trying to instead test a bunch of string parsing functionality that is reliant on a few parameters (customer-name and missing-value for example). The general idea is to simply utilize some of unittest's conveniences and keep things DRY.
import unittest
import parsingfunc as i
class CustomerTestCase(unittest.TestCase):
"""Base class, most general functionality and test cases"""
def __init__(self, testname):
super(CustomerTestCase, self).__init__(testname)
# helpers
def mess_with_spaces(self, value):
...
def mess_with_case(self, value):
...
# tests
def test_case(self, value):
"""
Test that parsing function produces the same
value regardless of case of input
"""
self.assertEqual(self.func(value, missing=self.missing, customer=self.name),
self.func(self.mess_with_case(value), missing=self.missing, customer=self.name))
...
def test_spaces(self, value):
"""
Test that parsing function produces the same
value regardless of spacing present in input
"""
...
class AisleValues(CustomerTestCase):
"""Base class for testing aisle values"""
def __init__(self, testname, customername=None, missing=None):
super(CustomerTestCase, self).__init__(testname)
self.name = customername
self.missing = missing
self.func = i.aisle_to_num
...
class PeacockAisles(AisleValues):
"""Peacock aisle parsing test cases"""
def __init__(self, testname):
super(AisleValues, self).__init__(testname, customername='peacock', missing='nan')
...
And now try to create instances of these classes
In [6]: a = i.CustomerTestCase('test_spaces')
In [7]: a.__dict__
Out[7]:
{'_cleanups': [],
'_resultForDoCleanups': None,
'_testMethodDoc': '\n Test that parsing function produces the same\n value regardless of spacing present in input\n ',
'_testMethodName': 'test_spaces',
'_type_equality_funcs': {list: 'assertListEqual',
dict: 'assertDictEqual',
set: 'assertSetEqual',
frozenset: 'assertSetEqual',
tuple: 'assertTupleEqual',
unicode: 'assertMultiLineEqual'}}
In [8]: b = i.AisleValues('test_spaces')
In [9]: b.__dict__
Out[9]:
{'_cleanups': [],
'_resultForDoCleanups': None,
'_testMethodDoc': '\n Test that parsing function produces the same\n value regardless of spacing present in input\n ',
'_testMethodName': 'test_spaces',
'_type_equality_funcs': {list: 'assertListEqual',
dict: 'assertDictEqual',
set: 'assertSetEqual',
frozenset: 'assertSetEqual',
tuple: 'assertTupleEqual',
unicode: 'assertMultiLineEqual'},
'func': <function integration.aisle_to_num>,
'missing': None,
'name': None}
In [10]: b = i.AisleValues('test_spaces', customername='peacock', missing='nan')
In [11]: b.__dict__
Out[12]:
{'_cleanups': [],
'_resultForDoCleanups': None,
'_testMethodDoc': '\n Test that parsing function produces the same\n value regardless of spacing present in input\n ',
'_testMethodName': 'test_spaces',
'_type_equality_funcs': {list: 'assertListEqual',
dict: 'assertDictEqual',
set: 'assertSetEqual',
frozenset: 'assertSetEqual',
tuple: 'assertTupleEqual',
unicode: 'assertMultiLineEqual'},
'func': <function integration.aisle_to_num>,
'missing': 'nan',
'name': 'peacock'}
In [13]: c = i.PeacockAisles('test_spaces')
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-12-d4bba181b94e> in <module>()
----> 1 c = i.PeacockAisles('test_spaces')
/path/to/python.py in __init__(self, testname)
89
90 def __init__(self, testname):
---> 91 super(AisleValues, self).__init__(testname, customername='peacock', missing='nan')
92 pprint(self.__dict__)
93
TypeError: __init__() got an unexpected keyword argument 'customername'
So what's the deal? Thanks!
You're not calling super correctly. When you name a class in the super call, it should be the current class, not a base class (unless you really know what you're doing and you want to skip the base class's implementation too, in favor of a "grandparent" class).
Your current code has AisleValues.__init__ calling unittest.TestCase.__init__, bypassing CustomerTestCase.__init__. It happens to work because CustomerTestCase.__init__ doesn't do anything useful (you could delete it with no effect), but that's just luck. When PeacockAisles.__init__ calls CustomerTestCase.__init__ (bypassing AisleValues.__init__), it fails because the grandparent class doesn't allow all the same arguments as its child.
You want:
class AisleValues(CustomerTestCase):
    """Base class for testing aisle values"""
    def __init__(self, testname, customername=None, missing=None):
        super(AisleValues, self).__init__(testname)  # change the class named here
        ...
And:
class PeacockAisles(AisleValues):
    """Peacock aisle parsing test cases"""
    def __init__(self, testname):
        super(PeacockAisles, self).__init__(testname, customername='peacock', missing='nan')
        ...  # and on the previous line
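On Python 3 the class name can be left out altogether, which removes the chance of naming the wrong one. The sketch below uses plain stand-in classes (Base, Middle and Leaf are made-up names) rather than unittest.TestCase, so the chain of __init__ calls is easy to follow.

```python
class Base:
    def __init__(self, testname):
        self.testname = testname


class Middle(Base):
    def __init__(self, testname, customername=None, missing=None):
        super().__init__(testname)  # zero-argument form: always the next class in the MRO
        self.name = customername
        self.missing = missing


class Leaf(Middle):
    def __init__(self, testname):
        # Reaches Middle.__init__, which accepts the extra keyword arguments.
        super().__init__(testname, customername='peacock', missing='nan')


leaf = Leaf('test_spaces')
print(leaf.testname, leaf.name, leaf.missing)
```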