Give names to generated structures to allow forward declaration.
[apps/agl-service-can-low-level.git] / generator / nanopb_generator.py
index 2ceafc7..1c21422 100644 (file)
@@ -1,7 +1,27 @@
 '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
 
-import google.protobuf.descriptor_pb2 as descriptor
-import nanopb_pb2
+try:
+    import google.protobuf.descriptor_pb2 as descriptor
+except:
+    print
+    print "*************************************************************"
+    print "*** Could not import the Google protobuf Python libraries ***"
+    print "*** Try installing package 'python-protobuf' or similar.  ***"
+    print "*************************************************************"
+    print
+    raise
+
+try:
+    import nanopb_pb2
+except:
+    print
+    print "***************************************************************"
+    print "*** Could not import the precompiled nanopb_pb2.py.         ***"
+    print "*** Run 'make' in the 'generator' folder to update the file.***"
+    print "***************************************************************"
+    print
+    raise
+
 import os.path
 
 # Values are tuple (c type, pb ltype)
@@ -55,7 +75,7 @@ class Enum:
         self.values = [(self.names + x.name, x.number) for x in desc.value]
     
     def __str__(self):
-        result = 'typedef enum {\n'
+        result = 'typedef enum _%s {\n' % self.names
         result += ',\n'.join(["    %s = %d" % x for x in self.values])
         result += '\n} %s;' % self.names
         return result
@@ -197,10 +217,10 @@ class Field:
         prev_field_name is the name of the previous field or None.
         '''
         result = '    {%d, ' % self.tag
-        result += self.htype
+        result += '(pb_type_t) ((int) ' + self.htype
         if self.ltype is not None:
-            result += ' | ' + self.ltype
-        result += ',\n'
+            result += ' | (int) ' + self.ltype
+        result += '),\n'
         
         if prev_field_name is None:
             result += '    offsetof(%s, %s),' % (self.struct_name, self.name)
@@ -231,6 +251,18 @@ class Field:
             result += '\n    &%s_default}' % (self.struct_name + self.name)
         
         return result
+    
+    def largest_field_value(self):
+        '''Determine if this field needs 16bit or 32bit pb_field_t structure to compile properly.
+        Returns numeric value or a C-expression for assert.'''
+        if self.ltype == 'PB_LTYPE_SUBMESSAGE':
+            if self.htype == 'PB_HTYPE_ARRAY':
+                return 'pb_membersize(%s, %s[0])' % (self.struct_name, self.name)
+            else:
+                return 'pb_membersize(%s, %s)' % (self.struct_name, self.name)
+
+        return max(self.tag, self.max_size, self.max_count)        
+
 
 class Message:
     def __init__(self, names, desc):
@@ -244,7 +276,7 @@ class Message:
         return [str(field.ctype) for field in self.fields]
     
     def __str__(self):
-        result = 'typedef struct {\n'
+        result = 'typedef struct _%s {\n' % self.name
         result += '\n'.join([str(f) for f in self.ordered_fields])
         result += '\n} %s;' % self.name
         return result
@@ -323,7 +355,7 @@ def toposort2(data):
     '''
     for k, v in data.items():
         v.discard(k) # Ignore self dependencies
-    extra_items_in_deps = reduce(set.union, data.values()) - set(data.keys())
+    extra_items_in_deps = reduce(set.union, data.values(), set()) - set(data.keys())
     data.update(dict([(item, set()) for item in extra_items_in_deps]))
     while True:
         ordered = set(item for item,dep in data.items() if not dep)
@@ -347,7 +379,7 @@ def sort_dependencies(messages):
         if msgname in message_by_name:
             yield message_by_name[msgname]
 
-def generate_header(headername, enums, messages):
+def generate_header(dependencies, headername, enums, messages):
     '''Generate content for a header file.
     Generates strings, which should be concatenated and stored to file.
     '''
@@ -359,6 +391,11 @@ def generate_header(headername, enums, messages):
     yield '#define _PB_%s_\n' % symbol
     yield '#include <pb.h>\n\n'
     
+    for dependency in dependencies:
+        noext = os.path.splitext(dependency)[0]
+        yield '#include "%s.pb.h"\n' % noext
+    yield '\n'
+    
     yield '/* Enum definitions */\n'
     for enum in enums:
         yield str(enum) + '\n\n'
@@ -377,6 +414,51 @@ def generate_header(headername, enums, messages):
     for msg in messages:
         yield msg.fields_declaration() + '\n'
     
+    if messages:
+        count_required_fields = lambda m: len([f for f in msg.fields if f.htype == 'PB_HTYPE_REQUIRED'])
+        largest_msg = max(messages, key = count_required_fields)
+        largest_count = count_required_fields(largest_msg)
+        if largest_count > 64:
+            yield '\n/* Check that missing required fields will be properly detected */\n'
+            yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
+            yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
+            yield '       setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
+            yield '#endif\n'
+    
+    worst = 0
+    worst_field = ''
+    checks = []
+    for msg in messages:
+        for field in msg.fields:
+            status = field.largest_field_value()
+            if isinstance(status, (str, unicode)):
+                checks.append(status)
+            elif status > worst:
+                worst = status
+                worst_field = str(field.struct_name) + '.' + str(field.name)
+
+    if worst > 255 or checks:
+        yield '\n/* Check that field information fits in pb_field_t */\n'
+        
+        if worst < 65536:
+            yield '#if !defined(PB_FIELD_16BIT) && !defined(PB_FIELD_32BIT)\n'
+            if worst > 255:
+                yield '#error Field descriptor for %s is too large. Define PB_FIELD_16BIT to fix this.\n' % worst_field
+            else:
+                assertion = ' && '.join(str(c) + ' < 256' for c in checks)
+                yield 'STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT)\n' % assertion
+            yield '#endif\n\n'
+        
+        if worst > 65535 or checks:
+            yield '#if !defined(PB_FIELD_32BIT)\n'
+            if worst > 65535:
+                yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field
+            else:
+                assertion = ' && '.join(str(c) + ' < 65536' for c in checks)
+                yield 'STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT)\n' % assertion
+            yield '#endif\n'
+    
+    # End of header
     yield '\n#endif\n'
 
 def generate_source(headername, enums, messages):
@@ -404,7 +486,7 @@ if __name__ == '__main__':
         print "Output fill be written to file.pb.h and file.pb.c"
         sys.exit(1)
     
-    data = open(sys.argv[1]).read()
+    data = open(sys.argv[1], 'rb').read()
     fdesc = descriptor.FileDescriptorSet.FromString(data)
     enums, messages = parse_file(fdesc.file[0])
     
@@ -415,8 +497,13 @@ if __name__ == '__main__':
     
     print "Writing to " + headername + " and " + sourcename
     
+    # List of .proto files that should not be included in the C header file
+    # even if they are mentioned in the source .proto.
+    excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto']
+    dependencies = [d for d in fdesc.file[0].dependency if d not in excludes]
+    
     header = open(headername, 'w')
-    for part in generate_header(headerbasename, enums, messages):
+    for part in generate_header(dependencies, headerbasename, enums, messages):
         header.write(part)
 
     source = open(sourcename, 'w')