Files
test_exam_2025/test_ex6.py
2025-03-21 21:28:32 +01:00

426 lines
14 KiB
Python

import ast
import copy
import functools
import importlib
import importlib.util
import inspect
import io
import itertools
import json
import math
import os
import re
import shutil
import signal
import subprocess
import sys
import textwrap
import unittest
from dataclasses import dataclass, field, fields, is_dataclass, InitVar
from enum import Enum
from functools import wraps
from types import NoneType, FunctionType, GenericAlias
from typing import get_args, Callable, Iterable, Iterator, Type, Optional, TypeAliasType, ForwardRef, Any, Literal
from unittest import mock
from unittest.mock import patch
# Text captured by mock_print while builtins.print is patched; the test helpers
# reset it to "" before running code under test.
output = ""
class TimeoutException(Exception):
    """Raised by ``Timeout`` when the guarded code exceeds its time limit."""
class Timeout:
    """Context manager that aborts its body after ``seconds`` via SIGALRM.

    On entry a SIGALRM handler is installed and an alarm is scheduled; when
    the alarm fires, ``TimeoutException(msg)`` is raised inside the guarded
    code. On exit the alarm is cancelled and the handler that was active
    before entry is put back, so the timeout handler does not leak.

    Args:
        seconds: Time limit passed to ``signal.alarm``.
        msg: Message of the raised exception; a default mentioning
            ``seconds`` is generated when omitted.
    """

    def __init__(self, seconds, msg=None) -> None:
        if msg is None:
            msg = f"Test timed out after {seconds} seconds."
        self.seconds = seconds
        self.msg = msg
        # Handler that was installed before __enter__; restored in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signal, frame):
        """Signal handler: turn the alarm into a ``TimeoutException``."""
        # NOTE: the parameter name shadows the ``signal`` module inside this
        # method only; it is kept unchanged for interface compatibility.
        raise TimeoutException(self.msg)

    def __enter__(self):
        self._previous_handler = signal.signal(
            signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Cancel the pending alarm first so it cannot fire while the
        # handler is being swapped back.
        signal.alarm(0)
        # signal.signal() returns None for handlers not installed from
        # Python; those cannot be re-installed, so they are skipped.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
class IterablePure:
    """Iterable wrapper that optionally replays its source forever.

    Args:
        i: The wrapped iterable. When ``None``, an endless counter
            0, 1, 2, ... is used instead.
        repeat: When true, iteration restarts the wrapped iterable each
            time it is exhausted.
    """

    def __init__(self, i: Optional[Iterable] = None, repeat: bool = False):
        self.i = itertools.count() if i is None else i
        self.repeat = repeat

    def __iter__(self):
        if not self.repeat:
            yield from self.i
            return
        while True:
            produced = False
            for item in self.i:
                produced = True
                yield item
            # A pass that yields nothing (empty source, or a one-shot
            # iterator that is already consumed) can never yield again;
            # stop instead of spinning forever without producing a value.
            if not produced:
                return
def resolve_type_alias(t: Type) -> Type:
while isinstance(t, TypeAliasType | GenericAlias) and hasattr(t, "__value__"):
if isinstance(t, TypeAliasType):
t = t.__value__
elif isinstance(t, GenericAlias):
t = t.__value__[*t.__args__]
return t
def check_annot(function: Callable, expected: list[Type], ret=None):
    """Check a function's parameter and return annotations.

    Args:
        function: Function whose ``__annotations__`` are inspected.
        expected: Expected annotations of the annotated parameters, in
            declaration order (an annotation on ``self`` is ignored).
        ret: Expected return annotation (``None`` when none is expected).

    Returns:
        ``True`` when the number of annotated parameters equals
        ``len(expected)``, every pair matches after alias resolution, and
        the return annotation matches ``ret``; ``False`` otherwise.
    """
    # Copy first: __annotations__ is the function's live dict, and removing
    # entries from it would alter the inspected function for later checks.
    annotations = dict(function.__annotations__)
    annotations.pop("self", None)
    return_type = annotations.pop("return", None)
    actual = list(annotations.values())
    # zip() would silently ignore surplus entries on either side, letting a
    # function with missing parameter annotations pass.
    if len(actual) != len(expected):
        return False
    for a, e in zip(actual, expected):
        if resolve_type_alias(a) != resolve_type_alias(e):
            return False
    return resolve_type_alias(ret) == resolve_type_alias(return_type)
def check_lambda_annotations(module, f_name: str, param_types: list[object], return_type: object = None) -> bool:
annotation = module.__annotations__[f_name]
if annotation.__name__ != 'Callable':
return False
all_types = annotation.__args__
types, ret_type = all_types[:-1], all_types[-1]
if ret_type != return_type:
return False
for param, param_type in zip(types, param_types):
while isinstance(param, TypeAliasType):
param = param.__value__
while isinstance(param_type, TypeAliasType):
param_type = param_type.__value__
if param != param_type:
return False
return True
def check_attributes(o: object, inherited_attr: list[str], new_attr: list[str]) -> bool:
    """Check a dataclass' field layout against inherited and new attributes.

    Returns ``True`` when the dataclass fields of ``o`` are exactly
    ``inherited_attr`` followed by ``new_attr`` and none of the inherited
    names appears in ``o.__annotations__``.
    """
    field_names = [f.name for f in fields(o)]
    if field_names != inherited_attr + new_attr:
        return False
    # An inherited name showing up in the annotations counts as a failure.
    return not any(name in o.__annotations__ for name in inherited_attr)
def check_no_if_used(f: Callable) -> bool:
    """Return ``True`` when the source of ``f`` contains no ``if`` statement.

    Only ``ast.If`` nodes are counted; conditional expressions
    (``a if c else b``) are not.
    """
    # Methods and nested functions come back from getsource() indented;
    # dedent so that ast.parse() accepts them.
    tree = ast.parse(textwrap.dedent(inspect.getsource(f)))
    return not any(isinstance(node, ast.If) for node in ast.walk(tree))
def check_comprehension_used(f: Callable) -> bool:
    """Return ``True`` when the source of ``f`` contains a comprehension.

    List, set and dict comprehensions as well as generator expressions
    count.
    """
    comprehension_nodes = (ast.ListComp, ast.SetComp,
                           ast.DictComp, ast.GeneratorExp)
    # Methods and nested functions come back from getsource() indented;
    # dedent so that ast.parse() accepts them.
    tree = ast.parse(textwrap.dedent(inspect.getsource(f)))
    return any(isinstance(node, comprehension_nodes)
               for node in ast.walk(tree))
def check_no_sideeffects(f: Callable, *args) -> bool:
    """Call ``f(*args)`` and report whether the arguments are left unchanged.

    Each argument is deep-copied before the call; afterwards the arguments
    are compared to those copies with ``==``.
    """
    snapshot = tuple(map(copy.deepcopy, args))
    f(*args)
    unchanged = args == snapshot
    return unchanged
def check_function_used(f: Callable, name: str) -> bool:
    """Return ``True`` when the source of ``f`` calls a function named ``name``.

    Only calls through a bare name (``name(...)``) are detected; attribute
    calls such as ``obj.name(...)`` are not.
    """
    # dedent() removes whatever common indentation getsource() returns,
    # so methods and nested definitions at any depth can be parsed.
    tree = ast.parse(textwrap.dedent(inspect.getsource(f)))
    return any(
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == name
        for node in ast.walk(tree)
    )
def check_function_not_used(f: Callable, name: str) -> bool:
    """Return ``True`` when the source of ``f`` never calls ``name(...)``.

    Only calls through a bare name are considered; attribute calls such as
    ``obj.name(...)`` do not count as a use.
    """
    # dedent() removes whatever common indentation getsource() returns,
    # so methods and nested definitions at any depth can be parsed.
    tree = ast.parse(textwrap.dedent(inspect.getsource(f)))
    return not any(
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == name
        for node in ast.walk(tree)
    )
def check_pattern_matching_used(f: Callable) -> bool:
source = inspect.getsource(f)
tree = ast.parse(source)
return any(isinstance(node, ast.Match) for node in ast.walk(tree))
def check_gen(g: Iterator, res: Iterable, check_end: bool = True, max: int = 100) -> bool:
    """Compare the values produced by iterator ``g`` with ``res``.

    Args:
        g: Iterator under test; anything that is not an ``Iterator`` fails.
        res: Expected values, in order.
        check_end: When true, ``g`` must be exhausted after the compared
            values.
        max: Comparison stops after the element with index ``max`` has
            been compared.

    Returns:
        ``True`` when every compared element matches (and, with
        ``check_end``, ``g`` has nothing left); ``False`` otherwise,
        including when ``g`` ends before ``res`` does.
    """
    if not isinstance(g, Iterator):
        return False
    for i, x in enumerate(res):
        try:
            el = next(g)
        except StopIteration:
            # g produced fewer values than expected: a failed check,
            # not an error to propagate to the caller.
            return False
        if el != x:
            return False
        if i >= max:
            break
    if not check_end:
        return True
    try:
        next(g)
    except StopIteration:
        return True
    return False
def check_only_one_line(f: Callable) -> bool:
    """Return ``True`` when ``f`` is a lambda or a one-statement-line function.

    For non-lambdas the source lines of ``f`` are counted, skipping blank
    lines and lines that start with ``#``; exactly two must remain (the
    ``def`` line plus one body line).
    """
    if f.__name__ == '<lambda>':
        return True
    significant = 0
    for raw_line in inspect.getsource(f).split('\n'):
        text = raw_line.strip()
        if text != '' and not text.startswith('#'):
            significant += 1
    return significant == 2
def mock_print(*args, sep=None, end=None, **_kwargs):
    """Stand-in for ``print`` that appends to the module-level ``output``.

    ``sep`` and ``end`` default to a space and a newline; any other
    keyword arguments are accepted and ignored.
    """
    global output
    separator = " " if sep is None else sep
    terminator = "\n" if end is None else end
    output += separator.join(map(str, args)) + terminator
def mock_input(*arg: str):
    """Build a replacement for ``input`` that replays the given answers.

    The returned function hands out the answers in order. Each call also
    records the prompt followed by the answer through ``mock_print``.
    Asking for more input than was supplied raises an ``Exception``.
    """
    remaining = list(arg)

    def input_function(arg=None):
        prompt = "" if arg is None else arg
        if not remaining:
            raise Exception("More input asked than given")
        answer = remaining.pop(0)
        mock_print(prompt, answer, sep="")
        return answer

    return input_function
def input_error(*arg):
    """Replacement for ``input`` that always fails, whatever it is given."""
    message = "Input is not allowed here"
    raise Exception(message)
@patch("builtins.print", mock_print)
@patch("builtins.input", input_error)
def check_global_statements(unit: unittest.TestCase, module_name: str):
    """Assert that importing ``module_name`` prints nothing and has no top-level asserts.

    The module is dropped from ``sys.modules`` and imported afresh under a
    2-second timeout while ``print`` is captured and ``input`` is disabled.
    Afterwards the captured output must be empty and the module's source
    must not contain an ``assert`` statement at module level.
    """
    global output
    output = ""
    with Timeout(2):
        # Force a real re-import so module-level statements run again.
        sys.modules.pop(module_name, None)
        module = importlib.import_module(module_name)
    unit.assertEqual(output, "")
    module_body = ast.parse(inspect.getsource(module)).body
    has_global_asserts = any(
        isinstance(statement, ast.Assert) for statement in module_body)
    unit.assertFalse(has_global_asserts)
def no_input_test(test_function: Callable) -> Callable:
    """Decorate a test method so it runs with captured print and no input.

    While the wrapped test runs, ``print`` is replaced by ``mock_print``
    and ``input`` by ``input_error``; the captured ``output`` is cleared
    first and the test body is limited to 2 seconds.
    """
    @wraps(test_function)
    def inner(self):
        global output
        with patch("builtins.print", mock_print), \
                patch("builtins.input", input_error):
            output = ""
            with Timeout(2):
                test_function(self)

    return inner
class Ex6a_sum_of_subtree(unittest.TestCase):
    """Tests for ``sum_of_subtree`` from the ``ex6_recursion`` module.

    The module under test is imported inside each test. Tests decorated
    with ``no_input_test`` rely on that decorator to clear the captured
    output and to enforce the 2-second time limit.
    """

    def test6_global_statements(self):
        check_global_statements(self, "ex6_recursion")

    def test6_test_pattern_matching_sum_of_subtree(self):
        with Timeout(2):
            import ex6_recursion
            # Inspect the function this test is named after, and assert the
            # outcome: an unchecked call can never fail the test.
            self.assertTrue(check_pattern_matching_used(
                ex6_recursion.sum_of_subtree))

    @no_input_test
    def test6_annot_sum_of_subtree(self):
        import ex6_recursion
        self.assertTrue(check_annot(ex6_recursion.sum_of_subtree, [
            ex6_recursion.BTree[int]], int))

    @no_input_test
    def test6_sum_of_subtree_basisfall(self):
        import ex6_recursion
        sum_of_subtree = ex6_recursion.sum_of_subtree
        tree = None
        self.assertEqual(sum_of_subtree(tree), 0)
        self.assertIsNone(tree)

    @no_input_test
    def test6_sum_of_subtree(self):
        import ex6_recursion
        sum_of_subtree = ex6_recursion.sum_of_subtree
        Node = ex6_recursion.Node

        # The call is expected to return the total and to overwrite every
        # node's mark with the sum of its own subtree.
        tree = Node(1, Node(2, Node(3), Node(4)), Node(5, Node(6)))
        self.assertEqual(sum_of_subtree(tree), 21)
        self.assertEqual(tree, Node(mark=21, left=Node(mark=9, left=Node(mark=3, left=None, right=None), right=Node(
            mark=4, left=None, right=None)), right=Node(mark=11, left=Node(mark=6, left=None, right=None), right=None)))

        single_node_tree = Node(5)
        self.assertEqual(sum_of_subtree(single_node_tree), 5)

        simple_tree = Node(1, Node(2), Node(3))
        total_sum = sum_of_subtree(simple_tree)
        self.assertEqual(total_sum, 6)  # 1 + 2 + 3 = 6
        self.assertEqual(simple_tree.mark, 6)
        self.assertEqual(simple_tree.left.mark, 2)
        self.assertEqual(simple_tree.right.mark, 3)

        complex_tree = Node(1, Node(2, Node(3), Node(4)), Node(5, Node(6)))
        total_sum = sum_of_subtree(complex_tree)
        self.assertEqual(total_sum, 21)  # 1 + 2 + 3 + 4 + 5 + 6 = 21
        self.assertEqual(complex_tree.mark, 21)
        self.assertEqual(complex_tree.left.mark, 9)  # 2 + 3 + 4 = 9
        self.assertEqual(complex_tree.right.mark, 11)  # 5 + 6 = 11
        self.assertEqual(complex_tree.left.left.mark, 3)
        self.assertEqual(complex_tree.left.right.mark, 4)
        self.assertEqual(complex_tree.right.left.mark, 6)
class Ex6b_cut_at(unittest.TestCase):
    """Tests for ``cut_at`` from the ``ex6_recursion`` module.

    The module under test is imported inside each test. Tests decorated
    with ``no_input_test`` rely on that decorator to clear the captured
    output and to enforce the 2-second time limit.
    """

    def test6_global_statements(self):
        check_global_statements(self, "ex6_recursion")

    def test6_test_pattern_matching_cut_at(self):
        with Timeout(2):
            import ex6_recursion
            # The helper only reports its finding; it has to be asserted
            # for this test to be able to fail.
            self.assertTrue(check_pattern_matching_used(ex6_recursion.cut_at))

    def test6_test_no_side_effects_cut_at(self):
        with Timeout(2):
            import ex6_recursion
            Node = ex6_recursion.Node
            tree = Node(1, Node(2, Node(3), Node(4)),
                        Node(3, Node(3), Node(6)))
            # Same here: assert that the arguments were left unchanged.
            self.assertTrue(check_no_sideeffects(
                ex6_recursion.cut_at, tree, 3))

    @no_input_test
    def test6_annot_cut_at(self):
        import ex6_recursion
        # cut_at is expected to declare exactly one type parameter, which
        # is then used to build the expected annotations.
        type_params = ex6_recursion.cut_at.__type_params__
        self.assertEqual(len(type_params), 1)
        T = type_params[0]
        self.assertTrue(check_annot(ex6_recursion.cut_at, [
            ex6_recursion.BTree[T], T], ex6_recursion.BTree[T]))

    @no_input_test
    def test6_cut_at_basisfall(self):
        import ex6_recursion
        cut_at = ex6_recursion.cut_at
        tree = None
        self.assertIsNone(cut_at(tree, 42))

    @no_input_test
    def test6_cut_at(self):
        import ex6_recursion
        Node = ex6_recursion.Node
        cut_at = ex6_recursion.cut_at

        tree = Node(1, Node(2, Node(3), Node(4)),
                    Node(3, Node(3), Node(6)))
        self.assertIsNone(cut_at(tree, 1))
        self.assertEqual(cut_at(tree, 3), Node(mark=1, left=Node(
            mark=2, left=None, right=Node(mark=4, left=None, right=None)), right=None))

        single_node_tree = Node(5)
        new_tree = cut_at(single_node_tree, 10)
        self.assertIsNotNone(new_tree)
        self.assertEqual(new_tree.mark, 5)
        self.assertIsNone(new_tree.left)
        self.assertIsNone(new_tree.right)

        tree_without_cut = Node(
            1, Node(2, Node(3), Node(4)), Node(5, Node(6)))
        new_tree = cut_at(tree_without_cut, 42)  # No node has value 42
        self.assertIsNotNone(new_tree)
        self.assertEqual(new_tree.mark, 1)
        self.assertIsNotNone(new_tree.left)
        self.assertIsNotNone(new_tree.right)

        tree_with_cut = Node(
            1, Node(2, Node(3), Node(4)), Node(3, Node(3), Node(6)))
        new_tree = cut_at(tree_with_cut, 3)  # Cut all nodes with value 3
        self.assertIsNotNone(new_tree)
        self.assertEqual(new_tree.mark, 1)
        self.assertEqual(new_tree.left.mark, 2)
        self.assertIsNone(new_tree.right)  # The right subtree is removed
        self.assertIsNone(new_tree.left.left)  # Node(3) is removed
        self.assertIsNotNone(new_tree.left.right)
        self.assertEqual(new_tree.left.right.mark, 4)  # Node(4) remains
def get_results():
    """Run every ``unittest.TestCase`` defined in this module and summarise it.

    Returns:
        A dict with the keys ``"tests"`` (string form of every test),
        ``"failure"`` (``True`` when the run was not successful),
        ``"tests_run"``, ``"failures"`` and ``"errors"`` (lists of
        ``(test, traceback)`` string pairs). When the module defines no
        test class, an empty run is reported instead of raising.
    """
    test_classes = [
        cls
        for _, cls in inspect.getmembers(sys.modules[__name__], inspect.isclass)
        if cls.__module__ == __name__ and unittest.TestCase in cls.__mro__
    ]
    loader = unittest.TestLoader()
    # Building the suite from an iterable also covers the case of zero
    # test classes, where indexing the first class would fail.
    suite = unittest.TestSuite(
        loader.loadTestsFromTestCase(cls) for cls in test_classes)
    # The runner's report goes to a throwaway buffer; only the result
    # object is used below.
    stream = io.StringIO("")
    test_result = unittest.TextTestRunner(stream=stream, verbosity=2).run(suite)
    return {
        "tests": [str(test_method)
                  for cls in test_classes
                  for test_method in loader.loadTestsFromTestCase(cls)],
        "failure": not test_result.wasSuccessful(),
        "tests_run": test_result.testsRun,
        "failures": [(str(func), str(err))
                     for func, err in test_result.failures],
        "errors": [(str(func), str(err))
                   for func, err in test_result.errors],
    }
if __name__ == "__main__":
    # Write a JSON report to /output/output.json when that location can be
    # opened; if opening fails with FileNotFoundError, fall back to the
    # regular unittest command-line runner.
    try:
        with open("/output/output.json", "w", encoding="utf-8") as f:
            json.dump(get_results(), f)
        docker_run = True
    except FileNotFoundError:
        docker_run = False
    if not docker_run:
        unittest.main()