From 0366e7395c97e51b8a6294c10176bef73e1bdcf7 Mon Sep 17 00:00:00 2001 From: Alexander Larsson Date: Wed, 26 May 2010 12:19:58 +0200 Subject: Initial import of spice protocol description and demarshall generator The "spice.proto" file describes in detail the networking prototcol that spice uses and spice_codegen.py can parse this and generate demarshallers for such network messages. --- python_modules/__init__.py | 0 python_modules/codegen.py | 354 +++++++++++++ python_modules/demarshal.py | 1033 ++++++++++++++++++++++++++++++++++++++ python_modules/ptypes.py | 965 +++++++++++++++++++++++++++++++++++ python_modules/spice_parser.py | 157 ++++++ spice.proto | 1086 ++++++++++++++++++++++++++++++++++++++++ spice_gen.py | 165 ++++++ 7 files changed, 3760 insertions(+) create mode 100644 python_modules/__init__.py create mode 100644 python_modules/codegen.py create mode 100644 python_modules/demarshal.py create mode 100644 python_modules/ptypes.py create mode 100644 python_modules/spice_parser.py create mode 100644 spice.proto create mode 100755 spice_gen.py diff --git a/python_modules/__init__.py b/python_modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python_modules/codegen.py b/python_modules/codegen.py new file mode 100644 index 00000000..5bb659ac --- /dev/null +++ b/python_modules/codegen.py @@ -0,0 +1,354 @@ +from cStringIO import StringIO + +def camel_to_underscores(s, upper = False): + res = "" + for i in range(len(s)): + c = s[i] + if i > 0 and c.isupper(): + res = res + "_" + if upper: + res = res + c.upper() + else: + res = res + c.lower() + return res + +def underscores_to_camel(s): + res = "" + do_upper = True + for i in range(len(s)): + c = s[i] + if c == "_": + do_upper = True + else: + if do_upper: + res = res + c.upper() + else: + res = res + c + do_upper = False + return res + +proto_prefix = "Temp" + +def set_prefix(prefix): + global proto_prefix + global proto_prefix_upper + global proto_prefix_lower + proto_prefix = prefix + proto_prefix_upper = prefix.upper() + proto_prefix_lower = prefix.lower() + +def prefix_underscore_upper(*args): + s = proto_prefix_upper + for arg in args: + s = s + "_" + arg + return s + +def prefix_underscore_lower(*args): + s = proto_prefix_lower + for arg in args: + s = s + "_" + arg + return s + +def prefix_camel(*args): + s = proto_prefix + for arg in args: + s = s + underscores_to_camel(arg) + return s + +def increment_identifier(idf): + v = idf[-1:] + if v.isdigit(): + return idf[:-1] + str(int(v) + 1) + return idf + "2" + +def sum_array(array): + if len(array) == 0: + return 0 + return " + ".join(array) + +class CodeWriter: + def __init__(self): + self.out = StringIO() + self.contents = [self.out] + self.indentation = 0 + self.at_line_start = True + self.indexes = ["i", "j", "k", "ii", "jj", "kk"] + self.current_index = 0 + self.generated = {} + self.vars = [] + self.has_error_check = False + self.options = {} + self.function_helper_writer = None + + def set_option(self, opt, value = True): + self.options[opt] = value + + def has_option(self, opt): + return self.options.has_key(opt) + + def set_is_generated(self, kind, name): + if not self.generated.has_key(kind): + v = {} + self.generated[kind] = v + else: + v = self.generated[kind] + v[name] = 1 + + def is_generated(self, kind, name): + if not self.generated.has_key(kind): + return False + v = self.generated[kind] + return v.has_key(name) + + def getvalue(self): + strs = map(lambda writer: writer.getvalue(), self.contents) + return "".join(strs) + + def get_subwriter(self): + writer = CodeWriter() + self.contents.append(writer) + self.out = StringIO() + self.contents.append(self.out) + writer.indentation = self.indentation + writer.at_line_start = self.at_line_start + writer.generated = self.generated + writer.options = self.options + + return writer; + + def write(self, s): + # Ensure its a string + s = str(s) + + if len(s) == 0: + return + + if self.at_line_start: + for i in range(self.indentation): + self.out.write(" ") + self.at_line_start = False + self.out.write(s) + return self + + def newline(self): + self.out.write("\n") + self.at_line_start = True + return self + + def writeln(self, s): + self.write(s) + self.newline() + return self + + def label(self, s): + self.indentation = self.indentation - 1 + self.write(s + ":") + self.indentation = self.indentation + 1 + self.newline() + + def statement(self, s): + self.write(s) + self.write(";") + self.newline() + return self + + def assign(self, var, val): + self.write("%s = %s" % (var, val)) + self.write(";") + self.newline() + return self + + def increment(self, var, val): + self.write("%s += %s" % (var, val)) + self.write(";") + self.newline() + return self + + def comment(self, str): + self.write("/* " + str + " */") + return self + + def todo(self, str): + self.comment("TODO: *** %s ***" % str).newline() + return self + + def error_check(self, check, label = "error"): + self.has_error_check = True + with self.block("if (SPICE_UNLIKELY(%s))" % check): + if self.has_option("print_error"): + self.statement('printf("%%s: Caught error - %s", __PRETTY_FUNCTION__)' % check) + if self.has_option("assert_on_error"): + self.statement("assert(0)") + self.statement("goto %s" % label) + + def indent(self): + self.indentation += 4; + + def unindent(self): + self.indentation -= 4; + if self.indentation < 0: + self.indenttation = 0 + + def begin_block(self, prefix= "", comment = ""): + if len(prefix) > 0: + self.write(prefix) + if self.at_line_start: + self.write("{") + else: + self.write(" {") + if len(comment) > 0: + self.write(" ") + self.comment(comment) + self.newline() + self.indent() + + def end_block(self, semicolon=False, newline=True): + self.unindent() + if self.at_line_start: + self.write("}") + else: + self.write(" }") + if semicolon: + self.write(";") + if newline: + self.newline() + + class Block: + def __init__(self, writer, semicolon, newline): + self.writer = writer + self.semicolon = semicolon + self.newline = newline + + def __enter__(self): + return self.writer.get_subwriter() + + def __exit__(self, exc_type, exc_value, traceback): + self.writer.end_block(self.semicolon, self.newline) + + class PartialBlock: + def __init__(self, writer, scope, semicolon, newline): + self.writer = writer + self.scope = scope + self.semicolon = semicolon + self.newline = newline + + def __enter__(self): + return self.scope + + def __exit__(self, exc_type, exc_value, traceback): + self.writer.end_block(self.semicolon, self.newline) + + class NoBlock: + def __init__(self, scope): + self.scope = scope + + def __enter__(self): + return self.scope + + def __exit__(self, exc_type, exc_value, traceback): + pass + + def block(self, prefix= "", comment = "", semicolon=False, newline=True): + self.begin_block(prefix, comment) + return self.Block(self, semicolon, newline) + + def partial_block(self, scope, semicolon=False, newline=True): + return self.PartialBlock(self, scope, semicolon, newline) + + def no_block(self, scope): + return self.NoBlock(scope) + + def optional_block(self, scope): + if scope != None: + return self.NoBlock(scope) + return self.block() + + def for_loop(self, index, limit): + return self.block("for (%s = 0; %s < %s; %s++)" % (index, index, limit, index)) + + def while_loop(self, expr): + return self.block("while (%s)" % (expr)) + + def if_block(self, check, elseif=False, newline=True): + s = "if (%s)" % (check) + if elseif: + s = " else " + s + self.begin_block(s, "") + return self.Block(self, False, newline) + + def variable_defined(self, name): + for n in self.vars: + if n == name: + return True + return False + + def variable_def(self, ctype, *names): + for n in names: + # Strip away initialization + i = n.find("=") + if i != -1: + n = n[0:i] + self.vars.append(n.strip()) + # only add space for non-pointer types + if ctype[-1] == "*": + ctype = ctype[:-1].rstrip() + self.writeln("%s *%s;"%(ctype, ", *".join(names))) + else: + self.writeln("%s %s;"%(ctype, ", ".join(names))) + return self + + def function_helper(self): + if self.function_helper_writer != None: + writer = self.function_helper_writer.get_subwriter() + self.function_helper_writer.newline() + else: + writer = self.get_subwriter() + return writer + + def function(self, name, return_type, args, static = False): + self.has_error_check = False + self.function_helper_writer = self.get_subwriter() + if static: + self.write("static ") + self.write(return_type) + self.write(" %s(%s)"% (name, args)).newline() + self.begin_block() + self.function_variables_writer = self.get_subwriter() + self.function_variables = {} + return self.function_variables_writer + + def macro(self, name, args, define): + self.write("#define %s(%s) %s" % (name, args, define)).newline() + + def add_function_variable(self, ctype, name): + if self.function_variables.has_key(name): + assert(self.function_variables[name] == ctype) + else: + self.function_variables[name] = ctype + self.function_variables_writer.variable_def(ctype, name) + + def pop_index(self): + index = self.indexes[self.current_index] + self.current_index = self.current_index + 1 + self.add_function_variable("uint32_t", index) + return index + + def push_index(self): + self.current_index = self.current_index - 1 + + class Index: + def __init__(self, writer, val): + self.writer = writer + self.val = val + + def __enter__(self): + return self.val + + def __exit__(self, exc_type, exc_value, traceback): + self.writer.push_index() + + def index(self, no_block = False): + if no_block: + return self.no_block(None) + val = self.pop_index() + return self.Index(self, val) diff --git a/python_modules/demarshal.py b/python_modules/demarshal.py new file mode 100644 index 00000000..fcd68508 --- /dev/null +++ b/python_modules/demarshal.py @@ -0,0 +1,1033 @@ +import ptypes +import codegen + + +def write_parser_helpers(writer): + if writer.is_generated("helper", "demarshaller"): + return + + writer.set_is_generated("helper", "demarshaller") + + writer = writer.function_helper() + + writer.writeln("#ifdef WORDS_BIGENDIAN") + for size in [8, 16, 32, 64]: + for sign in ["", "u"]: + utype = "uint%d" % (size) + type = "%sint%d" % (sign, size) + swap = "SPICE_BYTESWAP%d" % size + if size == 8: + writer.macro("read_%s" % type, "ptr", "(*((%s_t *)(ptr)))" % type) + else: + writer.macro("read_%s" % type, "ptr", "((%s_t)%s(*((%s_t *)(ptr)))" % (type, swap, utype)) + writer.writeln("#else") + for size in [8, 16, 32, 64]: + for sign in ["", "u"]: + type = "%sint%d" % (sign, size) + writer.macro("read_%s" % type, "ptr", "(*((%s_t *)(ptr)))" % type) + writer.writeln("#endif") + + for size in [8, 16, 32, 64]: + for sign in ["", "u"]: + writer.newline() + type = "%sint%d" % (sign, size) + ctype = "%s_t" % type + scope = writer.function("SPICE_GNUC_UNUSED consume_%s" % type, ctype, "uint8_t **ptr", True) + scope.variable_def(ctype, "val") + writer.assign("val", "read_%s(*ptr)" % type) + writer.increment("*ptr", size / 8) + writer.statement("return val") + writer.end_block() + + writer.newline() + writer.statement("typedef struct PointerInfo PointerInfo") + writer.statement("typedef uint8_t * (*parse_func_t)(uint8_t *message_start, uint8_t *message_end, uint8_t *struct_data, PointerInfo *ptr_info, int minor)") + writer.statement("typedef uint8_t * (*parse_msg_func_t)(uint8_t *message_start, uint8_t *message_end, int minor, size_t *size_out)") + writer.statement("typedef uint8_t * (*spice_parse_channel_func_t)(uint8_t *message_start, uint8_t *message_end, uint16_t message_type, int minor, size_t *size_out)") + + writer.newline() + writer.begin_block("struct PointerInfo") + writer.variable_def("uint64_t", "offset") + writer.variable_def("parse_func_t", "parse") + writer.variable_def("SPICE_ADDRESS *", "dest") + writer.variable_def("uint32_t", "nelements") + writer.end_block(semicolon=True) + +def write_read_primitive(writer, start, container, name, scope): + m = container.lookup_member(name) + assert(m.is_primitive()) + writer.assign("pos", start + " + " + container.get_nw_offset(m, "", "__nw_size")) + writer.error_check("pos + %s > message_end" % m.member_type.get_fixed_nw_size()) + + var = "%s__value" % (name) + scope.variable_def(m.member_type.c_type(), var) + writer.assign(var, "read_%s(pos)" % (m.member_type.primitive_type())) + return var + +def write_read_primitive_item(writer, item, scope): + assert(item.type.is_primitive()) + writer.assign("pos", item.get_position()) + writer.error_check("pos + %s > message_end" % item.type.get_fixed_nw_size()) + var = "%s__value" % (item.subprefix) + scope.variable_def(item.type.c_type(), var) + writer.assign(var, "read_%s(pos)" % (item.type.primitive_type())) + return var + +class ItemInfo: + def __init__(self, type, prefix, position): + self.type = type + self.prefix = prefix + self.subprefix = prefix + self.position = position + self.non_null = False + self.member = None + + def nw_size(self): + return self.prefix + "__nw_size" + + def mem_size(self): + return self.prefix + "__mem_size" + + def extra_size(self): + return self.prefix + "__extra_size" + + def get_position(self): + return self.position + +class MemberItemInfo(ItemInfo): + def __init__(self, member, container, start): + if not member.is_switch(): + self.type = member.member_type + self.prefix = member.name + self.subprefix = member.name + self.non_null = member.has_attr("nonnull") + self.position = "(%s + %s)" % (start, container.get_nw_offset(member, "", "__nw_size")) + self.member = member + +def write_validate_switch_member(writer, container, switch_member, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + var = container.lookup_member(switch_member.variable) + var_type = var.member_type + + v = write_read_primitive(writer, start, container, switch_member.variable, parent_scope) + + item = MemberItemInfo(switch_member, container, start) + + first = True + for c in switch_member.cases: + check = c.get_check(v, var_type) + m = c.member + with writer.if_block(check, not first, False) as if_scope: + item.type = c.member.member_type + item.subprefix = item.prefix + "_" + m.name + item.non_null = c.member.has_attr("nonnull") + sub_want_extra_size = want_extra_size + if sub_want_extra_size and not m.contains_extra_size(): + writer.assign(item.extra_size(), 0) + sub_want_extra_size = False + + write_validate_item(writer, container, item, if_scope, scope, start, + want_nw_size, want_mem_size, sub_want_extra_size) + + first = False + + with writer.block(" else"): + if want_nw_size: + writer.assign(item.nw_size(), 0) + if want_mem_size: + writer.assign(item.mem_size(), 0) + if want_extra_size: + writer.assign(item.extra_size(), 0) + + writer.newline() + +def write_validate_struct_function(writer, struct): + validate_function = "validate_%s" % struct.c_type() + if writer.is_generated("validator", validate_function): + return validate_function + + writer.set_is_generated("validator", validate_function) + writer = writer.function_helper() + scope = writer.function(validate_function, "intptr_t", "uint8_t *message_start, uint8_t *message_end, SPICE_ADDRESS offset, int minor") + scope.variable_def("uint8_t *", "start = message_start + offset") + scope.variable_def("SPICE_GNUC_UNUSED uint8_t *", "pos"); + scope.variable_def("size_t", "mem_size", "nw_size"); + num_pointers = struct.get_num_pointers() + if num_pointers != 0: + scope.variable_def("SPICE_GNUC_UNUSED intptr_t", "ptr_size"); + + writer.newline() + with writer.if_block("offset == 0"): + writer.statement("return 0") + + writer.newline() + writer.error_check("start >= message_end") + + writer.newline() + write_validate_container(writer, None, struct, "start", scope, True, True, False) + + writer.newline() + writer.comment("Check if struct fits in reported side").newline() + writer.error_check("start + nw_size > message_end") + + writer.statement("return mem_size") + + writer.newline() + writer.label("error") + writer.statement("return -1") + + writer.end_block() + + return validate_function + +def write_validate_pointer_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + if want_nw_size: + writer.assign(item.nw_size(), 8) + + if want_mem_size or want_extra_size: + target_type = item.type.target_type + + v = write_read_primitive_item(writer, item, scope) + if item.non_null: + writer.error_check("%s == 0" % v) + + # pointer target is struct, or array of primitives + # if array, need no function check + + if target_type.is_array(): + writer.error_check("message_start + %s >= message_end" % v) + + + assert target_type.element_type.is_primitive() + + array_item = ItemInfo(target_type, "%s__array" % item.prefix, start) + scope.variable_def("uint32_t", array_item.nw_size()) + scope.variable_def("uint32_t", array_item.mem_size()) + if target_type.is_cstring_length(): + writer.assign(array_item.nw_size(), "spice_strnlen((char *)message_start + %s, message_end - (message_start + %s))" % (v, v)) + writer.error_check("*(message_start + %s + %s) != 0" % (v, array_item.nw_size())) + writer.assign(array_item.mem_size(), array_item.nw_size()) + else: + write_validate_array_item(writer, container, array_item, scope, parent_scope, start, + True, True, False) + writer.error_check("message_start + %s + %s > message_end" % (v, array_item.nw_size())) + + if want_extra_size: + if item.member and item.member.has_attr("nocopy"): + writer.comment("@nocopy, so no extra size").newline() + writer.assign(item.extra_size(), 0) + elif target_type.element_type.get_fixed_nw_size == 1: + writer.assign(item.extra_size(), array_item.mem_size()) + # If not bytes or zero, add padding needed for alignment + else: + writer.assign(item.extra_size(), "%s + /* for alignment */ 3" % array_item.mem_size()) + if want_mem_size: + writer.assign(item.mem_size(), "sizeof(void *) + %s" % array_item.mem_size()) + + elif target_type.is_struct(): + validate_function = write_validate_struct_function(writer, target_type) + writer.assign("ptr_size", "%s(message_start, message_end, %s, minor)" % (validate_function, v)) + writer.error_check("ptr_size < 0") + + if want_extra_size: + writer.assign(item.extra_size(), "ptr_size + /* for alignment */ 3") + if want_mem_size: + writer.assign(item.mem_size(), "sizeof(void *) + ptr_size") + else: + raise NotImplementedError("pointer to unsupported type %s" % target_type) + + +def write_validate_array_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + array = item.type + is_byte_size = False + element_type = array.element_type + if array.is_bytes_length(): + nelements = "%s__nbytes" %(item.prefix) + else: + nelements = "%s__nelements" %(item.prefix) + if not parent_scope.variable_defined(nelements): + parent_scope.variable_def("uint32_t", nelements) + + if array.is_constant_length(): + writer.assign(nelements, array.size) + elif array.is_remaining_length(): + if element_type.is_fixed_nw_size(): + if element_type.get_fixed_nw_size() == 1: + writer.assign(nelements, "message_end - %s" % item.get_position()) + else: + writer.assign(nelements, "(message_end - %s) / (%s)" %(item.get_position(), element_type.get_fixed_nw_size())) + else: + raise NotImplementedError("TODO array[] of dynamic element size not done yet") + elif array.is_identifier_length(): + v = write_read_primitive(writer, start, container, array.size, scope) + writer.assign(nelements, v) + elif array.is_image_size_length(): + bpp = array.size[1] + width = array.size[2] + rows = array.size[3] + width_v = write_read_primitive(writer, start, container, width, scope) + rows_v = write_read_primitive(writer, start, container, rows, scope) + # TODO: Handle multiplication overflow + if bpp == 8: + writer.assign(nelements, "%s * %s" % (width_v, rows_v)) + elif bpp == 1: + writer.assign(nelements, "((%s + 7) / 8 ) * %s" % (width_v, rows_v)) + else: + writer.assign(nelements, "((%s * %s + 7) / 8 ) * %s" % (bpp, width_v, rows_v)) + elif array.is_bytes_length(): + is_byte_size = True + v = write_read_primitive(writer, start, container, array.size[1], scope) + writer.assign(nelements, v) + elif array.is_cstring_length(): + writer.todo("cstring array size type not handled yet") + else: + writer.todo("array size type not handled yet") + + writer.newline() + + nw_size = item.nw_size() + mem_size = item.mem_size() + extra_size = item.extra_size() + + if is_byte_size and want_nw_size: + writer.assign(nw_size, nelements) + want_nw_size = False + + if element_type.is_fixed_nw_size() and want_nw_size: + element_size = element_type.get_fixed_nw_size() + # TODO: Overflow check the multiplication + if element_size == 1: + writer.assign(nw_size, nelements) + else: + writer.assign(nw_size, "(%s) * %s" % (element_size, nelements)) + want_nw_size = False + + if element_type.is_fixed_sizeof() and want_mem_size and not is_byte_size: + # TODO: Overflow check the multiplication + writer.assign(mem_size, "%s * %s" % (element_type.sizeof(), nelements)) + want_mem_size = False + + if not element_type.contains_extra_size() and want_extra_size: + writer.assign(extra_size, 0) + want_extra_size = False + + if not (want_mem_size or want_nw_size or want_extra_size): + return + + start2 = codegen.increment_identifier(start) + scope.variable_def("uint8_t *", "%s = %s" % (start2, item.get_position())) + if is_byte_size: + start2_end = "%s_array_end" % start2 + scope.variable_def("uint8_t *", start2_end) + + element_item = ItemInfo(element_type, "%s__element" % item.prefix, start2) + + element_nw_size = element_item.nw_size() + element_mem_size = element_item.mem_size() + scope.variable_def("uint32_t", element_nw_size) + scope.variable_def("uint32_t", element_mem_size) + + if want_nw_size: + writer.assign(nw_size, 0) + if want_mem_size: + writer.assign(mem_size, 0) + if want_extra_size: + writer.assign(extra_size, 0) + + want_element_nw_size = want_nw_size + if element_type.is_fixed_nw_size(): + start_increment = element_type.get_fixed_nw_size() + else: + want_element_nw_size = True + start_increment = element_nw_size + + if is_byte_size: + writer.assign(start2_end, "%s + %s" % (start2, nelements)) + + with writer.index(no_block = is_byte_size) as index: + with writer.while_loop("%s < %s" % (start2, start2_end) ) if is_byte_size else writer.for_loop(index, nelements) as scope: + write_validate_item(writer, container, element_item, scope, parent_scope, start2, + want_element_nw_size, want_mem_size, want_extra_size) + + if want_nw_size: + writer.increment(nw_size, element_nw_size) + if want_mem_size: + writer.increment(mem_size, element_mem_size) + if want_extra_size: + writer.increment(extra_size, element_extra_size) + + writer.increment(start2, start_increment) + if is_byte_size: + writer.error_check("%s != %s" % (start2, start2_end)) + +def write_validate_struct_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + struct = item.type + start2 = codegen.increment_identifier(start) + scope.variable_def("SPICE_GNUC_UNUSED uint8_t *", start2 + " = %s" % (item.get_position())) + + write_validate_container(writer, item.prefix, struct, start2, scope, want_nw_size, want_mem_size, want_extra_size) + +def write_validate_primitive_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + if want_nw_size: + nw_size = item.nw_size() + writer.assign(nw_size, item.type.get_fixed_nw_size()) + if want_mem_size: + mem_size = item.mem_size() + writer.assign(mem_size, item.type.sizeof()) + assert not want_extra_size + +def write_validate_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + if item.type.is_pointer(): + write_validate_pointer_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + elif item.type.is_array(): + write_validate_array_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + elif item.type.is_struct(): + write_validate_struct_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + elif item.type.is_primitive(): + write_validate_primitive_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + else: + writer.todo("Implement validation of %s" % item.type) + +def write_validate_member(writer, container, member, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + if member.has_minor_attr(): + prefix = "if (minor >= %s)" % (member.get_minor_attr()) + newline = False + else: + prefix = "" + newline = True + item = MemberItemInfo(member, container, start) + with writer.block(prefix, newline=newline, comment=member.name) as scope: + if member.is_switch(): + write_validate_switch_member(writer, container, member, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + else: + write_validate_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + + if member.has_minor_attr(): + with writer.block(" else", comment = "minor < %s" % (member.get_minor_attr())): + if member.is_array(): + nelements = "%s__nelements" %(item.prefix) + writer.assign(nelements, 0) + if want_nw_size: + writer.assign(item.nw_size(), 0) + + if want_mem_size: + if member.is_fixed_sizeof(): + writer.assign(item.mem_size(), member.sizeof()) + elif member.is_array(): + writer.assign(item.mem_size(), 0) + else: + raise NotImplementedError("TODO minor check for non-constant items") + + assert not want_extra_size + +def write_validate_container(writer, prefix, container, start, parent_scope, want_nw_size, want_mem_size, want_extra_size): + for m in container.members: + sub_want_nw_size = want_nw_size and not m.is_fixed_nw_size() + sub_want_mem_size = m.is_extra_size() + sub_want_extra_size = not m.is_extra_size() and m.contains_extra_size() + + defs = ["size_t"] + if sub_want_nw_size: + defs.append (m.name + "__nw_size") + if sub_want_mem_size: + defs.append (m.name + "__mem_size") + if sub_want_extra_size: + defs.append (m.name + "__extra_size") + + if sub_want_nw_size or sub_want_mem_size or sub_want_extra_size: + parent_scope.variable_def(*defs) + write_validate_member(writer, container, m, parent_scope, start, + sub_want_nw_size, sub_want_mem_size, sub_want_extra_size) + writer.newline() + + if want_nw_size: + if prefix: + nw_size = prefix + "__nw_size" + else: + nw_size = "nw_size" + + size = 0 + for m in container.members: + if m.is_fixed_nw_size(): + size = size + m.get_fixed_nw_size() + + nm_sum = str(size) + for m in container.members: + if not m.is_fixed_nw_size(): + nm_sum = nm_sum + " + " + m.name + "__nw_size" + + writer.assign(nw_size, nm_sum) + + if want_mem_size: + if prefix: + mem_size = prefix + "__mem_size" + else: + mem_size = "mem_size" + + mem_sum = container.sizeof() + for m in container.members: + if m.is_extra_size(): + mem_sum = mem_sum + " + " + m.name + "__mem_size" + elif m.contains_extra_size(): + mem_sum = mem_sum + " + " + m.name + "__extra_size" + + writer.assign(mem_size, mem_sum) + + if want_extra_size: + if prefix: + extra_size = prefix + "__extra_size" + else: + extra_size = "extra_size" + + extra_sum = [] + for m in container.members: + if m.is_extra_size(): + extra_sum.append(m.name + "__mem_size") + elif m.contains_extra_size(): + extra_sum.append(m.name + "__extra_size") + writer.assign(extra_size, codegen.sum_array(extra_sum)) + +class DemarshallingDestination: + def __init__(self): + pass + + def child_at_end(self, writer, t): + return RootDemarshallingDestination(self, t.c_type(), t.sizeof()) + + def child_sub(self, member): + return SubDemarshallingDestination(self, member) + + def declare(self, writer): + return writer.optional_block(self.reuse_scope) + + def is_toplevel(self): + return self.parent_dest == None and not self.is_helper + +class RootDemarshallingDestination(DemarshallingDestination): + def __init__(self, parent_dest, c_type, sizeof, pointer = None): + self.is_helper = False + self.reuse_scope = None + self.parent_dest = parent_dest + if parent_dest: + self.base_var = codegen.increment_identifier(parent_dest.base_var) + else: + self.base_var = "out" + self.c_type = c_type + self.sizeof = sizeof + self.pointer = pointer # None == at "end" + + def get_ref(self, member): + return self.base_var + "->" + member + + def declare(self, writer): + if self.reuse_scope: + scope = self.reuse_scope + else: + writer.begin_block() + scope = writer.get_subwriter() + + scope.variable_def(self.c_type + " *", self.base_var) + if not self.reuse_scope: + scope.newline() + + if self.pointer: + writer.assign(self.base_var, "(%s *)%s" % (self.c_type, self.pointer)) + else: + writer.assign(self.base_var, "(%s *)end" % (self.c_type)) + writer.increment("end", self.sizeof) + writer.newline() + + if self.reuse_scope: + return writer.no_block(self.reuse_scope) + else: + return writer.partial_block(scope) + +class SubDemarshallingDestination(DemarshallingDestination): + def __init__(self, parent_dest, member): + self.reuse_scope = None + self.parent_dest = parent_dest + self.base_var = parent_dest.base_var + self.member = member + self.is_helper = False + + def get_ref(self, member): + return self.parent_dest.get_ref(self.member) + "." + member + +def read_array_len(writer, prefix, array, dest, scope, handles_bytes = False): + if array.is_bytes_length(): + nelements = "%s__nbytes" % prefix + else: + nelements = "%s__nelements" % prefix + if dest.is_toplevel(): + return nelements # Already there for toplevel, need not recalculate + element_type = array.element_type + scope.variable_def("uint32_t", nelements) + if array.is_constant_length(): + writer.assign(nelements, array.size) + elif array.is_identifier_length(): + writer.assign(nelements, dest.get_ref(array.size)) + elif array.is_remaining_length(): + if element_type.is_fixed_nw_size(): + writer.assign(nelements, "(message_end - in) / (%s)" %(element_type.get_fixed_nw_size())) + else: + raise NotImplementedError("TODO array[] of dynamic element size not done yet") + elif array.is_image_size_length(): + bpp = array.size[1] + width = array.size[2] + rows = array.size[3] + width_v = dest.get_ref(width) + rows_v = dest.get_ref(rows) + # TODO: Handle multiplication overflow + if bpp == 8: + writer.assign(nelements, "%s * %s" % (width_v, rows_v)) + elif bpp == 1: + writer.assign(nelements, "((%s + 7) / 8 ) * %s" % (width_v, rows_v)) + else: + writer.assign(nelements, "((%s * %s + 7) / 8 ) * %s" % (bpp, width_v, rows_v)) + elif array.is_bytes_length(): + if not handles_bytes: + raise NotImplementedError("handling of bytes() not supported here yet") + writer.assign(nelements, dest.get_ref(array.size[1])) + else: + raise NotImplementedError("TODO array size type not handled yet") + return nelements + +def write_switch_parser(writer, container, switch, dest, scope): + var = container.lookup_member(switch.variable) + var_type = var.member_type + + if switch.has_attr("fixedsize"): + scope.variable_def("uint8_t *", "in_save") + writer.assign("in_save", "in") + + first = True + for c in switch.cases: + check = c.get_check(dest.get_ref(switch.variable), var_type) + m = c.member + with writer.if_block(check, not first, False) as block: + t = m.member_type + if switch.has_end_attr(): + dest2 = dest.child_at_end(writer, m.member_type) + elif switch.has_attr("anon"): + dest2 = dest + else: + if t.is_struct(): + dest2 = dest.child_sub(switch.name + "." + m.name) + else: + dest2 = dest.child_sub(switch.name) + dest2.reuse_scope = block + + if t.is_struct(): + write_container_parser(writer, t, dest2) + elif t.is_pointer(): + write_parse_pointer(writer, t, False, dest2, m.name, not m.has_attr("ptr32"), block) + elif t.is_primitive(): + writer.assign(dest2.get_ref(m.name), "consume_%s(&in)" % (t.primitive_type())) + #TODO validate e.g. flags and enums + elif t.is_array(): + nelements = read_array_len(writer, m.name, t, dest, block) + write_array_parser(writer, nelements, t, dest, block) + else: + writer.todo("Can't handle type %s" % m.member_type) + + first = False + + writer.newline() + + if switch.has_attr("fixedsize"): + writer.assign("in", "in_save + %s" % switch.get_fixed_nw_size()) + +def write_parse_ptr_function(writer, target_type): + if target_type.is_array(): + parse_function = "parse_array_%s" % target_type.element_type.primitive_type() + else: + parse_function = "parse_struct_%s" % target_type.c_type() + if writer.is_generated("parser", parse_function): + return parse_function + + writer.set_is_generated("parser", parse_function) + + writer = writer.function_helper() + scope = writer.function(parse_function, "uint8_t *", "uint8_t *message_start, uint8_t *message_end, uint8_t *struct_data, PointerInfo *this_ptr_info, int minor") + scope.variable_def("uint8_t *", "in = message_start + this_ptr_info->offset") + scope.variable_def("uint8_t *", "end") + + num_pointers = target_type.get_num_pointers() + if num_pointers != 0: + scope.variable_def("SPICE_GNUC_UNUSED intptr_t", "ptr_size"); + scope.variable_def("uint32_t", "n_ptr=0"); + scope.variable_def("PointerInfo", "ptr_info[%s]" % num_pointers) + + writer.newline() + if target_type.is_array(): + writer.assign("end", "struct_data") + else: + writer.assign("end", "struct_data + %s" % (target_type.sizeof())) + + dest = RootDemarshallingDestination(None, target_type.c_type(), target_type.sizeof(), "struct_data") + dest.is_helper = True + dest.reuse_scope = scope + if target_type.is_array(): + write_array_parser(writer, "this_ptr_info->nelements", target_type, dest, scope) + else: + write_container_parser(writer, target_type, dest) + + if num_pointers != 0: + write_ptr_info_check(writer) + + writer.statement("return end") + + if writer.has_error_check: + writer.newline() + writer.label("error") + writer.statement("return NULL") + + writer.end_block() + + return parse_function + +def write_array_parser(writer, nelements, array, dest, scope): + is_byte_size = array.is_bytes_length() + + element_type = array.element_type + if element_type == ptypes.uint8 or element_type == ptypes.int8: + writer.statement("memcpy(end, in, %s)" % (nelements)) + writer.increment("in", nelements) + writer.increment("end", nelements) + else: + if is_byte_size: + scope.variable_def("uint8_t *", "array_end") + writer.assign("array_end", "end + %s" % nelements) + with writer.index(no_block = is_byte_size) as index: + with writer.while_loop("end < array_end") if is_byte_size else writer.for_loop(index, nelements) as array_scope: + if element_type.is_primitive(): + writer.statement("*(%s *)end = consume_%s(&in)" % (element_type.c_type(), element_type.primitive_type())) + writer.increment("end", element_type.sizeof()) + else: + dest2 = dest.child_at_end(writer, element_type) + dest2.reuse_scope = array_scope + write_container_parser(writer, element_type, dest2) + +def write_parse_pointer(writer, t, at_end, dest, member_name, is_64bit, scope): + target_type = t.target_type + if is_64bit: + writer.assign("ptr_info[n_ptr].offset", "consume_uint64(&in)") + else: + writer.assign("ptr_info[n_ptr].offset", "consume_uint32(&in)") + writer.assign("ptr_info[n_ptr].parse", write_parse_ptr_function(writer, target_type)) + if at_end: + writer.assign("ptr_info[n_ptr].dest", "end") + writer.increment("end", "sizeof(SPICE_ADDRESS)"); + else: + writer.assign("ptr_info[n_ptr].dest", "&%s" % dest.get_ref(member_name)) + if target_type.is_array(): + nelements = read_array_len(writer, member_name, target_type, dest, scope) + writer.assign("ptr_info[n_ptr].nelements", nelements) + + writer.statement("n_ptr++") + +def write_member_parser(writer, container, member, dest, scope): + if member.is_switch(): + write_switch_parser(writer, container, member, dest, scope) + return + + t = member.member_type + + if t.is_pointer(): + if member.has_attr("nocopy"): + writer.comment("Reuse data from network message").newline() + writer.assign(dest.get_ref(member.name), "(size_t)(message_start + consume_uint64(&in))") + else: + write_parse_pointer(writer, t, member.has_end_attr(), dest, member.name, not member.has_attr("ptr32"), scope) + elif t.is_primitive(): + if member.has_end_attr(): + writer.statement("*(%s *)end = consume_%s(&in)" % (t.c_type(), t.primitive_type())) + writer.increment("end", t.sizeof()) + else: + writer.assign(dest.get_ref(member.name), "consume_%s(&in)" % (t.primitive_type())) + #TODO validate e.g. flags and enums + elif t.is_array(): + nelements = read_array_len(writer, member.name, t, dest, scope, handles_bytes = True) + write_array_parser(writer, nelements, t, dest, scope) + elif t.is_struct(): + if member.has_end_attr(): + dest2 = dest.child_at_end(writer, t) + else: + dest2 = dest.child_sub(member.name) + writer.comment(member.name) + write_container_parser(writer, t, dest2) + else: + raise NotImplementedError("TODO can't handle parsing of %s" % t) + +def write_container_parser(writer, container, dest): + with dest.declare(writer) as scope: + for m in container.members: + if m.has_minor_attr(): + writer.begin_block("if (minor >= %s)" % m.get_minor_attr()) + write_member_parser(writer, container, m, dest, scope) + if m.has_minor_attr(): + # We need to zero out the fixed part of all optional fields + if not m.member_type.is_array(): + writer.end_block(newline=False) + writer.begin_block(" else") + # TODO: This is not right for fields that don't exist in the struct + if m.member_type.is_primitive(): + writer.assign(dest.get_ref(m.name), "0") + elif m.is_fixed_sizeof(): + writer.statement("memset ((char *)&%s, 0, %s)" % (dest.get_ref(m.name), m.sizeof())) + else: + raise NotImplementedError("TODO Clear optional dynamic fields") + writer.end_block() + +def write_ptr_info_check(writer): + writer.newline() + with writer.index() as index: + with writer.for_loop(index, "n_ptr") as scope: + offset = "ptr_info[%s].offset" % index + function = "ptr_info[%s].parse" % index + dest = "ptr_info[%s].dest" % index + with writer.if_block("%s == 0" % offset, newline=False): + writer.assign("*%s" % dest, "0") + with writer.block(" else"): + writer.comment("Align to 32 bit").newline() + writer.assign("end", "(uint8_t *)SPICE_ALIGN((size_t)end, 4)") + writer.assign("*%s" % dest, "(size_t)end") + writer.assign("end", "%s(message_start, message_end, end, &ptr_info[%s], minor)" % (function, index)) + writer.error_check("end == NULL") + writer.newline() + +def write_msg_parser(writer, message): + msg_name = message.c_name() + function_name = "parse_%s" % msg_name + if writer.is_generated("demarshaller", function_name): + return function_name + writer.set_is_generated("demarshaller", function_name) + + msg_type = message.c_type() + msg_sizeof = message.sizeof() + + writer.newline() + parent_scope = writer.function(function_name, + "uint8_t *", + "uint8_t *message_start, uint8_t *message_end, int minor, size_t *size", True) + parent_scope.variable_def("SPICE_GNUC_UNUSED uint8_t *", "pos"); + parent_scope.variable_def("uint8_t *", "start = message_start"); + parent_scope.variable_def("uint8_t *", "data = NULL"); + parent_scope.variable_def("size_t", "mem_size", "nw_size"); + if not message.has_attr("nocopy"): + parent_scope.variable_def("uint8_t *", "in", "end"); + num_pointers = message.get_num_pointers() + if num_pointers != 0: + parent_scope.variable_def("SPICE_GNUC_UNUSED intptr_t", "ptr_size"); + parent_scope.variable_def("uint32_t", "n_ptr=0"); + parent_scope.variable_def("PointerInfo", "ptr_info[%s]" % num_pointers) + writer.newline() + + write_parser_helpers(writer) + + write_validate_container(writer, None, message, "start", parent_scope, True, True, False) + + writer.newline() + + writer.comment("Check if message fits in reported side").newline() + with writer.block("if (start + nw_size > message_end)"): + writer.statement("return NULL") + + writer.newline().comment("Validated extents and calculated size").newline() + + if message.has_attr("nocopy"): + writer.assign("data", "message_start") + writer.assign("*size", "message_end - message_start") + else: + writer.assign("data", "(uint8_t *)malloc(mem_size)") + writer.error_check("data == NULL") + writer.assign("end", "data + %s" % (msg_sizeof)) + writer.assign("in", "start").newline() + + dest = RootDemarshallingDestination(None, msg_type, msg_sizeof, "data") + dest.reuse_scope = parent_scope + write_container_parser(writer, message, dest) + + writer.newline() + writer.statement("assert(in <= message_end)") + + if num_pointers != 0: + write_ptr_info_check(writer) + + writer.statement("assert(end <= data + mem_size)") + + writer.newline() + writer.assign("*size", "end - data") + + writer.statement("return data") + writer.newline() + if writer.has_error_check: + writer.label("error") + with writer.block("if (data != NULL)"): + writer.statement("free(data)") + writer.statement("return NULL") + writer.end_block() + + return function_name + +def write_channel_parser(writer, channel, server): + writer.newline() + ids = {} + min_id = 1000000 + if server: + messages = channel.server_messages + else: + messages = channel.client_messages + for m in messages: + ids[m.value] = m + + ranges = [] + ids2 = ids.copy() + while len(ids2) > 0: + end = start = min(ids2.keys()) + while ids2.has_key(end): + del ids2[end] + end = end + 1 + + ranges.append( (start, end) ) + + if server: + function_name = "parse_%s_msg" % channel.name + else: + function_name = "parse_%s_msgc" % channel.name + writer.newline() + scope = writer.function(function_name, + "uint8_t *", + "uint8_t *message_start, uint8_t *message_end, uint16_t message_type, int minor, size_t *size_out") + + helpers = writer.function_helper() + + d = 0 + for r in ranges: + d = d + 1 + writer.write("static parse_msg_func_t funcs%d[%d] = " % (d, r[1] - r[0])); + writer.begin_block() + for i in range(r[0], r[1]): + func = write_msg_parser(helpers, ids[i].message_type) + writer.write(func) + if i != r[1] -1: + writer.write(",") + writer.newline() + + writer.end_block(semicolon = True) + + d = 0 + for r in ranges: + d = d + 1 + with writer.if_block("message_type >= %d && message_type < %d" % (r[0], r[1]), d > 1, False): + writer.statement("return funcs%d[message_type-%d](message_start, message_end, minor, size_out)" % (d, r[0])) + writer.newline() + + writer.statement("return NULL") + writer.end_block() + + return function_name + +def write_get_channel_parser(writer, channel_parsers, max_channel, is_server): + writer.newline() + if is_server: + function_name = "spice_get_server_channel_parser" + else: + function_name = "spice_get_client_channel_parser" + + scope = writer.function(function_name, + "spice_parse_channel_func_t", + "uint32_t channel, unsigned int *max_message_type") + + writer.write("static struct {spice_parse_channel_func_t func; unsigned int max_messages; } channels[%d] = " % (max_channel+1)) + writer.begin_block() + for i in range(0, max_channel + 1): + writer.write("{ ") + if channel_parsers.has_key(i): + writer.write(channel_parsers[i][1]) + writer.write(", ") + + channel = channel_parsers[i][0] + max_msg = 0 + if is_server: + messages = channel.server_messages + else: + messages = channel.client_messages + for m in messages: + max_msg = max(max_msg, m.value) + writer.write(max_msg) + else: + writer.write("NULL, 0") + writer.write("}") + + if i != max_channel: + writer.write(",") + writer.newline() + writer.end_block(semicolon = True) + + with writer.if_block("channel < %d" % (max_channel + 1)): + with writer.if_block("max_message_type != NULL"): + writer.assign("*max_message_type", "channels[channel].max_messages") + writer.statement("return channels[channel].func") + + writer.statement("return NULL") + writer.end_block() + + +def write_full_protocol_parser(writer, is_server): + writer.newline() + if is_server: + function_name = "spice_parse_msg" + else: + function_name = "spice_parse_reply" + scope = writer.function(function_name, + "uint8_t *", + "uint8_t *message_start, uint8_t *message_end, uint32_t channel, uint16_t message_type, int minor, size_t *size_out") + scope.variable_def("spice_parse_channel_func_t", "func" ) + + if is_server: + writer.assign("func", "spice_get_server_channel_parser(channel, NULL)") + else: + writer.assign("func", "spice_get_client_channel_parser(channel, NULL)") + + with writer.if_block("func != NULL"): + writer.statement("return func(message_start, message_end, message_type, minor, size_out)") + + writer.statement("return NULL") + writer.end_block() + +def write_protocol_parser(writer, proto, is_server): + max_channel = 0 + parsers = {} + + for channel in proto.channels: + max_channel = max(max_channel, channel.value) + + parsers[channel.value] = (channel.channel_type, write_channel_parser(writer, channel.channel_type, is_server)) + + write_get_channel_parser(writer, parsers, max_channel, is_server) + write_full_protocol_parser(writer, is_server) + +def write_includes(writer): + writer.writeln("#include ") + writer.writeln("#include ") + writer.writeln("#include ") + writer.writeln("#include ") + writer.writeln("#include ") + writer.writeln("#include ") + writer.newline() + writer.writeln("#ifdef _MSC_VER") + writer.writeln("#pragma warning(disable:4101)") + writer.writeln("#endif") diff --git a/python_modules/ptypes.py b/python_modules/ptypes.py new file mode 100644 index 00000000..fe8a3212 --- /dev/null +++ b/python_modules/ptypes.py @@ -0,0 +1,965 @@ +import codegen +import types + +_types_by_name = {} +_types = [] + +def type_exists(name): + return _types_by_name.has_key(name) + +def lookup_type(name): + return _types_by_name[name] + +def get_named_types(): + return _types + +class FixedSize: + def __init__(self, val = 0, minor = 0): + if isinstance(val, FixedSize): + self.vals = val.vals + else: + self.vals = [0] * (minor + 1) + self.vals[minor] = val + + def __add__(self, other): + if isinstance(other, types.IntType): + other = FixedSize(other) + + new = FixedSize() + l = max(len(self.vals), len(other.vals)) + shared = min(len(self.vals), len(other.vals)) + + new.vals = [0] * l + + for i in range(shared): + new.vals[i] = self.vals[i] + other.vals[i] + + for i in range(shared,len(self.vals)): + new.vals[i] = self.vals[i]; + + for i in range(shared,len(other.vals)): + new.vals[i] = new.vals[i] + other.vals[i]; + + return new + + def __radd__(self, other): + return self.__add__(other) + + def __str__(self): + s = "%d" % (self.vals[0]) + + for i in range(1,len(self.vals)): + if self.vals[i] > 0: + s = s + " + ((minor >= %d)?%d:0)" % (i, self.vals[i]) + return s + +class Type: + def __init__(self): + self.attributes = {} + self.registred = False + self.name = None + + def has_name(self): + return self.name != None + + def get_type(self, recursive=False): + return self + + def is_primitive(self): + return False + + def is_fixed_sizeof(self): + return True + + def is_extra_size(self): + return False + + def contains_extra_size(self): + return False + + def is_fixed_nw_size(self): + return True + + def is_array(self): + return isinstance(self, ArrayType) + + def is_struct(self): + return isinstance(self, StructType) + + def is_pointer(self): + return isinstance(self, PointerType) + + def get_num_pointers(self): + return 0 + + def get_pointer_names(self): + return [] + + def sizeof(self): + return "sizeof(%s)" % (self.c_type()) + + def __repr__(self): + return self.__str__() + + def __str__(self): + if self.name != None: + return self.name + return "anonymous type" + + def resolve(self): + return self + + def register(self): + if self.registred or self.name == None: + return + self.registred = True + if _types_by_name.has_key(self.name): + raise Exception, "Type %s already defined" % self.name + _types.append(self) + _types_by_name[self.name] = self + + def has_pointer(self): + return False + + def has_attr(self, name): + return self.attributes.has_key(name) + +class TypeRef(Type): + def __init__(self, name): + Type.__init__(self) + self.name = name + + def __str__(self): + return "ref to %s" % (self.name) + + def resolve(self): + if not _types_by_name.has_key(self.name): + raise Exception, "Unknown type %s" % self.name + return _types_by_name[self.name] + + def register(self): + assert True, "Can't register TypeRef!" + + +class IntegerType(Type): + def __init__(self, bits, signed): + Type.__init__(self) + self.bits = bits + self.signed = signed + + if signed: + self.name = "int%d" % bits + else: + self.name = "uint%d" % bits + + def primitive_type(self): + return self.name + + def c_type(self): + return self.name + "_t" + + def get_fixed_nw_size(self): + return self.bits / 8 + + def is_primitive(self): + return True + +class TypeAlias(Type): + def __init__(self, name, the_type, attribute_list): + Type.__init__(self) + self.name = name + self.the_type = the_type + for attr in attribute_list: + self.attributes[attr[0][1:]] = attr[1:] + + def get_type(self, recursive=False): + if recursive: + return self.the_type.get_type(True) + else: + return self.the_type + + def primitive_type(self): + return self.the_type.primitive_type() + + def resolve(self): + self.the_type = self.the_type.resolve() + return self + + def __str__(self): + return "alias %s" % self.name + + def is_primitive(self): + return self.the_type.is_primitive() + + def is_fixed_sizeof(self): + return self.the_type.is_fixed_sizeof() + + def is_fixed_nw_size(self): + return self.the_type.is_fixed_nw_size() + + def get_fixed_nw_size(self): + return self.the_type.get_fixed_nw_size() + + def get_num_pointers(self): + return self.the_type.get_num_pointers() + + def get_pointer_names(self): + return self.the_type.get_pointer_names() + + def c_type(self): + if self.has_attr("ctype"): + return self.attributes["ctype"][0] + return self.name + + def has_pointer(self): + return self.the_type.has_pointer() + +class EnumBaseType(Type): + def is_enum(self): + return isinstance(self, EnumType) + + def primitive_type(self): + return "uint%d" % (self.bits) + + def c_type(self): + return "uint%d_t" % (self.bits) + + def c_name(self): + return codegen.prefix_camel(self.name) + + def c_enumname(self, value): + if self.has_attr("prefix"): + return self.attributes["prefix"][0] + self.names[value] + return codegen.prefix_underscore_upper(self.name.upper(), self.names[value]) + + def c_enumname_by_name(self, name): + if self.has_attr("prefix"): + return self.attributes["prefix"][0] + self.names[value] + return codegen.prefix_underscore_upper(self.name.upper(), name) + + def is_primitive(self): + return True + + def get_fixed_nw_size(self): + return self.bits / 8 + +class EnumType(EnumBaseType): + def __init__(self, bits, name, enums, attribute_list): + Type.__init__(self) + self.bits = bits + self.name = name + + last = -1 + names = {} + values = {} + for v in enums: + name = v[0] + if len(v) > 1: + value = v[1] + else: + value = last + 1 + last = value + + assert not names.has_key(value) + names[value] = name + values[name] = value + + self.names = names + self.values = values + + for attr in attribute_list: + self.attributes[attr[0][1:]] = attr[1:] + + def __str__(self): + return "enum %s" % self.name + + def c_define(self, writer): + writer.write("enum ") + writer.write(self.c_name()) + writer.begin_block() + values = self.names.keys() + values.sort() + current_default = 0 + for i in values: + writer.write(self.c_enumname(i)) + if i != current_default: + writer.write(" = %d" % (i)) + writer.write(",") + writer.newline() + current_default = i + 1 + writer.newline() + writer.write(codegen.prefix_underscore_upper(self.name.upper(), "ENUM_END")) + writer.newline() + writer.end_block(semicolon=True) + writer.newline() + +class FlagsType(EnumBaseType): + def __init__(self, bits, name, flags, attribute_list): + Type.__init__(self) + self.bits = bits + self.name = name + + last = -1 + names = {} + values = {} + for v in flags: + name = v[0] + if len(v) > 1: + value = v[1] + else: + value = last + 1 + last = value + + assert not names.has_key(value) + names[value] = name + values[name] = value + + self.names = names + self.values = values + + for attr in attribute_list: + self.attributes[attr[0][1:]] = attr[1:] + + def __str__(self): + return "flags %s" % self.name + + def c_define(self, writer): + writer.write("enum ") + writer.write(self.c_name()) + writer.begin_block() + values = self.names.keys() + values.sort() + mask = 0 + for i in values: + writer.write(self.c_enumname(i)) + mask = mask | (1< 0 + + def is_image_size_length(self): + if isinstance(self.size, types.IntType) or isinstance(self.size, types.StringType): + return False + return self.size[0] == "image_size" + + def is_bytes_length(self): + if isinstance(self.size, types.IntType) or isinstance(self.size, types.StringType): + return False + return self.size[0] == "bytes" + + def is_cstring_length(self): + if isinstance(self.size, types.IntType) or isinstance(self.size, types.StringType): + return False + return self.size[0] == "cstring" + + def is_fixed_sizeof(self): + return self.is_constant_length() and self.element_type.is_fixed_sizeof() + + def is_fixed_nw_size(self): + return self.is_constant_length() and self.element_type.is_fixed_nw_size() + + def get_fixed_nw_size(self): + if not self.is_fixed_nw_size(): + raise Exception, "Not a fixed size type" + + return self.element_type.get_fixed_nw_size() * self.size + + def get_num_pointers(self): + element_count = self.element_type.get_num_pointers() + if element_count == 0: + return 0 + if self.is_constant_length(self): + return element_count * self.size + raise Exception, "Pointers in dynamic arrays not supported" + + def get_pointer_names(self): + element_count = self.element_type.get_num_pointers() + if element_count == 0: + return [] + raise Exception, "Pointer names in arrays not supported" + + def contains_extra_size(self): + return self.element_type.contains_extra_size() + + def sizeof(self): + return "%s * %s" % (self.element_type.sizeof(), self.size) + + def c_type(self): + return self.element_type.c_type() + +class PointerType(Type): + def __init__(self, target_type): + Type.__init__(self) + self.name = None + self.target_type = target_type + + def __str__(self): + return "%s*" % (str(self.target_type)) + + def resolve(self): + self.target_type = self.target_type.resolve() + return self + + def get_fixed_size(self): + return 8 # offsets are 64bit + + def is_fixed_nw_size(self): + return True + + def is_primitive(self): + return True + + def primitive_type(self): + return "uint64" + + def get_fixed_nw_size(self): + return 8 + + def c_type(self): + return "SPICE_ADDRESS" + + def has_pointer(self): + return True + + def contains_extra_size(self): + return True + + def get_num_pointers(self): + return 1 + +class Containee: + def __init__(self): + self.attributes = {} + + def is_switch(self): + return False + + def is_pointer(self): + return not self.is_switch() and self.member_type.is_pointer() + + def is_array(self): + return not self.is_switch() and self.member_type.is_array() + + def is_struct(self): + return not self.is_switch() and self.member_type.is_struct() + + def is_primitive(self): + return not self.is_switch() and self.member_type.is_primitive() + + def has_attr(self, name): + return self.attributes.has_key(name) + + def has_minor_attr(self): + return self.has_attr("minor") + + def has_end_attr(self): + return self.has_attr("end") + + def get_minor_attr(self): + return self.attributes["minor"][0] + +class Member(Containee): + def __init__(self, name, member_type, attribute_list): + Containee.__init__(self) + self.name = name + self.member_type = member_type + for attr in attribute_list: + self.attributes[attr[0][1:]] = attr[1:] + + def resolve(self, container): + self.container = container + self.member_type = self.member_type.resolve() + self.member_type.register() + return self + + def is_primitive(self): + return self.member_type.is_primitive() + + def is_fixed_sizeof(self): + if self.has_end_attr(): + return False + return self.member_type.is_fixed_sizeof() + + def is_extra_size(self): + return self.has_end_attr() + + def is_fixed_nw_size(self): + return self.member_type.is_fixed_nw_size() + + def get_fixed_nw_size(self): + size = self.member_type.get_fixed_nw_size() + if self.has_minor_attr(): + minor = self.get_minor_attr() + size = FixedSize(size, minor) + return size + + def contains_extra_size(self): + return self.member_type.contains_extra_size() + + def sizeof(self): + return self.member_type.sizeof() + + def __repr__(self): + return "%s (%s)" % (str(self.name), str(self.member_type)) + + def has_pointer(self): + return self.member_type.has_pointer() + + def get_num_pointers(self): + return self.member_type.get_num_pointers() + + def get_pointer_names(self): + if self.member_type.is_pointer(): + names = [self.name + "_out"] + else: + names = self.member_type.get_pointer_names() + if self.has_attr("outvar"): + prefix = self.attributes["outvar"][0] + names = map(lambda name: prefix + "_" + name, names) + return names + +class SwitchCase: + def __init__(self, values, member): + self.values = values + self.member = member + self.members = [member] + + def get_check(self, var_cname, var_type): + checks = [] + for v in self.values: + if v == None: + return "1" + elif var_type.is_enum(): + checks.append("%s == %s" % (var_cname, var_type.c_enumname_by_name(v))) + else: + checks.append("(%s & %s)" % (var_cname, var_type.c_enumname_by_name(v))) + return " || ".join(checks) + + def resolve(self, container): + self.switch = container + self.member = self.member.resolve(self) + return self + + def has_pointer(self): + return self.member.has_pointer() + + def get_num_pointers(self): + return self.member.get_num_pointers() + + def get_pointer_names(self): + return self.member.get_pointer_names() + +class Switch(Containee): + def __init__(self, variable, cases, name, attribute_list): + Containee.__init__(self) + self.variable = variable + self.name = name + self.cases = cases + for attr in attribute_list: + self.attributes[attr[0][1:]] = attr[1:] + + def is_switch(self): + return True + + def has_switch_member(self, member): + for c in self.cases: + if c.member == member: + return True + return False + + def resolve(self, container): + self.container = container + self.cases = map(lambda c : c.resolve(self), self.cases) + return self + + def __repr__(self): + return "switch on %s %s" % (str(self.variable),str(self.name)) + + def is_fixed_sizeof(self): + # Kinda weird, but we're unlikely to have a real struct if there is an @end + if self.has_end_attr(): + return False + return True + + def is_fixed_nw_size(self): + if self.has_attr("fixedsize"): + return True + + size = None + for c in self.cases: + if not c.member.is_fixed_nw_size(): + return False + if size == None: + size = c.member.get_fixed_nw_size() + elif size != c.member.get_fixed_nw_size(): + return False + return True + + def is_extra_size(self): + return self.has_end_attr() + + def contains_extra_size(self): + for c in self.cases: + if c.member.is_extra_size(): + return True + if c.member.contains_extra_size(): + return True + return False + + def get_fixed_nw_size(self): + if not self.is_fixed_nw_size(): + raise Exception, "Not a fixed size type" + size = 0; + for c in self.cases: + size = max(size, c.member.get_fixed_nw_size()) + return size + + def sizeof(self): + return "sizeof(((%s *)NULL)->%s)" % (self.container.c_type(), + self.name) + + def has_pointer(self): + for c in self.cases: + if c.has_pointer(): + return True + return False + + def get_num_pointers(self): + count = 0 + for c in self.cases: + count = max(count, c.get_num_pointers()) + return count + + def get_pointer_names(self): + names = [] + for c in self.cases: + names = names + c.get_pointer_names() + return names + +class ContainerType(Type): + def is_fixed_sizeof(self): + for m in self.members: + if not m.is_fixed_sizeof(): + return False + return True + + def contains_extra_size(self): + for m in self.members: + if m.is_extra_size(): + return True + if m.contains_extra_size(): + return True + return False + + def is_fixed_nw_size(self): + for i in self.members: + if not i.is_fixed_nw_size(): + return False + return True + + def get_fixed_nw_size(self): + size = 0 + for i in self.members: + size = size + i.get_fixed_nw_size() + return size + + def get_fixed_nw_offset(self, member): + size = 0 + for i in self.members: + if i == member: + break + if i.is_fixed_nw_size(): + size = size + i.get_fixed_nw_size() + return size + + def resolve(self): + self.members = map(lambda m : m.resolve(self), self.members) + return self + + def get_num_pointers(self): + count = 0 + for m in self.members: + count = count + m.get_num_pointers() + return count + + def get_pointer_names(self): + names = [] + for m in self.members: + names = names + m.get_pointer_names() + return names + + def has_pointer(self): + for m in self.members: + if m.has_pointer(): + return True + return False + + def get_nw_offset(self, member, prefix = "", postfix = ""): + fixed = self.get_fixed_nw_offset(member) + v = [] + for m in self.members: + if m == member: + break + if m.is_switch() and m.has_switch_member(member): + break + if not m.is_fixed_nw_size(): + v.append(prefix + m.name + postfix) + if len(v) > 0: + return str(fixed) + " + " + (" + ".join(v)) + else: + return str(fixed) + + def lookup_member(self, name): + return self.members_by_name[name] + +class StructType(ContainerType): + def __init__(self, name, members, attribute_list): + Type.__init__(self) + self.name = name + self.members = members + self.members_by_name = {} + for m in members: + self.members_by_name[m.name] = m + for attr in attribute_list: + self.attributes[attr[0][1:]] = attr[1:] + + def __str__(self): + if self.name == None: + return "anonymous struct" + else: + return "struct %s" % self.name + + def c_type(self): + if self.has_attr("ctype"): + return self.attributes["ctype"][0] + return codegen.prefix_camel(self.name) + +class MessageType(ContainerType): + def __init__(self, name, members, attribute_list): + Type.__init__(self) + self.name = name + self.members = members + self.members_by_name = {} + for m in members: + self.members_by_name[m.name] = m + self.reverse_members = {} # ChannelMembers referencing this message + for attr in attribute_list: + self.attributes[attr[0][1:]] = attr[1:] + + def __str__(self): + if self.name == None: + return "anonymous message" + else: + return "message %s" % self.name + + def c_name(self): + if self.name == None: + cms = self.reverse_members.keys() + if len(cms) != 1: + raise "Unknown typename for message" + cm = cms[0] + channelname = cm.channel.member_name + if channelname == None: + channelname = "" + else: + channelname = channelname + "_" + if cm.is_server: + return "msg_" + channelname + cm.name + else: + return "msgc_" + channelname + cm.name + else: + return codegen.prefix_camel("Msg", self.name) + + def c_type(self): + if self.has_attr("ctype"): + return self.attributes["ctype"][0] + if self.name == None: + cms = self.reverse_members.keys() + if len(cms) != 1: + raise "Unknown typename for message" + cm = cms[0] + channelname = cm.channel.member_name + if channelname == None: + channelname = "" + if cm.is_server: + return codegen.prefix_camel("Msg", channelname, cm.name) + else: + return codegen.prefix_camel("Msgc", channelname, cm.name) + else: + return codegen.prefix_camel("Msg", self.name) + +class ChannelMember(Containee): + def __init__(self, name, message_type, value): + Containee.__init__(self) + self.name = name + self.message_type = message_type + self.value = value + + def resolve(self, channel): + self.channel = channel + self.message_type = self.message_type.resolve() + self.message_type.reverse_members[self] = 1 + + return self + + def __repr__(self): + return "%s (%s)" % (str(self.name), str(self.message_type)) + +class ChannelType(Type): + def __init__(self, name, base, members): + Type.__init__(self) + self.name = name + self.base = base + self.member_name = None + self.members = members + + def __str__(self): + if self.name == None: + return "anonymous channel" + else: + return "channel %s" % self.name + + def is_fixed_nw_size(self): + return False + + def get_client_message(self, name): + return self.client_messages_byname[name] + + def get_server_message(self, name): + return self.server_messages_byname[name] + + def resolve(self): + if self.base != None: + self.base = self.base.resolve() + + server_messages = self.base.server_messages[:] + server_messages_byname = self.base.server_messages_byname.copy() + client_messages = self.base.client_messages[:] + client_messages_byname = self.base.client_messages_byname.copy() + else: + server_messages = [] + server_messages_byname = {} + client_messages = [] + client_messages_byname = {} + + server_count = 1 + client_count = 1 + + server = True + for m in self.members: + if m == "server": + server = True + elif m == "client": + server = False + elif server: + m.is_server = True + m = m.resolve(self) + if m.value: + server_count = m.value + 1 + else: + m.value = server_count + server_count = server_count + 1 + server_messages.append(m) + server_messages_byname[m.name] = m + else: + m.is_server = False + m = m.resolve(self) + if m.value: + client_count = m.value + 1 + else: + m.value = client_count + client_count = client_count + 1 + client_messages.append(m) + client_messages_byname[m.name] = m + + self.server_messages = server_messages + self.server_messages_byname = server_messages_byname + self.client_messages = client_messages + self.client_messages_byname = client_messages_byname + + return self + +class ProtocolMember: + def __init__(self, name, channel_type, value): + self.name = name + self.channel_type = channel_type + self.value = value + + def resolve(self, protocol): + self.channel_type = self.channel_type.resolve() + assert(self.channel_type.member_name == None) + self.channel_type.member_name = self.name + return self + + def __repr__(self): + return "%s (%s)" % (str(self.name), str(self.channel_type)) + +class ProtocolType(Type): + def __init__(self, name, channels): + Type.__init__(self) + self.name = name + self.channels = channels + + def __str__(self): + if self.name == None: + return "anonymous protocol" + else: + return "protocol %s" % self.name + + def is_fixed_nw_size(self): + return False + + def resolve(self): + count = 1 + for m in self.channels: + m = m.resolve(self) + if m.value: + count = m.value + 1 + else: + m.value = count + count = count + 1 + + return self + +int8 = IntegerType(8, True) +uint8 = IntegerType(8, False) +int16 = IntegerType(16, True) +uint16 = IntegerType(16, False) +int32 = IntegerType(32, True) +uint32 = IntegerType(32, False) +int64 = IntegerType(64, True) +uint64 = IntegerType(64, False) diff --git a/python_modules/spice_parser.py b/python_modules/spice_parser.py new file mode 100644 index 00000000..65916b35 --- /dev/null +++ b/python_modules/spice_parser.py @@ -0,0 +1,157 @@ +from pyparsing import Literal, CaselessLiteral, Word, OneOrMore, ZeroOrMore, \ + Forward, delimitedList, Group, Optional, Combine, alphas, nums, restOfLine, cStyleComment, \ + alphanums, ParseException, ParseResults, Keyword, StringEnd, replaceWith + +import ptypes +import sys + +cvtInt = lambda toks: int(toks[0]) + +def parseVariableDef(toks): + t = toks[0][0] + pointer = toks[0][1] + name = toks[0][2] + array_size = toks[0][3] + attributes = toks[0][4] + + if array_size != None: + t = ptypes.ArrayType(t, array_size) + + if pointer != None: + t = ptypes.PointerType(t); + + return ptypes.Member(name, t, attributes) + +bnf = None +def SPICE_BNF(): + global bnf + + if not bnf: + + # punctuation + colon = Literal(":").suppress() + lbrace = Literal("{").suppress() + rbrace = Literal("}").suppress() + lbrack = Literal("[").suppress() + rbrack = Literal("]").suppress() + lparen = Literal("(").suppress() + rparen = Literal(")").suppress() + equals = Literal("=").suppress() + comma = Literal(",").suppress() + semi = Literal(";").suppress() + + # primitive types + int8_ = Keyword("int8").setParseAction(replaceWith(ptypes.int8)) + uint8_ = Keyword("uint8").setParseAction(replaceWith(ptypes.uint8)) + int16_ = Keyword("int16").setParseAction(replaceWith(ptypes.int16)) + uint16_ = Keyword("uint16").setParseAction(replaceWith(ptypes.uint16)) + int32_ = Keyword("int32").setParseAction(replaceWith(ptypes.int32)) + uint32_ = Keyword("uint32").setParseAction(replaceWith(ptypes.uint32)) + int64_ = Keyword("int64").setParseAction(replaceWith(ptypes.int64)) + uint64_ = Keyword("uint64").setParseAction(replaceWith(ptypes.uint64)) + + # keywords + channel_ = Keyword("channel") + enum32_ = Keyword("enum32").setParseAction(replaceWith(32)) + enum16_ = Keyword("enum16").setParseAction(replaceWith(16)) + enum8_ = Keyword("enum8").setParseAction(replaceWith(8)) + flags32_ = Keyword("flags32").setParseAction(replaceWith(32)) + flags16_ = Keyword("flags16").setParseAction(replaceWith(16)) + flags8_ = Keyword("flags8").setParseAction(replaceWith(8)) + channel_ = Keyword("channel") + server_ = Keyword("server") + client_ = Keyword("client") + protocol_ = Keyword("protocol") + typedef_ = Keyword("typedef") + struct_ = Keyword("struct") + message_ = Keyword("message") + image_size_ = Keyword("image_size") + bytes_ = Keyword("bytes") + cstring_ = Keyword("cstring") + switch_ = Keyword("switch") + default_ = Keyword("default") + case_ = Keyword("case") + + identifier = Word( alphas, alphanums + "_" ) + enumname = Word( alphanums + "_" ) + + integer = ( Combine( CaselessLiteral("0x") + Word( nums+"abcdefABCDEF" ) ) | + Word( nums+"+-", nums ) ).setName("int").setParseAction(cvtInt) + + typename = identifier.copy().setParseAction(lambda toks : ptypes.TypeRef(str(toks[0]))) + + # This is just normal "types", i.e. not channels or messages + typeSpec = Forward() + + attributeValue = integer ^ identifier + attribute = Group(Combine ("@" + identifier) + Optional(lparen + delimitedList(attributeValue) + rparen)) + attributes = Group(ZeroOrMore(attribute)) + arraySizeSpecImage = Group(image_size_ + lparen + integer + comma + identifier + comma + identifier + rparen) + arraySizeSpecBytes = Group(bytes_ + lparen + identifier + rparen) + arraySizeSpecCString = Group(cstring_ + lparen + rparen) + arraySizeSpec = lbrack + Optional(identifier ^ integer ^ arraySizeSpecImage ^ arraySizeSpecBytes ^arraySizeSpecCString, default="") + rbrack + variableDef = Group(typeSpec + Optional("*", default=None) + identifier + Optional(arraySizeSpec, default=None) + attributes - semi) \ + .setParseAction(parseVariableDef) + + switchCase = Group(Group(OneOrMore(default_.setParseAction(replaceWith(None)) + colon | case_.suppress() + identifier + colon)) + variableDef) \ + .setParseAction(lambda toks: ptypes.SwitchCase(toks[0][0], toks[0][1])) + switchBody = Group(switch_ + lparen + identifier + rparen + lbrace + Group(OneOrMore(switchCase)) + rbrace + identifier + attributes - semi) \ + .setParseAction(lambda toks: ptypes.Switch(toks[0][1], toks[0][2], toks[0][3], toks[0][4])) + messageBody = structBody = Group(lbrace + ZeroOrMore(variableDef | switchBody) + rbrace) + structSpec = Group(struct_ + identifier + structBody + attributes).setParseAction(lambda toks: ptypes.StructType(toks[0][1], toks[0][2], toks[0][3])) + + # have to use longest match for type, in case a user-defined type name starts with a keyword type, like "channel_type" + typeSpec << ( structSpec ^ int8_ ^ uint8_ ^ int16_ ^ uint16_ ^ + int32_ ^ uint32_ ^ int64_ ^ uint64_ ^ + typename).setName("type") + + flagsBody = enumBody = Group(lbrace + delimitedList(Group (enumname + Optional(equals + integer))) + Optional(comma) + rbrace) + + messageSpec = Group(message_ + messageBody + attributes).setParseAction(lambda toks: ptypes.MessageType(None, toks[0][1], toks[0][2])) | typename + + channelParent = Optional(colon + typename, default=None) + channelMessage = Group(messageSpec + identifier + Optional(equals + integer, default=None) + semi) \ + .setParseAction(lambda toks: ptypes.ChannelMember(toks[0][1], toks[0][0], toks[0][2])) + channelBody = channelParent + Group(lbrace + ZeroOrMore( server_ + colon | client_ + colon | channelMessage) + rbrace) + + enum_ = (enum32_ | enum16_ | enum8_) + flags_ = (flags32_ | flags16_ | flags8_) + enumDef = Group(enum_ + identifier + enumBody + attributes - semi).setParseAction(lambda toks: ptypes.EnumType(toks[0][0], toks[0][1], toks[0][2], toks[0][3])) + flagsDef = Group(flags_ + identifier + flagsBody + attributes - semi).setParseAction(lambda toks: ptypes.FlagsType(toks[0][0], toks[0][1], toks[0][2], toks[0][3])) + messageDef = Group(message_ + identifier + messageBody + attributes - semi).setParseAction(lambda toks: ptypes.MessageType(toks[0][1], toks[0][2], toks[0][3])) + channelDef = Group(channel_ + identifier + channelBody - semi).setParseAction(lambda toks: ptypes.ChannelType(toks[0][1], toks[0][2], toks[0][3])) + structDef = Group(struct_ + identifier + structBody + attributes - semi).setParseAction(lambda toks: ptypes.StructType(toks[0][1], toks[0][2], toks[0][3])) + typedefDef = Group(typedef_ + identifier + typeSpec + attributes - semi).setParseAction(lambda toks: ptypes.TypeAlias(toks[0][1], toks[0][2], toks[0][3])) + + definitions = typedefDef | structDef | enumDef | flagsDef | messageDef | channelDef + + protocolChannel = Group(typename + identifier + Optional(equals + integer, default=None) + semi) \ + .setParseAction(lambda toks: ptypes.ProtocolMember(toks[0][1], toks[0][0], toks[0][2])) + protocolDef = Group(protocol_ + identifier + Group(lbrace + ZeroOrMore(protocolChannel) + rbrace) + semi) \ + .setParseAction(lambda toks: ptypes.ProtocolType(toks[0][1], toks[0][2])) + + bnf = ZeroOrMore (definitions) + protocolDef + StringEnd() + + singleLineComment = "//" + restOfLine + bnf.ignore( singleLineComment ) + bnf.ignore( cStyleComment ) + + return bnf + + +def parse(filename): + try: + bnf = SPICE_BNF() + types = bnf.parseFile(filename) + except ParseException, err: + print >> sys.stderr, err.line + print >> sys.stderr, " "*(err.column-1) + "^" + print >> sys.stderr, err + return None + + for t in types: + t.resolve() + t.register() + protocol = types[-1] + return protocol + diff --git a/spice.proto b/spice.proto new file mode 100644 index 00000000..dec6a63a --- /dev/null +++ b/spice.proto @@ -0,0 +1,1086 @@ +/* built in types: + int8, uint8, 16, 32, 64 +*/ + +typedef fixed28_4 int32 @ctype(SPICE_FIXED28_4); + +struct Point { + int32 x; + int32 y; +}; + +struct Point16 { + int16 x; + int16 y; +}; + +struct PointFix { + fixed28_4 x; + fixed28_4 y; +}; + +struct Rect { + int32 top; + int32 left; + int32 bottom; + int32 right; +}; + +enum32 link_err { + OK, + ERROR, + INVALID_MAGIC, + INVALID_DATA, + VERSION_MISMATCH, + NEED_SECURED, + NEED_UNSECURED, + PERMISSION_DENIED, + BAD_CONNECTION_ID, + CHANNEL_NOT_AVAILABLE +}; + +enum32 warn_code { + WARN_GENERAL +} @prefix(SPICE_); + +enum32 info_code { + INFO_GENERAL +} @prefix(SPICE_); + +flags32 migrate_flags { + NEED_FLUSH, + NEED_DATA_TRANSFER +} @prefix(SPICE_MIGRATE_); + +enum32 notify_severity { + INFO, + WARN, + ERROR, +}; + +enum32 notify_visibility { + LOW, + MEDIUM, + HIGH, +}; + +flags32 mouse_mode { + SERVER, + CLIENT, +}; + +enum16 pubkey_type { + INVALID, + RSA, + RSA2, + DSA, + DSA1, + DSA2, + DSA3, + DSA4, + DH, + EC, +}; + +message Empty { +}; + +message Data { + uint8 data[] @end @ctype(uint8_t); +} @nocopy; + +struct ChannelWait { + uint8 channel_type; + uint8 channel_id; + uint64 message_serial; +} @ctype(SpiceWaitForChannel); + +channel BaseChannel { + server: + message { + migrate_flags flags; + } migrate; + + Data migrate_data; + + message { + uint32 generation; + uint32 window; + } set_ack; + + message { + uint32 id; + uint64 timestamp; + uint8 data[] @end @ctype(uint8_t); + } ping; + + message { + uint8 wait_count; + ChannelWait wait_list[wait_count] @end; + } wait_for_channels; + + message { + uint64 time_stamp; + link_err reason; + } @ctype(SpiceMsgDisconnect) disconnecting; + + message { + uint64 time_stamp; + notify_severity severity; + notify_visibility visibilty; + uint32 what; /* error_code/warn_code/info_code */ + uint32 message_len; + uint8 message[message_len] @end; + uint8 zero @end @ctype(uint8_t) @zero; + } notify; + + client: + message { + uint32 generation; + } ack_sync; + + Empty ack; + + message { + uint32 id; + uint64 timestamp; + } @ctype(SpiceMsgPing) pong; + + Empty migrate_flush_mark; + + Data migrate_data; + + message { + uint64 time_stamp; + link_err reason; + } @ctype(SpiceMsgDisconnect) disconnecting; +}; + +struct ChannelId { + uint8 type; + uint8 id; +}; + +channel MainChannel : BaseChannel { + server: + message { + uint16 port; + uint16 sport; + uint32 host_offset; + uint32 host_size; + pubkey_type pub_key_type @minor(2); + uint32 pub_key_offset @minor(2); + uint32 pub_key_size @minor(2); + uint8 host_data[host_size] @end @ctype(uint8_t) @zero_terminated; + uint8 pub_key_data[pub_key_size] @minor(2) @end @ctype(uint8_t) @zero_terminated; + } @ctype(SpiceMsgMainMigrationBegin) migrate_begin = 101; + + Empty migrate_cancel; + + message { + uint32 session_id; + uint32 display_channels_hint; + uint32 supported_mouse_modes; + uint32 current_mouse_mode; + uint32 agent_connected; + uint32 agent_tokens; + uint32 multi_media_time; + uint32 ram_hint; + } init; + + message { + uint32 num_of_channels; + ChannelId channels[num_of_channels] @end; + } @ctype(SpiceMsgChannels) channels_list; + + message { + mouse_mode supported_modes; + mouse_mode current_mode @unique_flag; + } mouse_mode; + + message { + uint32 time; + } @ctype(SpiceMsgMainMultiMediaTime) multi_media_time; + + Empty agent_connected; + + message { + link_err error_code; + } @ctype(SpiceMsgMainAgentDisconnect) agent_disconnected; + + Data agent_data; + + message { + uint32 num_tokens; + } @ctype(SpiceMsgMainAgentTokens) agent_token; + + message { + uint16 port; + uint16 sport; + uint32 host_offset; + uint32 host_size; + uint32 cert_subject_offset; + uint32 cert_subject_size; + uint8 host_data[host_size] @end @ctype(uint8_t) @zero_terminated; + uint8 cert_subject_data[cert_subject_size] @end @ctype(uint8_t) @zero_terminated; + } @ctype(SpiceMsgMainMigrationSwitchHost) migrate_switch_host; + + client: + message { + uint64 cache_size; + } @ctype(SpiceMsgcClientInfo) client_info = 101; + + Empty migrate_connected; + + Empty migrate_connect_error; + + Empty attach_channels; + + message { + mouse_mode mode; + } mouse_mode_request; + + message { + uint32 num_tokens; + } agent_start; + + Data agent_data; + + message { + uint32 num_tokens; + } @ctype(SpiceMsgcMainAgentTokens) agent_token; +}; + +enum32 clip_type { + NONE, + RECTS, + PATH, +}; + +flags32 path_flags { /* TODO: C enum names changes */ + BEGIN = 0, + END = 1, + CLOSE = 3, + BEZIER = 4, +} @prefix(SPICE_PATH_); + +enum32 video_codec_type { + MJPEG = 1, +}; + +flags32 stream_flags { + TOP_DOWN = 0, +}; + +enum32 brush_type { + NONE, + SOLID, + PATTERN, +}; + +flags8 mask_flags { + INVERS, +}; + +enum8 image_type { + BITMAP, + QUIC, + RESERVED, + LZ_PLT = 100, + LZ_RGB, + GLZ_RGB, + FROM_CACHE, + SURFACE, + JPEG, + FROM_CACHE_LOSSLESS, +}; + +flags8 image_flags { + CACHE_ME, + HIGH_BITS_SET, + CACHE_REPLACE_ME, +}; + +enum8 bitmap_fmt { + INVALID, + 1BIT_LE, + 1BIT_BE, + 4BIT_LE, + 4BIT_BE, + 8BIT /* 8bit indexed mode */, + 16BIT, /* 0555 mode */ + 24BIT /* 3 byte, brg */, + 32BIT /* 4 byte, xrgb in little endian format */, + RGBA /* 4 byte, argb in little endian format */ +}; + +flags8 bitmap_flags { + PAL_CACHE_ME, + PAL_FROM_CACHE, + TOP_DOWN, +}; + +enum8 image_scale_mode { + INTERPOLATE, + NEAREST, +}; + +flags16 ropd { + INVERS_SRC, + INVERS_BRUSH, + INVERS_DEST, + OP_PUT, + OP_OR, + OP_AND, + OP_XOR, + OP_BLACKNESS, + OP_WHITENESS, + OP_INVERS, + INVERS_RES, +}; + +flags8 line_flags { + STYLED = 3, + START_WITH_GAP = 2, +}; + +enum8 line_cap { + ROUND, + SQUARE, + BUTT, +}; + +enum8 line_join { + ROUND, + BEVEL, + MITER, +}; + +flags16 string_flags { + RASTER_A1, + RASTER_A4, + RASTER_A8, + RASTER_TOP_DOWN, +}; + +flags32 surface_flags { + PRIMARY +}; + +enum32 surface_fmt { + INVALID, + 1_A = 1, + 8_A = 8, + 16_555 = 16 , + 16_565 = 80, + 32_xRGB = 32, + 32_ARGB = 96 +}; + +flags16 alpha_flags { + DEST_HAS_ALPHA, + SRC_SURFACE_HAS_ALPHA +}; + +enum8 resource_type { + INVALID, + PIXMAP +} @prefix(SPICE_RES_TYPE_); + +struct ClipRects { + uint32 num_rects; + Rect rects[num_rects] @end; +}; + +struct PathSegment { + path_flags flags; + uint32 count; + PointFix points[count] @end; +} @ctype(SpicePathSeg); + +struct Path { + uint32 size; + PathSegment segments[bytes(size)] @end; +}; + +struct Clip { + clip_type type; + switch (type) { + case NONE: + uint64 data @zero; + case RECTS: + ClipRects *data; + case PATH: + Path *data; + } u @anon; +}; + +struct DisplayBase { + uint32 surface_id; + Rect box; + Clip clip; +} @ctype(SpiceMsgDisplayBase); + +struct ResourceID { + uint8 type; + uint64 id; +}; + +struct WaitForChannel { + uint8 channel_type; + uint8 channel_id; + uint64 message_serial; +}; + +struct Palette { + uint64 unique; + uint16 num_ents; + uint32 ents[num_ents] @end; +}; + +struct BitmapData { + bitmap_fmt format; + bitmap_flags flags; + uint32 x; + uint32 y; + uint32 stride; + switch (flags) { + case PAL_FROM_CACHE: + uint64 palette; + default: + Palette *palette; + } pal @anon; + uint8 *data[image_size(8, stride, y)] @nocopy; /* pointer to array, not array of pointers as in C */ +} @ctype(SpiceBitmap); + +struct BinaryData { + uint32 data_size; + uint8 data[data_size] @end; +} @ctype(SpiceQUICData); + +struct LZPLTData { + bitmap_flags flags; + uint32 data_size; + switch (flags) { + case PAL_FROM_CACHE: + uint64 palette; + default: + Palette *palette @nonnull; + } pal @anon; + uint8 data[data_size] @end; +}; + +struct Surface { + uint32 surface_id; +}; + +struct Image { + uint64 id; + image_type type; + image_flags flags; + uint32 width; + uint32 height; + + switch (type) { + case BITMAP: + BitmapData bitmap_data @ctype(SpiceBitmap); + case QUIC: + case LZ_RGB: + case GLZ_RGB: + case JPEG: + BinaryData binary_data @ctype(SpiceQUICData); + case LZ_PLT: + LZPLTData lzplt_data @ctype(SpiceLZPLTData); + case SURFACE: + Surface surface_data; + } u @end; +} @ctype(SpiceImageDescriptor); + +struct Pattern { + Image *pat @nonnull; + Point pos; +}; + +struct Brush { + brush_type type; + switch (type) { + case SOLID: + uint32 color; + case PATTERN: + Pattern pattern; + } u @fixedsize; +}; + +struct QMask { + mask_flags flags; + Point pos; + Image *bitmap; +}; + +struct LineAttr { + line_flags flags; + line_join join_style; + line_cap end_style; + uint8 style_nseg; + fixed28_4 width; + fixed28_4 miter_limit; + fixed28_4 *style[style_nseg]; +}; + +struct RasterGlyphA1 { + Point render_pos; + Point glyph_origin; + uint16 width; + uint16 height; + uint8 data[image_size(1, width, height)] @end; +} @ctype(SpiceRasterGlyph); + +struct RasterGlyphA4 { + Point render_pos; + Point glyph_origin; + uint16 width; + uint16 height; + uint8 data[image_size(4, width, height)] @end; +} @ctype(SpiceRasterGlyph); + +struct RasterGlyphA8 { + Point render_pos; + Point glyph_origin; + uint16 width; + uint16 height; + uint8 data[image_size(8, width, height)] @end; +} @ctype(SpiceRasterGlyph); + +struct String { + uint16 length; + string_flags flags; /* Special: Only one of a1/a4/a8 set */ + switch (flags) { + case RASTER_A1: + RasterGlyphA1 glyphs[length] @ctype(SpiceRasterGlyph); + case RASTER_A4: + RasterGlyphA4 glyphs[length] @ctype(SpiceRasterGlyph); + case RASTER_A8: + RasterGlyphA8 glyphs[length] @ctype(SpiceRasterGlyph); + } u @end; +}; + +channel DisplayChannel : BaseChannel { + server: + message { + uint32 x_res; + uint32 y_res; + uint32 bits; + } mode = 101; + + Empty mark; + Empty reset; + message { + DisplayBase base; + Point src_pos; + } copy_bits; + + message { + uint16 count; + ResourceID resources[count] @end; + } @ctype(SpiceResourceList) inval_list; + + message { + uint8 wait_count; + WaitForChannel wait_list[wait_count] @end; + } @ctype(SpiceMsgWaitForChannels) inval_all_pixmaps; + + message { + uint64 id; + } @ctype(SpiceMsgDisplayInvalOne) inval_palette; + + Empty inval_all_palettes; + + message { + uint32 surface_id; + uint32 id; + stream_flags flags; + video_codec_type codec_type; + uint64 stamp; + uint32 stream_width; + uint32 stream_height; + uint32 src_width; + uint32 src_height; + Rect dest; + Clip clip; + } stream_create = 122; + + message { + uint32 id; + uint32 multi_media_time; + uint32 data_size; + uint32 pad_size; + uint8 data[data_size] @end; + uint8 padding[pad_size] @end @ctype(uint8_t); /* Uhm, why are we sending padding over network? */ + } stream_data; + + message { + uint32 id; + Clip clip; + } stream_clip; + + message { + uint32 id; + } stream_destroy; + + Empty stream_destroy_all; + + message { + DisplayBase base; + struct Fill { + Brush brush; + uint16 rop_decriptor; + QMask mask; + } data; + } draw_fill = 302; + + message { + DisplayBase base; + struct Opaque { + Image *src_bitmap; + Rect src_area; + Brush brush; + ropd rop_decriptor; + image_scale_mode scale_mode; + QMask mask; + } data; + } draw_opaque; + + message { + DisplayBase base; + struct Copy { + Image *src_bitmap; + Rect src_area; + ropd rop_decriptor; + image_scale_mode scale_mode; + QMask mask; + } data; + } draw_copy; + + message { + DisplayBase base; + struct Blend { + Image *src_bitmap; + Rect src_area; + ropd rop_decriptor; + image_scale_mode scale_mode; + QMask mask; + } @ctype(SpiceCopy) data; + } draw_blend; + + message { + DisplayBase base; + struct Blackness { + QMask mask; + } data; + } draw_blackness; + + message { + DisplayBase base; + struct Whiteness { + QMask mask; + } data; + } draw_whiteness; + + message { + DisplayBase base; + struct Invers { + QMask mask; + } data; + } draw_invers; + + message { + DisplayBase base; + struct Rop3 { + Image *src_bitmap; + Rect src_area; + Brush brush; + uint8 rop3; + image_scale_mode scale_mode; + QMask mask; + } data; + } draw_rop3; + + message { + DisplayBase base; + struct Stroke { + Path *path; + LineAttr attr; + Brush brush; + uint16 fore_mode; + uint16 back_mode; + } data; + } draw_stroke; + + message { + DisplayBase base; + struct Text { + String *str; + Rect back_area; + Brush fore_brush; + Brush back_brush; + uint16 fore_mode; + uint16 back_mode; + } data; + } draw_text; + + message { + DisplayBase base; + struct Transparent { + Image *src_bitmap; + Rect src_area; + uint32 src_color; + uint32 true_color; + } data; + } draw_transparent; + + message { + DisplayBase base; + struct AlphaBlnd { + alpha_flags alpha_flags; + uint8 alpha; + Image *src_bitmap; + Rect src_area; + } data; + } draw_alpha_blend; + + message { + uint32 surface_id; + uint32 width; + uint32 height; + uint32 format; + surface_flags flags; + } @ctype(SpiceMsgSurfaceCreate) surface_create; + + message { + uint32 surface_id; + } @ctype(SpiceMsgSurfaceDestroy) surface_destroy; + + client: + message { + uint8 pixmap_cache_id; + int64 pixmap_cache_size; //in pixels + uint8 glz_dictionary_id; + int32 glz_dictionary_window_size; // in pixels + } init = 101; +}; + +flags32 keyboard_modifier_flags { + SCROLL_LOCK, + NUM_LOCK, + CAPS_LOCK +}; + +enum32 mouse_button { + INVALID, + LEFT, + MIDDLE, + RIGHT, + UP, + DOWN, +}; + +flags32 mouse_button_mask { + LEFT, + MIDDLE, + RIGHT +}; + +channel InputsChannel : BaseChannel { + client: + message { + uint32 code; + } @ctype(SpiceMsgcKeyDown) key_down = 101; + + message { + uint32 code; + } @ctype(SpiceMsgcKeyUp) key_up; + + message { + keyboard_modifier_flags modifiers; + } @ctype(SpiceMsgcKeyModifiers) key_modifiers; + + message { + int32 dx; + int32 dy; + mouse_button_mask buttons_state; + } @ctype(SpiceMsgcMouseMotion) mouse_motion = 111; + + message { + uint32 x; + uint32 y; + mouse_button_mask buttons_state; + uint8 display_id; + } @ctype(SpiceMsgcMousePosition) mouse_position; + + message { + mouse_button button; + mouse_button_mask buttons_state; + } @ctype(SpiceMsgcMousePress) mouse_press; + + message { + mouse_button button; + mouse_button_mask buttons_state; + } @ctype(SpiceMsgcMouseRelease) mouse_release; + + server: + message { + keyboard_modifier_flags keyboard_modifiers; + } init = 101; + + message { + keyboard_modifier_flags modifiers; + } key_modifiers; + + Empty mouse_motion_ack = 111; +}; + +enum16 cursor_type { + ALPHA, + MONO, + COLOR4, + COLOR8, + COLOR16, + COLOR24, + COLOR32, +}; + +flags32 cursor_flags { + NONE, /* Means no cursor */ + CACHE_ME, + FROM_CACHE, +}; + +struct CursorHeader { + uint64 unique; + cursor_type type; + uint16 width; + uint16 height; + uint16 hot_spot_x; + uint16 hot_spot_y; +}; + +struct Cursor { + cursor_flags flags; + CursorHeader header; + uint8 data[] @end; +}; + +channel CursorChannel : BaseChannel { + server: + message { + Point16 position; + uint16 trail_length; + uint16 trail_frequency; + uint8 visible; + Cursor cursor; + } init = 101; + + Empty reset; + + message { + Point16 position; + uint8 visible; + Cursor cursor; + } set; + + message { + Point16 position; + } move; + + Empty hide; + + message { + uint16 length; + uint16 frequency; + } trail; + + message { + uint64 id; + } @ctype(SpiceMsgDisplayInvalOne) inval_one; + + Empty inval_all; +}; + +enum32 audio_data_mode { + INVALID, + RAW, + CELT_0_5_1, +}; + +enum32 audio_fmt { + INVALID, + S16, +}; + +channel PlaybackChannel : BaseChannel { + server: + message { + uint32 time; + uint8 data[] @end; + } @ctype(SpiceMsgPlaybackPacket) data = 101; + + message { + uint32 time; + audio_data_mode mode; + uint8 data[] @end; + } mode; + + message { + uint32 channels; + audio_fmt format; + uint32 frequency; + uint32 time; + } start; + + Empty stop; +}; + +channel RecordChannel : BaseChannel { + server: + message { + uint32 channels; + audio_fmt format; + uint32 frequency; + } start = 101; + + Empty stop; + client: + message { + uint32 time; + uint8 data[] @end; + } @ctype(SpiceMsgcRecordPacket) data = 101; + + message { + uint32 time; + audio_data_mode mode; + uint8 data[] @end; + } mode; + + message { + uint32 time; + } start_mark; +}; + +enum32 tunnel_service_type { + INVALID, + GENERIC, + IPP, +}; + +enum16 tunnel_ip_type { + INVALID, + IPv4, +}; + +struct TunnelIpInfo { + tunnel_ip_type type; + switch (type) { + case IPv4: + uint8 ipv4[4] @ctype(uint8_t); + } u @end; +} @ctype(SpiceMsgTunnelIpInfo); + +channel TunnelChannel : BaseChannel { + server: + message { + uint16 max_num_of_sockets; + uint32 max_socket_data_size; + } init = 101; + + message { + uint32 service_id; + TunnelIpInfo virtual_ip; + } service_ip_map; + + message { + uint16 connection_id; + uint32 service_id; + uint32 tokens; + } socket_open; + + message { + uint16 connection_id; + } socket_fin; + + message { + uint16 connection_id; + } socket_close; + + message { + uint16 connection_id; + uint8 data[] @end; + } socket_data; + + message { + uint16 connection_id; + } socket_closed_ack; + + message { + uint16 connection_id; + uint32 num_tokens; + } @ctype(SpiceMsgTunnelSocketTokens) socket_token; + + client: + message { + tunnel_service_type type; + uint32 id; + uint32 group; + uint32 port; + uint32 name; + uint32 description; + switch (type) { + case IPP: + TunnelIpInfo ip @ctype(SpiceMsgTunnelIpInfo); + } u @end; + } @ctype(SpiceMsgcTunnelAddGenericService) service_add = 101; + + message { + uint32 id; + } @ctype(SpiceMsgcTunnelRemoveService) service_remove; + + message { + uint16 connection_id; + uint32 tokens; + } socket_open_ack; + + message { + uint16 connection_id; + } socket_open_nack; + + message { + uint16 connection_id; + } socket_fin; + + message { + uint16 connection_id; + } socket_closed; + + message { + uint16 connection_id; + } socket_closed_ack; + + message { + uint16 connection_id; + uint8 data[] @end; + } socket_data; + + message { + uint16 connection_id; + uint32 num_tokens; + } @ctype(SpiceMsgcTunnelSocketTokens) socket_token; +}; + +protocol Spice { + MainChannel main = 1; + DisplayChannel display; + InputsChannel inputs; + CursorChannel cursor; + PlaybackChannel playback; + RecordChannel record; + TunnelChannel tunnel; +}; diff --git a/spice_gen.py b/spice_gen.py new file mode 100755 index 00000000..f897ce81 --- /dev/null +++ b/spice_gen.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python + +import os +import sys +from optparse import OptionParser +import traceback +from python_modules import spice_parser +from python_modules import ptypes +from python_modules import codegen +from python_modules import demarshal + +def write_channel_enums(writer, channel, client): + messages = filter(lambda m : m.channel == channel, \ + channel.client_messages if client else channel.server_messages) + if len(messages) == 0: + return + writer.begin_block("enum") + i = 0; + if client: + prefix = [ "MSGC" ] + else: + prefix = [ "MSG" ] + if channel.member_name: + prefix.append(channel.member_name.upper()) + prefix.append(None) # To be replaced with name + for m in messages: + prefix[-1] = m.name.upper() + enum = codegen.prefix_underscore_upper(*prefix) + if m.value == i: + writer.writeln("%s," % enum) + i = i + 1 + else: + writer.writeln("%s = %s," % (enum, m.value)) + i = m.value + 1 + if channel.member_name: + prefix[-1] = prefix[-2] + prefix[-2] = "END" + writer.newline() + writer.writeln("%s" % (codegen.prefix_underscore_upper(*prefix))) + writer.end_block(semicolon=True) + writer.newline() + +def write_enums(writer): + writer.writeln("#ifndef _H_SPICE_ENUMS") + writer.writeln("#define _H_SPICE_ENUMS") + writer.newline() + writer.comment("Generated from %s, don't edit" % writer.options["source"]).newline() + writer.newline() + + # Define enums + for t in ptypes.get_named_types(): + if isinstance(t, ptypes.EnumBaseType): + t.c_define(writer) + + i = 0; + writer.begin_block("enum") + for c in proto.channels: + enum = codegen.prefix_underscore_upper("CHANNEL", c.name.upper()) + if c.value == i: + writer.writeln("%s," % enum) + i = i + 1 + else: + writer.writeln("%s = %s," % (enum, c.value)) + i = c.value + 1 + writer.newline() + writer.writeln("SPICE_END_CHANNEL") + writer.end_block(semicolon=True) + writer.newline() + + for c in ptypes.get_named_types(): + if not isinstance(c, ptypes.ChannelType): + continue + write_channel_enums(writer, c, False) + write_channel_enums(writer, c, True) + + writer.writeln("#endif /* _H_SPICE_ENUMS */") + +parser = OptionParser(usage="usage: %prog [options] ") +parser.add_option("-e", "--generate-enums", + action="store_true", dest="generate_enums", default=False, + help="Generate enums") +parser.add_option("-d", "--generate-demarshallers", + action="store_true", dest="generate_demarshallers", default=False, + help="Generate demarshallers") +parser.add_option("-a", "--assert-on-error", + action="store_true", dest="assert_on_error", default=False, + help="Assert on error") +parser.add_option("-p", "--print-error", + action="store_true", dest="print_error", default=False, + help="Print errors") +parser.add_option("-s", "--server", + action="store_true", dest="server", default=False, + help="Print errors") +parser.add_option("-c", "--client", + action="store_true", dest="client", default=False, + help="Print errors") +parser.add_option("-k", "--keep-identical-file", + action="store_true", dest="keep_identical_file", default=False, + help="Print errors") +parser.add_option("-i", "--include", + dest="include", default=None, metavar="FILE", + help="Include FILE in generated code") + +(options, args) = parser.parse_args() + +if len(args) == 0: + parser.error("No protocol file specified") + +if len(args) == 1: + parser.error("No destination file specified") + +proto_file = args[0] +dest_file = args[1] +proto = spice_parser.parse(proto_file) + +if proto == None: + exit(1) + +codegen.set_prefix(proto.name) +writer = codegen.CodeWriter() +writer.set_option("source", os.path.basename(proto_file)) + +if options.assert_on_error: + writer.set_option("assert_on_error") + +if options.print_error: + writer.set_option("print_error") + +if options.include: + writer.writeln('#include "%s"' % options.include) + +if options.generate_enums: + write_enums(writer) + +if options.generate_demarshallers: + if not options.server and not options.client: + print >> sys.stderr, "Must specify client and/or server" + sys.exit(1) + demarshal.write_includes(writer) + + if options.server: + demarshal.write_protocol_parser(writer, proto, False) + if options.client: + demarshal.write_protocol_parser(writer, proto, True) + +content = writer.getvalue() +if options.keep_identical_file: + try: + f = open(dest_file, 'rb') + old_content = f.read() + f.close() + + if content == old_content: + print "No changes to %s" % dest_file + sys.exit(0) + + except IOError: + pass + +f = open(dest_file, 'wb') +f.write(content) +f.close() + +print "Wrote %s" % dest_file +sys.exit(0) -- cgit