From e17d23f92a569df9322b84cfd9aea98678caf6f5 Mon Sep 17 00:00:00 2001 From: Nils Pukropp Date: Sun, 4 Feb 2024 21:40:00 +0100 Subject: [PATCH] better annotation asserts --- src/utils.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/utils.py b/src/utils.py index 705dc56..b247768 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,9 +1,10 @@ -from typing import Callable, TypeAliasType +from typing import Any, Callable, TypeAliasType -def has_annotation_callable(f: Callable, param_types: list[type], return_type: object = None): +def has_annotation_callable(f: Callable, param_types: list[type], return_type: object = None) -> bool: a = f.__annotations__ ret = a['return'] if 'return' in a else None - assert ret == return_type + if ret != return_type: + return False params = list(a.values())[:-1] if 'return' in a else list(a.values()) for param, param_type in zip(params, param_types): @@ -11,4 +12,24 @@ def has_annotation_callable(f: Callable, param_types: list[type], return_type: o param = param.__value__ while isinstance(param_type, TypeAliasType): param_type = param_type.__value__ - assert param == param_type \ No newline at end of file + if param != param_type: + return False + return True + +def format_type(t) -> str: + while isinstance(t, TypeAliasType): + t = t.__value__ + match t: + case type(): + return t.__name__.split('.')[-1] + case _: + return str(t).split('.')[-1] + +def format_callable(f: Callable[..., Any]) -> str: + a = list(f.__annotations__.values() if 'return' in f.__annotations__ else f.__annotations__.values() + [None]) + a = [format_type(t) for t in a] + return f"{f.__name__}({", ".join(a[:-1])}) -> {a[-1]}" + + +def assert_annotation_callable(f: Callable[..., Any], param_types: list[type], return_type: object = None): + assert has_annotation_callable(f, param_types, return_type), f"{f.__name__}({", ".join(format_type(p) for p in param_types)}) -> {format_type(return_type)} != {format_callable(f)}" \ No newline at end of file