better annotation asserts
This commit is contained in:
29
src/utils.py
29
src/utils.py
@ -1,9 +1,10 @@
|
||||
from typing import Callable, TypeAliasType
|
||||
from typing import Any, Callable, TypeAliasType
|
||||
|
||||
def has_annotation_callable(f: Callable, param_types: list[type], return_type: object = None):
|
||||
def has_annotation_callable(f: Callable, param_types: list[type], return_type: object = None) -> bool:
|
||||
a = f.__annotations__
|
||||
ret = a['return'] if 'return' in a else None
|
||||
assert ret == return_type
|
||||
if ret != return_type:
|
||||
return False
|
||||
|
||||
params = list(a.values())[:-1] if 'return' in a else list(a.values())
|
||||
for param, param_type in zip(params, param_types):
|
||||
@ -11,4 +12,24 @@ def has_annotation_callable(f: Callable, param_types: list[type], return_type: o
|
||||
param = param.__value__
|
||||
while isinstance(param_type, TypeAliasType):
|
||||
param_type = param_type.__value__
|
||||
assert param == param_type
|
||||
if param != param_type:
|
||||
return False
|
||||
return True
|
||||
|
||||
def format_type(t) -> str:
|
||||
while isinstance(t, TypeAliasType):
|
||||
t = t.__value__
|
||||
match t:
|
||||
case type():
|
||||
return t.__name__.split('.')[-1]
|
||||
case _:
|
||||
return str(t).split('.')[-1]
|
||||
|
||||
def format_callable(f: Callable[..., Any]) -> str:
|
||||
a = list(f.__annotations__.values() if 'return' in f.__annotations__ else f.__annotations__.values() + [None])
|
||||
a = [format_type(t) for t in a]
|
||||
return f"{f.__name__}({", ".join(a[:-1])}) -> {a[-1]}"
|
||||
|
||||
|
||||
def assert_annotation_callable(f: Callable[..., Any], param_types: list[type], return_type: object = None):
|
||||
assert has_annotation_callable(f, param_types, return_type), f"{f.__name__}({", ".join(format_type(p) for p in param_types)}) -> {format_type(return_type)} != {format_callable(f)}"
|
Reference in New Issue
Block a user