Warn if PB_MAX_REQUIRED_FIELDS is not large enough.
[apps/agl-service-can-low-level.git] / generator / nanopb_generator.py
index e62d04f..730c0aa 100644 (file)
@@ -1,21 +1,41 @@
 '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
 
-import google.protobuf.descriptor_pb2 as descriptor
-import nanopb_pb2
+try:
+    import google.protobuf.descriptor_pb2 as descriptor
+except:
+    print
+    print "*************************************************************"
+    print "*** Could not import the Google protobuf Python libraries ***"
+    print "*** Try installing package 'python-protobuf' or similar.  ***"
+    print "*************************************************************"
+    print
+    raise
+
+try:
+    import nanopb_pb2
+except:
+    print
+    print "***************************************************************"
+    print "*** Could not import the precompiled nanopb_pb2.py.         ***"
+    print "*** Run 'make' in the 'generator' folder to update the file.***"
+    print "***************************************************************"
+    print
+    raise
+
 import os.path
 
 # Values are tuple (c type, pb ltype)
 FieldD = descriptor.FieldDescriptorProto
 datatypes = {
     FieldD.TYPE_BOOL: ('bool', 'PB_LTYPE_VARINT'),
-    FieldD.TYPE_DOUBLE: ('double', 'PB_LTYPE_FIXED'),
-    FieldD.TYPE_FIXED32: ('uint32_t', 'PB_LTYPE_FIXED'),
-    FieldD.TYPE_FIXED64: ('uint64_t', 'PB_LTYPE_FIXED'),
-    FieldD.TYPE_FLOAT: ('float', 'PB_LTYPE_FIXED'),
+    FieldD.TYPE_DOUBLE: ('double', 'PB_LTYPE_FIXED64'),
+    FieldD.TYPE_FIXED32: ('uint32_t', 'PB_LTYPE_FIXED32'),
+    FieldD.TYPE_FIXED64: ('uint64_t', 'PB_LTYPE_FIXED64'),
+    FieldD.TYPE_FLOAT: ('float', 'PB_LTYPE_FIXED32'),
     FieldD.TYPE_INT32: ('int32_t', 'PB_LTYPE_VARINT'),
     FieldD.TYPE_INT64: ('int64_t', 'PB_LTYPE_VARINT'),
-    FieldD.TYPE_SFIXED32: ('int32_t', 'PB_LTYPE_FIXED'),
-    FieldD.TYPE_SFIXED64: ('int64_t', 'PB_LTYPE_FIXED'),
+    FieldD.TYPE_SFIXED32: ('int32_t', 'PB_LTYPE_FIXED32'),
+    FieldD.TYPE_SFIXED64: ('int64_t', 'PB_LTYPE_FIXED64'),
     FieldD.TYPE_SINT32: ('int32_t', 'PB_LTYPE_SVARINT'),
     FieldD.TYPE_SINT64: ('int64_t', 'PB_LTYPE_SVARINT'),
     FieldD.TYPE_UINT32: ('uint32_t', 'PB_LTYPE_VARINT'),
@@ -108,7 +128,8 @@ class Field:
         elif desc.type == FieldD.TYPE_ENUM:
             self.ltype = 'PB_LTYPE_VARINT'
             self.ctype = names_from_type_name(desc.type_name)
-            self.default = Names(self.ctype) + self.default
+            if self.default is not None:
+                self.default = self.ctype + self.default
         elif desc.type == FieldD.TYPE_STRING:
             self.ltype = 'PB_LTYPE_STRING'
             if self.max_size is None:
@@ -124,7 +145,7 @@ class Field:
                 self.ctype = self.struct_name + self.name + 't'
         elif desc.type == FieldD.TYPE_MESSAGE:
             self.ltype = 'PB_LTYPE_SUBMESSAGE'
-            self.ctype = names_from_type_name(desc.type_name)
+            self.ctype = self.submsgname = names_from_type_name(desc.type_name)
         else:
             raise NotImplementedError(desc.type)
         
@@ -167,8 +188,8 @@ class Field:
             if self.max_size is None:
                 return None # Not implemented
             else:
-                array_decl = '[%d]' % self.max_size
-            default = self.default.encode('string_escape')
+                array_decl = '[%d]' % (self.max_size + 1)
+            default = str(self.default).encode('string_escape')
             default = default.replace('"', '\\"')
             default = '"' + default + '"'
         elif self.ltype == 'PB_LTYPE_BYTES':
@@ -223,7 +244,7 @@ class Field:
             result += ' 0,'
         
         if self.ltype == 'PB_LTYPE_SUBMESSAGE':
-            result += '\n    &%s_fields}' % self.ctype
+            result += '\n    &%s_fields}' % self.submsgname
         elif self.default is None or self.htype == 'PB_HTYPE_CALLBACK':
             result += ' 0}'
         else:
@@ -237,24 +258,10 @@ class Message:
         self.fields = [Field(self.name, f) for f in desc.field]
         self.ordered_fields = self.fields[:]
         self.ordered_fields.sort()
-    
-    def __cmp__(self, other):
-        '''Sort messages so that submessages are declared before the message
-        that uses them.
-        '''
-        if self.refers_to(other.name):
-            return 1
-        elif other.refers_to(self.name):
-            return -1
-        else:
-            return 0
-    
-    def refers_to(self, name):
-        '''Returns True if this message uses the specified type as field type.'''
-        for field in self.fields:
-            if str(field.ctype) == str(name):
-                return True
-        return False
+
+    def get_dependencies(self):
+        '''Get list of type names that this structure refers to.'''
+        return [str(field.ctype) for field in self.fields]
     
     def __str__(self):
         result = 'typedef struct {\n'
@@ -314,17 +321,53 @@ def parse_file(fdesc):
     enums = []
     messages = []
     
+    if fdesc.package:
+        base_name = Names(fdesc.package.split('.'))
+    else:
+        base_name = Names()
+    
     for enum in fdesc.enum_type:
-        enums.append(Enum(Names(), enum))
+        enums.append(Enum(base_name, enum))
     
-    for names, message in iterate_messages(fdesc):
+    for names, message in iterate_messages(fdesc, base_name):
         messages.append(Message(names, message))
         for enum in message.enum_type:
             enums.append(Enum(names, enum))
     
     return enums, messages
 
-def generate_header(headername, enums, messages):
+def toposort2(data):
+    '''Topological sort.
+    From http://code.activestate.com/recipes/577413-topological-sort/
+    This function is under the MIT license.
+    '''
+    for k, v in data.items():
+        v.discard(k) # Ignore self dependencies
+    extra_items_in_deps = reduce(set.union, data.values(), set()) - set(data.keys())
+    data.update(dict([(item, set()) for item in extra_items_in_deps]))
+    while True:
+        ordered = set(item for item,dep in data.items() if not dep)
+        if not ordered:
+            break
+        for item in sorted(ordered):
+            yield item
+        data = dict([(item, (dep - ordered)) for item,dep in data.items()
+                if item not in ordered])
+    assert not data, "A cyclic dependency exists amongst %r" % data
+
+def sort_dependencies(messages):
+    '''Sort a list of Messages based on dependencies.'''
+    dependencies = {}
+    message_by_name = {}
+    for message in messages:
+        dependencies[str(message.name)] = set(message.get_dependencies())
+        message_by_name[str(message.name)] = message
+    
+    for msgname in toposort2(dependencies):
+        if msgname in message_by_name:
+            yield message_by_name[msgname]
+
+def generate_header(dependencies, headername, enums, messages):
     '''Generate content for a header file.
     Generates strings, which should be concatenated and stored to file.
     '''
@@ -336,13 +379,17 @@ def generate_header(headername, enums, messages):
     yield '#define _PB_%s_\n' % symbol
     yield '#include <pb.h>\n\n'
     
+    for dependency in dependencies:
+        noext = os.path.splitext(dependency)[0]
+        yield '#include "%s.pb.h"\n' % noext
+    yield '\n'
+    
     yield '/* Enum definitions */\n'
     for enum in enums:
         yield str(enum) + '\n\n'
     
     yield '/* Struct definitions */\n'
-    messages.sort()
-    for msg in messages:
+    for msg in sort_dependencies(messages):
         yield msg.types()
         yield str(msg) + '\n\n'
         
@@ -355,6 +402,16 @@ def generate_header(headername, enums, messages):
     for msg in messages:
         yield msg.fields_declaration() + '\n'
     
+    count_required_fields = lambda m: len([f for f in msg.fields if f.htype == 'PB_HTYPE_REQUIRED'])
+    largest_msg = max(messages, key = count_required_fields)
+    largest_count = count_required_fields(largest_msg)
+    if largest_count > 64:
+        yield '\n/* Check that missing required fields will be properly detected */\n'
+        yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
+        yield '#warning Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
+        yield '         setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
+        yield '#endif\n'
+    
     yield '\n#endif\n'
 
 def generate_source(headername, enums, messages):
@@ -379,26 +436,31 @@ if __name__ == '__main__':
         print "Usage: " + sys.argv[0] + " file.pb"
         print "where file.pb has been compiled from .proto by:"
         print "protoc -ofile.pb file.proto"
-        print "Output fill be written to file.h and file.c"
+        print "Output fill be written to file.pb.h and file.pb.c"
         sys.exit(1)
     
-    data = open(sys.argv[1]).read()
+    data = open(sys.argv[1], 'rb').read()
     fdesc = descriptor.FileDescriptorSet.FromString(data)
     enums, messages = parse_file(fdesc.file[0])
     
     noext = os.path.splitext(sys.argv[1])[0]
-    headername = noext + '.h'
-    sourcename = noext + '.c'
+    headername = noext + '.pb.h'
+    sourcename = noext + '.pb.c'
     headerbasename = os.path.basename(headername)
     
     print "Writing to " + headername + " and " + sourcename
     
+    # List of .proto files that should not be included in the C header file
+    # even if they are mentioned in the source .proto.
+    excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto']
+    dependencies = [d for d in fdesc.file[0].dependency if d not in excludes]
+    
     header = open(headername, 'w')
-    for part in generate_header(headerbasename, enums, messages):
+    for part in generate_header(dependencies, headerbasename, enums, messages):
         header.write(part)
 
     source = open(sourcename, 'w')
     for part in generate_source(headerbasename, enums, messages):
         source.write(part)
 
-    
\ No newline at end of file
+