diff options
-rw-r--r-- | pyarg-parsetuple.cocci | 13 | ||||
-rw-r--r-- | validate.py | 77 |
2 files changed, 67 insertions, 23 deletions
diff --git a/pyarg-parsetuple.cocci b/pyarg-parsetuple.cocci index 4beb481..706a489 100644 --- a/pyarg-parsetuple.cocci +++ b/pyarg-parsetuple.cocci @@ -11,7 +11,7 @@ PyArg_ParseTuple(args, fmt@p1, e1) @initialize:python@ import sys sys.path.append('.') -from validate import * +from validate import validate_types @script:python@ args << ParseTuple_1.args; @@ -22,12 +22,17 @@ p1 << ParseTuple_1.p1; """ Analyze format strings, compare to vararg types actually passed + +FIXME: generalize this to varargs """ -# FIXME: generalize this to varargs #print "args: %s" % args #print "fmt: %s" % fmt #print "var1: %s" % t1 #print get_types(fmt.expr) -print p1[0].__dict__ #.line, p1.column -validate_types(fmt.expr, [t1]) + +# For some reason, locations are coming as a 1-tuple containing a Location (from +# coccilibs.elems), rather than the location itself +# Hence we use p1[0], not p1 + +validate_types(p1[0], fmt.expr, [t1]) diff --git a/validate.py b/validate.py index 5827ddf..950e0e7 100644 --- a/validate.py +++ b/validate.py @@ -4,13 +4,16 @@ Hooks for validating CPython extension source code def get_types(strfmt): """ Generate a list of C type names from a PyArg_ParseTuple format string + Compare to Python/getargs.c:vgetargs1 """ result = [] i = 0 while i < len(strfmt): c = strfmt[i] - if c == 'i': - result.append('int *') + simple = {'i':'int', + 's':'char *'} + if c in simple: + result.append(simple[c] + ' *') if c in [':', ';']: break i += 1 @@ -18,27 +21,63 @@ def get_types(strfmt): class CExtensionError(Exception): # Base class for errors discovered by static analysis in C extension code - pass + def __init__(self, location): + self.location = location + + def __str__(self): + return '%s:%s: %s' % (self.location.file, + self.location.line, + self._get_desc()) + + def _get_desc(self): + raise NotImplementedError + class WrongNumberOfVars(CExtensionError): - pass + def __init__(self, location, exp_types, actual_types): + CExtensionError.__init__(self, location) + self.exp_types = exp_types + self.actual_types = actual_types class NotEnoughVars(WrongNumberOfVars): - pass + def _get_desc(self): + return 'Not enough arguments: expected %i (%s), but got %i (%s)' % ( + len(self.exp_types), + self.exp_types, + len(self.actual_types), + self.actual_types) class TooManyVars(WrongNumberOfVars): - pass - -class MismatchingType(WrongNumberOfVars): - pass - -def validate_types(format_string, actual_types): - exp_types = get_types(format_string) - if len(actual_types) < len(exp_types): - raise NotEnoughVars(actual_types, exp_types) - if len(actual_types) > len(exp_types): - raise TooManyVars(actual_types, exp_types) - for exp, actual in zip(exp_types, actual_types): - if exp != actual: - raise MismatchingType(exp, actual) + def _get_desc(self): + return 'Too many arguments: expected %i (%s), but got %i (%s)' % ( + len(self.exp_types), + self.exp_types, + len(self.actual_types), + self.actual_types) + +class MismatchingType(CExtensionError): + def __init__(self, location, arg_num, exp_type, actual_type): + super(self.__class__, self).__init__(location) + self.arg_num = arg_num + self.exp_type = exp_type + self.actual_type = actual_type + + def _get_desc(self): + return 'Mismatching type of argument %i: expected "%s" but got "%s"' % ( + self.arg_num, + self.exp_type, + self.actual_type) + +def validate_types(location, format_string, actual_types): + try: + exp_types = get_types(format_string) + if len(actual_types) < len(exp_types): + raise NotEnoughVars(location, actual_types, exp_types) + if len(actual_types) > len(exp_types): + raise TooManyVars(location, actual_types, exp_types) + for i, (exp, actual) in enumerate(zip(exp_types, actual_types)): + if exp != actual: + raise MismatchingType(location, i+1, exp, actual) + except CExtensionError, err: + print err |