tests for a and b
This commit is contained in:
425
test_ex6.py
Normal file
425
test_ex6.py
Normal file
@ -0,0 +1,425 @@
|
||||
from dataclasses import dataclass, field, fields, is_dataclass, InitVar
|
||||
from types import NoneType, FunctionType, GenericAlias
|
||||
from typing import get_args, Callable, Iterable, Iterator, Type, Optional, TypeAliasType, ForwardRef, Any, Literal
|
||||
from enum import Enum
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
from functools import wraps
|
||||
import signal
|
||||
import io
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
import subprocess
|
||||
import os
|
||||
import shutil
|
||||
import re
|
||||
import math
|
||||
import inspect
|
||||
import json
|
||||
import ast
|
||||
import copy
|
||||
import functools
|
||||
import itertools
|
||||
|
||||
output = ""
|
||||
|
||||
|
||||
class TimeoutException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Timeout:
|
||||
def __init__(self, seconds, msg=None) -> None:
|
||||
if msg is None:
|
||||
msg = f"Test timed out after {seconds} seconds."
|
||||
self.seconds = seconds
|
||||
self.msg = msg
|
||||
|
||||
def handle_timeout(self, signal, frame):
|
||||
raise TimeoutException(self.msg)
|
||||
|
||||
def __enter__(self):
|
||||
signal.signal(signal.SIGALRM, self.handle_timeout)
|
||||
signal.alarm(self.seconds)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
signal.alarm(0)
|
||||
|
||||
|
||||
class IterablePure:
|
||||
def __init__(self, i: Optional[Iterable] = None, repeat: bool = False):
|
||||
if i is not None:
|
||||
self.i = i
|
||||
else:
|
||||
def endless():
|
||||
c = 0
|
||||
while True:
|
||||
yield c
|
||||
c += 1
|
||||
self.i = endless()
|
||||
|
||||
self.repeat = repeat
|
||||
|
||||
def __iter__(self):
|
||||
if self.repeat:
|
||||
while True:
|
||||
yield from self.i
|
||||
else:
|
||||
yield from self.i
|
||||
|
||||
|
||||
def resolve_type_alias(t: Type) -> Type:
|
||||
while isinstance(t, TypeAliasType | GenericAlias) and hasattr(t, "__value__"):
|
||||
if isinstance(t, TypeAliasType):
|
||||
t = t.__value__
|
||||
elif isinstance(t, GenericAlias):
|
||||
t = t.__value__[*t.__args__]
|
||||
return t
|
||||
|
||||
|
||||
def check_annot(function: Callable, expected: list[Type], ret=None):
|
||||
actual = function.__annotations__
|
||||
if "self" in actual:
|
||||
del actual["self"]
|
||||
return_type = actual["return"] if "return" in actual else None
|
||||
actual = list(actual.values())[
|
||||
:-1] if "return" in actual else list(actual.values())
|
||||
for a, e in zip(actual, expected):
|
||||
a = resolve_type_alias(a)
|
||||
e = resolve_type_alias(e)
|
||||
if a != e:
|
||||
return False
|
||||
return resolve_type_alias(ret) == resolve_type_alias(return_type)
|
||||
|
||||
|
||||
def check_lambda_annotations(module, f_name: str, param_types: list[object], return_type: object = None) -> bool:
|
||||
annotation = module.__annotations__[f_name]
|
||||
if annotation.__name__ != 'Callable':
|
||||
return False
|
||||
|
||||
all_types = annotation.__args__
|
||||
types, ret_type = all_types[:-1], all_types[-1]
|
||||
if ret_type != return_type:
|
||||
return False
|
||||
|
||||
for param, param_type in zip(types, param_types):
|
||||
while isinstance(param, TypeAliasType):
|
||||
param = param.__value__
|
||||
while isinstance(param_type, TypeAliasType):
|
||||
param_type = param_type.__value__
|
||||
if param != param_type:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_attributes(o: object, inherited_attr: list[str], new_attr: list[str]) -> bool:
|
||||
attr = [f.name for f in fields(o)]
|
||||
if attr != inherited_attr + new_attr:
|
||||
return False
|
||||
|
||||
for attr in inherited_attr:
|
||||
if attr in o.__annotations__:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_no_if_used(f: Callable) -> bool:
|
||||
tree = ast.parse(inspect.getsource(f))
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.If):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_comprehension_used(f: Callable) -> bool:
|
||||
source = inspect.getsource(f)
|
||||
tree = ast.parse(source)
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_no_sideeffects(f: Callable, *args) -> bool:
|
||||
args_copy = tuple(copy.deepcopy(a) for a in args)
|
||||
f(*args)
|
||||
return args == args_copy
|
||||
|
||||
|
||||
def check_function_used(f: Callable, name: str) -> bool:
|
||||
source = inspect.getsource(f)
|
||||
if source.startswith(4 * " "):
|
||||
source = "\n".join(l[4:] for l in source.split("\n"))
|
||||
return name in [c.func.id for c in ast.walk(ast.parse(source)) if isinstance(c, ast.Call) and isinstance(c.func, ast.Name)]
|
||||
|
||||
|
||||
def check_function_not_used(f: Callable, name: str) -> bool:
|
||||
source = inspect.getsource(f)
|
||||
if source.startswith(4 * " "):
|
||||
source = "\n".join(l[4:] for l in source.split("\n"))
|
||||
return name not in [c.func.id for c in ast.walk(ast.parse(source)) if isinstance(c, ast.Call) and isinstance(c.func, ast.Name)]
|
||||
|
||||
|
||||
def check_pattern_matching_used(f: Callable) -> bool:
|
||||
source = inspect.getsource(f)
|
||||
tree = ast.parse(source)
|
||||
return any(isinstance(node, ast.Match) for node in ast.walk(tree))
|
||||
|
||||
|
||||
def check_gen(g: Iterator, res: Iterable, check_end: bool = True, max: int = 100) -> bool:
|
||||
if not isinstance(g, Iterator):
|
||||
return False
|
||||
for i, x in enumerate(res):
|
||||
el = next(g)
|
||||
if el != x:
|
||||
return False
|
||||
if i >= max:
|
||||
break
|
||||
if check_end:
|
||||
try:
|
||||
next(g)
|
||||
except StopIteration:
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def check_only_one_line(f: Callable) -> bool:
|
||||
if f.__name__ == '<lambda>':
|
||||
return True
|
||||
inspect.getsource(f)
|
||||
lines = [line for line in inspect.getsource(f).split(
|
||||
'\n') if not line.strip().startswith('#') and line.strip() != '']
|
||||
return len(lines) == 2
|
||||
|
||||
|
||||
def mock_print(*args, sep=None, end=None, **_kwargs):
|
||||
global output
|
||||
if sep is None:
|
||||
sep = " "
|
||||
if end is None:
|
||||
end = "\n"
|
||||
args = [str(a) for a in args]
|
||||
output += str(sep.join(args)) + end
|
||||
|
||||
|
||||
def mock_input(*arg: str):
|
||||
inputs = [s for s in arg]
|
||||
|
||||
def input_function(arg=None):
|
||||
if arg is None:
|
||||
arg = ""
|
||||
nonlocal inputs
|
||||
if len(inputs) == 0:
|
||||
raise Exception("More input asked than given")
|
||||
p = inputs[0]
|
||||
inputs = inputs[1:]
|
||||
mock_print(arg, p, sep="")
|
||||
return p
|
||||
return input_function
|
||||
|
||||
|
||||
def input_error(*arg):
|
||||
raise Exception("Input is not allowed here")
|
||||
|
||||
|
||||
@patch("builtins.print", mock_print)
|
||||
@patch("builtins.input", input_error)
|
||||
def check_global_statements(unit: unittest.TestCase, module_name: str):
|
||||
global output
|
||||
output = ""
|
||||
with Timeout(2):
|
||||
if module_name in sys.modules.keys():
|
||||
del sys.modules[module_name]
|
||||
module = importlib.import_module(module_name)
|
||||
unit.assertEqual(output, "")
|
||||
has_global_asserts = any(isinstance(node, ast.Assert)
|
||||
for node in ast.iter_child_nodes(ast.parse(inspect.getsource(module))))
|
||||
unit.assertFalse(has_global_asserts)
|
||||
|
||||
|
||||
def no_input_test(test_function: Callable) -> Callable:
|
||||
@wraps(test_function)
|
||||
@patch("builtins.print", mock_print)
|
||||
@patch("builtins.input", input_error)
|
||||
def inner(self):
|
||||
global output
|
||||
output = ""
|
||||
with Timeout(2):
|
||||
test_function(self)
|
||||
return inner
|
||||
|
||||
|
||||
class Ex6a_sum_of_subtree(unittest.TestCase):
|
||||
def test6_global_statements(self):
|
||||
check_global_statements(self, "ex6_recursion")
|
||||
|
||||
def test6_test_pattern_matching_sum_of_subtree(self):
|
||||
with Timeout(2):
|
||||
import ex6_recursion
|
||||
check_pattern_matching_used(ex6_recursion.cut_at)
|
||||
|
||||
@no_input_test
|
||||
def test6_annot_sum_of_subtree(self):
|
||||
global output
|
||||
output = ""
|
||||
with Timeout(2):
|
||||
import ex6_recursion
|
||||
self.assertTrue(check_annot(ex6_recursion.sum_of_subtree, [
|
||||
ex6_recursion.BTree[int]], int))
|
||||
|
||||
@no_input_test
|
||||
def test6_sum_of_subtree_basisfall(self):
|
||||
global output
|
||||
output = ""
|
||||
with Timeout(2):
|
||||
import ex6_recursion
|
||||
sum_of_subtree = ex6_recursion.sum_of_subtree
|
||||
tree = None
|
||||
self.assertEqual(sum_of_subtree(tree), 0)
|
||||
self.assertIsNone(tree)
|
||||
|
||||
@no_input_test
|
||||
def test6_sum_of_subtree(self):
|
||||
global output
|
||||
output = ""
|
||||
with Timeout(2):
|
||||
import ex6_recursion
|
||||
sum_of_subtree = ex6_recursion.sum_of_subtree
|
||||
Node = ex6_recursion.Node
|
||||
tree = Node(1, Node(2, Node(3), Node(4)), Node(5, Node(6)))
|
||||
self.assertEqual(sum_of_subtree(tree), 21)
|
||||
self.assertEqual(tree, Node(mark=21, left=Node(mark=9, left=Node(mark=3, left=None, right=None), right=Node(
|
||||
mark=4, left=None, right=None)), right=Node(mark=11, left=Node(mark=6, left=None, right=None), right=None)))
|
||||
single_node_tree = Node(5)
|
||||
self.assertEqual(sum_of_subtree(single_node_tree), 5)
|
||||
simple_tree = Node(1, Node(2), Node(3))
|
||||
total_sum = sum_of_subtree(simple_tree)
|
||||
self.assertEqual(total_sum, 6) # 1 + 2 + 3 = 6
|
||||
self.assertEqual(simple_tree.mark, 6)
|
||||
self.assertEqual(simple_tree.left.mark, 2)
|
||||
self.assertEqual(simple_tree.right.mark, 3)
|
||||
complex_tree = Node(1, Node(2, Node(3), Node(4)), Node(5, Node(6)))
|
||||
total_sum = sum_of_subtree(complex_tree)
|
||||
self.assertEqual(total_sum, 21) # 1 + 2 + 3 + 4 + 5 + 6 = 21
|
||||
self.assertEqual(complex_tree.mark, 21)
|
||||
self.assertEqual(complex_tree.left.mark, 9) # 2 + 3 + 4 = 9
|
||||
self.assertEqual(complex_tree.right.mark, 11) # 5 + 6 = 11
|
||||
self.assertEqual(complex_tree.left.left.mark, 3)
|
||||
self.assertEqual(complex_tree.left.right.mark, 4)
|
||||
self.assertEqual(complex_tree.right.left.mark, 6)
|
||||
|
||||
|
||||
class Ex6b_cut_at(unittest.TestCase):
|
||||
def test6_global_statements(self):
|
||||
check_global_statements(self, "ex6_recursion")
|
||||
|
||||
def test6_test_pattern_matching_cut_at(self):
|
||||
with Timeout(2):
|
||||
import ex6_recursion
|
||||
check_pattern_matching_used(ex6_recursion.cut_at)
|
||||
|
||||
def test6_test_no_side_effects_cut_at(self):
|
||||
with Timeout(2):
|
||||
import ex6_recursion
|
||||
Node = ex6_recursion.Node
|
||||
tree = Node(1, Node(2, Node(3), Node(4)),
|
||||
Node(3, Node(3), Node(6)))
|
||||
check_no_sideeffects(ex6_recursion.cut_at, tree, 3)
|
||||
|
||||
@no_input_test
|
||||
def test6_annot_cut_at(self):
|
||||
global output
|
||||
output = ""
|
||||
with Timeout(2):
|
||||
import ex6_recursion
|
||||
self.assertTrue(check_annot(ex6_recursion.sum_of_subtree, [
|
||||
ex6_recursion.BTree[int]], int))
|
||||
type_params = ex6_recursion.cut_at.__type_params__
|
||||
self.assertEqual(len(type_params), 1)
|
||||
T = type_params[0]
|
||||
self.assertTrue(check_annot(ex6_recursion.cut_at, [
|
||||
ex6_recursion.BTree[T]], ex6_recursion.BTree[T]))
|
||||
|
||||
@no_input_test
|
||||
def test6_cut_at_basisfall(self):
|
||||
global output
|
||||
output = ""
|
||||
with Timeout(2):
|
||||
import ex6_recursion
|
||||
cut_at = ex6_recursion.cut_at
|
||||
tree = None
|
||||
self.assertIsNone(cut_at(tree, 42))
|
||||
|
||||
@no_input_test
|
||||
def test6_cut_at(self):
|
||||
global output
|
||||
output = ""
|
||||
with Timeout(2):
|
||||
import ex6_recursion
|
||||
Node = ex6_recursion.Node
|
||||
cut_at = ex6_recursion.cut_at
|
||||
tree = Node(1, Node(2, Node(3), Node(4)),
|
||||
Node(3, Node(3), Node(6)))
|
||||
self.assertIsNone(cut_at(tree, 1))
|
||||
self.assertEqual(cut_at(tree, 3), Node(mark=1, left=Node(
|
||||
mark=2, left=None, right=Node(mark=4, left=None, right=None)), right=None))
|
||||
single_node_tree = Node(5)
|
||||
new_tree = cut_at(single_node_tree, 10)
|
||||
self.assertIsNotNone(new_tree)
|
||||
self.assertEqual(new_tree.mark, 5)
|
||||
self.assertIsNone(new_tree.left)
|
||||
self.assertIsNone(new_tree.right)
|
||||
tree_without_cut = Node(
|
||||
1, Node(2, Node(3), Node(4)), Node(5, Node(6)))
|
||||
new_tree = cut_at(tree_without_cut, 42) # No node has value 42
|
||||
self.assertIsNotNone(new_tree)
|
||||
self.assertEqual(new_tree.mark, 1)
|
||||
self.assertIsNotNone(new_tree.left)
|
||||
self.assertIsNotNone(new_tree.right)
|
||||
tree_with_cut = Node(
|
||||
1, Node(2, Node(3), Node(4)), Node(3, Node(3), Node(6)))
|
||||
new_tree = cut_at(tree_with_cut, 3) # Cut all nodes with value 3
|
||||
self.assertIsNotNone(new_tree)
|
||||
self.assertEqual(new_tree.mark, 1)
|
||||
self.assertEqual(new_tree.left.mark, 2)
|
||||
self.assertIsNone(new_tree.right) # The right subtree is removed
|
||||
self.assertIsNone(new_tree.left.left) # Node(3) is removed
|
||||
self.assertIsNotNone(new_tree.left.right)
|
||||
self.assertEqual(new_tree.left.right.mark, 4) # Node(4) remains
|
||||
|
||||
|
||||
def get_results():
|
||||
result_data = {}
|
||||
output = io.StringIO("")
|
||||
test_classes = [(name, cls) for name, cls in inspect.getmembers(sys.modules[__name__], lambda x: inspect.isclass(
|
||||
x) and (x.__module__ == __name__)) if unittest.TestCase in cls.__mro__]
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(test_classes[0][1])
|
||||
for other in test_classes[1:]:
|
||||
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(other[1]))
|
||||
testResult = unittest.TextTestRunner(stream=output, verbosity=2).run(suite)
|
||||
result_data["tests"] = [str(test_method) for test in test_classes
|
||||
for test_method in unittest.TestLoader()
|
||||
.loadTestsFromTestCase(test[1])]
|
||||
result_data["failure"] = not testResult.wasSuccessful()
|
||||
result_data["tests_run"] = testResult.testsRun
|
||||
result_data["failures"] = [(str(func), str(err))
|
||||
for func, err in testResult.failures]
|
||||
result_data["errors"] = [(str(func), str(err))
|
||||
for func, err in testResult.errors]
|
||||
return result_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
docker_run = True
|
||||
try:
|
||||
with open("/output/output.json", "w", encoding="utf-8") as f:
|
||||
result = get_results()
|
||||
json.dump(result, f)
|
||||
except FileNotFoundError:
|
||||
docker_run = False
|
||||
if not docker_run:
|
||||
unittest.main()
|
Reference in New Issue
Block a user