summaryrefslogtreecommitdiffstats
path: root/state.py
blob: 594b32a2ff3c6e9358990dadbb50373165f7a78e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# 
class Category:
    """A named state with keyword arguments, where ``None`` means "unbound".

    An argument whose value is ``None`` acts as a wildcard when matching
    (see ``equiv``/``intersect``) and is a slot that ``filter``/``fill`` can
    populate from another source.
    """

    def __init__(self, name, **args):
        self.args = args
        self.name = name

    def equiv(self, other):
        """Return True if *other* matches this category as a pattern.

        Names must be equal and every key of ``self.args`` must be present in
        ``other.args``.  A ``None`` value on either side matches anything;
        otherwise the values must be equal.  Extra keys on *other* are
        ignored, so the relation is not symmetric.
        """
        if self.name != other.name:
            return False
        for key, value in self.args.items():
            if key not in other.args:
                return False
            theirs = other.args[key]
            if value is None or theirs is None:
                continue  # unbound on either side: wildcard
            if theirs != value:
                return False
        return True

    def intersect(self, other):
        """Return the most specific Category matching both, or None.

        The result has the union of both key sets.  For each key, a value
        that is missing or ``None`` on one side takes the other side's value;
        two concrete values must be equal, otherwise the categories are
        incompatible and None is returned.  None is also returned when the
        names differ.
        """
        if self.name != other.name:
            return None
        args = {}
        for key in set(self.args) | set(other.args):
            # A missing key and a None value are handled identically here.
            mine = self.args.get(key)
            theirs = other.args.get(key)
            if mine is None:
                args[key] = theirs
            elif theirs is None or mine == theirs:
                args[key] = mine
            else:
                return None  # conflicting concrete values
        return Category(self.name, **args)

    def __str__(self):
        # Unbound arguments render as "%key", bound ones as "key: value".
        parts = []
        for key, val in self.args.items():
            if val is None:
                parts.append("%%%s" % key)
            else:
                parts.append("%s: %s" % (key, val))
        return self.name + "(" + ", ".join(parts) + ")"

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        return hash((self.argstup(), self.name))

    def __eq__(self, other):
        # Returning NotImplemented (instead of touching other.name) keeps
        # comparisons such as ``cat != None`` from raising AttributeError.
        if not isinstance(other, Category):
            return NotImplemented
        return self.name == other.name and self.argstup() == other.argstup()

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__, so define it
        # explicitly to keep == and != consistent.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def argstup(self):
        """Return the arguments as a tuple of (key, value) pairs sorted by key."""
        # Keys are unique, so sorting items only ever compares keys.
        return tuple(sorted(self.args.items()))

    def filter(self, other):
        """Return a copy of *other* with its unbound slots taken from self.

        For each key of *other*, a ``None`` value is replaced by this
        category's value for that key when present; everything else is kept.

        Raises:
            TypeError: if the two categories have different names.
        """
        if other.name != self.name:
            raise TypeError("States must be the same class")
        args = {}
        for key, value in other.args.items():
            if value is None and key in self.args:
                args[key] = self.args[key]
            else:
                args[key] = value
        return Category(self.name, **args)

    def fill(self, info):
        """Return a copy of self with ``None`` arguments filled from *info*.

        *info* is a mapping; keys absent from it stay ``None``.
        """
        args = {}
        for key, value in self.args.items():
            if value is None and key in info:
                args[key] = info[key]
            else:
                args[key] = value
        return Category(self.name, **args)

    def is_finite(self):
        """Return True if no argument is unbound (``None``)."""
        return None not in self.args.values()

class StateMachine:
    """Track asserted states ("holds") and the dependencies that gate them.

    ``holds`` maps each asserted state to the number of times it was added.
    ``deps`` is a list of ``(dependent, dependency)`` pairs: a state matching
    *dependent* can only be asserted while some held state matches the
    corresponding *dependency*.
    """

    def __init__(self):
        self.holds = {}
        self.deps = []

    def assert_state(self, cat):
        """Try to assert *cat*; return True on success, False otherwise.

        With no applicable dependency, *cat* itself is added to ``holds``.
        Otherwise every applicable dependency must be satisfied by at least
        one held state.  The per-dependency candidate sets are combined with
        ``intersect_list``, and every resulting state is added to ``holds``.
        """
        found = None
        for dependency in self.get_applicable_deps(cat):
            satisfied = self.get_satisfied_states(cat, dependency)
            if not satisfied:
                return False  # this dependency has no matching hold
            if found is None:
                found = satisfied
            else:
                found = self.intersect_list(found, satisfied)
        if found is None:
            self.add_hold(cat)
            return True
        if not found:
            return False
        for state in found:
            self.add_hold(state)
        return True

    def intersect_list(self, cats1, cats2):
        """Combine two sets of candidate states.

        Each pair of distinct, compatible states (one from each set) yields
        its ``intersect`` result.  States present in both sets are kept
        unless they already took part in such an intersection.
        """
        merged = set()
        consumed = set()
        for x in cats1:
            for y in cats2:
                if x == y:
                    continue
                inter = x.intersect(y)
                if inter is not None:
                    merged.add(inter)
                    consumed.add(x)
                    consumed.add(y)
        # When nothing intersected, merged and consumed are empty and this is
        # simply cats1 & cats2.
        return merged | ((cats1 & cats2) - consumed)

    def add_hold(self, cat):
        """Increment the hold count for *cat* (starting from 0)."""
        self.holds[cat] = self.holds.get(cat, 0) + 1

    def get_satisfied_states(self, dependents, dependencies):
        """Return *dependents* filled from each positively-held state matching *dependencies*."""
        return set(
            dependents.fill(held.args)
            for held, count in self.holds.items()
            if dependencies.equiv(held) and count > 0
        )

    def get_applicable_deps(self, cat):
        """Return the dependency patterns that apply to *cat*.

        For each ``(dependent, dependency)`` pair whose *dependent* matches
        *cat*, the *dependency* is specialised with the arguments of
        *dependent* that *cat* binds.
        """
        return [
            dependency.fill(cat.filter(dependent).args)
            for (dependent, dependency) in self.deps
            if dependent.equiv(cat)
        ]

if __name__ == "__main__":
    # Demo: NFS mounts depend on the network; other mounts depend on both a
    # discovered disk and a volume configuration sharing uuid/devname/label.
    sm = StateMachine()
    sm.deps.append((Category("mounted", type="nfs"), Category("network_up")))
    sm.deps.append((Category("mounted", uuid=None, devname=None, label=None),
                    Category("found_disk", uuid=None, devname=None, label=None)))
    sm.deps.append((Category("mounted", uuid=None, devname=None, label=None),
                    Category("vol_conf", uuid=None, devname=None, label=None)))
    sm.assert_state(Category("vol_conf", uuid=None, devname=None, label="myroot",
                             type="ext3", mountpoint="/"))
    sm.assert_state(Category("found_disk", uuid="d3adb3ef", devname="/dev/sda",
                             label="myroot"))
    sm.assert_state(Category("mounted", uuid=None, type="ext3", devname=None,
                             label=None, mountpoint=None))
    # Parenthesised print works identically on Python 2 and 3.
    print(sm.holds)