summaryrefslogtreecommitdiffstats
path: root/di/core.py
blob: 09627505a2f35cf6165f8039a6a61f19d7689740 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""This module implements dependency injection mechanisms."""

__author__ = "Martin Sivak <msivak@redhat.com>"
__all__ = ["DI_ENABLE", "di_enable", "inject", "usesclassinject"]

from functools import wraps, partial
from types import FunctionType

DI_ENABLE = True

def di_enable(method):
    """This decorator enables DI mechanisms in an environment
       where DI is disabled by default. Must be the outermost
       decorator.

       Can be used only on methods or simple functions.

       The module-level DI_ENABLE flag is set to True for the duration
       of the call and restored to its previous value afterwards, even
       when the decorated callable raises.  Note that @inject and
       @usesclassinject read DI_ENABLE when they decorate, so only
       decorations performed during the call are affected.
    """
    @wraps(method)
    def caller(*args, **kwargs):
        """The replacement method doing the DI enablement.
        """
        global DI_ENABLE
        old = DI_ENABLE
        DI_ENABLE = True
        try:
            return method(*args, **kwargs)
        finally:
            # Restore the previous state even if method() raised;
            # otherwise one failing call would leave DI enabled globally.
            DI_ENABLE = old

    return caller

class DiRegistry(object):
    """This class is the internal core of the DI engine.
       It records the injected objects, handles the execution
       and cleanup tasks associated with the DI mechanisms.

       The wrapped object gets a ``_di_`` attribute pointing at the same
       dict of registered injections; @usesclassinject reads it from
       ``self`` at call time.
    """

    try:
        _STRING_TYPES = basestring  # Python 2: str and unicode
    except NameError:
        _STRING_TYPES = str  # Python 3: basestring no longer exists

    def __init__(self, obj):
        self._obj = obj
        self._used_objects = {}
        # Shared, not copied, so later register() calls are visible
        # through obj._di_ as well.
        self._obj._di_ = self._used_objects

    def __get__(self, obj, objtype=None):
        """Support instance methods.

           Accessing a wrapped function through an instance returns a
           callable bound to that instance.  Accessing it through the
           class, or accessing a wrapped non-function (e.g. a class),
           returns the registry itself, mirroring how plain functions and
           classes behave as class attributes.
        """
        if obj is None or not isinstance(self._obj, FunctionType):
            return self
        return partial(self.__call__, obj)

    def register(self, *args, **kwargs):
        """Add registered injections to the instance of DiRegistry.

           Keyword arguments are stored under their keyword names.
           Positional arguments are stored under their ``__name__``;
           plain strings are accepted as annotations only (the name is
           expected to already be a global) and are not stored.

           Raises ValueError for positional arguments that are neither
           strings nor objects with a ``__name__``.
        """
        self._used_objects.update(kwargs)
        for used_object in args:
            if hasattr(used_object, "__name__"):
                self._used_objects[used_object.__name__] = used_object
            elif isinstance(used_object, self._STRING_TYPES):
                pass  # it is already global, so this is just an annotation
            else:
                # Wrap in a 1-tuple so tuple arguments are formatted
                # instead of making the % operator raise TypeError.
                raise ValueError("%s is not a string or object with __name__"
                                 % (used_object,))

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped object.

           Plain functions run through di_call() with the registered
           injections; anything else (a class or callable object) is
           called directly and relies on @usesclassinject for DI.
        """
        if isinstance(self._obj, FunctionType):
            return di_call(self._used_objects, self._obj, *args, **kwargs)
        # call constructor or callable class
        # (which use @usesclassinject if needed)
        return self._obj(*args, **kwargs)

def func_globals(func):
    """Return the global namespace dict of *func*.

       Python 2 exposes it as ``func_globals``; Python 3 only provides
       ``__globals__``.  The Python 2 spelling is tried first.
    """
    try:
        return func.func_globals  # Python 2
    except AttributeError:
        return func.__globals__  # Python 3

def di_call(di_dict, method, *args, **kwargs):
    """This method is the core of dependency injection framework.
       It builds a copy of the method's global namespace, adds all the
       injected variables to that copy and executes the method against
       it.  The real module namespace is never modified, so there is
       nothing to restore afterwards.

       This variant is used on plain functions.

       The modified global namespace is discarded after the method finishes
       so all new global variables and changes to scalars will be lost.

       :param di_dict: mapping of global name -> object to inject
       :param method: plain Python function to execute
       :returns: whatever method returns
    """
    # modify a copy of the globals; the original dict stays untouched
    new_globals = method.__globals__.copy()
    new_globals.update(di_dict)

    # create new func with modified globals; the dunder attribute names
    # exist on Python 2.6+ and Python 3 (func_code & co. are Python 2 only)
    new_method = FunctionType(method.__code__,
                              new_globals, method.__name__,
                              method.__defaults__, method.__closure__)
    # keyword-only defaults (Python 3) are not a FunctionType() argument,
    # so carry them over explicitly
    kwdefaults = getattr(method, "__kwdefaults__", None)
    if kwdefaults:
        new_method.__kwdefaults__ = dict(kwdefaults)

    # execute the method and return its ret value
    return new_method(*args, **kwargs)

        
def inject(*args, **kwargs):
    """Decorator factory registering the injections for a unit that may
       be under test.

       Positional arguments are injected under their ``__name__``
       (strings are accepted as annotations only); keyword arguments
       are injected under the keyword name.

       It can decorate a class, a method or a plain function.  For a
       decorated class, the methods that should see the injections must
       additionally be decorated with @usesclassinject.  Stacking several
       @inject decorators accumulates their injections in one registry.

       When DI_ENABLE is False at decoration time the object is returned
       unchanged.
    """
    def inject_decorate(obj):
        """Wrap obj in a DiRegistry (reusing an existing one) and
           record this decorator's injections in it."""
        if DI_ENABLE:
            registry = obj if isinstance(obj, DiRegistry) else DiRegistry(obj)
            registry.register(*args, **kwargs)
            return registry
        return obj

    return inject_decorate
    
def usesclassinject(method):
    """Mark a method of an @inject decorated class as one that should
       run with the class's injections.

       The injections are read at call time from ``self._di_`` (the
       attribute DiRegistry attaches to the object it wraps).  When
       DI_ENABLE is False at decoration time, method is returned as-is.
    """
    if DI_ENABLE:
        @wraps(method)
        def call(*args, **kwargs):
            """Proxy that runs the original method through di_call()
               using the injections registered on its instance."""
            instance = args[0]
            return di_call(instance._di_, method, *args, **kwargs)

        return call
    return method

### Unittests are defined below this point
import unittest


class BareFuncEnableTestCase(unittest.TestCase):
    """Tests for the @di_enable decorator."""

    @inject(injected_func = str.lower)
    def method(self, arg):
        return injected_func(arg)

    def test_enabled_method(self):
        """Tests that the injected method works with DI enabled."""
        self.assertEqual("a", self.method("A"))

    def test_enable_flag_scoped_to_call(self):
        """Tests that @di_enable turns DI on only during the call
           and that @inject applied inside that call takes effect."""
        global DI_ENABLE
        old = DI_ENABLE
        DI_ENABLE = False
        try:
            seen = []

            @di_enable
            def build():
                seen.append(DI_ENABLE)

                @inject(injected_func=str.upper)
                def inner(arg):
                    return injected_func(arg)

                return inner

            wrapped = build()
            self.assertEqual([True], seen)
            self.assertFalse(DI_ENABLE)
            self.assertTrue(isinstance(wrapped, DiRegistry))
            self.assertEqual("A", wrapped("a"))
        finally:
            # Restore the switch so other tests see the original state.
            DI_ENABLE = old
    
class BareFuncTestCase(unittest.TestCase):
    """Tests for @inject applied to plain (instance) methods."""

    @inject(injected_func = str.lower)
    def method(self, arg):
        return injected_func(arg)

    # With DI enabled, method is a DiRegistry at this point, so it can
    # itself be injected into another method.
    @inject(injected_func = method)
    def method2(self, arg):
        return injected_func(self, arg)

    def test_bare_inject(self):
        """Tests the injection to plain methods."""
        self.assertEqual("a", self.method("A"))

    def test_double_inject(self):
        """Tests the injection to two plain methods."""
        self.assertEqual("a", self.method2("A"))

    def test_inject_global_tainting(self):
        """Tests whether the global namespace is clean
           after the injection is done."""
        global injected_func
        missing = object()
        saved = globals().get("injected_func", missing)
        injected_func = None
        try:
            self.method("A")
            self.assertEqual(None, injected_func)
        finally:
            # Undo this test's own change so it does not leak a module
            # global into other tests.
            if saved is missing:
                del injected_func
            else:
                injected_func = saved
        
    
@inject(injected_func=str.lower)
class Test(object):
    """Test fixture for class injection.

       ``method`` sees ``injected_func`` only through @usesclassinject.
    """

    @usesclassinject
    def method(self, arg):
        lowered = injected_func(arg)
        return lowered

    
@inject(injected_func=str.lower)
class TestInit(object):
    """Test fixture for injection to __init__.

       The constructor stores the injected function's result on ``value``.
    """

    @usesclassinject
    def __init__(self, arg):
        converted = injected_func(arg)
        self.value = converted


@inject(injected_func=str.lower)
class TestCallable(object):
    """Test fixture for callable classes.

       The class is wrapped by @inject and ``__call__`` opts in through
       @usesclassinject.
    """

    @usesclassinject
    def __call__(self, arg):
        outcome = injected_func(arg)
        return outcome

class TestCallableSingle(object):
    """Test fixture for callable classes with
       simple method injection.

       Only ``__call__`` is decorated; the class itself is not wrapped.
    """

    @inject(injected_func=str.lower)
    def __call__(self, arg):
        result = injected_func(arg)
        return result
    
class ClassDITestCase(unittest.TestCase):
    """Tests for @inject applied to whole classes and to __call__."""

    def test_class_inject(self):
        """Test injection to instance method."""
        self.assertEqual("a", Test().method("A"))

    def test_class_init_inject(self):
        """Test injection to class constructor."""
        self.assertEqual("a", TestInit("A").value)

    def test_callable_class(self):
        """Test class injection to callable class."""
        instance = TestCallable()
        self.assertEqual("a", instance("A"))

    def test_callable_class_single(self):
        """Test method injection to callable class."""
        instance = TestCallableSingle()
        self.assertEqual("a", instance("A"))
        
if __name__ == "__main__":
    # Run the embedded unit tests when this file is executed directly.
    unittest.main(module=__name__)