better annotation asserts

This commit is contained in:
2024-02-04 21:40:00 +01:00
parent e677575530
commit e17d23f92a

View File

@ -1,9 +1,10 @@
from typing import Callable, TypeAliasType
from typing import Any, Callable, TypeAliasType
def has_annotation_callable(f: Callable, param_types: list[type], return_type: object = None):
def has_annotation_callable(f: Callable, param_types: list[type], return_type: object = None) -> bool:
a = f.__annotations__
ret = a['return'] if 'return' in a else None
assert ret == return_type
if ret != return_type:
return False
params = list(a.values())[:-1] if 'return' in a else list(a.values())
for param, param_type in zip(params, param_types):
@ -11,4 +12,24 @@ def has_annotation_callable(f: Callable, param_types: list[type], return_type: o
param = param.__value__
while isinstance(param_type, TypeAliasType):
param_type = param_type.__value__
assert param == param_type
if param != param_type:
return False
return True
def format_type(t) -> str:
while isinstance(t, TypeAliasType):
t = t.__value__
match t:
case type():
return t.__name__.split('.')[-1]
case _:
return str(t).split('.')[-1]
def format_callable(f: Callable[..., Any]) -> str:
a = list(f.__annotations__.values() if 'return' in f.__annotations__ else f.__annotations__.values() + [None])
a = [format_type(t) for t in a]
return f"{f.__name__}({", ".join(a[:-1])}) -> {a[-1]}"
def assert_annotation_callable(f: Callable[..., Any], param_types: list[type], return_type: object = None):
assert has_annotation_callable(f, param_types, return_type), f"{f.__name__}({", ".join(format_type(p) for p in param_types)}) -> {format_type(return_type)} != {format_callable(f)}"