summaryrefslogtreecommitdiffstats
path: root/statemachine.py
blob: 3a0e1166c9cf6b1bce613ceee8490beb02334428 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Declares that docstrings in this module are written in reStructuredText
# markup (the PEP 258 `__docformat__` convention read by documentation tools).
__docformat__ = 'restructuredtext'

from category import Category
from pattern import Pattern
import setcross

class StateMachine:
    """
    The state machine contains a list of dependencies between states, and a set
    of states which are "up."

    `deps` is a list of ``(dependents, dependencies)`` pairs of Category
    objects: states in `dependents` depend on states in `dependencies`.
    `up` is a set of Category objects describing the states currently up.
    """

    def __init__(self):
        """
        Create a new state machine with no dependencies and nothing up.
        """
        self.up = set()
        self.deps = []

    def bring_up(self, cat):
        """
        Move states in the given Category `cat` from down to up.

        Every dependency that might apply to members of `cat` is checked.  If
        any of them has no satisfied states, nothing is changed and False is
        returned.  If no dependency applies, `cat` itself is added to `up`.
        Otherwise every category produced by `cat_cross` over the per-
        dependency satisfied sets is added to `up`.

        :return: False if some applicable dependency could not be satisfied,
            True otherwise.
        """
        found = []
        for (match, dependency) in self.get_applicable_deps(cat):
            res = self.get_satisfied_states(match, dependency)
            if not res:
                # A single unsatisfiable dependency blocks the whole bring-up.
                return False
            found.append(res)
        if not found:
            # Nothing constrains `cat`, so it comes up as-is.
            self.up.add(cat)
            return True
        self.up.update(self.cat_cross(found))
        return True

    def cat_cross(self, found):
        """
        Given a list of sets, where each set contains Category objects, return a
        set of all categories that can be made by intersecting one element from
        each set.

        Combinations whose intersection comes back as None (no overlap) are
        skipped.
        """
        to_add = set()
        for combo in setcross.cross(*found):
            inter = self._intersect_all(combo)
            if inter is not None:
                to_add.add(inter)
        return to_add

    def _intersect_all(self, cats):
        """
        Intersect the Category objects in sequence `cats` from left to right.

        Return None if `cats` is empty or if any intermediate ``intersect``
        call returns None.
        """
        if len(cats) == 0:
            return None
        result = cats[0]
        for other in cats[1:]:
            result = result.intersect(other)
            if result is None:
                # Short-circuit: once empty, the intersection stays empty.
                return None
        return result

    def intersect_list(self, cats1, cats2):
        """
        Given two lists of categories, return a set of categories such that a
        state appearing in at least one category in each list will appear in at
        least one category in the returned set.
        """
        retval = set()
        for x in cats1:
            for y in cats2:
                inter = x.intersect(y)
                if inter is not None:
                    retval.add(inter)
        return retval

    def add_hold(self, cat):
        """
        Add a hold to a state. Does not check dependencies.

        `cat` is not added if it is a subset of a category already up.
        """
        for x in self.up:
            if cat.subset_of(x):
                return
        self.up.add(cat)

    def get_satisfied_states(self, dependents, dependencies):
        """
        Given that states in `dependents` depend on states in `dependencies`,
        return a set of Category objects covering the states in `dependents`
        that could match states in `dependencies`.

        For every up category equivalent to `dependencies`, `dependents` is
        filled with that category's args.  The result is unioned with
        ``dependents.inverse_set()``.  NOTE(review): that set is presumably the
        part of `dependents` not constrained by this dependency; confirm
        against Category.
        """
        retval = set()
        for cat in self.up:
            if dependencies.equiv(cat):
                retval.add(dependents.fill(cat.args))
        return retval | dependents.inverse_set()

    def get_applicable_deps(self, cat):
        """
        Find dependencies that might apply to members of `cat`.

        :return: a list of ``(match, dependency)`` pairs.  `match` is `cat`
            intersected with the dependency's dependents, and `dependency` is
            the dependency's dependencies filled with `match`'s args.
        """
        retval = []
        for (x, y) in self.deps:
            if x.equiv(cat):
                un = cat.intersect(x)
                retval.append((un, y.fill(un.args)))
        return retval

    def __str__(self):
        # One up category per line.
        return "\n".join("%s" % k for k in self.up)

    def __repr__(self):
        return str(self)

if __name__ == "__main__":
    # Demo: declare three dependencies, bring up several states, then try to
    # bring up a generic "mounted" category and print everything that is up.
    # NOTE(review): the meaning of Pattern(True, x) / Pattern(False[, x]) is
    # defined in pattern.py and is not restated here.
    sm = StateMachine()

    # Dependency 1: mounted (type=Pattern(True, "nfs")) -> network_up.
    sm.deps.append((Category("mounted", type=Pattern(True, "nfs")),
                    Category("network_up")))
    # Dependency 2: mounted (type=Pattern(False, "nfs")) -> found_disk.
    sm.deps.append((Category("mounted", uuid=Pattern(False),
                             devname=Pattern(False), label=Pattern(False),
                             type=Pattern(False, "nfs")),
                    Category("found_disk", uuid=Pattern(False),
                             devname=Pattern(False), label=Pattern(False))))
    # Dependency 3: mounted -> vol_conf, over uuid/devname/label.
    sm.deps.append((Category("mounted", uuid=Pattern(False),
                             devname=Pattern(False), label=Pattern(False)),
                    Category("vol_conf", uuid=Pattern(False),
                             devname=Pattern(False), label=Pattern(False))))

    sm.bring_up(Category("network_up"))
    sm.bring_up(Category("vol_conf", uuid=Pattern(False),
                         devname=Pattern(False),
                         label=Pattern(True, "myroot"),
                         type=Pattern(True, "ext3"),
                         mountpoint=Pattern(True, "/")))
    sm.bring_up(Category("vol_conf", uuid=Pattern(False),
                         devname=Pattern(True, "foosrv.com:/vol/home"),
                         label=Pattern(False),
                         type=Pattern(True, "nfs"),
                         mountpoint=Pattern(True, "/home")))
    sm.bring_up(Category("found_disk", uuid=Pattern(True, "d3adb3ef"),
                         devname=Pattern(True, "/dev/sda"),
                         label=Pattern(True, "myroot")))
    sm.bring_up(Category("mounted", uuid=Pattern(False), type=Pattern(False),
                         devname=Pattern(False), label=Pattern(False),
                         mountpoint=Pattern(False)))

    # print(...) with a single argument behaves identically on Python 2 and
    # is also valid Python 3 (the bare `print sm` statement is not).
    print(sm)