better annotation asserts
This commit is contained in:
29
src/utils.py
29
src/utils.py
@ -1,9 +1,10 @@
|
|||||||
from typing import Callable, TypeAliasType
|
from typing import Any, Callable, TypeAliasType
|
||||||
|
|
||||||
def has_annotation_callable(f: Callable, param_types: list[type], return_type: object = None):
|
def has_annotation_callable(f: Callable, param_types: list[type], return_type: object = None) -> bool:
|
||||||
a = f.__annotations__
|
a = f.__annotations__
|
||||||
ret = a['return'] if 'return' in a else None
|
ret = a['return'] if 'return' in a else None
|
||||||
assert ret == return_type
|
if ret != return_type:
|
||||||
|
return False
|
||||||
|
|
||||||
params = list(a.values())[:-1] if 'return' in a else list(a.values())
|
params = list(a.values())[:-1] if 'return' in a else list(a.values())
|
||||||
for param, param_type in zip(params, param_types):
|
for param, param_type in zip(params, param_types):
|
||||||
@ -11,4 +12,24 @@ def has_annotation_callable(f: Callable, param_types: list[type], return_type: o
|
|||||||
param = param.__value__
|
param = param.__value__
|
||||||
while isinstance(param_type, TypeAliasType):
|
while isinstance(param_type, TypeAliasType):
|
||||||
param_type = param_type.__value__
|
param_type = param_type.__value__
|
||||||
assert param == param_type
|
if param != param_type:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def format_type(t) -> str:
|
||||||
|
while isinstance(t, TypeAliasType):
|
||||||
|
t = t.__value__
|
||||||
|
match t:
|
||||||
|
case type():
|
||||||
|
return t.__name__.split('.')[-1]
|
||||||
|
case _:
|
||||||
|
return str(t).split('.')[-1]
|
||||||
|
|
||||||
|
def format_callable(f: Callable[..., Any]) -> str:
|
||||||
|
a = list(f.__annotations__.values() if 'return' in f.__annotations__ else f.__annotations__.values() + [None])
|
||||||
|
a = [format_type(t) for t in a]
|
||||||
|
return f"{f.__name__}({", ".join(a[:-1])}) -> {a[-1]}"
|
||||||
|
|
||||||
|
|
||||||
|
def assert_annotation_callable(f: Callable[..., Any], param_types: list[type], return_type: object = None):
|
||||||
|
assert has_annotation_callable(f, param_types, return_type), f"{f.__name__}({", ".join(format_type(p) for p in param_types)}) -> {format_type(return_type)} != {format_callable(f)}"
|
Reference in New Issue
Block a user