summary refs log tree commit diff
path: root/tools.py
blob: 11e962b73931621adf66b74f70943baa40c79f66 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from z3 import And, Or, sat


# implement python operators for some z3 objects
# z3.BoolRef.__radd__ = lambda self, other: self + other
# z3.BoolRef.__add__ = lambda self, other: And(self, other)
# z3.BoolRef.__rmul__ = lambda self, other: self * other
# z3.BoolRef.__mul__ = lambda self, other: Or(self, other)


def coordinate_l(width):
    """Build a converter from a linear cell number to an ``(x, y)`` pair.

    The returned callable maps ``cell_number`` to
    ``(cell_number % width, cell_number // width)``, i.e. cells are numbered
    row by row on a grid that is ``width`` cells wide.
    """
    def to_coordinate(cell_number):
        column = cell_number % width
        row = cell_number // width
        return column, row

    return to_coordinate


def cell_number_l(width):
    """Build a converter from ``(x, y)`` to a linear cell number.

    The returned callable computes ``x + y * width``, the inverse of the
    mapping produced by ``coordinate_l`` for the same ``width``.
    """
    def to_cell_number(x, y):
        row_offset = y * width
        return x + row_offset

    return to_cell_number


def z3_bool_mat_mul(mat1, mat2):
    """Return a lazy accessor for the boolean product of two matrices.

    The result is a callable ``f(i, j)`` that computes entry ``(i, j)`` of the
    product on demand: the disjunction over ``k`` of
    ``mat1[i][k] AND mat2[k][j]``, built with ``simplify_and`` /
    ``simplify_or`` so that literal ``False`` entries are dropped.

    Both matrices are indexed as ``mat[row][col]``.  The contraction index
    ``k`` ranges over the rows of ``mat2`` (equivalently the columns of
    ``mat1``), so non-square operands with compatible shapes are supported.
    """
    def entry(i, j):
        # The inner dimension is the number of rows of mat2; using the number
        # of rows of mat1 is only correct when mat1 happens to be square.
        inner = len(mat2)
        return simplify_or([simplify_and(mat1[i][k], mat2[k][j]) for k in range(inner)])

    return entry


def simplify_and(*terms):
    """Conjoin ``terms``, collapsing to ``False`` when one is literally ``False``.

    A term counts as false only if it *is* the Python ``False`` object
    (identity check).  If no such term is present the tuple of terms is
    handed to z3's ``And`` and that result is returned.
    """
    for term in terms:
        if term is False:
            # One definite False makes the whole conjunction False.
            return False
    return And(terms)


def simplify_or(terms):
    """Disjoin ``terms`` after discarding every literal ``False``.

    Terms that *are* the Python ``False`` object (identity check) are
    removed.  If nothing remains, ``False`` is returned; otherwise the
    remaining terms are passed as a list to z3's ``Or``.
    """
    remaining = []
    for term in terms:
        if term is not False:
            remaining.append(term)
    if not remaining:
        # Every disjunct was False (or there were none at all).
        return False
    return Or(remaining)


def z3_bool_mat_sum(mats):
    """Return a lazy accessor for the element-wise OR of several matrices.

    The returned callable accepts any number of index arguments, packs them
    into a tuple ``index`` and looks up ``mat[index]`` in every matrix of
    ``mats``; the collected entries are combined with ``simplify_or``.
    """
    def summed_entry(*index):
        entries = [mat[index] for mat in mats]
        return simplify_or(entries)

    return summed_entry


def mat_display(mat):
    """Print ``mat`` as a grid with column and row indices.

    ``mat`` is indexed as ``mat[x][y]``: ``len(mat)`` is the width and
    ``len(mat[0])`` the height.  The first output line is ``x`` followed by
    the column indices; each subsequent line starts with the row index ``y``
    followed by the entries of that row, every entry followed by one space.

    An empty matrix prints only the header line instead of raising
    ``IndexError``.
    """
    width = len(mat)
    # Guard the empty case: there is no first column to take the height from.
    height = len(mat[0]) if width else 0
    print('x', ' '.join(str(x) for x in range(width)))
    for y in range(height):
        # Each cell keeps its trailing space so the layout is unchanged.
        row = ''.join(str(mat[x][y]) + ' ' for x in range(width))
        print(y, row)


def bool_display(v):
    """Render a truthy value as ``'#'`` and a falsy value as a single space."""
    if v:
        return '#'
    return ' '


def cell_display_l(shading, numbers):
    """Build a per-cell renderer showing ``"#"`` or the cell's number.

    The returned callable packs its arguments into a tuple ``index``.  When
    ``shading[index]`` is truthy it returns ``"#"``; otherwise it returns
    ``str(numbers[index])``.  ``numbers`` is only consulted for unshaded
    cells.
    """
    def render(*index):
        if shading[index]:
            return "#"
        return str(numbers[index])

    # Alternative rendering (blank instead of the number for unshaded cells):
    # return lambda *index: "#" if shading[index] else ' '
    return render


def shaded_or_display_l(shading, other):
    """Build a per-cell renderer showing ``"#"`` or the entry from ``other``.

    The returned callable packs its arguments into a tuple ``index``.  When
    ``shading[index]`` is truthy it returns ``"#"``; otherwise it returns
    ``other[index]`` unchanged (no string conversion).  ``other`` is only
    consulted for unshaded cells.
    """
    def render(*index):
        if shading[index]:
            return "#"
        return other[index]

    # Alternative rendering (blank for every unshaded cell):
    # return lambda *index: "#" if shading[index] else ' '
    return render


def all_smt(s, initial_terms):
    """Lazily yield models of solver ``s``, branching on ``initial_terms``.

    This is a generator.  Each time ``s.check()`` returns ``sat`` the current
    model is yielded; the search then splits on the given terms, using
    ``s.push()`` / ``s.pop()`` to scope the extra constraints it adds.

    NOTE(review): this appears intended to enumerate every model that
    differs on at least one of ``initial_terms`` — confirm against the z3
    documentation.  If the consumer stops iterating early, the pending
    ``s.pop()`` calls are not executed, so ``s`` may be left with extra
    scopes pushed.

    :param s: solver-like object providing ``check``, ``model``, ``add``,
        ``push`` and ``pop``.
    :param initial_terms: iterable of terms to branch on; it is copied into
        a list before the search starts.
    """
    def block_term(s, m, t):
        # Require t to take a value different from the one it has in model m.
        s.add(t != m.eval(t, model_completion=True))

    def fix_term(s, m, t):
        # Require t to keep the value it has in model m.
        s.add(t == m.eval(t, model_completion=True))

    def all_smt_rec(terms):
        # Only descend while the current constraints are satisfiable.
        if sat == s.check():
            m = s.model()
            yield m
            for i in range(len(terms)):
                # New scope: terms[i] must change, terms[0..i-1] stay as in m.
                s.push()
                block_term(s, m, terms[i])
                for j in range(i):
                    fix_term(s, m, terms[j])
                # Recurse on terms[i:] only; the earlier terms are pinned above.
                yield from all_smt_rec(terms[i:])
                # Drop the constraints added in this scope before the next split.
                s.pop()

    yield from all_smt_rec(list(initial_terms))