from dataclasses import dataclass, field, fields, is_dataclass, InitVar from types import NoneType, FunctionType, GenericAlias from typing import get_args, Callable, Iterable, Iterator, Type, Optional, TypeAliasType, ForwardRef, Any, Literal from enum import Enum import unittest from unittest import mock from unittest.mock import patch from functools import wraps import signal import io import importlib import importlib.util import sys import subprocess import os import shutil import re import math import inspect import json import ast import copy import functools import itertools output = "" class TimeoutException(Exception): pass class Timeout: def __init__(self, seconds, msg=None) -> None: if msg is None: msg = f"Test timed out after {seconds} seconds." self.seconds = seconds self.msg = msg def handle_timeout(self, signal, frame): raise TimeoutException(self.msg) def __enter__(self): signal.signal(signal.SIGALRM, self.handle_timeout) signal.alarm(self.seconds) def __exit__(self, exc_type, exc_val, exc_tb): signal.alarm(0) class IterablePure: def __init__(self, i: Optional[Iterable] = None, repeat: bool = False): if i is not None: self.i = i else: def endless(): c = 0 while True: yield c c += 1 self.i = endless() self.repeat = repeat def __iter__(self): if self.repeat: while True: yield from self.i else: yield from self.i def resolve_type_alias(t: Type) -> Type: while isinstance(t, TypeAliasType | GenericAlias) and hasattr(t, "__value__"): if isinstance(t, TypeAliasType): t = t.__value__ elif isinstance(t, GenericAlias): t = t.__value__[*t.__args__] return t def check_annot(function: Callable, expected: list[Type], ret=None): actual = function.__annotations__ if "self" in actual: del actual["self"] return_type = actual["return"] if "return" in actual else None actual = list(actual.values())[ :-1] if "return" in actual else list(actual.values()) for a, e in zip(actual, expected): a = resolve_type_alias(a) e = resolve_type_alias(e) if a != e: return False return resolve_type_alias(ret) == resolve_type_alias(return_type) def check_lambda_annotations(module, f_name: str, param_types: list[object], return_type: object = None) -> bool: annotation = module.__annotations__[f_name] if annotation.__name__ != 'Callable': return False all_types = annotation.__args__ types, ret_type = all_types[:-1], all_types[-1] if ret_type != return_type: return False for param, param_type in zip(types, param_types): while isinstance(param, TypeAliasType): param = param.__value__ while isinstance(param_type, TypeAliasType): param_type = param_type.__value__ if param != param_type: return False return True def check_attributes(o: object, inherited_attr: list[str], new_attr: list[str]) -> bool: attr = [f.name for f in fields(o)] if attr != inherited_attr + new_attr: return False for attr in inherited_attr: if attr in o.__annotations__: return False return True def check_no_if_used(f: Callable) -> bool: tree = ast.parse(inspect.getsource(f)) for node in ast.walk(tree): if isinstance(node, ast.If): return False return True def check_comprehension_used(f: Callable) -> bool: source = inspect.getsource(f) tree = ast.parse(source) for node in ast.walk(tree): if isinstance(node, (ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp)): return True return False def check_no_sideeffects(f: Callable, *args) -> bool: args_copy = tuple(copy.deepcopy(a) for a in args) f(*args) return args == args_copy def check_function_used(f: Callable, name: str) -> bool: source = inspect.getsource(f) if source.startswith(4 * " "): source = "\n".join(l[4:] for l in source.split("\n")) return name in [c.func.id for c in ast.walk(ast.parse(source)) if isinstance(c, ast.Call) and isinstance(c.func, ast.Name)] def check_function_not_used(f: Callable, name: str) -> bool: source = inspect.getsource(f) if source.startswith(4 * " "): source = "\n".join(l[4:] for l in source.split("\n")) return name not in [c.func.id for c in ast.walk(ast.parse(source)) if isinstance(c, ast.Call) and isinstance(c.func, ast.Name)] def check_pattern_matching_used(f: Callable) -> bool: source = inspect.getsource(f) tree = ast.parse(source) return any(isinstance(node, ast.Match) for node in ast.walk(tree)) def check_gen(g: Iterator, res: Iterable, check_end: bool = True, max: int = 100) -> bool: if not isinstance(g, Iterator): return False for i, x in enumerate(res): el = next(g) if el != x: return False if i >= max: break if check_end: try: next(g) except StopIteration: return True return False return True def check_only_one_line(f: Callable) -> bool: if f.__name__ == '': return True inspect.getsource(f) lines = [line for line in inspect.getsource(f).split( '\n') if not line.strip().startswith('#') and line.strip() != ''] return len(lines) == 2 def mock_print(*args, sep=None, end=None, **_kwargs): global output if sep is None: sep = " " if end is None: end = "\n" args = [str(a) for a in args] output += str(sep.join(args)) + end def mock_input(*arg: str): inputs = [s for s in arg] def input_function(arg=None): if arg is None: arg = "" nonlocal inputs if len(inputs) == 0: raise Exception("More input asked than given") p = inputs[0] inputs = inputs[1:] mock_print(arg, p, sep="") return p return input_function def input_error(*arg): raise Exception("Input is not allowed here") @patch("builtins.print", mock_print) @patch("builtins.input", input_error) def check_global_statements(unit: unittest.TestCase, module_name: str): global output output = "" with Timeout(2): if module_name in sys.modules.keys(): del sys.modules[module_name] module = importlib.import_module(module_name) unit.assertEqual(output, "") has_global_asserts = any(isinstance(node, ast.Assert) for node in ast.iter_child_nodes(ast.parse(inspect.getsource(module)))) unit.assertFalse(has_global_asserts) def no_input_test(test_function: Callable) -> Callable: @wraps(test_function) @patch("builtins.print", mock_print) @patch("builtins.input", input_error) def inner(self): global output output = "" with Timeout(2): test_function(self) return inner class Ex6a_sum_of_subtree(unittest.TestCase): def test6_global_statements(self): check_global_statements(self, "ex6_recursion") def test6_test_pattern_matching_sum_of_subtree(self): with Timeout(2): import ex6_recursion check_pattern_matching_used(ex6_recursion.cut_at) @no_input_test def test6_annot_sum_of_subtree(self): global output output = "" with Timeout(2): import ex6_recursion self.assertTrue(check_annot(ex6_recursion.sum_of_subtree, [ ex6_recursion.BTree[int]], int)) @no_input_test def test6_sum_of_subtree_basisfall(self): global output output = "" with Timeout(2): import ex6_recursion sum_of_subtree = ex6_recursion.sum_of_subtree tree = None self.assertEqual(sum_of_subtree(tree), 0) self.assertIsNone(tree) @no_input_test def test6_sum_of_subtree(self): global output output = "" with Timeout(2): import ex6_recursion sum_of_subtree = ex6_recursion.sum_of_subtree Node = ex6_recursion.Node tree = Node(1, Node(2, Node(3), Node(4)), Node(5, Node(6))) self.assertEqual(sum_of_subtree(tree), 21) self.assertEqual(tree, Node(mark=21, left=Node(mark=9, left=Node(mark=3, left=None, right=None), right=Node( mark=4, left=None, right=None)), right=Node(mark=11, left=Node(mark=6, left=None, right=None), right=None))) single_node_tree = Node(5) self.assertEqual(sum_of_subtree(single_node_tree), 5) simple_tree = Node(1, Node(2), Node(3)) total_sum = sum_of_subtree(simple_tree) self.assertEqual(total_sum, 6) # 1 + 2 + 3 = 6 self.assertEqual(simple_tree.mark, 6) self.assertEqual(simple_tree.left.mark, 2) self.assertEqual(simple_tree.right.mark, 3) complex_tree = Node(1, Node(2, Node(3), Node(4)), Node(5, Node(6))) total_sum = sum_of_subtree(complex_tree) self.assertEqual(total_sum, 21) # 1 + 2 + 3 + 4 + 5 + 6 = 21 self.assertEqual(complex_tree.mark, 21) self.assertEqual(complex_tree.left.mark, 9) # 2 + 3 + 4 = 9 self.assertEqual(complex_tree.right.mark, 11) # 5 + 6 = 11 self.assertEqual(complex_tree.left.left.mark, 3) self.assertEqual(complex_tree.left.right.mark, 4) self.assertEqual(complex_tree.right.left.mark, 6) class Ex6b_cut_at(unittest.TestCase): def test6_global_statements(self): check_global_statements(self, "ex6_recursion") def test6_test_pattern_matching_cut_at(self): with Timeout(2): import ex6_recursion check_pattern_matching_used(ex6_recursion.cut_at) def test6_test_no_side_effects_cut_at(self): with Timeout(2): import ex6_recursion Node = ex6_recursion.Node tree = Node(1, Node(2, Node(3), Node(4)), Node(3, Node(3), Node(6))) check_no_sideeffects(ex6_recursion.cut_at, tree, 3) @no_input_test def test6_annot_cut_at(self): global output output = "" with Timeout(2): import ex6_recursion self.assertTrue(check_annot(ex6_recursion.sum_of_subtree, [ ex6_recursion.BTree[int]], int)) type_params = ex6_recursion.cut_at.__type_params__ self.assertEqual(len(type_params), 1) T = type_params[0] self.assertTrue(check_annot(ex6_recursion.cut_at, [ ex6_recursion.BTree[T]], ex6_recursion.BTree[T])) @no_input_test def test6_cut_at_basisfall(self): global output output = "" with Timeout(2): import ex6_recursion cut_at = ex6_recursion.cut_at tree = None self.assertIsNone(cut_at(tree, 42)) @no_input_test def test6_cut_at(self): global output output = "" with Timeout(2): import ex6_recursion Node = ex6_recursion.Node cut_at = ex6_recursion.cut_at tree = Node(1, Node(2, Node(3), Node(4)), Node(3, Node(3), Node(6))) self.assertIsNone(cut_at(tree, 1)) self.assertEqual(cut_at(tree, 3), Node(mark=1, left=Node( mark=2, left=None, right=Node(mark=4, left=None, right=None)), right=None)) single_node_tree = Node(5) new_tree = cut_at(single_node_tree, 10) self.assertIsNotNone(new_tree) self.assertEqual(new_tree.mark, 5) self.assertIsNone(new_tree.left) self.assertIsNone(new_tree.right) tree_without_cut = Node( 1, Node(2, Node(3), Node(4)), Node(5, Node(6))) new_tree = cut_at(tree_without_cut, 42) # No node has value 42 self.assertIsNotNone(new_tree) self.assertEqual(new_tree.mark, 1) self.assertIsNotNone(new_tree.left) self.assertIsNotNone(new_tree.right) tree_with_cut = Node( 1, Node(2, Node(3), Node(4)), Node(3, Node(3), Node(6))) new_tree = cut_at(tree_with_cut, 3) # Cut all nodes with value 3 self.assertIsNotNone(new_tree) self.assertEqual(new_tree.mark, 1) self.assertEqual(new_tree.left.mark, 2) self.assertIsNone(new_tree.right) # The right subtree is removed self.assertIsNone(new_tree.left.left) # Node(3) is removed self.assertIsNotNone(new_tree.left.right) self.assertEqual(new_tree.left.right.mark, 4) # Node(4) remains def get_results(): result_data = {} output = io.StringIO("") test_classes = [(name, cls) for name, cls in inspect.getmembers(sys.modules[__name__], lambda x: inspect.isclass( x) and (x.__module__ == __name__)) if unittest.TestCase in cls.__mro__] suite = unittest.TestLoader().loadTestsFromTestCase(test_classes[0][1]) for other in test_classes[1:]: suite.addTest(unittest.TestLoader().loadTestsFromTestCase(other[1])) testResult = unittest.TextTestRunner(stream=output, verbosity=2).run(suite) result_data["tests"] = [str(test_method) for test in test_classes for test_method in unittest.TestLoader() .loadTestsFromTestCase(test[1])] result_data["failure"] = not testResult.wasSuccessful() result_data["tests_run"] = testResult.testsRun result_data["failures"] = [(str(func), str(err)) for func, err in testResult.failures] result_data["errors"] = [(str(func), str(err)) for func, err in testResult.errors] return result_data if __name__ == "__main__": docker_run = True try: with open("/output/output.json", "w", encoding="utf-8") as f: result = get_results() json.dump(result, f) except FileNotFoundError: docker_run = False if not docker_run: unittest.main()