Merge pull request #169 from kylemanna/python3
[apps/agl-service-can-low-level.git] / generator / nanopb_generator.py
index 4fd3a4d..7fe0db9 100755 (executable)
@@ -1,10 +1,13 @@
-#!/usr/bin/python
+#!/usr/bin/env python
+
+from __future__ import unicode_literals
 
 '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
 nanopb_version = "nanopb-0.3.5-dev"
 
 import sys
 import re
+from functools import reduce
 
 try:
     # Add some dummy imports to keep packaging tools happy.
@@ -82,7 +85,14 @@ class Names:
         return '_'.join(self.parts)
 
     def __add__(self, other):
-        if isinstance(other, (str, unicode)):
+        # The fdesc names are unicode and need to be handled for
+        # python2 and python3
+        try:
+              realstr = unicode
+        except NameError:
+              realstr = str
+
+        if isinstance(other, realstr):
             return Names(self.parts + (other,))
         elif isinstance(other, tuple):
             return Names(self.parts + other)
@@ -123,7 +133,7 @@ class EncodedSize:
         self.symbols = symbols
     
     def __add__(self, other):
-        if isinstance(other, (int, long)):
+        if isinstance(other, int):
             return EncodedSize(self.value + other, self.symbols)
         elif isinstance(other, (str, Names)):
             return EncodedSize(self.value, self.symbols + [str(other)])
@@ -133,7 +143,7 @@ class EncodedSize:
             raise ValueError("Cannot add size: " + repr(other))
 
     def __mul__(self, other):
-        if isinstance(other, (int, long)):
+        if isinstance(other, int):
             return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols])
         else:
             raise ValueError("Cannot multiply size: " + repr(other))
@@ -192,6 +202,24 @@ class Enum:
         
         return result
 
+class FieldMaxSize:
+    def __init__(self, worst = 0, checks = [], field_name = 'undefined'):
+        if isinstance(worst, list):
+            self.worst = max(i for i in worst if i is not None)
+        else:
+            self.worst = worst
+
+        self.worst_field = field_name
+        self.checks = checks
+
+    def extend(self, extend, field_name = None):
+        self.worst = max(self.worst, extend.worst)
+
+        if self.worst == extend.worst:
+            self.worst_field = extend.worst_field
+
+        self.checks.extend(extend.checks)
+
 class Field:
     def __init__(self, struct_name, desc, field_options):
         '''desc is FieldDescriptorProto'''
@@ -260,7 +288,7 @@ class Field:
             raise NotImplementedError(field_options.type)
         
         # Decide the C data type to use in the struct.
-        if datatypes.has_key(desc.type):
+        if desc.type in datatypes:
             self.ctype, self.pbtype, self.enc_size, isa = datatypes[desc.type]
 
             # Override the field size if user wants to use smaller integers
@@ -295,8 +323,8 @@ class Field:
         else:
             raise NotImplementedError(desc.type)
         
-    def __cmp__(self, other):
-        return cmp(self.tag, other.tag)
+    def __lt__(self, other):
+        return self.tag < other.tag
     
     def __str__(self):
         result = ''
@@ -360,12 +388,10 @@ class Field:
                 inner_init = '0'
         else:
             if self.pbtype == 'STRING':
-                inner_init = self.default.encode('utf-8').encode('string_escape')
-                inner_init = inner_init.replace('"', '\\"')
+                inner_init = self.default.replace('"', '\\"')
                 inner_init = '"' + inner_init + '"'
             elif self.pbtype == 'BYTES':
-                data = str(self.default).decode('string_escape')
-                data = ['0x%02x' % ord(c) for c in data]
+                data = ['0x%02x' % ord(c) for c in self.default]
                 if len(data) == 0:
                     inner_init = '{0, {0}}'
                 else:
@@ -467,15 +493,18 @@ class Field:
     def largest_field_value(self):
         '''Determine if this field needs 16bit or 32bit pb_field_t structure to compile properly.
         Returns numeric value or a C-expression for assert.'''
+        check = []
         if self.pbtype == 'MESSAGE':
             if self.rules == 'REPEATED' and self.allocation == 'STATIC':
-                return 'pb_membersize(%s, %s[0])' % (self.struct_name, self.name)
+                check.append('pb_membersize(%s, %s[0])' % (self.struct_name, self.name))
             elif self.rules == 'ONEOF':
-                return 'pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name)
+                check.append('pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name))
             else:
-                return 'pb_membersize(%s, %s)' % (self.struct_name, self.name)
+                check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name))
 
-        return max(self.tag, self.max_size, self.max_count)        
+        return FieldMaxSize([self.tag, self.max_size, self.max_count],
+                            check,
+                            ('%s.%s' % (self.struct_name, self.name)))
 
     def encoded_size(self, dependencies):
         '''Return the maximum size that this field can take when encoded,
@@ -639,9 +668,6 @@ class OneOf(Field):
         # Sort by the lowest tag number inside union
         self.tag = min([f.tag for f in self.fields])
 
-    def __cmp__(self, other):
-        return cmp(self.tag, other.tag)
-
     def __str__(self):
         result = ''
         if self.fields:
@@ -675,7 +701,10 @@ class OneOf(Field):
         return result
 
     def largest_field_value(self):
-        return max([f.largest_field_value() for f in self.fields])
+        largest = FieldMaxSize()
+        for f in self.fields:
+            largest.extend(f.largest_field_value())
+        return largest
 
     def encoded_size(self, dependencies):
         largest = EncodedSize(0)
@@ -875,17 +904,17 @@ def toposort2(data):
     From http://code.activestate.com/recipes/577413-topological-sort/
     This function is under the MIT license.
     '''
-    for k, v in data.items():
+    for k, v in list(data.items()):
         v.discard(k) # Ignore self dependencies
-    extra_items_in_deps = reduce(set.union, data.values(), set()) - set(data.keys())
+    extra_items_in_deps = reduce(set.union, list(data.values()), set()) - set(data.keys())
     data.update(dict([(item, set()) for item in extra_items_in_deps]))
     while True:
-        ordered = set(item for item,dep in data.items() if not dep)
+        ordered = set(item for item,dep in list(data.items()) if not dep)
         if not ordered:
             break
         for item in sorted(ordered):
             yield item
-        data = dict([(item, (dep - ordered)) for item,dep in data.items()
+        data = dict([(item, (dep - ordered)) for item,dep in list(data.items())
                 if item not in ordered])
     assert not data, "A cyclic dependency exists amongst %r" % data
 
@@ -1136,20 +1165,17 @@ class ProtoFile:
                 yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
                 yield '       setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
                 yield '#endif\n'
-        
-        worst = 0
-        worst_field = ''
-        checks = []
+
+        max_field = FieldMaxSize()
         checks_msgnames = []
         for msg in self.messages:
             checks_msgnames.append(msg.name)
             for field in msg.fields:
-                status = field.largest_field_value()
-                if isinstance(status, (str, unicode)):
-                    checks.append(status)
-                elif status > worst:
-                    worst = status
-                    worst_field = str(field.struct_name) + '.' + str(field.name)
+                max_field.extend(field.largest_field_value())
+
+        worst = max_field.worst
+        worst_field = max_field.worst_field
+        checks = max_field.checks
 
         if worst > 255 or checks:
             yield '\n/* Check that field information fits in pb_field_t */\n'
@@ -1237,7 +1263,7 @@ def read_options_file(infile):
         
         try:
             text_format.Merge(parts[1], opts)
-        except Exception, e:
+        except Exception as e:
             sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
                              "Unparseable option line: '%s'. " % line +
                              "Error: %s\n" % str(e))
@@ -1439,14 +1465,15 @@ def main_cli():
 def main_plugin():
     '''Main function when invoked as a protoc plugin.'''
 
-    import sys
+    import io, sys
     if sys.platform == "win32":
         import os, msvcrt
         # Set stdin and stdout to binary mode
         msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
         msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
     
-    data = sys.stdin.read()
+    data = io.open(sys.stdin.fileno(), "rb").read()
+
     request = plugin_pb2.CodeGeneratorRequest.FromString(data)
     
     try:
@@ -1489,7 +1516,7 @@ def main_plugin():
                 f.name = results['sourcename']
                 f.content = results['sourcedata']    
     
-    sys.stdout.write(response.SerializeToString())
+    io.open(sys.stdout.fileno(), "wb").write(response.SerializeToString())
 
 if __name__ == '__main__':
     # Check if we are running as a plugin under protoc