An idea
A member of my team suggested that sobrecargar could become, aside from the library, a tool - particularly, a transpiler that takes code written with @sobrecargar / @overload and produces either:
- .pyi stubs that contain the overloaded signatures, paired with a transpiled version of the code that has a single implementation that inlines runtime type checking specialized for the cases relevant to each signature; or
- a transpiled version of the code that "auto-refactors" the overloads into separate functions and updates calling code acordingly
She's working on a prototype but I'm not completely sold on the approach. I generally dislike introducing build steps for interpreted languages as not having a build step is a feature of those languages. But the feedback is limited to my team, so i'd like to ask:
- Does this idea seem useful? do the improvements to performance (again, in te context of an interpreted, garbage collected and dynamicaly typed language) justify the added build step?
"""
===============
sobrecargar.py
===============
Method and function overloading for Python 3.
* Project Repository: https://github.com/Hernanatn/sobrecargar.py
* Documentation: https://github.com/hernanatn/sobrecargar.py/blob/master/README.MD
Copyright (c) 2023 Hernán A. Teszkiewicz Novick. Distributed under the MIT license.
Hernan ATN | [email protected] ==============
"""
__author__ = "Hernan ATN"
__copyright__ = "(c) 2023, Hernán A. Teszkiewicz Novick."
__license__ = "MIT"
__version__ = "3.1.1"
__email__ = "[email protected]"
__all__ = ['sobrecargar', 'overload']
from inspect import signature as get_signature, Signature, Parameter, currentframe as current_frame, getframeinfo as get_frame_info
from types import MappingProxyType
from typing import Callable, TypeVar, Iterator, ItemsView, Any, List, Tuple, Iterable, Generic, Optional, Unpack, Union, get_origin, get_args
from collections.abc import Sequence, Mapping
from collections import namedtuple
from functools import partial
from sys import modules, version_info
from itertools import zip_longest
from os.path import abspath as absolute_path
if version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self
if version_info < (3, 9):
raise ImportError("Module 'sobrecargar' requires Python 3.9 or higher.")
class _DeferredOverload(type):
"""Metaclass that handles deferred initialization of overloads, existing only to handle the case of overloading class/instance methods.
When decorating a function/method with @overload, instead of creating an instance of `overload`, an instance of `_DeferredOverload` is created,
which behaves *as if* it were `overload` and retains all the state needed to build the real instance later, only when the overloaded
function or method is called for the first time.
"""
def __init__(cls, name, bases, namespace):
super().__init__(name, bases, namespace)
class _Deferred(object):
def __new__(cls_inner, positional, keywords):
obj = cls.__new__(cls, *positional, *keywords)
if not hasattr(obj, "_Deferred__initial_params") or getattr(obj, "_Deferred__initial_params") is None:
obj.__initial_params = []
obj.__initial_params.append((positional, keywords))
obj.__class__ = cls_inner
return obj
def __initialize__(self):
initial = self.__initial_params
del self.__dict__['_Deferred__initial_params']
super().__setattr__('__class__', cls)
for positional, keywords in initial:
self.__init__(*positional, **keywords)
def __get__(self, obj, obj_type):
self.__initialize__()
return self.__get__(obj, obj_type)
def __call__(self, *positional, **keywords):
self.__initialize__()
return self.__call__(*positional, **keywords)
_Deferred.__name__ = f"{cls.__name__}_Deferred"
_Deferred.__qualname__ = f"{cls.__qualname__}_Deferred"
cls._Deferred = _Deferred
def __call__(cls, *positional, **keywords):
return cls._Deferred(positional, keywords)
def __instancecheck__(cls, instance):
return super().__instancecheck__(instance) or isinstance(instance, cls._Deferred)
def __subclasscheck__(cls, subclass):
return super().__subclasscheck__(subclass) or (subclass == cls._Deferred)
import __main__
class _sobrecargar(metaclass=_DeferredOverload):
"""
Class that acts as a decorator for functions, allowing multiple
versions of a function or method to be defined with different sets of parameters and types.
This enables function overloading (i.e., dynamic dispatch based on the provided arguments).
Class Attributes:
_overloaded (dict): A dictionary that keeps a record of '_overload' instances created
for each decorated function or method. The keys are the names of the functions or methods,
and the values are the corresponding '_overload' instances.
Instance Attributes:
overloads (dict): A dictionary storing the defined overloads for the decorated function or method.
The keys are Signature objects representing the overload signatures, and the values are the
corresponding functions or methods.
__cache (dict): A dictionary that maps parameter type combinations in the call to the underlying
function object to be called. A simple optimization that reduces the cost for subsequent calls,
which is very useful in loops.
__debug (Callable): A lambda that prints diagnostic information if the overload is initialized in debug mode,
otherwise it does nothing.
"""
_overloaded : dict[str, '_overload'] = {}
def __new__(cls, function: Callable, *positional, **keywords) -> '_overload':
"""
Constructor. Creates a unique instance per function name.
Args:
function (Callable): The function or method to be decorated.
Returns:
_overload: The instance of the '_overload' class associated with the provided function name.
"""
name: str = cls.__full_name(function)
if name not in cls._overloaded.keys():
cls._overloaded[name] = super().__new__(_overload)
cls._overloaded[name].__name = function.__name__
cls._overloaded[name].__full_name = name
return cls._overloaded[name]
def __init__(self, function: Callable, *, cache: bool = True, debug: bool = False) -> None:
"""
Initializer. Responsible for initializing the overload dictionary (if not already present)
and registering the current version of the decorated function or method.
Args:
function (Callable): The decorated function or method.
cache (bool): Option indicating whether the overload should use caching.
debug (bool): Option indicating whether to initialize in debug mode.
"""
if not hasattr(self, 'overloads'):
self.overloads : dict[Signature, Callable] = {}
self.__cache : Optional[dict[tuple[tuple[type[Any], ...], tuple[tuple[str, type[Any]]]], Callable[..., Any]]] = (
self.__cache if hasattr(self, "_overload__cache") and self.__cache is not None else {} if cache else None
)
self.__debug = (
self.__debug if hasattr(self, "_overload__debug") and self.__debug is not None
else (lambda msg: print(f"[DEBUG] {msg}") if debug else lambda msg: None)
)
signature_obj: Signature
underlying_function: Callable
signature_obj, underlying_function = _overload.__unwrap(function)
signature_obj, underlying_function = _overload.__unwrap(function)
self.__debug(f"Overload registered for: {self.__name}. Signature: {signature_obj}")
if type(self).__is_method(function):
cls: type = type(self).__get_class(function)
self.__debug(f"{self.__name} is a method of {cls}.")
self.__debug(f"{self.__name} is a method of {cls}.")
for ancestor in cls.__mro__:
for base in ancestor.__bases__:
if base is object: break
full_method_name: str = f"{base.__module__}.{base.__name__}.{function.__name__}"
if full_method_name in type(self)._overloaded.keys():
base_overload: '_overload' = type(self)._overloaded[full_method_name]
self.overloads.update(base_overload.overloads)
self.overloads[signature_obj] = underlying_function
if not self.__doc__: self.__doc__ = ""
self.__doc__ += f"\n{function.__doc__ or ''}"
def __call__(self, *positional, **keywords) -> Any:
"""
Method that allows the decorator instance to be called as a function.
The core engine of the module. It validates the provided parameters and builds a tuple
of 'candidates' from functions that match the provided parameters. It prioritizes the overload
that best fits the types and number of arguments. If several candidates match, it propagates the result
of the most specific one.
If caching is enabled, the selected function is stored for later calls.
Args:
*positional: Positional arguments passed to the function or method.
**keywords: Keyword arguments passed to the function or method.
Returns:
Any: The result of the selected version of the decorated function or method.
Raises:
TypeError: If no compatible overload exists for the provided parameters.
"""
if self.__cache is not None:
parameters = (
tuple(type(p) for p in positional),
tuple((n, type(v)) for n, v in keywords.items()),
)
if parameters in self.__cache.keys():
func = self.__cache.get(parameters)
self.__debug(
f"Cached call for {self.__name}"
f"\n\tProvided positional parameters: {', '.join(f'{type(p).__name__} [{repr(p)}]' for p in positional)}"
f"\n\tProvided keyword parameters: {', '.join(f'{k}: {type(v).__name__} [{v}]' for k, v in keywords.items())}"
f"\n\tCached signature: {get_signature(func)}"
)
return func(*positional, **keywords)
self.__debug(
f"Starting candidate selection for {self.__name}"
f"\n\tProvided positional parameters: {', '.join(f'{type(p).__name__} [{repr(p)}]' for p in positional)}"
f"\n\tProvided keyword parameters: {', '.join(f'{k}: {type(v).__name__} [{v}]' for k, v in keywords.items())}"
f"\n\tSupported overloads:"
f"\n" + "\n".join(
f"\t- {', '.join(f'{v}' for v in dict(sig.parameters).values())}"
for sig in self.overloads.keys()
)
)
_C = TypeVar("_C", bound=Sequence)
_T = TypeVar("_T", bound=Any)
Candidate = namedtuple('Candidate', ['score', 'function_object', "function_signature"])
candidates: List[Candidate] = []
def validate_container(value: _C, container_param: Parameter) -> int | bool:
type_score: int = 0
container_annotation = container_param.annotation
if not hasattr(container_annotation, "__origin__") or not hasattr(container_annotation, "__args__"):
type_score += 1
return type_score
if get_origin(container_annotation) is Union:
if not issubclass(type(value), get_args(container_annotation)):
return False
elif not issubclass(type(value), container_annotation.__origin__):
return False
container_args: Tuple[type[_C]] = container_annotation.__args__
has_ellipsis: bool = Ellipsis in container_args
has_single_type: bool = len(container_args) == 1 or has_ellipsis
if has_ellipsis:
aux_list: list = list(container_args)
aux_list[1] = aux_list[0]
container_args = tuple(aux_list)
type_iterator: Iterator
if has_single_type:
type_iterator = zip_longest((type(t) for t in value), container_args, fillvalue=container_args[0])
else:
type_iterator = zip_longest((type(t) for t in value), container_args)
if not issubclass(type(value[0]), container_args[0]):
return False
for received_type, expected_type in type_iterator:
if expected_type == None:
return False
if received_type == expected_type:
type_score += 2
elif issubclass(received_type, expected_type):
type_score += 1
else:
return False
return type_score
def validate_param_type(value: _T, func_param: Parameter) -> int | bool:
type_score: int = 0
expected_type = func_param.annotation
received_type: type[_T] = type(value)
is_untyped: bool = (expected_type == Any)
default_value: _T = func_param.default
is_null: bool = value is None and default_value is None
is_default: bool = value is None and default_value is not func_param.empty
param_is_self: bool = func_param.name == 'self' or func_param.name == 'cls'
param_is_var_pos: bool = func_param.kind == func_param.VAR_POSITIONAL
param_is_var_kw: bool = func_param.kind == func_param.VAR_KEYWORD
param_is_variable: bool = param_is_var_pos or param_is_var_kw
param_is_union: bool = hasattr(expected_type, "__origin__") and get_origin(expected_type) is Union
param_is_container: bool = (hasattr(expected_type, "__origin__") or (issubclass(expected_type, Sequence) and not issubclass(expected_type, str)) or issubclass(expected_type, Mapping)) and not param_is_union
numeric_compatible: bool = (issubclass(expected_type, complex) and issubclass(received_type, (float, int))
or issubclass(expected_type, float) and issubclass(received_type, int))
"""Check the special case where typed Python diverges from untyped Python.
See: https://typing.python.org/en/latest/spec/special-types.html#special-cases-for-float-and-complex
"""
is_different_type: bool
if param_is_variable and param_is_container and param_is_var_pos:
expected_type = expected_type.__args__[0] if get_origin(type(expected_type)) is Unpack else expected_type
is_different_type = not issubclass(received_type, expected_type.__args__[0])
elif param_is_variable and param_is_container and param_is_var_kw:
expected_type = expected_type.__args__[0] if get_origin(type(expected_type)) is Unpack else expected_type
is_different_type = not issubclass(received_type, expected_type.__args__[1])
elif param_is_union:
is_different_type = not issubclass(received_type, get_args(expected_type))
elif param_is_container:
is_different_type = not validate_container(value, func_param)
else:
is_different_type = not (
issubclass(received_type, expected_type)
or numeric_compatible
)
if not is_untyped and not is_null and not param_is_self and not is_default and is_different_type:
return False
elif param_is_variable and not param_is_container:
type_score += 1
else:
if param_is_variable and param_is_container and param_is_var_pos:
if received_type == expected_type.__args__[0]:
type_score += 3
elif issubclass(received_type, expected_type.__args__[0]):
type_score += 1
elif param_is_variable and param_is_container and param_is_var_kw:
if received_type == expected_type.__args__[1]:
type_score += 3
elif issubclass(received_type, expected_type.__args__[1]):
type_score += 1
elif param_is_container:
type_score += validate_container(value, func_param)
elif received_type == expected_type:
type_score += 5
elif issubclass(received_type, expected_type):
type_score += 4
elif numeric_compatible:
type_score += 3
elif is_default:
type_score += 2
elif is_null or param_is_self or is_untyped:
type_score += 1
return type_score
def validate_signature(func_params: MappingProxyType[str, Parameter], positional_count: int, positional_iterator: Iterator[tuple], keyword_view: ItemsView) -> int | bool:
signature_score: int = 0
this_score: int | bool
for pos_value, pos_name in positional_iterator:
this_score = validate_param_type(pos_value, func_params[pos_name])
if this_score:
signature_score += this_score
else:
return False
for key_name, key_value in keyword_view:
if key_name not in func_params and type(self).__has_var_kw(func_params):
var_kw: Optional[Parameter] = next((p for p in func_params.values() if p.kind == p.VAR_KEYWORD), None)
if var_kw is not None:
this_score = validate_param_type(key_value, var_kw)
else:
return False
elif key_name not in func_params:
return False
else:
this_score = validate_param_type(key_value, func_params[key_name])
if this_score:
signature_score += this_score
else:
return False
return signature_score
for sig, function in self.overloads.items():
length_score: int = 0
func_params: MappingProxyType[str, Parameter] = sig.parameters
positional_count: int = len(func_params) if type(self).__has_var_pos(func_params) else len(positional)
keyword_count: int = len({key: keywords[key] for key in func_params if key in keywords}) if (type(self).__has_var_kw(func_params) or type(self).__has_only_kw(func_params)) else len(keywords)
default_count: int = type(self).__has_default(func_params) if type(self).__has_default(func_params) else 0
positional_iterator: Iterator[tuple[Any, str]] = zip(positional, list(func_params)[:positional_count])
keyword_view: ItemsView[str, Any] = keywords.items()
if (len(func_params) == 0 or not (type(self).__has_variables(func_params) or type(self).__has_default(func_params))) and len(func_params) != (len(positional) + len(keywords)):
continue
if len(func_params) - (positional_count + keyword_count) == 0 and not (type(self).__has_variables(func_params) or type(self).__has_default(func_params)):
length_score += 3
elif len(func_params) - (positional_count + keyword_count) == 0:
length_score += 2
elif (0 <= len(func_params) - (positional_count + keyword_count) <= default_count) or (type(self).__has_variables(func_params)):
length_score += 1
else:
continue
signature_validation_score: int | bool = validate_signature(func_params, positional_count, positional_iterator, keyword_view)
if signature_validation_score:
candidate: Candidate = Candidate(score=(length_score + 2 * signature_validation_score), function_object=function, function_signature=sig)
candidates.append(candidate)
else:
continue
if candidates:
if len(candidates) > 1:
candidates.sort(key=lambda c: c.score, reverse=True)
self.__debug(f"Candidates: \n\t- " + "\n\t- ".join(' | '.join([str(i) for i in c if not callable(i)]) for c in candidates))
best_function = candidates[0].function_object
if self.__cache is not None:
parameters = (
tuple(type(p) for p in positional),
tuple(tuple(n, type(v)) for n, v in keywords.items()),
)
self.__cache.update({
parameters: best_function
})
return best_function(*positional, **keywords)
else:
call_frame = current_frame().f_back
frame_info = get_frame_info(call_frame)
if "return self.__call__(*positional,**keywords)" in frame_info.code_context and frame_info.function == "__call__":
frame_info = call_frame.f_back
raise TypeError(
f"[ERROR] Could not call {function.__name__} in {absolute_path(frame_info.filename)}:{frame_info.lineno} "
f"\n\tProvided parameters"
f"\n\t- Positional: {', '.join(p.__name__ for p in map(type, positional))}"
f"\n\t- Keywords: {', '.join(f'{k}: {type(v).__name__}' for k, v in keywords.items())}"
f"\n"
f"\n\tSupported overloads:\n" +
"\n".join(
f"\t- {', '.join(f'{v}' for v in dict(sig.parameters).values())}"
for sig in self.overloads.keys()
)
)
def __get__(self, obj, obj_type):
class OverloadedMethod:
__doc__ = self.__doc__
__call__ = partial(self.__call__, obj) if obj is not None else partial(self.__call__, obj_type)
return OverloadedMethod()
@staticmethod
def __unwrap(function: Callable) -> Tuple[Signature, Callable]:
while hasattr(function, '__func__'):
function = function.__func__
while hasattr(function, '__wrapped__'):
function = function.__wrapped__
sig: Signature = get_signature(function)
return (sig, function)
@staticmethod
def __full_name(function: Callable) -> str:
return f"{function.__module__}.{function.__qualname__}"
@staticmethod
def __is_method(function: Callable) -> bool:
return function.__name__ != function.__qualname__ and "<locals>" not in function.__qualname__.split(".")
@staticmethod
def __is_nested(function: Callable) -> bool:
return function.__name__ != function.__qualname__ and "<locals>" in function.__qualname__.split(".")
@staticmethod
def __get_class(method: Callable) -> type:
return getattr(modules[method.__module__], method.__qualname__.split(".")[0])
@staticmethod
def __has_variables(func_params: MappingProxyType[str, Parameter]) -> bool:
for param in func_params.values():
if _overload.__has_var_kw(func_params) or _overload.__has_var_pos(func_params):
return True
return False
@staticmethod
def __has_var_pos(func_params: MappingProxyType[str, Parameter]) -> bool:
for param in func_params.values():
if param.kind == Parameter.VAR_POSITIONAL:
return True
return False
@staticmethod
def __has_var_kw(func_params: MappingProxyType[str, Parameter]) -> bool:
for param in func_params.values():
if param.kind == Parameter.VAR_KEYWORD:
return True
return False
@staticmethod
def __has_default(func_params: MappingProxyType[str, Parameter]) -> int | bool:
default_count: int = 0
for param in func_params.values():
if param.default != param.empty:
default_count += 1
return default_count if default_count else False
@staticmethod
def __has_only_kw(func_params: MappingProxyType[str, Parameter]) -> bool:
for param in func_params.values():
if param.kind == Parameter.KEYWORD_ONLY:
return True
return False
def sobrecargar(*args, cache: bool = True, debug: bool = False) -> Callable:
"""Function decorator that transforms functions into overloads.
**Parameters:**
:param Callable f: the function to be overloaded.
:param bool cache: indicates whether to cache the dispatch result. Default: True.
:param bool debug: indicates whether to print diagnostic information. Default: False.
**Returns:**
:param Callable: the decorator.
---
"""
if args and callable(args[0]):
return _sobrecargar(args[0], cache=cache, debug=debug)
def decorator(f):
if debug:
frame_info = get_frame_info(current_frame().f_back)
print(
f"[DEBUG] Function overload."
f"\n\t{f.__name__} in {absolute_path(frame_info.filename)}:{frame_info.lineno}"
f"\n\t- cache = {cache}"
f"\n\t- debug = {debug}"
)
return _sobrecargar(f, cache=cache, debug=debug)
return decorator
# Alias
overload = sobrecargar
if __name__ == '__main__':
print(__doc__)
Installation
You can install sobrecargar using pip:
pip install sobrecargar
This example is far from perfect. Some stricter **kwargs handling (instead of implicitly postitioned *args) would improve it a lot- at the cost of added runtime overhead. But the core remains. The solution that typing.overload offers by itself amounts to a standard template for LSPs and type checkers. It does not give any type assurances (those are deferred for the actual unified implementation), it parasites duck-typing in such away that the actual implementation often gets muddied with ad-hoc type checking, it requieres extra boilerplate and as it's only a hint, when the exact same set of instructions can't handle the diverse inputs, for the function to "plow ahead regardless" (as per the example), branches need be introduced, imposing a runtime cost and potentially leading to brittle and coupled code.
As I see it, the greatest problem with typing.overload is it that it lies. It offers hints for multiple signatures but makes no guaratees about the implementation that actually has to handle them. Indeed calling typing.overloaded code often leads to a lot of effort trying to understand heavily branched ad hoc type checks with little context inside catch-all actual implementations that can, (andand often do), simply not handle the set of cases their signatures say they should. singledispatch trys to address this, but it ultimatelly works for a very narrow set of use cases, and introduces a typing sintax that differs from the already stablished type-hint syntax.
So...
Typed Python isn't going anywhere, and within typetyped codebases the pattern offers a lot. Evidence of that can be found in the fact that even with the current state of support, overloading is widely used both in the standard and in popular 3rd party libraryslibraries.
So a version of this patterns that: ensuresThe pattern would benefit from type correctness assurances, simplifies thesimplified overload definition of overloads, applies a consistant set of rules for overload selection,dispatch resolution (in oposition to ad hoc type checks) and offers a better debuging expirience, while trying to minimize overhead, should be an upgrade over the current state of thingsdebugging experience.