Move the generator .proto files to a subdir, and get rid of precompiled versions.
[apps/agl-service-can-low-level.git] / generator / nanopb_generator.py
index 130ff93..0002409 100755 (executable)
@@ -1,7 +1,7 @@
 #!/usr/bin/python
 
 '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
 #!/usr/bin/python
 
 '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
-nanopb_version = "nanopb-0.2.4-dev"
+nanopb_version = "nanopb-0.2.5-dev"
 
 try:
     import google.protobuf.descriptor_pb2 as descriptor
 
 try:
     import google.protobuf.descriptor_pb2 as descriptor
@@ -16,7 +16,7 @@ except:
     raise
 
 try:
     raise
 
 try:
-    import nanopb_pb2
+    import proto.nanopb_pb2 as nanopb_pb2
 except:
     print
     print "***************************************************************"
 except:
     print
     print "***************************************************************"
@@ -46,7 +46,7 @@ datatypes = {
     FieldD.TYPE_FIXED32:    ('uint32_t', 'FIXED32',     4),
     FieldD.TYPE_FIXED64:    ('uint64_t', 'FIXED64',     8),
     FieldD.TYPE_FLOAT:      ('float',    'FLOAT',       4),
     FieldD.TYPE_FIXED32:    ('uint32_t', 'FIXED32',     4),
     FieldD.TYPE_FIXED64:    ('uint64_t', 'FIXED64',     8),
     FieldD.TYPE_FLOAT:      ('float',    'FLOAT',       4),
-    FieldD.TYPE_INT32:      ('int32_t',  'INT32',       5),
+    FieldD.TYPE_INT32:      ('int32_t',  'INT32',      10),
     FieldD.TYPE_INT64:      ('int64_t',  'INT64',      10),
     FieldD.TYPE_SFIXED32:   ('int32_t',  'SFIXED32',    4),
     FieldD.TYPE_SFIXED64:   ('int64_t',  'SFIXED64',    8),
     FieldD.TYPE_INT64:      ('int64_t',  'INT64',      10),
     FieldD.TYPE_SFIXED32:   ('int32_t',  'SFIXED32',    4),
     FieldD.TYPE_SFIXED64:   ('int64_t',  'SFIXED64',    8),
@@ -94,6 +94,44 @@ assert varint_max_size(0) == 1
 assert varint_max_size(127) == 1
 assert varint_max_size(128) == 2
 
 assert varint_max_size(127) == 1
 assert varint_max_size(128) == 2
 
+class EncodedSize:
+    '''Class used to represent the encoded size of a field or a message.
+    Consists of a combination of symbolic sizes and integer sizes.'''
+    def __init__(self, value = 0, symbols = []):
+        if isinstance(value, (str, Names)):
+            symbols = [str(value)]
+            value = 0
+        self.value = value
+        self.symbols = symbols
+    
+    def __add__(self, other):
+        if isinstance(other, (int, long)):
+            return EncodedSize(self.value + other, self.symbols)
+        elif isinstance(other, (str, Names)):
+            return EncodedSize(self.value, self.symbols + [str(other)])
+        elif isinstance(other, EncodedSize):
+            return EncodedSize(self.value + other.value, self.symbols + other.symbols)
+        else:
+            raise ValueError("Cannot add size: " + repr(other))
+
+    def __mul__(self, other):
+        if isinstance(other, (int, long)):
+            return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols])
+        else:
+            raise ValueError("Cannot multiply size: " + repr(other))
+
+    def __str__(self):
+        if not self.symbols:
+            return str(self.value)
+        else:
+            return '(' + str(self.value) + ' + ' + ' + '.join(self.symbols) + ')'
+
+    def upperlimit(self):
+        if not self.symbols:
+            return self.value
+        else:
+            return 2**32 - 1
+
 class Enum:
     def __init__(self, names, desc, enum_options):
         '''desc is EnumDescriptorProto'''
 class Enum:
     def __init__(self, names, desc, enum_options):
         '''desc is EnumDescriptorProto'''
@@ -151,6 +189,34 @@ class Field:
         else:
             raise NotImplementedError(desc.label)
         
         else:
             raise NotImplementedError(desc.label)
         
+        # Check if the field can be implemented with static allocation
+        # i.e. whether the data size is known.
+        if desc.type == FieldD.TYPE_STRING and self.max_size is None:
+            can_be_static = False
+        
+        if desc.type == FieldD.TYPE_BYTES and self.max_size is None:
+            can_be_static = False
+        
+        # Decide how the field data will be allocated
+        if field_options.type == nanopb_pb2.FT_DEFAULT:
+            if can_be_static:
+                field_options.type = nanopb_pb2.FT_STATIC
+            else:
+                field_options.type = nanopb_pb2.FT_CALLBACK
+        
+        if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static:
+            raise Exception("Field %s is defined as static, but max_size or "
+                            "max_count is not given." % self.name)
+        
+        if field_options.type == nanopb_pb2.FT_STATIC:
+            self.allocation = 'STATIC'
+        elif field_options.type == nanopb_pb2.FT_POINTER:
+            self.allocation = 'POINTER'
+        elif field_options.type == nanopb_pb2.FT_CALLBACK:
+            self.allocation = 'CALLBACK'
+        else:
+            raise NotImplementedError(field_options.type)
+        
         # Decide the C data type to use in the struct.
         if datatypes.has_key(desc.type):
             self.ctype, self.pbtype, self.enc_size = datatypes[desc.type]
         # Decide the C data type to use in the struct.
         if datatypes.has_key(desc.type):
             self.ctype, self.pbtype, self.enc_size = datatypes[desc.type]
@@ -162,19 +228,18 @@ class Field:
             self.enc_size = 5 # protoc rejects enum values > 32 bits
         elif desc.type == FieldD.TYPE_STRING:
             self.pbtype = 'STRING'
             self.enc_size = 5 # protoc rejects enum values > 32 bits
         elif desc.type == FieldD.TYPE_STRING:
             self.pbtype = 'STRING'
-            if self.max_size is None:
-                can_be_static = False
-            else:
+            self.ctype = 'char'
+            if self.allocation == 'STATIC':
                 self.ctype = 'char'
                 self.array_decl += '[%d]' % self.max_size
                 self.enc_size = varint_max_size(self.max_size) + self.max_size
         elif desc.type == FieldD.TYPE_BYTES:
             self.pbtype = 'BYTES'
                 self.ctype = 'char'
                 self.array_decl += '[%d]' % self.max_size
                 self.enc_size = varint_max_size(self.max_size) + self.max_size
         elif desc.type == FieldD.TYPE_BYTES:
             self.pbtype = 'BYTES'
-            if self.max_size is None:
-                can_be_static = False
-            else:
+            if self.allocation == 'STATIC':
                 self.ctype = self.struct_name + self.name + 't'
                 self.enc_size = varint_max_size(self.max_size) + self.max_size
                 self.ctype = self.struct_name + self.name + 't'
                 self.enc_size = varint_max_size(self.max_size) + self.max_size
+            elif self.allocation == 'POINTER':
+                self.ctype = 'pb_bytes_ptr_t'
         elif desc.type == FieldD.TYPE_MESSAGE:
             self.pbtype = 'MESSAGE'
             self.ctype = self.submsgname = names_from_type_name(desc.type_name)
         elif desc.type == FieldD.TYPE_MESSAGE:
             self.pbtype = 'MESSAGE'
             self.ctype = self.submsgname = names_from_type_name(desc.type_name)
@@ -182,35 +247,31 @@ class Field:
         else:
             raise NotImplementedError(desc.type)
         
         else:
             raise NotImplementedError(desc.type)
         
-        if field_options.type == nanopb_pb2.FT_DEFAULT:
-            if can_be_static:
-                field_options.type = nanopb_pb2.FT_STATIC
-            else:
-                field_options.type = nanopb_pb2.FT_CALLBACK
-        
-        if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static:
-            raise Exception("Field %s is defined as static, but max_size or max_count is not given." % self.name)
-        
-        if field_options.type == nanopb_pb2.FT_STATIC:
-            self.allocation = 'STATIC'
-        elif field_options.type == nanopb_pb2.FT_CALLBACK:
-            self.allocation = 'CALLBACK'
-            self.ctype = 'pb_callback_t'
-            self.array_decl = ''
-        else:
-            raise NotImplementedError(field_options.type)
-    
     def __cmp__(self, other):
         return cmp(self.tag, other.tag)
     
     def __str__(self):
     def __cmp__(self, other):
         return cmp(self.tag, other.tag)
     
     def __str__(self):
-        if self.rules == 'OPTIONAL' and self.allocation == 'STATIC':
-            result = '    bool has_' + self.name + ';\n'
-        elif self.rules == 'REPEATED' and self.allocation == 'STATIC':
-            result = '    size_t ' + self.name + '_count;\n'
+        result = ''
+        if self.allocation == 'POINTER':
+            if self.rules == 'REPEATED':
+                result += '    size_t ' + self.name + '_count;\n'
+            
+            if self.pbtype == 'MESSAGE':
+                # Use struct definition, so recursive submessages are possible
+                result += '    struct _%s *%s;' % (self.ctype, self.name)
+            elif self.rules == 'REPEATED' and self.pbtype == 'STRING':
+                # String arrays need to be defined as pointers to pointers
+                result += '    %s **%s;' % (self.ctype, self.name)
+            else:
+                result += '    %s *%s;' % (self.ctype, self.name)
+        elif self.allocation == 'CALLBACK':
+            result += '    pb_callback_t %s;' % self.name
         else:
         else:
-            result = ''
-        result += '    %s %s%s;' % (self.ctype, self.name, self.array_decl)
+            if self.rules == 'OPTIONAL' and self.allocation == 'STATIC':
+                result += '    bool has_' + self.name + ';\n'
+            elif self.rules == 'REPEATED' and self.allocation == 'STATIC':
+                result += '    size_t ' + self.name + '_count;\n'
+            result += '    %s %s%s;' % (self.ctype, self.name, self.array_decl)
         return result
     
     def types(self):
         return result
     
     def types(self):
@@ -265,7 +326,7 @@ class Field:
         result  = '    PB_FIELD2(%3d, ' % self.tag
         result += '%-8s, ' % self.pbtype
         result += '%s, ' % self.rules
         result  = '    PB_FIELD2(%3d, ' % self.tag
         result += '%-8s, ' % self.pbtype
         result += '%s, ' % self.rules
-        result += '%s, ' % self.allocation
+        result += '%-8s, ' % self.allocation
         result += '%s, ' % ("FIRST" if not prev_field_name else "OTHER")
         result += '%s, ' % self.struct_name
         result += '%s, ' % self.name
         result += '%s, ' % ("FIRST" if not prev_field_name else "OTHER")
         result += '%s, ' % self.struct_name
         result += '%s, ' % self.name
@@ -301,23 +362,32 @@ class Field:
         if self.allocation != 'STATIC':
             return None
         
         if self.allocation != 'STATIC':
             return None
         
-        encsize = self.enc_size
         if self.pbtype == 'MESSAGE':
             for msg in allmsgs:
                 if msg.name == self.submsgname:
                     encsize = msg.encoded_size(allmsgs)
                     if encsize is None:
                         return None # Submessage size is indeterminate
         if self.pbtype == 'MESSAGE':
             for msg in allmsgs:
                 if msg.name == self.submsgname:
                     encsize = msg.encoded_size(allmsgs)
                     if encsize is None:
                         return None # Submessage size is indeterminate
-                    encsize += varint_max_size(encsize) # submsg length is encoded also
+                        
+                    # Include submessage length prefix
+                    encsize += varint_max_size(encsize.upperlimit())
                     break
             else:
                 # Submessage cannot be found, this currently occurs when
                 # the submessage type is defined in a different file.
                     break
             else:
                 # Submessage cannot be found, this currently occurs when
                 # the submessage type is defined in a different file.
-                return None
-        
-        if encsize is None:
+                # Instead of direct numeric value, reference the size that
+                # has been #defined in the other file.
+                encsize = EncodedSize(self.submsgname + 'size')
+
+                # We will have to make a conservative assumption on the length
+                # prefix size, though.
+                encsize += 5
+
+        elif self.enc_size is None:
             raise RuntimeError("Could not determine encoded size for %s.%s"
                                % (self.struct_name, self.name))
             raise RuntimeError("Could not determine encoded size for %s.%s"
                                % (self.struct_name, self.name))
+        else:
+            encsize = EncodedSize(self.enc_size)
         
         encsize += varint_max_size(self.tag << 3) # Tag + wire type
 
         
         encsize += varint_max_size(self.tag << 3) # Tag + wire type
 
@@ -362,7 +432,7 @@ class ExtensionRange(Field):
         # We exclude extensions from the count, because they cannot be known
         # until runtime. Other option would be to return None here, but this
         # way the value remains useful if extensions are not used.
         # We exclude extensions from the count, because they cannot be known
         # until runtime. Other option would be to return None here, but this
         # way the value remains useful if extensions are not used.
-        return 0
+        return EncodedSize(0)
 
 class ExtensionField(Field):
     def __init__(self, struct_name, desc, field_options):
 
 class ExtensionField(Field):
     def __init__(self, struct_name, desc, field_options):
@@ -376,6 +446,11 @@ class ExtensionField(Field):
             self.skip = False
             self.rules = 'OPTEXT'
 
             self.skip = False
             self.rules = 'OPTEXT'
 
+    def tags(self):
+        '''Return the #define for the tag number of this field.'''
+        identifier = '%s_tag' % self.fullname
+        return '#define %-40s %d\n' % (identifier, self.tag)
+
     def extension_decl(self):
         '''Declaration of the extension type in the .pb.h file'''
         if self.skip:
     def extension_decl(self):
         '''Declaration of the extension type in the .pb.h file'''
         if self.skip:
@@ -491,7 +566,7 @@ class Message:
         '''Return the maximum size that this message can take when encoded.
         If the size cannot be determined, returns None.
         '''
         '''Return the maximum size that this message can take when encoded.
         If the size cannot be determined, returns None.
         '''
-        size = 0
+        size = EncodedSize(0)
         for field in self.fields:
             fsize = field.encoded_size(allmsgs)
             if fsize is None:
         for field in self.fields:
             fsize = field.encoded_size(allmsgs)
             if fsize is None:
@@ -662,6 +737,8 @@ def generate_header(dependencies, headername, enums, messages, extensions, optio
     for msg in sort_dependencies(messages):
         for field in msg.fields:
             yield field.tags()
     for msg in sort_dependencies(messages):
         for field in msg.fields:
             yield field.tags()
+    for extension in extensions:
+        yield extension.tags()
     yield '\n'
     
     yield '/* Struct field encoding specification for nanopb */\n'
     yield '\n'
     
     yield '/* Struct field encoding specification for nanopb */\n'
@@ -674,7 +751,7 @@ def generate_header(dependencies, headername, enums, messages, extensions, optio
         msize = msg.encoded_size(messages)
         if msize is not None:
             identifier = '%s_size' % msg.name
         msize = msg.encoded_size(messages)
         if msize is not None:
             identifier = '%s_size' % msg.name
-            yield '#define %-40s %d\n' % (identifier, msize)
+            yield '#define %-40s %s\n' % (identifier, msize)
     yield '\n'
     
     yield '#ifdef __cplusplus\n'
     yield '\n'
     
     yield '#ifdef __cplusplus\n'
@@ -948,7 +1025,7 @@ def main_cli():
 def main_plugin():
     '''Main function when invoked as a protoc plugin.'''
 
 def main_plugin():
     '''Main function when invoked as a protoc plugin.'''
 
-    import plugin_pb2
+    import proto.plugin_pb2 as plugin_pb2
     data = sys.stdin.read()
     request = plugin_pb2.CodeGeneratorRequest.FromString(data)
     
     data = sys.stdin.read()
     request = plugin_pb2.CodeGeneratorRequest.FromString(data)