Generator bugfixes
authorPetteri Aimonen <jpa@npb.mail.kapsi.fi>
Wed, 24 Aug 2011 13:52:08 +0000 (13:52 +0000)
committerPetteri Aimonen <jpa@npb.mail.kapsi.fi>
Wed, 24 Aug 2011 13:52:08 +0000 (13:52 +0000)
git-svn-id: https://svn.kapsi.fi/jpa/nanopb@970 e3a754e5-d11d-0410-8d38-ebb782a927b9

generator/nanopb_generator.py

index f09be34..67c422d 100644 (file)
@@ -108,7 +108,8 @@ class Field:
         elif desc.type == FieldD.TYPE_ENUM:
             self.ltype = 'PB_LTYPE_VARINT'
             self.ctype = names_from_type_name(desc.type_name)
-            self.default = Names(self.ctype) + self.default
+            if self.default is not None:
+                self.default = self.ctype + self.default
         elif desc.type == FieldD.TYPE_STRING:
             self.ltype = 'PB_LTYPE_STRING'
             if self.max_size is None:
@@ -218,7 +219,7 @@ class Field:
             result += '\n    pb_membersize(%s, %s[0]),' % (self.struct_name, self.name)
             result += ('\n    pb_membersize(%s, %s) / pb_membersize(%s, %s[0]),'
                        % (self.struct_name, self.name, self.struct_name, self.name))
-        elif self.ltype == 'PB_LTYPE_BYTES':
+        elif self.htype != 'PB_HTYPE_CALLBACK' and self.ltype == 'PB_LTYPE_BYTES':
             result += '\n    pb_membersize(%s, bytes),' % self.ctype
             result += ' 0,'
         else:
@@ -240,24 +241,10 @@ class Message:
         self.fields = [Field(self.name, f) for f in desc.field]
         self.ordered_fields = self.fields[:]
         self.ordered_fields.sort()
-    
-    def __cmp__(self, other):
-        '''Sort messages so that submessages are declared before the message
-        that uses them.
-        '''
-        if self.refers_to(other.name):
-            return 1
-        elif other.refers_to(self.name):
-            return -1
-        else:
-            return 0
-    
-    def refers_to(self, name):
-        '''Returns True if this message uses the specified type as field type.'''
-        for field in self.fields:
-            if str(field.ctype) == str(name):
-                return True
-        return False
+
+    def get_dependencies(self):
+        '''Get list of type names that this structure refers to.'''
+        return [str(field.ctype) for field in self.fields]
     
     def __str__(self):
         result = 'typedef struct {\n'
@@ -317,16 +304,52 @@ def parse_file(fdesc):
     enums = []
     messages = []
     
+    if fdesc.package:
+        base_name = Names(fdesc.package.split('.'))
+    else:
+        base_name = Names()
+    
     for enum in fdesc.enum_type:
-        enums.append(Enum(Names(), enum))
+        enums.append(Enum(base_name, enum))
     
-    for names, message in iterate_messages(fdesc):
+    for names, message in iterate_messages(fdesc, base_name):
         messages.append(Message(names, message))
         for enum in message.enum_type:
             enums.append(Enum(names, enum))
     
     return enums, messages
 
+def toposort2(data):
+    '''Topological sort.
+    From http://code.activestate.com/recipes/577413-topological-sort/
+    This function is under the MIT license.
+    '''
+    for k, v in data.items():
+        v.discard(k) # Ignore self dependencies
+    extra_items_in_deps = reduce(set.union, data.values()) - set(data.keys())
+    data.update({item:set() for item in extra_items_in_deps})
+    while True:
+        ordered = set(item for item,dep in data.items() if not dep)
+        if not ordered:
+            break
+        for item in sorted(ordered):
+            yield item
+        data = {item: (dep - ordered) for item,dep in data.items()
+                if item not in ordered}
+    assert not data, "A cyclic dependency exists amongst %r" % data
+
+def sort_dependencies(messages):
+    '''Sort a list of Messages based on dependencies.'''
+    dependencies = {}
+    message_by_name = {}
+    for message in messages:
+        dependencies[str(message.name)] = set(message.get_dependencies())
+        message_by_name[str(message.name)] = message
+    
+    for msgname in toposort2(dependencies):
+        if msgname in message_by_name:
+            yield message_by_name[msgname]
+
 def generate_header(headername, enums, messages):
     '''Generate content for a header file.
     Generates strings, which should be concatenated and stored to file.
@@ -344,8 +367,7 @@ def generate_header(headername, enums, messages):
         yield str(enum) + '\n\n'
     
     yield '/* Struct definitions */\n'
-    messages.sort()
-    for msg in messages:
+    for msg in sort_dependencies(messages):
         yield msg.types()
         yield str(msg) + '\n\n'