Sanitize filenames before putting them in #ifndef.
[apps/agl-service-can-low-level.git] / generator / nanopb_generator.py
index dadad64..d35a425 100644 (file)
@@ -1,5 +1,5 @@
 '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
-nanopb_version = "nanopb-0.1.7-dev"
+nanopb_version = "nanopb-0.1.9-dev"
 
 try:
     import google.protobuf.descriptor_pb2 as descriptor
@@ -73,6 +73,9 @@ class Names:
         else:
             raise ValueError("Name parts should be of type str")
     
+    def __eq__(self, other):
+        return isinstance(other, Names) and self.parts == other.parts
+    
 def names_from_type_name(type_name):
     '''Parse Names() from FieldDescriptorProto type_name'''
     if type_name[0] != '.':
@@ -83,12 +86,15 @@ class Enum:
     def __init__(self, names, desc, enum_options):
         '''desc is EnumDescriptorProto'''
         
+        self.options = enum_options
+        self.names = names + desc.name
+        
         if enum_options.long_names:
-            self.names = names + desc.name
+            self.values = [(self.names + x.name, x.number) for x in desc.value]            
         else:
-            self.names = names
+            self.values = [(names + x.name, x.number) for x in desc.value] 
         
-        self.values = [(self.names + x.name, x.number) for x in desc.value]
+        self.value_longnames = [self.names + x.name for x in desc.value]
     
     def __str__(self):
         result = 'typedef enum _%s {\n' % self.names
@@ -300,7 +306,14 @@ class Field:
 class Message:
     def __init__(self, names, desc, message_options):
         self.name = names
-        self.fields = [Field(self.name, f, get_nanopb_suboptions(f, message_options)) for f in desc.field]
+        self.fields = []
+        
+        for f in desc.field:
+            field_options = get_nanopb_suboptions(f, message_options)
+            if field_options.type != nanopb_pb2.FT_IGNORE:
+                self.fields.append(Field(self.name, f, field_options))
+        
+        self.packed = message_options.packed_struct
         self.ordered_fields = self.fields[:]
         self.ordered_fields.sort()
 
@@ -311,7 +324,12 @@ class Message:
     def __str__(self):
         result = 'typedef struct _%s {\n' % self.name
         result += '\n'.join([str(f) for f in self.ordered_fields])
-        result += '\n} %s;' % self.name
+        result += '\n}'
+        
+        if self.packed:
+            result += ' pb_packed'
+        
+        result += ' %s;' % self.name
         return result
     
     def types(self):
@@ -389,7 +407,17 @@ def parse_file(fdesc, file_options):
         message_options = get_nanopb_suboptions(message, file_options)
         messages.append(Message(names, message, message_options))
         for enum in message.enum_type:
-            enums.append(Enum(names, enum, message_options))
+            enum_options = get_nanopb_suboptions(enum, message_options)
+            enums.append(Enum(names, enum, enum_options))
+    
+    # Fix field default values where enum short names are used.
+    for enum in enums:
+        if not enum.options.long_names:
+            for message in messages:
+                for field in message.fields:
+                    if field.default in enum.value_longnames:
+                        idx = enum.value_longnames.index(field.default)
+                        field.default = enum.values[idx][0]
     
     return enums, messages
 
@@ -424,6 +452,16 @@ def sort_dependencies(messages):
         if msgname in message_by_name:
             yield message_by_name[msgname]
 
+def make_identifier(headername):
+    '''Make #ifndef identifier that contains uppercase A-Z and digits 0-9'''
+    result = ""
+    for c in headername.upper():
+        if c.isalnum():
+            result += c
+        else:
+            result += '_'
+    return result
+
 def generate_header(dependencies, headername, enums, messages):
     '''Generate content for a header file.
     Generates strings, which should be concatenated and stored to file.
@@ -432,7 +470,7 @@ def generate_header(dependencies, headername, enums, messages):
     yield '/* Automatically generated nanopb header */\n'
     yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
     
-    symbol = headername.replace('.', '_').upper()
+    symbol = make_identifier(headername)
     yield '#ifndef _PB_%s_\n' % symbol
     yield '#define _PB_%s_\n' % symbol
     yield '#include <pb.h>\n\n'
@@ -477,7 +515,9 @@ def generate_header(dependencies, headername, enums, messages):
     worst = 0
     worst_field = ''
     checks = []
+    checks_msgnames = []
     for msg in messages:
+        checks_msgnames.append(msg.name)
         for field in msg.fields:
             status = field.largest_field_value()
             if isinstance(status, (str, unicode)):
@@ -495,7 +535,8 @@ def generate_header(dependencies, headername, enums, messages):
                 yield '#error Field descriptor for %s is too large. Define PB_FIELD_16BIT to fix this.\n' % worst_field
             else:
                 assertion = ' && '.join(str(c) + ' < 256' for c in checks)
-                yield 'STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT)\n' % assertion
+                msgs = '_'.join(str(n) for n in checks_msgnames)
+                yield 'STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
             yield '#endif\n\n'
         
         if worst > 65535 or checks:
@@ -504,7 +545,8 @@ def generate_header(dependencies, headername, enums, messages):
                 yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field
             else:
                 assertion = ' && '.join(str(c) + ' < 65536' for c in checks)
-                yield 'STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT)\n' % assertion
+                msgs = '_'.join(str(n) for n in checks_msgnames)
+                yield 'STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
             yield '#endif\n'
     
     yield '\n#ifdef __cplusplus\n'