From 0366e7395c97e51b8a6294c10176bef73e1bdcf7 Mon Sep 17 00:00:00 2001 From: Alexander Larsson Date: Wed, 26 May 2010 12:19:58 +0200 Subject: Initial import of spice protocol description and demarshall generator The "spice.proto" file describes in detail the networking prototcol that spice uses and spice_codegen.py can parse this and generate demarshallers for such network messages. --- python_modules/demarshal.py | 1033 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1033 insertions(+) create mode 100644 python_modules/demarshal.py (limited to 'python_modules/demarshal.py') diff --git a/python_modules/demarshal.py b/python_modules/demarshal.py new file mode 100644 index 00000000..fcd68508 --- /dev/null +++ b/python_modules/demarshal.py @@ -0,0 +1,1033 @@ +import ptypes +import codegen + + +def write_parser_helpers(writer): + if writer.is_generated("helper", "demarshaller"): + return + + writer.set_is_generated("helper", "demarshaller") + + writer = writer.function_helper() + + writer.writeln("#ifdef WORDS_BIGENDIAN") + for size in [8, 16, 32, 64]: + for sign in ["", "u"]: + utype = "uint%d" % (size) + type = "%sint%d" % (sign, size) + swap = "SPICE_BYTESWAP%d" % size + if size == 8: + writer.macro("read_%s" % type, "ptr", "(*((%s_t *)(ptr)))" % type) + else: + writer.macro("read_%s" % type, "ptr", "((%s_t)%s(*((%s_t *)(ptr)))" % (type, swap, utype)) + writer.writeln("#else") + for size in [8, 16, 32, 64]: + for sign in ["", "u"]: + type = "%sint%d" % (sign, size) + writer.macro("read_%s" % type, "ptr", "(*((%s_t *)(ptr)))" % type) + writer.writeln("#endif") + + for size in [8, 16, 32, 64]: + for sign in ["", "u"]: + writer.newline() + type = "%sint%d" % (sign, size) + ctype = "%s_t" % type + scope = writer.function("SPICE_GNUC_UNUSED consume_%s" % type, ctype, "uint8_t **ptr", True) + scope.variable_def(ctype, "val") + writer.assign("val", "read_%s(*ptr)" % type) + writer.increment("*ptr", size / 8) + writer.statement("return val") + writer.end_block() + + writer.newline() + writer.statement("typedef struct PointerInfo PointerInfo") + writer.statement("typedef uint8_t * (*parse_func_t)(uint8_t *message_start, uint8_t *message_end, uint8_t *struct_data, PointerInfo *ptr_info, int minor)") + writer.statement("typedef uint8_t * (*parse_msg_func_t)(uint8_t *message_start, uint8_t *message_end, int minor, size_t *size_out)") + writer.statement("typedef uint8_t * (*spice_parse_channel_func_t)(uint8_t *message_start, uint8_t *message_end, uint16_t message_type, int minor, size_t *size_out)") + + writer.newline() + writer.begin_block("struct PointerInfo") + writer.variable_def("uint64_t", "offset") + writer.variable_def("parse_func_t", "parse") + writer.variable_def("SPICE_ADDRESS *", "dest") + writer.variable_def("uint32_t", "nelements") + writer.end_block(semicolon=True) + +def write_read_primitive(writer, start, container, name, scope): + m = container.lookup_member(name) + assert(m.is_primitive()) + writer.assign("pos", start + " + " + container.get_nw_offset(m, "", "__nw_size")) + writer.error_check("pos + %s > message_end" % m.member_type.get_fixed_nw_size()) + + var = "%s__value" % (name) + scope.variable_def(m.member_type.c_type(), var) + writer.assign(var, "read_%s(pos)" % (m.member_type.primitive_type())) + return var + +def write_read_primitive_item(writer, item, scope): + assert(item.type.is_primitive()) + writer.assign("pos", item.get_position()) + writer.error_check("pos + %s > message_end" % item.type.get_fixed_nw_size()) + var = "%s__value" % (item.subprefix) + scope.variable_def(item.type.c_type(), var) + writer.assign(var, "read_%s(pos)" % (item.type.primitive_type())) + return var + +class ItemInfo: + def __init__(self, type, prefix, position): + self.type = type + self.prefix = prefix + self.subprefix = prefix + self.position = position + self.non_null = False + self.member = None + + def nw_size(self): + return self.prefix + "__nw_size" + + def mem_size(self): + return self.prefix + "__mem_size" + + def extra_size(self): + return self.prefix + "__extra_size" + + def get_position(self): + return self.position + +class MemberItemInfo(ItemInfo): + def __init__(self, member, container, start): + if not member.is_switch(): + self.type = member.member_type + self.prefix = member.name + self.subprefix = member.name + self.non_null = member.has_attr("nonnull") + self.position = "(%s + %s)" % (start, container.get_nw_offset(member, "", "__nw_size")) + self.member = member + +def write_validate_switch_member(writer, container, switch_member, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + var = container.lookup_member(switch_member.variable) + var_type = var.member_type + + v = write_read_primitive(writer, start, container, switch_member.variable, parent_scope) + + item = MemberItemInfo(switch_member, container, start) + + first = True + for c in switch_member.cases: + check = c.get_check(v, var_type) + m = c.member + with writer.if_block(check, not first, False) as if_scope: + item.type = c.member.member_type + item.subprefix = item.prefix + "_" + m.name + item.non_null = c.member.has_attr("nonnull") + sub_want_extra_size = want_extra_size + if sub_want_extra_size and not m.contains_extra_size(): + writer.assign(item.extra_size(), 0) + sub_want_extra_size = False + + write_validate_item(writer, container, item, if_scope, scope, start, + want_nw_size, want_mem_size, sub_want_extra_size) + + first = False + + with writer.block(" else"): + if want_nw_size: + writer.assign(item.nw_size(), 0) + if want_mem_size: + writer.assign(item.mem_size(), 0) + if want_extra_size: + writer.assign(item.extra_size(), 0) + + writer.newline() + +def write_validate_struct_function(writer, struct): + validate_function = "validate_%s" % struct.c_type() + if writer.is_generated("validator", validate_function): + return validate_function + + writer.set_is_generated("validator", validate_function) + writer = writer.function_helper() + scope = writer.function(validate_function, "intptr_t", "uint8_t *message_start, uint8_t *message_end, SPICE_ADDRESS offset, int minor") + scope.variable_def("uint8_t *", "start = message_start + offset") + scope.variable_def("SPICE_GNUC_UNUSED uint8_t *", "pos"); + scope.variable_def("size_t", "mem_size", "nw_size"); + num_pointers = struct.get_num_pointers() + if num_pointers != 0: + scope.variable_def("SPICE_GNUC_UNUSED intptr_t", "ptr_size"); + + writer.newline() + with writer.if_block("offset == 0"): + writer.statement("return 0") + + writer.newline() + writer.error_check("start >= message_end") + + writer.newline() + write_validate_container(writer, None, struct, "start", scope, True, True, False) + + writer.newline() + writer.comment("Check if struct fits in reported side").newline() + writer.error_check("start + nw_size > message_end") + + writer.statement("return mem_size") + + writer.newline() + writer.label("error") + writer.statement("return -1") + + writer.end_block() + + return validate_function + +def write_validate_pointer_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + if want_nw_size: + writer.assign(item.nw_size(), 8) + + if want_mem_size or want_extra_size: + target_type = item.type.target_type + + v = write_read_primitive_item(writer, item, scope) + if item.non_null: + writer.error_check("%s == 0" % v) + + # pointer target is struct, or array of primitives + # if array, need no function check + + if target_type.is_array(): + writer.error_check("message_start + %s >= message_end" % v) + + + assert target_type.element_type.is_primitive() + + array_item = ItemInfo(target_type, "%s__array" % item.prefix, start) + scope.variable_def("uint32_t", array_item.nw_size()) + scope.variable_def("uint32_t", array_item.mem_size()) + if target_type.is_cstring_length(): + writer.assign(array_item.nw_size(), "spice_strnlen((char *)message_start + %s, message_end - (message_start + %s))" % (v, v)) + writer.error_check("*(message_start + %s + %s) != 0" % (v, array_item.nw_size())) + writer.assign(array_item.mem_size(), array_item.nw_size()) + else: + write_validate_array_item(writer, container, array_item, scope, parent_scope, start, + True, True, False) + writer.error_check("message_start + %s + %s > message_end" % (v, array_item.nw_size())) + + if want_extra_size: + if item.member and item.member.has_attr("nocopy"): + writer.comment("@nocopy, so no extra size").newline() + writer.assign(item.extra_size(), 0) + elif target_type.element_type.get_fixed_nw_size == 1: + writer.assign(item.extra_size(), array_item.mem_size()) + # If not bytes or zero, add padding needed for alignment + else: + writer.assign(item.extra_size(), "%s + /* for alignment */ 3" % array_item.mem_size()) + if want_mem_size: + writer.assign(item.mem_size(), "sizeof(void *) + %s" % array_item.mem_size()) + + elif target_type.is_struct(): + validate_function = write_validate_struct_function(writer, target_type) + writer.assign("ptr_size", "%s(message_start, message_end, %s, minor)" % (validate_function, v)) + writer.error_check("ptr_size < 0") + + if want_extra_size: + writer.assign(item.extra_size(), "ptr_size + /* for alignment */ 3") + if want_mem_size: + writer.assign(item.mem_size(), "sizeof(void *) + ptr_size") + else: + raise NotImplementedError("pointer to unsupported type %s" % target_type) + + +def write_validate_array_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + array = item.type + is_byte_size = False + element_type = array.element_type + if array.is_bytes_length(): + nelements = "%s__nbytes" %(item.prefix) + else: + nelements = "%s__nelements" %(item.prefix) + if not parent_scope.variable_defined(nelements): + parent_scope.variable_def("uint32_t", nelements) + + if array.is_constant_length(): + writer.assign(nelements, array.size) + elif array.is_remaining_length(): + if element_type.is_fixed_nw_size(): + if element_type.get_fixed_nw_size() == 1: + writer.assign(nelements, "message_end - %s" % item.get_position()) + else: + writer.assign(nelements, "(message_end - %s) / (%s)" %(item.get_position(), element_type.get_fixed_nw_size())) + else: + raise NotImplementedError("TODO array[] of dynamic element size not done yet") + elif array.is_identifier_length(): + v = write_read_primitive(writer, start, container, array.size, scope) + writer.assign(nelements, v) + elif array.is_image_size_length(): + bpp = array.size[1] + width = array.size[2] + rows = array.size[3] + width_v = write_read_primitive(writer, start, container, width, scope) + rows_v = write_read_primitive(writer, start, container, rows, scope) + # TODO: Handle multiplication overflow + if bpp == 8: + writer.assign(nelements, "%s * %s" % (width_v, rows_v)) + elif bpp == 1: + writer.assign(nelements, "((%s + 7) / 8 ) * %s" % (width_v, rows_v)) + else: + writer.assign(nelements, "((%s * %s + 7) / 8 ) * %s" % (bpp, width_v, rows_v)) + elif array.is_bytes_length(): + is_byte_size = True + v = write_read_primitive(writer, start, container, array.size[1], scope) + writer.assign(nelements, v) + elif array.is_cstring_length(): + writer.todo("cstring array size type not handled yet") + else: + writer.todo("array size type not handled yet") + + writer.newline() + + nw_size = item.nw_size() + mem_size = item.mem_size() + extra_size = item.extra_size() + + if is_byte_size and want_nw_size: + writer.assign(nw_size, nelements) + want_nw_size = False + + if element_type.is_fixed_nw_size() and want_nw_size: + element_size = element_type.get_fixed_nw_size() + # TODO: Overflow check the multiplication + if element_size == 1: + writer.assign(nw_size, nelements) + else: + writer.assign(nw_size, "(%s) * %s" % (element_size, nelements)) + want_nw_size = False + + if element_type.is_fixed_sizeof() and want_mem_size and not is_byte_size: + # TODO: Overflow check the multiplication + writer.assign(mem_size, "%s * %s" % (element_type.sizeof(), nelements)) + want_mem_size = False + + if not element_type.contains_extra_size() and want_extra_size: + writer.assign(extra_size, 0) + want_extra_size = False + + if not (want_mem_size or want_nw_size or want_extra_size): + return + + start2 = codegen.increment_identifier(start) + scope.variable_def("uint8_t *", "%s = %s" % (start2, item.get_position())) + if is_byte_size: + start2_end = "%s_array_end" % start2 + scope.variable_def("uint8_t *", start2_end) + + element_item = ItemInfo(element_type, "%s__element" % item.prefix, start2) + + element_nw_size = element_item.nw_size() + element_mem_size = element_item.mem_size() + scope.variable_def("uint32_t", element_nw_size) + scope.variable_def("uint32_t", element_mem_size) + + if want_nw_size: + writer.assign(nw_size, 0) + if want_mem_size: + writer.assign(mem_size, 0) + if want_extra_size: + writer.assign(extra_size, 0) + + want_element_nw_size = want_nw_size + if element_type.is_fixed_nw_size(): + start_increment = element_type.get_fixed_nw_size() + else: + want_element_nw_size = True + start_increment = element_nw_size + + if is_byte_size: + writer.assign(start2_end, "%s + %s" % (start2, nelements)) + + with writer.index(no_block = is_byte_size) as index: + with writer.while_loop("%s < %s" % (start2, start2_end) ) if is_byte_size else writer.for_loop(index, nelements) as scope: + write_validate_item(writer, container, element_item, scope, parent_scope, start2, + want_element_nw_size, want_mem_size, want_extra_size) + + if want_nw_size: + writer.increment(nw_size, element_nw_size) + if want_mem_size: + writer.increment(mem_size, element_mem_size) + if want_extra_size: + writer.increment(extra_size, element_extra_size) + + writer.increment(start2, start_increment) + if is_byte_size: + writer.error_check("%s != %s" % (start2, start2_end)) + +def write_validate_struct_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + struct = item.type + start2 = codegen.increment_identifier(start) + scope.variable_def("SPICE_GNUC_UNUSED uint8_t *", start2 + " = %s" % (item.get_position())) + + write_validate_container(writer, item.prefix, struct, start2, scope, want_nw_size, want_mem_size, want_extra_size) + +def write_validate_primitive_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + if want_nw_size: + nw_size = item.nw_size() + writer.assign(nw_size, item.type.get_fixed_nw_size()) + if want_mem_size: + mem_size = item.mem_size() + writer.assign(mem_size, item.type.sizeof()) + assert not want_extra_size + +def write_validate_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + if item.type.is_pointer(): + write_validate_pointer_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + elif item.type.is_array(): + write_validate_array_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + elif item.type.is_struct(): + write_validate_struct_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + elif item.type.is_primitive(): + write_validate_primitive_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + else: + writer.todo("Implement validation of %s" % item.type) + +def write_validate_member(writer, container, member, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size): + if member.has_minor_attr(): + prefix = "if (minor >= %s)" % (member.get_minor_attr()) + newline = False + else: + prefix = "" + newline = True + item = MemberItemInfo(member, container, start) + with writer.block(prefix, newline=newline, comment=member.name) as scope: + if member.is_switch(): + write_validate_switch_member(writer, container, member, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + else: + write_validate_item(writer, container, item, scope, parent_scope, start, + want_nw_size, want_mem_size, want_extra_size) + + if member.has_minor_attr(): + with writer.block(" else", comment = "minor < %s" % (member.get_minor_attr())): + if member.is_array(): + nelements = "%s__nelements" %(item.prefix) + writer.assign(nelements, 0) + if want_nw_size: + writer.assign(item.nw_size(), 0) + + if want_mem_size: + if member.is_fixed_sizeof(): + writer.assign(item.mem_size(), member.sizeof()) + elif member.is_array(): + writer.assign(item.mem_size(), 0) + else: + raise NotImplementedError("TODO minor check for non-constant items") + + assert not want_extra_size + +def write_validate_container(writer, prefix, container, start, parent_scope, want_nw_size, want_mem_size, want_extra_size): + for m in container.members: + sub_want_nw_size = want_nw_size and not m.is_fixed_nw_size() + sub_want_mem_size = m.is_extra_size() + sub_want_extra_size = not m.is_extra_size() and m.contains_extra_size() + + defs = ["size_t"] + if sub_want_nw_size: + defs.append (m.name + "__nw_size") + if sub_want_mem_size: + defs.append (m.name + "__mem_size") + if sub_want_extra_size: + defs.append (m.name + "__extra_size") + + if sub_want_nw_size or sub_want_mem_size or sub_want_extra_size: + parent_scope.variable_def(*defs) + write_validate_member(writer, container, m, parent_scope, start, + sub_want_nw_size, sub_want_mem_size, sub_want_extra_size) + writer.newline() + + if want_nw_size: + if prefix: + nw_size = prefix + "__nw_size" + else: + nw_size = "nw_size" + + size = 0 + for m in container.members: + if m.is_fixed_nw_size(): + size = size + m.get_fixed_nw_size() + + nm_sum = str(size) + for m in container.members: + if not m.is_fixed_nw_size(): + nm_sum = nm_sum + " + " + m.name + "__nw_size" + + writer.assign(nw_size, nm_sum) + + if want_mem_size: + if prefix: + mem_size = prefix + "__mem_size" + else: + mem_size = "mem_size" + + mem_sum = container.sizeof() + for m in container.members: + if m.is_extra_size(): + mem_sum = mem_sum + " + " + m.name + "__mem_size" + elif m.contains_extra_size(): + mem_sum = mem_sum + " + " + m.name + "__extra_size" + + writer.assign(mem_size, mem_sum) + + if want_extra_size: + if prefix: + extra_size = prefix + "__extra_size" + else: + extra_size = "extra_size" + + extra_sum = [] + for m in container.members: + if m.is_extra_size(): + extra_sum.append(m.name + "__mem_size") + elif m.contains_extra_size(): + extra_sum.append(m.name + "__extra_size") + writer.assign(extra_size, codegen.sum_array(extra_sum)) + +class DemarshallingDestination: + def __init__(self): + pass + + def child_at_end(self, writer, t): + return RootDemarshallingDestination(self, t.c_type(), t.sizeof()) + + def child_sub(self, member): + return SubDemarshallingDestination(self, member) + + def declare(self, writer): + return writer.optional_block(self.reuse_scope) + + def is_toplevel(self): + return self.parent_dest == None and not self.is_helper + +class RootDemarshallingDestination(DemarshallingDestination): + def __init__(self, parent_dest, c_type, sizeof, pointer = None): + self.is_helper = False + self.reuse_scope = None + self.parent_dest = parent_dest + if parent_dest: + self.base_var = codegen.increment_identifier(parent_dest.base_var) + else: + self.base_var = "out" + self.c_type = c_type + self.sizeof = sizeof + self.pointer = pointer # None == at "end" + + def get_ref(self, member): + return self.base_var + "->" + member + + def declare(self, writer): + if self.reuse_scope: + scope = self.reuse_scope + else: + writer.begin_block() + scope = writer.get_subwriter() + + scope.variable_def(self.c_type + " *", self.base_var) + if not self.reuse_scope: + scope.newline() + + if self.pointer: + writer.assign(self.base_var, "(%s *)%s" % (self.c_type, self.pointer)) + else: + writer.assign(self.base_var, "(%s *)end" % (self.c_type)) + writer.increment("end", self.sizeof) + writer.newline() + + if self.reuse_scope: + return writer.no_block(self.reuse_scope) + else: + return writer.partial_block(scope) + +class SubDemarshallingDestination(DemarshallingDestination): + def __init__(self, parent_dest, member): + self.reuse_scope = None + self.parent_dest = parent_dest + self.base_var = parent_dest.base_var + self.member = member + self.is_helper = False + + def get_ref(self, member): + return self.parent_dest.get_ref(self.member) + "." + member + +def read_array_len(writer, prefix, array, dest, scope, handles_bytes = False): + if array.is_bytes_length(): + nelements = "%s__nbytes" % prefix + else: + nelements = "%s__nelements" % prefix + if dest.is_toplevel(): + return nelements # Already there for toplevel, need not recalculate + element_type = array.element_type + scope.variable_def("uint32_t", nelements) + if array.is_constant_length(): + writer.assign(nelements, array.size) + elif array.is_identifier_length(): + writer.assign(nelements, dest.get_ref(array.size)) + elif array.is_remaining_length(): + if element_type.is_fixed_nw_size(): + writer.assign(nelements, "(message_end - in) / (%s)" %(element_type.get_fixed_nw_size())) + else: + raise NotImplementedError("TODO array[] of dynamic element size not done yet") + elif array.is_image_size_length(): + bpp = array.size[1] + width = array.size[2] + rows = array.size[3] + width_v = dest.get_ref(width) + rows_v = dest.get_ref(rows) + # TODO: Handle multiplication overflow + if bpp == 8: + writer.assign(nelements, "%s * %s" % (width_v, rows_v)) + elif bpp == 1: + writer.assign(nelements, "((%s + 7) / 8 ) * %s" % (width_v, rows_v)) + else: + writer.assign(nelements, "((%s * %s + 7) / 8 ) * %s" % (bpp, width_v, rows_v)) + elif array.is_bytes_length(): + if not handles_bytes: + raise NotImplementedError("handling of bytes() not supported here yet") + writer.assign(nelements, dest.get_ref(array.size[1])) + else: + raise NotImplementedError("TODO array size type not handled yet") + return nelements + +def write_switch_parser(writer, container, switch, dest, scope): + var = container.lookup_member(switch.variable) + var_type = var.member_type + + if switch.has_attr("fixedsize"): + scope.variable_def("uint8_t *", "in_save") + writer.assign("in_save", "in") + + first = True + for c in switch.cases: + check = c.get_check(dest.get_ref(switch.variable), var_type) + m = c.member + with writer.if_block(check, not first, False) as block: + t = m.member_type + if switch.has_end_attr(): + dest2 = dest.child_at_end(writer, m.member_type) + elif switch.has_attr("anon"): + dest2 = dest + else: + if t.is_struct(): + dest2 = dest.child_sub(switch.name + "." + m.name) + else: + dest2 = dest.child_sub(switch.name) + dest2.reuse_scope = block + + if t.is_struct(): + write_container_parser(writer, t, dest2) + elif t.is_pointer(): + write_parse_pointer(writer, t, False, dest2, m.name, not m.has_attr("ptr32"), block) + elif t.is_primitive(): + writer.assign(dest2.get_ref(m.name), "consume_%s(&in)" % (t.primitive_type())) + #TODO validate e.g. flags and enums + elif t.is_array(): + nelements = read_array_len(writer, m.name, t, dest, block) + write_array_parser(writer, nelements, t, dest, block) + else: + writer.todo("Can't handle type %s" % m.member_type) + + first = False + + writer.newline() + + if switch.has_attr("fixedsize"): + writer.assign("in", "in_save + %s" % switch.get_fixed_nw_size()) + +def write_parse_ptr_function(writer, target_type): + if target_type.is_array(): + parse_function = "parse_array_%s" % target_type.element_type.primitive_type() + else: + parse_function = "parse_struct_%s" % target_type.c_type() + if writer.is_generated("parser", parse_function): + return parse_function + + writer.set_is_generated("parser", parse_function) + + writer = writer.function_helper() + scope = writer.function(parse_function, "uint8_t *", "uint8_t *message_start, uint8_t *message_end, uint8_t *struct_data, PointerInfo *this_ptr_info, int minor") + scope.variable_def("uint8_t *", "in = message_start + this_ptr_info->offset") + scope.variable_def("uint8_t *", "end") + + num_pointers = target_type.get_num_pointers() + if num_pointers != 0: + scope.variable_def("SPICE_GNUC_UNUSED intptr_t", "ptr_size"); + scope.variable_def("uint32_t", "n_ptr=0"); + scope.variable_def("PointerInfo", "ptr_info[%s]" % num_pointers) + + writer.newline() + if target_type.is_array(): + writer.assign("end", "struct_data") + else: + writer.assign("end", "struct_data + %s" % (target_type.sizeof())) + + dest = RootDemarshallingDestination(None, target_type.c_type(), target_type.sizeof(), "struct_data") + dest.is_helper = True + dest.reuse_scope = scope + if target_type.is_array(): + write_array_parser(writer, "this_ptr_info->nelements", target_type, dest, scope) + else: + write_container_parser(writer, target_type, dest) + + if num_pointers != 0: + write_ptr_info_check(writer) + + writer.statement("return end") + + if writer.has_error_check: + writer.newline() + writer.label("error") + writer.statement("return NULL") + + writer.end_block() + + return parse_function + +def write_array_parser(writer, nelements, array, dest, scope): + is_byte_size = array.is_bytes_length() + + element_type = array.element_type + if element_type == ptypes.uint8 or element_type == ptypes.int8: + writer.statement("memcpy(end, in, %s)" % (nelements)) + writer.increment("in", nelements) + writer.increment("end", nelements) + else: + if is_byte_size: + scope.variable_def("uint8_t *", "array_end") + writer.assign("array_end", "end + %s" % nelements) + with writer.index(no_block = is_byte_size) as index: + with writer.while_loop("end < array_end") if is_byte_size else writer.for_loop(index, nelements) as array_scope: + if element_type.is_primitive(): + writer.statement("*(%s *)end = consume_%s(&in)" % (element_type.c_type(), element_type.primitive_type())) + writer.increment("end", element_type.sizeof()) + else: + dest2 = dest.child_at_end(writer, element_type) + dest2.reuse_scope = array_scope + write_container_parser(writer, element_type, dest2) + +def write_parse_pointer(writer, t, at_end, dest, member_name, is_64bit, scope): + target_type = t.target_type + if is_64bit: + writer.assign("ptr_info[n_ptr].offset", "consume_uint64(&in)") + else: + writer.assign("ptr_info[n_ptr].offset", "consume_uint32(&in)") + writer.assign("ptr_info[n_ptr].parse", write_parse_ptr_function(writer, target_type)) + if at_end: + writer.assign("ptr_info[n_ptr].dest", "end") + writer.increment("end", "sizeof(SPICE_ADDRESS)"); + else: + writer.assign("ptr_info[n_ptr].dest", "&%s" % dest.get_ref(member_name)) + if target_type.is_array(): + nelements = read_array_len(writer, member_name, target_type, dest, scope) + writer.assign("ptr_info[n_ptr].nelements", nelements) + + writer.statement("n_ptr++") + +def write_member_parser(writer, container, member, dest, scope): + if member.is_switch(): + write_switch_parser(writer, container, member, dest, scope) + return + + t = member.member_type + + if t.is_pointer(): + if member.has_attr("nocopy"): + writer.comment("Reuse data from network message").newline() + writer.assign(dest.get_ref(member.name), "(size_t)(message_start + consume_uint64(&in))") + else: + write_parse_pointer(writer, t, member.has_end_attr(), dest, member.name, not member.has_attr("ptr32"), scope) + elif t.is_primitive(): + if member.has_end_attr(): + writer.statement("*(%s *)end = consume_%s(&in)" % (t.c_type(), t.primitive_type())) + writer.increment("end", t.sizeof()) + else: + writer.assign(dest.get_ref(member.name), "consume_%s(&in)" % (t.primitive_type())) + #TODO validate e.g. flags and enums + elif t.is_array(): + nelements = read_array_len(writer, member.name, t, dest, scope, handles_bytes = True) + write_array_parser(writer, nelements, t, dest, scope) + elif t.is_struct(): + if member.has_end_attr(): + dest2 = dest.child_at_end(writer, t) + else: + dest2 = dest.child_sub(member.name) + writer.comment(member.name) + write_container_parser(writer, t, dest2) + else: + raise NotImplementedError("TODO can't handle parsing of %s" % t) + +def write_container_parser(writer, container, dest): + with dest.declare(writer) as scope: + for m in container.members: + if m.has_minor_attr(): + writer.begin_block("if (minor >= %s)" % m.get_minor_attr()) + write_member_parser(writer, container, m, dest, scope) + if m.has_minor_attr(): + # We need to zero out the fixed part of all optional fields + if not m.member_type.is_array(): + writer.end_block(newline=False) + writer.begin_block(" else") + # TODO: This is not right for fields that don't exist in the struct + if m.member_type.is_primitive(): + writer.assign(dest.get_ref(m.name), "0") + elif m.is_fixed_sizeof(): + writer.statement("memset ((char *)&%s, 0, %s)" % (dest.get_ref(m.name), m.sizeof())) + else: + raise NotImplementedError("TODO Clear optional dynamic fields") + writer.end_block() + +def write_ptr_info_check(writer): + writer.newline() + with writer.index() as index: + with writer.for_loop(index, "n_ptr") as scope: + offset = "ptr_info[%s].offset" % index + function = "ptr_info[%s].parse" % index + dest = "ptr_info[%s].dest" % index + with writer.if_block("%s == 0" % offset, newline=False): + writer.assign("*%s" % dest, "0") + with writer.block(" else"): + writer.comment("Align to 32 bit").newline() + writer.assign("end", "(uint8_t *)SPICE_ALIGN((size_t)end, 4)") + writer.assign("*%s" % dest, "(size_t)end") + writer.assign("end", "%s(message_start, message_end, end, &ptr_info[%s], minor)" % (function, index)) + writer.error_check("end == NULL") + writer.newline() + +def write_msg_parser(writer, message): + msg_name = message.c_name() + function_name = "parse_%s" % msg_name + if writer.is_generated("demarshaller", function_name): + return function_name + writer.set_is_generated("demarshaller", function_name) + + msg_type = message.c_type() + msg_sizeof = message.sizeof() + + writer.newline() + parent_scope = writer.function(function_name, + "uint8_t *", + "uint8_t *message_start, uint8_t *message_end, int minor, size_t *size", True) + parent_scope.variable_def("SPICE_GNUC_UNUSED uint8_t *", "pos"); + parent_scope.variable_def("uint8_t *", "start = message_start"); + parent_scope.variable_def("uint8_t *", "data = NULL"); + parent_scope.variable_def("size_t", "mem_size", "nw_size"); + if not message.has_attr("nocopy"): + parent_scope.variable_def("uint8_t *", "in", "end"); + num_pointers = message.get_num_pointers() + if num_pointers != 0: + parent_scope.variable_def("SPICE_GNUC_UNUSED intptr_t", "ptr_size"); + parent_scope.variable_def("uint32_t", "n_ptr=0"); + parent_scope.variable_def("PointerInfo", "ptr_info[%s]" % num_pointers) + writer.newline() + + write_parser_helpers(writer) + + write_validate_container(writer, None, message, "start", parent_scope, True, True, False) + + writer.newline() + + writer.comment("Check if message fits in reported side").newline() + with writer.block("if (start + nw_size > message_end)"): + writer.statement("return NULL") + + writer.newline().comment("Validated extents and calculated size").newline() + + if message.has_attr("nocopy"): + writer.assign("data", "message_start") + writer.assign("*size", "message_end - message_start") + else: + writer.assign("data", "(uint8_t *)malloc(mem_size)") + writer.error_check("data == NULL") + writer.assign("end", "data + %s" % (msg_sizeof)) + writer.assign("in", "start").newline() + + dest = RootDemarshallingDestination(None, msg_type, msg_sizeof, "data") + dest.reuse_scope = parent_scope + write_container_parser(writer, message, dest) + + writer.newline() + writer.statement("assert(in <= message_end)") + + if num_pointers != 0: + write_ptr_info_check(writer) + + writer.statement("assert(end <= data + mem_size)") + + writer.newline() + writer.assign("*size", "end - data") + + writer.statement("return data") + writer.newline() + if writer.has_error_check: + writer.label("error") + with writer.block("if (data != NULL)"): + writer.statement("free(data)") + writer.statement("return NULL") + writer.end_block() + + return function_name + +def write_channel_parser(writer, channel, server): + writer.newline() + ids = {} + min_id = 1000000 + if server: + messages = channel.server_messages + else: + messages = channel.client_messages + for m in messages: + ids[m.value] = m + + ranges = [] + ids2 = ids.copy() + while len(ids2) > 0: + end = start = min(ids2.keys()) + while ids2.has_key(end): + del ids2[end] + end = end + 1 + + ranges.append( (start, end) ) + + if server: + function_name = "parse_%s_msg" % channel.name + else: + function_name = "parse_%s_msgc" % channel.name + writer.newline() + scope = writer.function(function_name, + "uint8_t *", + "uint8_t *message_start, uint8_t *message_end, uint16_t message_type, int minor, size_t *size_out") + + helpers = writer.function_helper() + + d = 0 + for r in ranges: + d = d + 1 + writer.write("static parse_msg_func_t funcs%d[%d] = " % (d, r[1] - r[0])); + writer.begin_block() + for i in range(r[0], r[1]): + func = write_msg_parser(helpers, ids[i].message_type) + writer.write(func) + if i != r[1] -1: + writer.write(",") + writer.newline() + + writer.end_block(semicolon = True) + + d = 0 + for r in ranges: + d = d + 1 + with writer.if_block("message_type >= %d && message_type < %d" % (r[0], r[1]), d > 1, False): + writer.statement("return funcs%d[message_type-%d](message_start, message_end, minor, size_out)" % (d, r[0])) + writer.newline() + + writer.statement("return NULL") + writer.end_block() + + return function_name + +def write_get_channel_parser(writer, channel_parsers, max_channel, is_server): + writer.newline() + if is_server: + function_name = "spice_get_server_channel_parser" + else: + function_name = "spice_get_client_channel_parser" + + scope = writer.function(function_name, + "spice_parse_channel_func_t", + "uint32_t channel, unsigned int *max_message_type") + + writer.write("static struct {spice_parse_channel_func_t func; unsigned int max_messages; } channels[%d] = " % (max_channel+1)) + writer.begin_block() + for i in range(0, max_channel + 1): + writer.write("{ ") + if channel_parsers.has_key(i): + writer.write(channel_parsers[i][1]) + writer.write(", ") + + channel = channel_parsers[i][0] + max_msg = 0 + if is_server: + messages = channel.server_messages + else: + messages = channel.client_messages + for m in messages: + max_msg = max(max_msg, m.value) + writer.write(max_msg) + else: + writer.write("NULL, 0") + writer.write("}") + + if i != max_channel: + writer.write(",") + writer.newline() + writer.end_block(semicolon = True) + + with writer.if_block("channel < %d" % (max_channel + 1)): + with writer.if_block("max_message_type != NULL"): + writer.assign("*max_message_type", "channels[channel].max_messages") + writer.statement("return channels[channel].func") + + writer.statement("return NULL") + writer.end_block() + + +def write_full_protocol_parser(writer, is_server): + writer.newline() + if is_server: + function_name = "spice_parse_msg" + else: + function_name = "spice_parse_reply" + scope = writer.function(function_name, + "uint8_t *", + "uint8_t *message_start, uint8_t *message_end, uint32_t channel, uint16_t message_type, int minor, size_t *size_out") + scope.variable_def("spice_parse_channel_func_t", "func" ) + + if is_server: + writer.assign("func", "spice_get_server_channel_parser(channel, NULL)") + else: + writer.assign("func", "spice_get_client_channel_parser(channel, NULL)") + + with writer.if_block("func != NULL"): + writer.statement("return func(message_start, message_end, message_type, minor, size_out)") + + writer.statement("return NULL") + writer.end_block() + +def write_protocol_parser(writer, proto, is_server): + max_channel = 0 + parsers = {} + + for channel in proto.channels: + max_channel = max(max_channel, channel.value) + + parsers[channel.value] = (channel.channel_type, write_channel_parser(writer, channel.channel_type, is_server)) + + write_get_channel_parser(writer, parsers, max_channel, is_server) + write_full_protocol_parser(writer, is_server) + +def write_includes(writer): + writer.writeln("#include ") + writer.writeln("#include ") + writer.writeln("#include ") + writer.writeln("#include ") + writer.writeln("#include ") + writer.writeln("#include ") + writer.newline() + writer.writeln("#ifdef _MSC_VER") + writer.writeln("#pragma warning(disable:4101)") + writer.writeln("#endif") -- cgit