summaryrefslogtreecommitdiffstats
path: root/lib2to3c/initmodule.py
blob: c890a9ebe2836aca89e86d06e5898a7ac5fa320a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from lib2to3c import CocciFix, FixTest
class FixInitModule(CocciFix):
    """Rewrite a Python 2 module init function so it also builds on Python 3.

    The C-level rewriting is delegated to the Coccinelle semantic patch
    ``init-module.cocci`` (via :class:`CocciFix`).  The patched output
    contains placeholder tokens such as ``__HASH_ELSE;`` standing in for
    preprocessor lines (presumably because spatch cannot emit directives
    itself -- NOTE(review): confirm against the .cocci file), and
    :meth:`postprocess` maps them back to the real ``#if`` / ``#else`` /
    ``#define`` / ``#endif`` lines.
    """

    # (placeholder, replacement) pairs, applied in order by postprocess().
    # Order matters: each bare marker (e.g. "__HASH_ELSE;") is a substring
    # of its "struct ..." form, so the struct forms must be replaced first,
    # otherwise a stray "struct " would be left in front of the directive.
    # The "ge_3" marker means PY_MAJOR_VERSION >= 3; a ">" test would send
    # Python 3 builds down the Python 2 (Py_InitModule3) branch.
    _MARKERS = (
        ('struct __HASH_IF_PY_MAJOR_VERSION_ge_3;',
         '#if PY_MAJOR_VERSION >= 3'),
        ('struct __HASH_ELSE;',
         '#else'),
        ('struct __HASH_DEFINE__MOD_ERROR_VAL__NULL;',
         '#define MOD_ERROR_VAL NULL'),
        ('struct __HASH_DEFINE__MOD_ERROR_VAL__;',
         '#define MOD_ERROR_VAL'),
        ('struct __HASH_ENDIF;',
         '#endif'),
        ('__HASH_IF_PY_MAJOR_VERSION_ge_3;',
         '#if PY_MAJOR_VERSION >= 3'),
        ('__HASH_ELSE;',
         '#else'),
        ('__HASH_ENDIF;',
         '#endif'),
    )

    def __init__(self):
        CocciFix.__init__(self, 'init-module.cocci')

    def preprocess(self, string):
        """Prepare the C source before the semantic patch runs.

        FIXME: not implemented yet; *string* is returned unchanged.
        """
        return string

    def postprocess(self, string):
        """Return *string* with every placeholder token from the semantic
        patch replaced by the corresponding preprocessor directive."""
        for before, after in self._MARKERS:
            string = string.replace(before, after)
        return string

    def transform(self, string):
        """Return the C source *string* after running preprocess, the
        Coccinelle transform, and postprocess, in that order."""
        string = self.preprocess(string)
        string = CocciFix.transform(self, string)
        return self.postprocess(string)


class TestFixups(FixTest):
    """Regression tests for FixInitModule."""

    def setUp(self):
        self.fixer = FixInitModule()

    def test_fixups(self):
        """An init function with bare ``return;`` error paths gains a
        PyModuleDef, a version-guarded PyModule_Create call, and
        ``return MOD_ERROR_VAL;`` on each error path."""
        self.assertTransformsTo(self.fixer,
                                '''
PyMODINIT_FUNC
initxx(void)
{
    PyObject *m;

    if (something_that_can_fail() < 0)
        return;

    m = Py_InitModule3("xx", xx_methods, module_doc);
    if (m == NULL)
        return;

    PyModule_AddObject(m, "Null", (PyObject *)&Null_Type);
}
''',
                                '''
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef moduledef = {
    PyModuleDef_HEAD_INIT,
    "xx",/* m_name */
    module_doc,/* m_doc */
    0,/* m_size */
    xx_methods,/* m_methods */
    NULL,/* m_reload */
    NULL,/* m_traverse */
    NULL,/* m_clear */
    NULL,/* m_free */
};
#define MOD_ERROR_VAL NULL
#else
#define MOD_ERROR_VAL
#endif
PyMODINIT_FUNC
initxx(void)
{
    PyObject *m;

    if (something_that_can_fail() < 0)
        return MOD_ERROR_VAL;

    #if PY_MAJOR_VERSION >= 3
    m = PyModule_Create(&moduledef);
    #else
    m = Py_InitModule3("xx", xx_methods, module_doc);
    #endif
    if (m == NULL)
        return MOD_ERROR_VAL;

    PyModule_AddObject(m, "Null", (PyObject *)&Null_Type);
}
''')
# FIXME: the expected output of test_fixups should also end with:
#
#     #if PY_MAJOR_VERSION >= 3
#     return m;
#     #endif
#
# but I haven't figured out how to get spatch to add that whilst correctly
# handling error paths.

# Some code that isn't handled yet -- multiple error-handling paths:
#
#     PyMODINIT_FUNC
#     initxx(void)
#     {
#         PyObject *m;
#
#         if (something_that_can_fail() < 0)
#             return;
#
#         m = Py_InitModule3("xx", xx_methods, module_doc);
#         if (m == NULL)
#             return;
#
#         if (PyType_Ready(&Null_Type) < 0)
#             return;
#         PyModule_AddObject(m, "Null", (PyObject *)&Null_Type);
#     }

if __name__ == '__main__':
    # unittest is not imported at module level, so import it here; without
    # this, running the file as a script fails with NameError.
    import unittest
    unittest.main()