Extend inline / fixed length bytes array support (issue #244)
[apps/agl-service-can-low-level.git] / generator / nanopb_generator.py
index 0e9b018..9cce6a5 100755 (executable)
@@ -3,7 +3,7 @@
 from __future__ import unicode_literals
 
 '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
-nanopb_version = "nanopb-0.3.6-dev"
+nanopb_version = "nanopb-0.3.8-dev"
 
 import sys
 import re
@@ -207,6 +207,27 @@ class Enum:
             for i, x in enumerate(self.values):
                 result += '\n#define %s %s' % (self.value_longnames[i], x[0])
 
+        if self.options.enum_to_string:
+            result += '\nconst char *%s_name(%s v);\n' % (self.names, self.names)
+
+        return result
+
+    def enum_to_string_definition(self):
+        if not self.options.enum_to_string:
+            return ""
+
+        result = 'const char *%s_name(%s v) {\n' % (self.names, self.names)
+        result += '    switch (v) {\n'
+
+        for ((enumname, _), strname) in zip(self.values, self.value_longnames):
+            # Strip off the leading type name from the string value.
+            strval = str(strname)[len(str(self.names)) + 1:]
+            result += '        case %s: return "%s";\n' % (enumname, strval)
+
+        result += '    }\n'
+        result += '    return "unknown";\n'
+        result += '}\n'
+
         return result
 
 class FieldMaxSize:
@@ -217,7 +238,7 @@ class FieldMaxSize:
             self.worst = worst
 
         self.worst_field = field_name
-        self.checks = checks
+        self.checks = list(checks)
 
     def extend(self, extend, field_name = None):
         self.worst = max(self.worst, extend.worst)
@@ -241,9 +262,20 @@ class Field:
         self.enc_size = None
         self.ctype = None
 
+        if field_options.type == nanopb_pb2.FT_INLINE:
+            # Before nanopb-0.3.8, fixed length bytes arrays were specified
+            # by setting type to FT_INLINE. But to handle pointer typed fields,
+            # it makes sense to have it as a separate option.
+            field_options.type = nanopb_pb2.FT_STATIC
+            field_options.fixed_length = True
+
         # Parse field options
         if field_options.HasField("max_size"):
             self.max_size = field_options.max_size
+        
+        if desc.type == FieldD.TYPE_STRING and field_options.HasField("max_length"):
+            # max_length overrides max_size for strings
+            self.max_size = field_options.max_length + 1
 
         if field_options.HasField("max_count"):
             self.max_count = field_options.max_count
@@ -253,16 +285,18 @@ class Field:
 
         # Check field rules, i.e. required/optional/repeated.
         can_be_static = True
-        if desc.label == FieldD.LABEL_REQUIRED:
-            self.rules = 'REQUIRED'
-        elif desc.label == FieldD.LABEL_OPTIONAL:
-            self.rules = 'OPTIONAL'
-        elif desc.label == FieldD.LABEL_REPEATED:
+        if desc.label == FieldD.LABEL_REPEATED:
             self.rules = 'REPEATED'
             if self.max_count is None:
                 can_be_static = False
             else:
                 self.array_decl = '[%d]' % self.max_count
+        elif field_options.proto3:
+            self.rules = 'SINGULAR'
+        elif desc.label == FieldD.LABEL_REQUIRED:
+            self.rules = 'REQUIRED'
+        elif desc.label == FieldD.LABEL_OPTIONAL:
+            self.rules = 'OPTIONAL'
         else:
             raise NotImplementedError(desc.label)
 
@@ -317,12 +351,17 @@ class Field:
                 self.array_decl += '[%d]' % self.max_size
                 self.enc_size = varint_max_size(self.max_size) + self.max_size
         elif desc.type == FieldD.TYPE_BYTES:
-            self.pbtype = 'BYTES'
-            if self.allocation == 'STATIC':
-                self.ctype = self.struct_name + self.name + 't'
+            if field_options.fixed_length:
+                self.pbtype = 'FIXED_LENGTH_BYTES'
                 self.enc_size = varint_max_size(self.max_size) + self.max_size
-            elif self.allocation == 'POINTER':
+                self.ctype = 'pb_byte_t'
+                self.array_decl += '[%d]' % self.max_size
+            else:
+                self.pbtype = 'BYTES'
                 self.ctype = 'pb_bytes_array_t'
+                if self.allocation == 'STATIC':
+                    self.ctype = self.struct_name + self.name + 't'
+                    self.enc_size = varint_max_size(self.max_size) + self.max_size
         elif desc.type == FieldD.TYPE_MESSAGE:
             self.pbtype = 'MESSAGE'
             self.ctype = self.submsgname = names_from_type_name(desc.type_name)
@@ -342,6 +381,9 @@ class Field:
             if self.pbtype == 'MESSAGE':
                 # Use struct definition, so recursive submessages are possible
                 result += '    struct _%s *%s;' % (self.ctype, self.name)
+            elif self.pbtype == 'FIXED_LENGTH_BYTES':
+                # Pointer to fixed size array
+                result += '    %s (*%s)%s;' % (self.ctype, self.name, self.array_decl)
             elif self.rules == 'REPEATED' and self.pbtype in ['STRING', 'BYTES']:
                 # String/bytes arrays need to be defined as pointers to pointers
                 result += '    %s **%s;' % (self.ctype, self.name)
@@ -389,6 +431,8 @@ class Field:
                 inner_init = '""'
             elif self.pbtype == 'BYTES':
                 inner_init = '{0, {0}}'
+            elif self.pbtype == 'FIXED_LENGTH_BYTES':
+                inner_init = '{0}'
             elif self.pbtype in ('ENUM', 'UENUM'):
                 inner_init = '(%s)0' % self.ctype
             else:
@@ -403,6 +447,12 @@ class Field:
                     inner_init = '{0, {0}}'
                 else:
                     inner_init = '{%d, {%s}}' % (len(data), ','.join(data))
+            elif self.pbtype == 'FIXED_LENGTH_BYTES':
+                data = ['0x%02x' % ord(c) for c in self.default]
+                if len(data) == 0:
+                    inner_init = '{0}'
+                else:
+                    inner_init = '{%s}' % ','.join(data)
             elif self.pbtype in ['FIXED32', 'UINT32']:
                 inner_init = str(self.default) + 'u'
             elif self.pbtype in ['FIXED64', 'UINT64']:
@@ -454,6 +504,10 @@ class Field:
         elif self.pbtype == 'BYTES':
             if self.allocation != 'STATIC':
                 return None # Not implemented
+        elif self.pbtype == 'FIXED_LENGTH_BYTES':
+            if self.allocation != 'STATIC':
+                return None # Not implemented
+            array_decl = '[%d]' % self.max_size
 
         if declaration_only:
             return 'extern const %s %s_default%s;' % (ctype, self.struct_name + self.name, array_decl)
@@ -465,9 +519,10 @@ class Field:
         identifier = '%s_%s_tag' % (self.struct_name, self.name)
         return '#define %-40s %d\n' % (identifier, self.tag)
 
-    def pb_field_t(self, prev_field_name):
+    def pb_field_t(self, prev_field_name, union_index = None):
         '''Return the pb_field_t initializer to use in the constant array.
-        prev_field_name is the name of the previous field or None.
+        prev_field_name is the name of the previous field or None. For OneOf
+        unions, union_index is the index of this field inside the OneOf.
         '''
 
         if self.rules == 'ONEOF':
@@ -482,7 +537,14 @@ class Field:
         result += '%-8s, ' % self.pbtype
         result += '%s, ' % self.rules
         result += '%-8s, ' % self.allocation
-        result += '%s, ' % ("FIRST" if not prev_field_name else "OTHER")
+        
+        if union_index is not None and union_index > 0:
+            result += 'UNION, '
+        elif prev_field_name is None:
+            result += 'FIRST, '
+        else:
+            result += 'OTHER, '
+        
         result += '%s, ' % self.struct_name
         result += '%s, ' % self.name
         result += '%s, ' % (prev_field_name or self.name)
@@ -491,7 +553,7 @@ class Field:
             result += '&%s_fields)' % self.submsgname
         elif self.default is None:
             result += '0)'
-        elif self.pbtype in ['BYTES', 'STRING'] and self.allocation != 'STATIC':
+        elif self.pbtype in ['BYTES', 'STRING', 'FIXED_LENGTH_BYTES'] and self.allocation != 'STATIC':
             result += '0)' # Arbitrary size default values not implemented
         elif self.rules == 'OPTEXT':
             result += '0)' # Default value for extensions is not implemented
@@ -507,8 +569,8 @@ class Field:
         '''Determine if this field needs 16bit or 32bit pb_field_t structure to compile properly.
         Returns numeric value or a C-expression for assert.'''
         check = []
-        if self.pbtype == 'MESSAGE':
-            if self.rules == 'REPEATED' and self.allocation == 'STATIC':
+        if self.pbtype == 'MESSAGE' and self.allocation == 'STATIC':
+            if self.rules == 'REPEATED':
                 check.append('pb_membersize(%s, %s[0])' % (self.struct_name, self.name))
             elif self.rules == 'ONEOF':
                 if self.anonymous:
@@ -517,6 +579,9 @@ class Field:
                     check.append('pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name))
             else:
                 check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name))
+        elif self.pbtype == 'BYTES' and self.allocation == 'STATIC':
+            if self.max_size > 251:
+                check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name))
 
         return FieldMaxSize([self.tag, self.max_size, self.max_count],
                             check,
@@ -718,8 +783,10 @@ class OneOf(Field):
         return ''.join([f.tags() for f in self.fields])
 
     def pb_field_t(self, prev_field_name):
-        result = ',\n'.join([f.pb_field_t(prev_field_name) for f in self.fields])
-        return result
+        parts = []
+        for union_index, field in enumerate(self.fields):
+            parts.append(field.pb_field_t(prev_field_name, union_index))
+        return ',\n'.join(parts)
 
     def get_last_field_name(self):
         if self.anonymous:
@@ -1045,7 +1112,10 @@ class ProtoFile:
         else:
             yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
 
-        symbol = make_identifier(headername)
+        if self.fdesc.package:
+            symbol = make_identifier(self.fdesc.package + '_' + headername)
+        else:
+            symbol = make_identifier(headername)
         yield '#ifndef PB_%s_INCLUDED\n' % symbol
         yield '#define PB_%s_INCLUDED\n' % symbol
         try:
@@ -1059,7 +1129,7 @@ class ProtoFile:
             noext = os.path.splitext(incfile)[0]
             yield options.genformat % (noext + options.extension + '.h')
             yield '\n'
-            
+
         yield '/* @@protoc_insertion_point(includes) */\n'
 
         yield '#if PB_PROTO_HEADER_VERSION != 30\n'
@@ -1119,9 +1189,11 @@ class ProtoFile:
             yield '/* Maximum encoded size of messages (where known) */\n'
             for msg in self.messages:
                 msize = msg.encoded_size(self.dependencies)
+                identifier = '%s_size' % msg.name
                 if msize is not None:
-                    identifier = '%s_size' % msg.name
                     yield '#define %-40s %s\n' % (identifier, msize)
+                else:
+                    yield '/* %s depends on runtime parameters */\n' % identifier
             yield '\n'
 
             yield '/* Message IDs (where set with "msgid" option) */\n'
@@ -1187,6 +1259,9 @@ class ProtoFile:
         for ext in self.extensions:
             yield ext.extension_def() + '\n'
 
+        for enum in self.enums:
+            yield enum.enum_to_string_definition() + '\n'
+
         # Add checks for numeric limits
         if self.messages:
             largest_msg = max(self.messages, key = lambda m: m.count_required_fields())
@@ -1316,6 +1391,9 @@ def get_nanopb_suboptions(subdesc, options, name):
     new_options = nanopb_pb2.NanoPBOptions()
     new_options.CopyFrom(options)
 
+    if hasattr(subdesc, 'syntax') and subdesc.syntax == "proto3":
+        new_options.proto3 = True
+
     # Handle options defined in a separate file
     dotname = '.'.join(name.parts)
     for namemask, options in Globals.separate_options:
@@ -1367,6 +1445,9 @@ optparser.add_option("-f", "--options-file", dest="options_file", metavar="FILE"
 optparser.add_option("-I", "--options-path", dest="options_path", metavar="DIR",
     action="append", default = [],
     help="Search for .options files additionally in this path")
+optparser.add_option("-D", "--output-dir", dest="output_dir",
+                     metavar="OUTPUTDIR", default=None,
+                     help="Output directory of .pb.h and .pb.c files")
 optparser.add_option("-Q", "--generated-include-format", dest="genformat",
     metavar="FORMAT", default='#include "%s"\n',
     help="Set format string to use for including other .pb.h files. [default: %default]")
@@ -1483,17 +1564,29 @@ def main_cli():
     if options.quiet:
         options.verbose = False
 
-    Globals.verbose_options = options.verbose
+    if options.output_dir and not os.path.exists(options.output_dir):
+        optparser.print_help()
+        sys.stderr.write("\noutput_dir does not exist: %s\n" % options.output_dir)
+        sys.exit(1)
+
 
+    Globals.verbose_options = options.verbose
     for filename in filenames:
         results = process_file(filename, None, options)
 
+        base_dir = options.output_dir or ''
+        to_write = [
+            (os.path.join(base_dir, results['headername']), results['headerdata']),
+            (os.path.join(base_dir, results['sourcename']), results['sourcedata']),
+        ]
+
         if not options.quiet:
-            sys.stderr.write("Writing to " + results['headername'] + " and "
-                             + results['sourcename'] + "\n")
+            paths = " and ".join([x[0] for x in to_write])
+            sys.stderr.write("Writing to %s\n" % paths)
 
-        open(results['headername'], 'w').write(results['headerdata'])
-        open(results['sourcename'], 'w').write(results['sourcedata'])
+        for path, data in to_write:
+            with open(path, 'w') as f:
+                f.write(data)
 
 def main_plugin():
     '''Main function when invoked as a protoc plugin.'''
@@ -1557,4 +1650,3 @@ if __name__ == '__main__':
         main_plugin()
     else:
         main_cli()
-