[Common] Add overrides function decorator

This commit is contained in:
bendikro 2016-03-07 20:48:13 +01:00 committed by Calum Lind
parent 4d3cf756e4
commit 891209d925
1 changed files with 90 additions and 0 deletions

View File

@ -7,6 +7,8 @@
# See LICENSE for more details.
#
import inspect
import re
from functools import wraps
@ -24,3 +26,91 @@ def proxy(proxy_func):
return proxy_func(func, *args, **kwargs)
return wrapper
return decorator
def overrides(*args):
"""
Decorater function to specify when class methods override
super class methods.
When used as
@overrides
def funcname
the argument will be the funcname function.
When used as
@overrides(BaseClass)
def funcname
the argument will be the BaseClass
"""
stack = inspect.stack()
if inspect.isfunction(args[0]):
return _overrides(stack, args[0])
else:
# One or more classes are specifed, so return a function that will be
# called with the real function as argument
def ret_func(func, **kwargs):
return _overrides(stack, func, explicit_base_classes=args)
return ret_func
def _overrides(stack, method, explicit_base_classes=None):
# stack[0]=overrides, stack[1]=inside class def'n, stack[2]=outside class def'n
classes = {}
derived_class_locals = stack[2][0].f_locals
# Find all super classes
m = re.search(r'class\s(.+)\((.+)\)\s*\:', stack[2][4][0])
class_name = m.group(1)
base_classes = m.group(2)
# Handle multiple inheritance
base_classes = [s.strip() for s in base_classes.split(',')]
check_classes = base_classes
if not base_classes:
raise ValueError('overrides decorator: unable to determine base class of class "%s"' % class_name)
def get_class(cls_name):
if '.' not in cls_name:
return derived_class_locals[cls_name]
else:
components = cls_name.split('.')
# obj is either a module or a class
obj = derived_class_locals[components[0]]
for c in components[1:]:
assert inspect.ismodule(obj) or inspect.isclass(obj)
obj = getattr(obj, c)
return obj
if explicit_base_classes:
# One or more base classes are explicitly given, check only those classes
override_classes = re.search(r'\s*@overrides\((.+)\)\s*', stack[1][4][0]).group(1)
override_classes = [c.strip() for c in override_classes.split(",")]
check_classes = override_classes
for c in base_classes + check_classes:
classes[c] = get_class(c)
# Verify that the excplicit override class is one of base classes
if explicit_base_classes:
from itertools import product
for bc, cc in product(base_classes, check_classes):
if issubclass(classes[bc], classes[cc]):
break
else:
raise Exception("Excplicit override class '%s' is not a super class of '%s'"
% (explicit_base_classes, class_name))
if not all(hasattr(classes[cls], method.__name__) for cls in check_classes):
for cls in check_classes:
if not hasattr(classes[cls], method.__name__):
raise Exception("Function override '%s' not found in superclass: '%s'\n%s"
% (method.__name__, cls, "File: %s:%s" % (stack[1][1], stack[1][2])))
if not any(hasattr(classes[cls], method.__name__) for cls in check_classes):
raise Exception("Function override '%s' not found in any superclass: '%s'\n%s"
% (method.__name__, check_classes, "File: %s:%s" % (stack[1][1], stack[1][2])))
return method