diff options
author | Martin Sivak <msivak@redhat.com> | 2012-12-06 14:31:00 +0100 |
---|---|---|
committer | Martin Sivak <msivak@redhat.com> | 2012-12-06 14:31:00 +0100 |
commit | 31fc9917f464d0a2251e108d6c8321c8ac7d7978 (patch) | |
tree | 21564af7247262d3d45f2301b3099ed2509aafbe /di | |
parent | 714f546372db8fba6b1078e04af2493da3860bd9 (diff) | |
download | python-di-31fc9917f464d0a2251e108d6c8321c8ac7d7978.tar.gz python-di-31fc9917f464d0a2251e108d6c8321c8ac7d7978.tar.xz python-di-31fc9917f464d0a2251e108d6c8321c8ac7d7978.zip |
Support accessing (class)attributes of the object decorated by @inject
Diffstat (limited to 'di')
-rw-r--r-- | di/core.py | 32 | ||||
-rw-r--r-- | di/core_test.py | 27 |
2 files changed, 49 insertions, 10 deletions
@@ -55,11 +55,14 @@ class DiRegistry(object): It records the injected objects, handles the execution and cleanup tasks associated with the DI mechanisms. """ + def __init__(self, obj): - self._obj = obj - self._used_objects = {} - self._obj._di_ = self._used_objects + object.__setattr__(self, "_DiRegistry__obj", obj) + object.__setattr__(self, "_DiRegistry__used_objects", {}) + for name in ["__name__", "__module__", "__doc__"]: + object.__setattr__(self, name, getattr(obj, name, getattr(self, name))) + self.__obj._di_ = self.__used_objects def __get__(self, obj, objtype): """Support instance methods.""" @@ -68,23 +71,32 @@ class DiRegistry(object): def register(self, *args, **kwargs): """Add registered injections to the instance of DiRegistry """ - self._used_objects.update(kwargs) + self.__used_objects.update(kwargs) for used_object in args: if hasattr(used_object, "__name__"): - self._used_objects[used_object.__name__] = used_object + self.__used_objects[used_object.__name__] = used_object elif isinstance(used_object, basestring): pass # it is already global, so this is just an annotation else: raise ValueError("%s is not a string or object with __name__" % used_object) def __call__(self, *args, **kwargs): - if not issubclass(type(self._obj), FunctionType): + if not issubclass(type(self.__obj), FunctionType): # call constructor or callable class # (which use @usesclassinject if needed) - return self._obj(*args, **kwargs) + return self.__obj(*args, **kwargs) + else: + return di_call(self.__used_objects, self.__obj, + *args, **kwargs) + + def __getattr__(self, name): + return getattr(self.__obj, name) + + def __setattr__(self, name, value): + if name in self.__dict__: + object.__setattr__(self, name, value) else: - return di_call(self._used_objects, self._obj, - *args, **kwargs) + setattr(self.__obj, name, value) def func_globals(func): """Helper method that allows access to globals @@ -133,7 +145,7 @@ def inject(*args, **kwargs): return obj if not isinstance(obj, DiRegistry): - obj = wraps(obj)(DiRegistry(obj)) + obj = DiRegistry(obj) obj.register(*args, **kwargs) return obj diff --git a/di/core_test.py b/di/core_test.py index 17d88fa..7cbc6bd 100644 --- a/di/core_test.py +++ b/di/core_test.py @@ -1,6 +1,29 @@ +# Tests for the dependency injection core +# +# Copyright (C) 2012 Red Hat, Inc. +# +# This copyrighted material is made available to anyone wishing to use, +# modify, copy, or redistribute it subject to the terms and conditions of +# the GNU General Public License v.2, or (at your option) any later version. +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY expressed or implied, including the implied warranties of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General +# Public License for more details. You should have received a copy of the +# GNU General Public License along with this program; if not, write to the +# Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. Any Red Hat trademarks that are incorporated in the +# source code or documentation are not subject to the GNU General Public +# License and may only be used or replicated with the express permission of +# Red Hat, Inc. +# +# Red Hat Author(s): Martin Sivak <msivak@redhat.com> +# + from .core import * import unittest +def method_to_inject(): + pass class BareFuncEnableTestCase(unittest.TestCase): @inject(injected_func = str.lower) @@ -70,6 +93,10 @@ class ClassDITestCase(unittest.TestCase): obj = Test() self.assertEqual("a", obj.method("A")) + def test_named_register(self): + """Test injection to instance method.""" + Test.register(method_to_inject) + def test_class_init_inject(self): """Test injection to class constructor.""" obj = TestInit("A") |