Rename poorly named identifier to avoid name conflicts.
[apps/agl-service-can-low-level.git] / generator / nanopb_generator.py
index 475e78a..5010814 100755 (executable)
@@ -1,7 +1,7 @@
 #!/usr/bin/python
 
 '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
-nanopb_version = "nanopb-0.2.5"
+nanopb_version = "nanopb-0.3.0-dev"
 
 import sys
 
@@ -15,6 +15,7 @@ except:
 
 try:
     import google.protobuf.text_format as text_format
+    import google.protobuf.descriptor_pb2 as descriptor
 except:
     sys.stderr.write('''
          *************************************************************
@@ -26,7 +27,7 @@ except:
 
 try:
     import proto.nanopb_pb2 as nanopb_pb2
-    import proto.descriptor_pb2 as descriptor
+    import proto.plugin_pb2 as plugin_pb2
 except:
     sys.stderr.write('''
          ********************************************************************
@@ -36,7 +37,6 @@ except:
     ''' + '\n')
     raise
 
-
 # ---------------------------------------------------------------------------
 #                     Generation of single fields
 # ---------------------------------------------------------------------------
@@ -169,6 +169,7 @@ class Field:
         self.max_count = None
         self.array_decl = ""
         self.enc_size = None
+        self.ctype = None
         
         # Parse field options
         if field_options.HasField("max_size"):
@@ -245,7 +246,7 @@ class Field:
                 self.ctype = self.struct_name + self.name + 't'
                 self.enc_size = varint_max_size(self.max_size) + self.max_size
             elif self.allocation == 'POINTER':
-                self.ctype = 'pb_bytes_ptr_t'
+                self.ctype = 'pb_bytes_array_t'
         elif desc.type == FieldD.TYPE_MESSAGE:
             self.pbtype = 'MESSAGE'
             self.ctype = self.submsgname = names_from_type_name(desc.type_name)
@@ -260,13 +261,13 @@ class Field:
         result = ''
         if self.allocation == 'POINTER':
             if self.rules == 'REPEATED':
-                result += '    size_t ' + self.name + '_count;\n'
+                result += '    pb_size_t ' + self.name + '_count;\n'
             
             if self.pbtype == 'MESSAGE':
                 # Use struct definition, so recursive submessages are possible
                 result += '    struct _%s *%s;' % (self.ctype, self.name)
-            elif self.rules == 'REPEATED' and self.pbtype == 'STRING':
-                # String arrays need to be defined as pointers to pointers
+            elif self.rules == 'REPEATED' and self.pbtype in ['STRING', 'BYTES']:
+                # String/bytes arrays need to be defined as pointers to pointers
                 result += '    %s **%s;' % (self.ctype, self.name)
             else:
                 result += '    %s *%s;' % (self.ctype, self.name)
@@ -276,44 +277,75 @@ class Field:
             if self.rules == 'OPTIONAL' and self.allocation == 'STATIC':
                 result += '    bool has_' + self.name + ';\n'
             elif self.rules == 'REPEATED' and self.allocation == 'STATIC':
-                result += '    size_t ' + self.name + '_count;\n'
+                result += '    pb_size_t ' + self.name + '_count;\n'
             result += '    %s %s%s;' % (self.ctype, self.name, self.array_decl)
         return result
     
     def types(self):
         '''Return definitions for any special types this field might need.'''
         if self.pbtype == 'BYTES' and self.allocation == 'STATIC':
-            result = 'typedef struct {\n'
-            result += '    size_t size;\n'
-            result += '    uint8_t bytes[%d];\n' % self.max_size
-            result += '} %s;\n' % self.ctype
+            result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, self.ctype)
         else:
             result = None
         return result
     
+    def get_initializer(self, null_init):
+        '''Return literal expression for this field's default value.'''
+        
+        if self.pbtype == 'MESSAGE':
+            if null_init:
+                return '%s_init_zero' % self.ctype
+            else:
+                return '%s_init_default' % self.ctype
+        
+        if self.default is None or null_init:
+            if self.pbtype == 'STRING':
+                return '""'
+            elif self.pbtype == 'BYTES':
+                return '{0, {0}}'
+            elif self.pbtype == 'ENUM':
+                return '(%s)0' % self.ctype
+            else:
+                return '0'
+        
+        default = str(self.default)
+        
+        if self.pbtype == 'STRING':
+            default = default.encode('utf-8').encode('string_escape')
+            default = default.replace('"', '\\"')
+            default = '"' + default + '"'
+        elif self.pbtype == 'BYTES':
+            data = default.decode('string_escape')
+            data = ['0x%02x' % ord(c) for c in data]
+            if len(data) == 0:
+                default = '{0, {0}}'
+            else:
+                default = '{%d, {%s}}' % (len(data), ','.join(data))
+        elif self.pbtype in ['FIXED32', 'UINT32']:
+            default += 'u'
+        elif self.pbtype in ['FIXED64', 'UINT64']:
+            default += 'ull'
+        elif self.pbtype in ['SFIXED64', 'INT64']:
+            default += 'll'
+        
+        return default
+    
     def default_decl(self, declaration_only = False):
         '''Return definition for this field's default value.'''
         if self.default is None:
             return None
 
-        ctype, default = self.ctype, self.default
+        ctype = self.ctype
+        default = self.get_initializer(False)
         array_decl = ''
         
         if self.pbtype == 'STRING':
             if self.allocation != 'STATIC':
                 return None # Not implemented
-        
             array_decl = '[%d]' % self.max_size
-            default = str(self.default).encode('string_escape')
-            default = default.replace('"', '\\"')
-            default = '"' + default + '"'
         elif self.pbtype == 'BYTES':
             if self.allocation != 'STATIC':
                 return None # Not implemented
-
-            data = self.default.decode('string_escape')
-            data = ['0x%02x' % ord(c) for c in data]
-            default = '{%d, {%s}}' % (len(data), ','.join(data))
         
         if declaration_only:
             return 'extern const %s %s_default%s;' % (ctype, self.struct_name + self.name, array_decl)
@@ -344,6 +376,8 @@ class Field:
             result += '0)'
         elif self.pbtype in ['BYTES', 'STRING'] and self.allocation != 'STATIC':
             result += '0)' # Arbitrary size default values not implemented
+        elif self.rules == 'OPTEXT':
+            result += '0)' # Default value for extensions is not implemented
         else:
             result += '&%s_default)' % (self.struct_name + self.name)
         
@@ -433,7 +467,7 @@ class ExtensionRange(Field):
     
     def tags(self):
         return ''
-        
+    
     def encoded_size(self, allmsgs):
         # We exclude extensions from the count, because they cannot be known
         # until runtime. Other option would be to return None here, but this
@@ -544,6 +578,32 @@ class Message:
                 result += types + '\n'
         return result
     
+    def get_initializer(self, null_init):
+        if not self.ordered_fields:
+            return '{0}'
+    
+        parts = []
+        for field in self.ordered_fields:
+            if field.allocation == 'STATIC':
+                if field.rules == 'REPEATED':
+                    parts.append('0')
+                    parts.append('{'
+                                 + ', '.join([field.get_initializer(null_init)] * field.max_count)
+                                 + '}')
+                elif field.rules == 'OPTIONAL':
+                    parts.append('false')
+                    parts.append(field.get_initializer(null_init))
+                else:
+                    parts.append(field.get_initializer(null_init))
+            elif field.allocation == 'POINTER':
+                parts.append('NULL')
+            elif field.allocation == 'CALLBACK':
+                if field.pbtype == 'EXTENSION':
+                    parts.append('NULL')
+                else:
+                    parts.append('{{NULL}, NULL}')
+        return '{' + ', '.join(parts) + '}'
+    
     def default_decl(self, declaration_only = False):
         result = ""
         for field in self.fields:
@@ -630,13 +690,17 @@ def parse_file(fdesc, file_options):
     
     for names, message in iterate_messages(fdesc, base_name):
         message_options = get_nanopb_suboptions(message, file_options, names)
+        
+        if message_options.skip_message:
+            continue
+        
         messages.append(Message(names, message, message_options))
         for enum in message.enum_type:
             enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name)
             enums.append(Enum(names, enum, enum_options))
     
     for names, extension in iterate_extensions(fdesc, base_name):
-        field_options = get_nanopb_suboptions(extension, file_options, names)
+        field_options = get_nanopb_suboptions(extension, file_options, names + extension.name)
         if field_options.type != nanopb_pb2.FT_IGNORE:
             extensions.append(ExtensionField(names, extension, field_options))
     
@@ -698,11 +762,14 @@ def generate_header(dependencies, headername, enums, messages, extensions, optio
     '''
     
     yield '/* Automatically generated nanopb header */\n'
-    yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
+    if options.notimestamp:
+        yield '/* Generated by %s */\n\n' % (nanopb_version)
+    else:
+        yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
     
     symbol = make_identifier(headername)
-    yield '#ifndef _PB_%s_\n' % symbol
-    yield '#define _PB_%s_\n' % symbol
+    yield '#ifndef PB_%s_INCLUDED\n' % symbol
+    yield '#define PB_%s_INCLUDED\n' % symbol
     try:
         yield options.libformat % ('pb.h')
     except TypeError:
@@ -712,7 +779,7 @@ def generate_header(dependencies, headername, enums, messages, extensions, optio
     
     for dependency in dependencies:
         noext = os.path.splitext(dependency)[0]
-        yield options.genformat % (noext + '.' + options.extension + '.h')
+        yield options.genformat % (noext + options.extension + '.h')
         yield '\n'
 
     yield '#ifdef __cplusplus\n'
@@ -739,6 +806,15 @@ def generate_header(dependencies, headername, enums, messages, extensions, optio
         yield msg.default_decl(True)
     yield '\n'
     
+    yield '/* Initializer values for message structs */\n'
+    for msg in messages:
+        identifier = '%s_init_default' % msg.name
+        yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False))
+    for msg in messages:
+        identifier = '%s_init_zero' % msg.name
+        yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True))
+    yield '\n'
+    
     yield '/* Field tags (for use in manual encoding/decoding) */\n'
     for msg in sort_dependencies(messages):
         for field in msg.fields:
@@ -771,7 +847,10 @@ def generate_source(headername, enums, messages, extensions, options):
     '''Generate content for a source file.'''
     
     yield '/* Automatically generated nanopb constant definitions */\n'
-    yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
+    if options.notimestamp:
+        yield '/* Generated by %s */\n\n' % (nanopb_version)
+    else:
+        yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
     yield options.genformat % (headername)
     yield '\n'
     
@@ -815,6 +894,23 @@ def generate_source(headername, enums, messages, extensions, options):
     if worst > 255 or checks:
         yield '\n/* Check that field information fits in pb_field_t */\n'
         
+        if worst > 65535 or checks:
+            yield '#if !defined(PB_FIELD_32BIT)\n'
+            if worst > 65535:
+                yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field
+            else:
+                assertion = ' && '.join(str(c) + ' < 65536' for c in checks)
+                msgs = '_'.join(str(n) for n in checks_msgnames)
+                yield '/* If you get an error here, it means that you need to define PB_FIELD_32BIT\n'
+                yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
+                yield ' * \n'
+                yield ' * The reason you need to do this is that some of your messages contain tag\n'
+                yield ' * numbers or field sizes that are larger than what can fit in 8 or 16 bit\n'
+                yield ' * field descriptors.\n'
+                yield ' */\n'
+                yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
+            yield '#endif\n\n'
+        
         if worst < 65536:
             yield '#if !defined(PB_FIELD_16BIT) && !defined(PB_FIELD_32BIT)\n'
             if worst > 255:
@@ -822,18 +918,15 @@ def generate_source(headername, enums, messages, extensions, options):
             else:
                 assertion = ' && '.join(str(c) + ' < 256' for c in checks)
                 msgs = '_'.join(str(n) for n in checks_msgnames)
-                yield 'STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
+                yield '/* If you get an error here, it means that you need to define PB_FIELD_16BIT\n'
+                yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
+                yield ' * \n'
+                yield ' * The reason you need to do this is that some of your messages contain tag\n'
+                yield ' * numbers or field sizes that are larger than what can fit in the default\n'
+                yield ' * 8 bit descriptors.\n'
+                yield ' */\n'
+                yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
             yield '#endif\n\n'
-        
-        if worst > 65535 or checks:
-            yield '#if !defined(PB_FIELD_32BIT)\n'
-            if worst > 65535:
-                yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field
-            else:
-                assertion = ' && '.join(str(c) + ' < 65536' for c in checks)
-                msgs = '_'.join(str(n) for n in checks_msgnames)
-                yield 'STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
-            yield '#endif\n'
     
     # Add check for sizeof(double)
     has_double = False
@@ -848,7 +941,7 @@ def generate_source(headername, enums, messages, extensions, options):
         yield ' * These are not directly supported by nanopb, but see example_avr_double.\n'
         yield ' * To get rid of this error, remove any double fields from your .proto.\n'
         yield ' */\n'
-        yield 'STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n'
+        yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n'
     
     yield '\n'
 
@@ -879,6 +972,7 @@ class Globals:
     '''Ugly global variables, should find a good way to pass these.'''
     verbose_options = False
     separate_options = []
+    matched_namemasks = set()
 
 def get_nanopb_suboptions(subdesc, options, name):
     '''Get copy of options, and merge information from subdesc.'''
@@ -889,6 +983,7 @@ def get_nanopb_suboptions(subdesc, options, name):
     dotname = '.'.join(name.parts)
     for namemask, options in Globals.separate_options:
         if fnmatch(dotname, namemask):
+            Globals.matched_namemasks.add(namemask)
             new_options.MergeFrom(options)
     
     # Handle options defined in .proto
@@ -908,8 +1003,8 @@ def get_nanopb_suboptions(subdesc, options, name):
         new_options.MergeFrom(ext)
     
     if Globals.verbose_options:
-        print "Options for " + dotname + ":"
-        print text_format.MessageToString(new_options)
+        sys.stderr.write("Options for " + dotname + ": ")
+        sys.stderr.write(text_format.MessageToString(new_options) + "\n")
     
     return new_options
 
@@ -928,8 +1023,8 @@ optparser = OptionParser(
              "Output will be written to file.pb.h and file.pb.c.")
 optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[],
     help="Exclude file from generated #include list.")
-optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default="pb",
-    help="Set extension to use instead of 'pb' for generated files. [default: %default]")
+optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default=".pb",
+    help="Set extension to use instead of '.pb' for generated files. [default: %default]")
 optparser.add_option("-f", "--options-file", dest="options_file", metavar="FILE", default="%s.options",
     help="Set name of a separate generator options file.")
 optparser.add_option("-Q", "--generated-include-format", dest="genformat",
@@ -938,6 +1033,8 @@ optparser.add_option("-Q", "--generated-include-format", dest="genformat",
 optparser.add_option("-L", "--library-include-format", dest="libformat",
     metavar="FORMAT", default='#include <%s>\n',
     help="Set format string to use for including the nanopb pb.h header. [default: %default]")
+optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=False,
+    help="Don't add timestamp to .pb.h and .pb.c preambles")
 optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False,
     help="Don't print anything except errors.")
 optparser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False,
@@ -967,28 +1064,38 @@ def process_file(filename, fdesc, options):
         fdesc = descriptor.FileDescriptorSet.FromString(data).file[0]
     
     # Check if there is a separate .options file
+    had_abspath = False
     try:
         optfilename = options.options_file % os.path.splitext(filename)[0]
     except TypeError:
         # No %s specified, use the filename as-is
         optfilename = options.options_file
-    
-    if options.verbose:
-        print 'Reading options from ' + optfilename
-    
+        had_abspath = True
+
     if os.path.isfile(optfilename):
+        if options.verbose:
+            sys.stderr.write('Reading options from ' + optfilename + '\n')
+
         Globals.separate_options = read_options_file(open(optfilename, "rU"))
     else:
+        # If we are given a full filename and it does not exist, give an error.
+        # However, don't give error when we automatically look for .options file
+        # with the same name as .proto.
+        if options.verbose or had_abspath:
+            sys.stderr.write('Options file not found: ' + optfilename)
+
         Globals.separate_options = []
+
+    Globals.matched_namemasks = set()
     
     # Parse the file
     file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename]))
     enums, messages, extensions = parse_file(fdesc, file_options)
-    
+
     # Decide the file names
     noext = os.path.splitext(filename)[0]
-    headername = noext + '.' + options.extension + '.h'
-    sourcename = noext + '.' + options.extension + '.c'
+    headername = noext + options.extension + '.h'
+    sourcename = noext + options.extension + '.c'
     headerbasename = os.path.basename(headername)
     
     # List of .proto files that should not be included in the C header file
@@ -1002,6 +1109,14 @@ def process_file(filename, fdesc, options):
     sourcedata = ''.join(generate_source(headerbasename, enums,
                                          messages, extensions, options))
 
+    # Check if there were any lines in .options that did not match a member
+    unmatched = [n for n,o in Globals.separate_options if n not in Globals.matched_namemasks]
+    if unmatched and not options.quiet:
+        sys.stderr.write("Following patterns in " + optfilename + " did not match any fields: "
+                         + ', '.join(unmatched) + "\n")
+        if not Globals.verbose_options:
+            sys.stderr.write("Use  protoc --nanopb-out=-v:.   to see a list of the field names.\n")
+
     return {'headername': headername, 'headerdata': headerdata,
             'sourcename': sourcename, 'sourcedata': sourcedata}
     
@@ -1023,7 +1138,8 @@ def main_cli():
         results = process_file(filename, None, options)
         
         if not options.quiet:
-            print "Writing to " + results['headername'] + " and " + results['sourcename']
+            sys.stderr.write("Writing to " + results['headername'] + " and "
+                             + results['sourcename'] + "\n")
     
         open(results['headername'], 'w').write(results['headerdata'])
         open(results['sourcename'], 'w').write(results['sourcedata'])        
@@ -1038,7 +1154,6 @@ def main_plugin():
         msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
         msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
     
-    import proto.plugin_pb2 as plugin_pb2
     data = sys.stdin.read()
     request = plugin_pb2.CodeGeneratorRequest.FromString(data)
     
@@ -1046,10 +1161,7 @@ def main_plugin():
     args = shlex.split(request.parameter)
     options, dummy = optparser.parse_args(args)
     
-    # We can't go printing stuff to stdout
-    Globals.verbose_options = False
-    options.verbose = False
-    options.quiet = True
+    Globals.verbose_options = options.verbose
     
     response = plugin_pb2.CodeGeneratorResponse()