generator: Use search $PATH for python
[apps/agl-service-can-low-level.git] / generator / nanopb_generator.py
1 #!/usr/bin/env python
2
3 from __future__ import unicode_literals
4
5 '''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
6 nanopb_version = "nanopb-0.3.4-dev"
7
8 import sys
9 import re
10 from functools import reduce
11
12 try:
13     # Add some dummy imports to keep packaging tools happy.
14     import google, distutils.util # bbfreeze seems to need these
15     import pkg_resources # pyinstaller / protobuf 2.5 seem to need these
16 except:
17     # Don't care, we will error out later if it is actually important.
18     pass
19
20 try:
21     import google.protobuf.text_format as text_format
22     import google.protobuf.descriptor_pb2 as descriptor
23 except:
24     sys.stderr.write('''
25          *************************************************************
26          *** Could not import the Google protobuf Python libraries ***
27          *** Try installing package 'python-protobuf' or similar.  ***
28          *************************************************************
29     ''' + '\n')
30     raise
31
32 try:
33     import proto.nanopb_pb2 as nanopb_pb2
34     import proto.plugin_pb2 as plugin_pb2
35 except:
36     sys.stderr.write('''
37          ********************************************************************
38          *** Failed to import the protocol definitions for generator.     ***
39          *** You have to run 'make' in the nanopb/generator/proto folder. ***
40          ********************************************************************
41     ''' + '\n')
42     raise
43
44 # ---------------------------------------------------------------------------
45 #                     Generation of single fields
46 # ---------------------------------------------------------------------------
47
48 import time
49 import os.path
50
51 # Values are tuple (c type, pb type, encoded size, int_size_allowed)
52 FieldD = descriptor.FieldDescriptorProto
53 datatypes = {
54     FieldD.TYPE_BOOL:       ('bool',     'BOOL',        1,  False),
55     FieldD.TYPE_DOUBLE:     ('double',   'DOUBLE',      8,  False),
56     FieldD.TYPE_FIXED32:    ('uint32_t', 'FIXED32',     4,  False),
57     FieldD.TYPE_FIXED64:    ('uint64_t', 'FIXED64',     8,  False),
58     FieldD.TYPE_FLOAT:      ('float',    'FLOAT',       4,  False),
59     FieldD.TYPE_INT32:      ('int32_t',  'INT32',      10,  True),
60     FieldD.TYPE_INT64:      ('int64_t',  'INT64',      10,  True),
61     FieldD.TYPE_SFIXED32:   ('int32_t',  'SFIXED32',    4,  False),
62     FieldD.TYPE_SFIXED64:   ('int64_t',  'SFIXED64',    8,  False),
63     FieldD.TYPE_SINT32:     ('int32_t',  'SINT32',      5,  True),
64     FieldD.TYPE_SINT64:     ('int64_t',  'SINT64',     10,  True),
65     FieldD.TYPE_UINT32:     ('uint32_t', 'UINT32',      5,  True),
66     FieldD.TYPE_UINT64:     ('uint64_t', 'UINT64',     10,  True)
67 }
68
69 # Integer size overrides (from .proto settings)
70 intsizes = {
71     nanopb_pb2.IS_8:     'int8_t',
72     nanopb_pb2.IS_16:    'int16_t',
73     nanopb_pb2.IS_32:    'int32_t',
74     nanopb_pb2.IS_64:    'int64_t',
75 }
76
77 class Names:
78     '''Keeps a set of nested names and formats them to C identifier.'''
79     def __init__(self, parts = ()):
80         if isinstance(parts, Names):
81             parts = parts.parts
82         self.parts = tuple(parts)
83     
84     def __str__(self):
85         return '_'.join(self.parts)
86
87     def __add__(self, other):
88         # The fdesc names are unicode and need to be handled for
89         # python2 and python3
90         try:
91               realstr = unicode
92         except NameError:
93               realstr = str
94
95         if isinstance(other, realstr):
96             return Names(self.parts + (other,))
97         elif isinstance(other, tuple):
98             return Names(self.parts + other)
99         else:
100             raise ValueError("Name parts should be of type str")
101     
102     def __eq__(self, other):
103         return isinstance(other, Names) and self.parts == other.parts
104     
105 def names_from_type_name(type_name):
106     '''Parse Names() from FieldDescriptorProto type_name'''
107     if type_name[0] != '.':
108         raise NotImplementedError("Lookup of non-absolute type names is not supported")
109     return Names(type_name[1:].split('.'))
110
111 def varint_max_size(max_value):
112     '''Returns the maximum number of bytes a varint can take when encoded.'''
113     if max_value < 0:
114         max_value = 2**64 - max_value
115     for i in range(1, 11):
116         if (max_value >> (i * 7)) == 0:
117             return i
118     raise ValueError("Value too large for varint: " + str(max_value))
119
120 assert varint_max_size(-1) == 10
121 assert varint_max_size(0) == 1
122 assert varint_max_size(127) == 1
123 assert varint_max_size(128) == 2
124
125 class EncodedSize:
126     '''Class used to represent the encoded size of a field or a message.
127     Consists of a combination of symbolic sizes and integer sizes.'''
128     def __init__(self, value = 0, symbols = []):
129         if isinstance(value, (str, Names)):
130             symbols = [str(value)]
131             value = 0
132         self.value = value
133         self.symbols = symbols
134     
135     def __add__(self, other):
136         if isinstance(other, int):
137             return EncodedSize(self.value + other, self.symbols)
138         elif isinstance(other, (str, Names)):
139             return EncodedSize(self.value, self.symbols + [str(other)])
140         elif isinstance(other, EncodedSize):
141             return EncodedSize(self.value + other.value, self.symbols + other.symbols)
142         else:
143             raise ValueError("Cannot add size: " + repr(other))
144
145     def __mul__(self, other):
146         if isinstance(other, int):
147             return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols])
148         else:
149             raise ValueError("Cannot multiply size: " + repr(other))
150
151     def __str__(self):
152         if not self.symbols:
153             return str(self.value)
154         else:
155             return '(' + str(self.value) + ' + ' + ' + '.join(self.symbols) + ')'
156
157     def upperlimit(self):
158         if not self.symbols:
159             return self.value
160         else:
161             return 2**32 - 1
162
163 class Enum:
164     def __init__(self, names, desc, enum_options):
165         '''desc is EnumDescriptorProto'''
166         
167         self.options = enum_options
168         self.names = names + desc.name
169         
170         if enum_options.long_names:
171             self.values = [(self.names + x.name, x.number) for x in desc.value]            
172         else:
173             self.values = [(names + x.name, x.number) for x in desc.value] 
174         
175         self.value_longnames = [self.names + x.name for x in desc.value]
176         self.packed = enum_options.packed_enum
177     
178     def has_negative(self):
179         for n, v in self.values:
180             if v < 0:
181                 return True
182         return False
183     
184     def encoded_size(self):
185         return max([varint_max_size(v) for n,v in self.values])
186     
187     def __str__(self):
188         result = 'typedef enum _%s {\n' % self.names
189         result += ',\n'.join(["    %s = %d" % x for x in self.values])
190         result += '\n}'
191         
192         if self.packed:
193             result += ' pb_packed'
194         
195         result += ' %s;' % self.names
196         
197         if not self.options.long_names:
198             # Define the long names always so that enum value references
199             # from other files work properly.
200             for i, x in enumerate(self.values):
201                 result += '\n#define %s %s' % (self.value_longnames[i], x[0])
202         
203         return result
204
205 class FieldMaxSize:
206     def __init__(self, worst = 0, checks = [], field_name = 'undefined'):
207         if isinstance(worst, list):
208             self.worst = max(i for i in worst if i is not None)
209         else:
210             self.worst = worst
211
212         self.worst_field = field_name
213         self.checks = checks
214
215     def extend(self, extend, field_name = None):
216         self.worst = max(self.worst, extend.worst)
217
218         if self.worst == extend.worst:
219             self.worst_field = extend.worst_field
220
221         self.checks.extend(extend.checks)
222
223 class Field:
224     def __init__(self, struct_name, desc, field_options):
225         '''desc is FieldDescriptorProto'''
226         self.tag = desc.number
227         self.struct_name = struct_name
228         self.union_name = None
229         self.name = desc.name
230         self.default = None
231         self.max_size = None
232         self.max_count = None
233         self.array_decl = ""
234         self.enc_size = None
235         self.ctype = None
236         
237         # Parse field options
238         if field_options.HasField("max_size"):
239             self.max_size = field_options.max_size
240         
241         if field_options.HasField("max_count"):
242             self.max_count = field_options.max_count
243         
244         if desc.HasField('default_value'):
245             self.default = desc.default_value
246            
247         # Check field rules, i.e. required/optional/repeated.
248         can_be_static = True
249         if desc.label == FieldD.LABEL_REQUIRED:
250             self.rules = 'REQUIRED'
251         elif desc.label == FieldD.LABEL_OPTIONAL:
252             self.rules = 'OPTIONAL'
253         elif desc.label == FieldD.LABEL_REPEATED:
254             self.rules = 'REPEATED'
255             if self.max_count is None:
256                 can_be_static = False
257             else:
258                 self.array_decl = '[%d]' % self.max_count
259         else:
260             raise NotImplementedError(desc.label)
261         
262         # Check if the field can be implemented with static allocation
263         # i.e. whether the data size is known.
264         if desc.type == FieldD.TYPE_STRING and self.max_size is None:
265             can_be_static = False
266         
267         if desc.type == FieldD.TYPE_BYTES and self.max_size is None:
268             can_be_static = False
269         
270         # Decide how the field data will be allocated
271         if field_options.type == nanopb_pb2.FT_DEFAULT:
272             if can_be_static:
273                 field_options.type = nanopb_pb2.FT_STATIC
274             else:
275                 field_options.type = nanopb_pb2.FT_CALLBACK
276         
277         if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static:
278             raise Exception("Field %s is defined as static, but max_size or "
279                             "max_count is not given." % self.name)
280         
281         if field_options.type == nanopb_pb2.FT_STATIC:
282             self.allocation = 'STATIC'
283         elif field_options.type == nanopb_pb2.FT_POINTER:
284             self.allocation = 'POINTER'
285         elif field_options.type == nanopb_pb2.FT_CALLBACK:
286             self.allocation = 'CALLBACK'
287         else:
288             raise NotImplementedError(field_options.type)
289         
290         # Decide the C data type to use in the struct.
291         if desc.type in datatypes:
292             self.ctype, self.pbtype, self.enc_size, isa = datatypes[desc.type]
293
294             # Override the field size if user wants to use smaller integers
295             if isa and field_options.int_size != nanopb_pb2.IS_DEFAULT:
296                 self.ctype = intsizes[field_options.int_size]
297                 if desc.type == FieldD.TYPE_UINT32 or desc.type == FieldD.TYPE_UINT64:
298                     self.ctype = 'u' + self.ctype;
299         elif desc.type == FieldD.TYPE_ENUM:
300             self.pbtype = 'ENUM'
301             self.ctype = names_from_type_name(desc.type_name)
302             if self.default is not None:
303                 self.default = self.ctype + self.default
304             self.enc_size = None # Needs to be filled in when enum values are known
305         elif desc.type == FieldD.TYPE_STRING:
306             self.pbtype = 'STRING'
307             self.ctype = 'char'
308             if self.allocation == 'STATIC':
309                 self.ctype = 'char'
310                 self.array_decl += '[%d]' % self.max_size
311                 self.enc_size = varint_max_size(self.max_size) + self.max_size
312         elif desc.type == FieldD.TYPE_BYTES:
313             self.pbtype = 'BYTES'
314             if self.allocation == 'STATIC':
315                 self.ctype = self.struct_name + self.name + 't'
316                 self.enc_size = varint_max_size(self.max_size) + self.max_size
317             elif self.allocation == 'POINTER':
318                 self.ctype = 'pb_bytes_array_t'
319         elif desc.type == FieldD.TYPE_MESSAGE:
320             self.pbtype = 'MESSAGE'
321             self.ctype = self.submsgname = names_from_type_name(desc.type_name)
322             self.enc_size = None # Needs to be filled in after the message type is available
323         else:
324             raise NotImplementedError(desc.type)
325         
326     def __lt__(self, other):
327         return self.tag < other.tag
328     
329     def __str__(self):
330         result = ''
331         if self.allocation == 'POINTER':
332             if self.rules == 'REPEATED':
333                 result += '    pb_size_t ' + self.name + '_count;\n'
334             
335             if self.pbtype == 'MESSAGE':
336                 # Use struct definition, so recursive submessages are possible
337                 result += '    struct _%s *%s;' % (self.ctype, self.name)
338             elif self.rules == 'REPEATED' and self.pbtype in ['STRING', 'BYTES']:
339                 # String/bytes arrays need to be defined as pointers to pointers
340                 result += '    %s **%s;' % (self.ctype, self.name)
341             else:
342                 result += '    %s *%s;' % (self.ctype, self.name)
343         elif self.allocation == 'CALLBACK':
344             result += '    pb_callback_t %s;' % self.name
345         else:
346             if self.rules == 'OPTIONAL' and self.allocation == 'STATIC':
347                 result += '    bool has_' + self.name + ';\n'
348             elif self.rules == 'REPEATED' and self.allocation == 'STATIC':
349                 result += '    pb_size_t ' + self.name + '_count;\n'
350             result += '    %s %s%s;' % (self.ctype, self.name, self.array_decl)
351         return result
352     
353     def types(self):
354         '''Return definitions for any special types this field might need.'''
355         if self.pbtype == 'BYTES' and self.allocation == 'STATIC':
356             result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, self.ctype)
357         else:
358             result = ''
359         return result
360     
361     def get_dependencies(self):
362         '''Get list of type names used by this field.'''
363         if self.allocation == 'STATIC':
364             return [str(self.ctype)]
365         else:
366             return []
367
368     def get_initializer(self, null_init, inner_init_only = False):
369         '''Return literal expression for this field's default value.
370         null_init: If True, initialize to a 0 value instead of default from .proto
371         inner_init_only: If True, exclude initialization for any count/has fields
372         '''
373
374         inner_init = None
375         if self.pbtype == 'MESSAGE':
376             if null_init:
377                 inner_init = '%s_init_zero' % self.ctype
378             else:
379                 inner_init = '%s_init_default' % self.ctype
380         elif self.default is None or null_init:
381             if self.pbtype == 'STRING':
382                 inner_init = '""'
383             elif self.pbtype == 'BYTES':
384                 inner_init = '{0, {0}}'
385             elif self.pbtype in ('ENUM', 'UENUM'):
386                 inner_init = '(%s)0' % self.ctype
387             else:
388                 inner_init = '0'
389         else:
390             if self.pbtype == 'STRING':
391                 inner_init = self.default.replace('"', '\\"')
392                 inner_init = '"' + inner_init + '"'
393             elif self.pbtype == 'BYTES':
394                 data = ['0x%02x' % ord(c) for c in self.default]
395                 if len(data) == 0:
396                     inner_init = '{0, {0}}'
397                 else:
398                     inner_init = '{%d, {%s}}' % (len(data), ','.join(data))
399             elif self.pbtype in ['FIXED32', 'UINT32']:
400                 inner_init = str(self.default) + 'u'
401             elif self.pbtype in ['FIXED64', 'UINT64']:
402                 inner_init = str(self.default) + 'ull'
403             elif self.pbtype in ['SFIXED64', 'INT64']:
404                 inner_init = str(self.default) + 'll'
405             else:
406                 inner_init = str(self.default)
407         
408         if inner_init_only:
409             return inner_init
410
411         outer_init = None
412         if self.allocation == 'STATIC':
413             if self.rules == 'REPEATED':
414                 outer_init = '0, {'
415                 outer_init += ', '.join([inner_init] * self.max_count)
416                 outer_init += '}'
417             elif self.rules == 'OPTIONAL':
418                 outer_init = 'false, ' + inner_init
419             else:
420                 outer_init = inner_init
421         elif self.allocation == 'POINTER':
422             if self.rules == 'REPEATED':
423                 outer_init = '0, NULL'
424             else:
425                 outer_init = 'NULL'
426         elif self.allocation == 'CALLBACK':
427             if self.pbtype == 'EXTENSION':
428                 outer_init = 'NULL'
429             else:
430                 outer_init = '{{NULL}, NULL}'
431
432         return outer_init
433
434     def default_decl(self, declaration_only = False):
435         '''Return definition for this field's default value.'''
436         if self.default is None:
437             return None
438
439         ctype = self.ctype
440         default = self.get_initializer(False, True)
441         array_decl = ''
442         
443         if self.pbtype == 'STRING':
444             if self.allocation != 'STATIC':
445                 return None # Not implemented
446             array_decl = '[%d]' % self.max_size
447         elif self.pbtype == 'BYTES':
448             if self.allocation != 'STATIC':
449                 return None # Not implemented
450         
451         if declaration_only:
452             return 'extern const %s %s_default%s;' % (ctype, self.struct_name + self.name, array_decl)
453         else:
454             return 'const %s %s_default%s = %s;' % (ctype, self.struct_name + self.name, array_decl, default)
455     
456     def tags(self):
457         '''Return the #define for the tag number of this field.'''
458         identifier = '%s_%s_tag' % (self.struct_name, self.name)
459         return '#define %-40s %d\n' % (identifier, self.tag)
460     
461     def pb_field_t(self, prev_field_name):
462         '''Return the pb_field_t initializer to use in the constant array.
463         prev_field_name is the name of the previous field or None.
464         '''
465
466         if self.rules == 'ONEOF':
467             result = '    PB_ONEOF_FIELD(%s, ' % self.union_name
468         else:
469             result = '    PB_FIELD('
470
471         result += '%3d, ' % self.tag
472         result += '%-8s, ' % self.pbtype
473         result += '%s, ' % self.rules
474         result += '%-8s, ' % self.allocation
475         result += '%s, ' % ("FIRST" if not prev_field_name else "OTHER")
476         result += '%s, ' % self.struct_name
477         result += '%s, ' % self.name
478         result += '%s, ' % (prev_field_name or self.name)
479         
480         if self.pbtype == 'MESSAGE':
481             result += '&%s_fields)' % self.submsgname
482         elif self.default is None:
483             result += '0)'
484         elif self.pbtype in ['BYTES', 'STRING'] and self.allocation != 'STATIC':
485             result += '0)' # Arbitrary size default values not implemented
486         elif self.rules == 'OPTEXT':
487             result += '0)' # Default value for extensions is not implemented
488         else:
489             result += '&%s_default)' % (self.struct_name + self.name)
490         
491         return result
492     
493     def largest_field_value(self):
494         '''Determine if this field needs 16bit or 32bit pb_field_t structure to compile properly.
495         Returns numeric value or a C-expression for assert.'''
496         check = []
497         if self.pbtype == 'MESSAGE':
498             if self.rules == 'REPEATED' and self.allocation == 'STATIC':
499                 check.append('pb_membersize(%s, %s[0])' % (self.struct_name, self.name))
500             elif self.rules == 'ONEOF':
501                 check.append('pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name))
502             else:
503                 check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name))
504
505         return FieldMaxSize([self.tag, self.max_size, self.max_count],
506                             check,
507                             ('%s.%s' % (self.struct_name, self.name)))
508
509     def encoded_size(self, dependencies):
510         '''Return the maximum size that this field can take when encoded,
511         including the field tag. If the size cannot be determined, returns
512         None.'''
513         
514         if self.allocation != 'STATIC':
515             return None
516         
517         if self.pbtype == 'MESSAGE':
518             if str(self.submsgname) in dependencies:
519                 submsg = dependencies[str(self.submsgname)]
520                 encsize = submsg.encoded_size(dependencies)
521                 if encsize is None:
522                     return None # Submessage size is indeterminate
523                     
524                 # Include submessage length prefix
525                 encsize += varint_max_size(encsize.upperlimit())
526             else:
527                 # Submessage cannot be found, this currently occurs when
528                 # the submessage type is defined in a different file and
529                 # not using the protoc plugin.
530                 # Instead of direct numeric value, reference the size that
531                 # has been #defined in the other file.
532                 encsize = EncodedSize(self.submsgname + 'size')
533
534                 # We will have to make a conservative assumption on the length
535                 # prefix size, though.
536                 encsize += 5
537
538         elif self.pbtype in ['ENUM', 'UENUM']:
539             if str(self.ctype) in dependencies:
540                 enumtype = dependencies[str(self.ctype)]
541                 encsize = enumtype.encoded_size()
542             else:
543                 # Conservative assumption
544                 encsize = 10
545
546         elif self.enc_size is None:
547             raise RuntimeError("Could not determine encoded size for %s.%s"
548                                % (self.struct_name, self.name))
549         else:
550             encsize = EncodedSize(self.enc_size)
551         
552         encsize += varint_max_size(self.tag << 3) # Tag + wire type
553
554         if self.rules == 'REPEATED':
555             # Decoders must be always able to handle unpacked arrays.
556             # Therefore we have to reserve space for it, even though
557             # we emit packed arrays ourselves.
558             encsize *= self.max_count
559         
560         return encsize
561
562
563 class ExtensionRange(Field):
564     def __init__(self, struct_name, range_start, field_options):
565         '''Implements a special pb_extension_t* field in an extensible message
566         structure. The range_start signifies the index at which the extensions
567         start. Not necessarily all tags above this are extensions, it is merely
568         a speed optimization.
569         '''
570         self.tag = range_start
571         self.struct_name = struct_name
572         self.name = 'extensions'
573         self.pbtype = 'EXTENSION'
574         self.rules = 'OPTIONAL'
575         self.allocation = 'CALLBACK'
576         self.ctype = 'pb_extension_t'
577         self.array_decl = ''
578         self.default = None
579         self.max_size = 0
580         self.max_count = 0
581         
582     def __str__(self):
583         return '    pb_extension_t *extensions;'
584     
585     def types(self):
586         return ''
587     
588     def tags(self):
589         return ''
590     
591     def encoded_size(self, dependencies):
592         # We exclude extensions from the count, because they cannot be known
593         # until runtime. Other option would be to return None here, but this
594         # way the value remains useful if extensions are not used.
595         return EncodedSize(0)
596
597 class ExtensionField(Field):
598     def __init__(self, struct_name, desc, field_options):
599         self.fullname = struct_name + desc.name
600         self.extendee_name = names_from_type_name(desc.extendee)
601         Field.__init__(self, self.fullname + 'struct', desc, field_options)
602         
603         if self.rules != 'OPTIONAL':
604             self.skip = True
605         else:
606             self.skip = False
607             self.rules = 'OPTEXT'
608
609     def tags(self):
610         '''Return the #define for the tag number of this field.'''
611         identifier = '%s_tag' % self.fullname
612         return '#define %-40s %d\n' % (identifier, self.tag)
613
614     def extension_decl(self):
615         '''Declaration of the extension type in the .pb.h file'''
616         if self.skip:
617             msg = '/* Extension field %s was skipped because only "optional"\n' % self.fullname
618             msg +='   type of extension fields is currently supported. */\n'
619             return msg
620         
621         return ('extern const pb_extension_type_t %s; /* field type: %s */\n' %
622             (self.fullname, str(self).strip()))
623
624     def extension_def(self):
625         '''Definition of the extension type in the .pb.c file'''
626
627         if self.skip:
628             return ''
629
630         result  = 'typedef struct {\n'
631         result += str(self)
632         result += '\n} %s;\n\n' % self.struct_name
633         result += ('static const pb_field_t %s_field = \n  %s;\n\n' %
634                     (self.fullname, self.pb_field_t(None)))
635         result += 'const pb_extension_type_t %s = {\n' % self.fullname
636         result += '    NULL,\n'
637         result += '    NULL,\n'
638         result += '    &%s_field\n' % self.fullname
639         result += '};\n'
640         return result
641
642
643 # ---------------------------------------------------------------------------
644 #                   Generation of oneofs (unions)
645 # ---------------------------------------------------------------------------
646
647 class OneOf(Field):
648     def __init__(self, struct_name, oneof_desc):
649         self.struct_name = struct_name
650         self.name = oneof_desc.name
651         self.ctype = 'union'
652         self.pbtype = 'oneof'
653         self.fields = []
654         self.allocation = 'ONEOF'
655         self.default = None
656         self.rules = 'ONEOF'
657
658     def add_field(self, field):
659         if field.allocation == 'CALLBACK':
660             raise Exception("Callback fields inside of oneof are not supported"
661                             + " (field %s)" % field.name)
662
663         field.union_name = self.name
664         field.rules = 'ONEOF'
665         self.fields.append(field)
666         self.fields.sort(key = lambda f: f.tag)
667
668         # Sort by the lowest tag number inside union
669         self.tag = min([f.tag for f in self.fields])
670
671     def __str__(self):
672         result = ''
673         if self.fields:
674             result += '    pb_size_t which_' + self.name + ";\n"
675             result += '    union {\n'
676             for f in self.fields:
677                 result += '    ' + str(f).replace('\n', '\n    ') + '\n'
678             result += '    } ' + self.name + ';'
679         return result
680
681     def types(self):
682         return ''.join([f.types() for f in self.fields])
683
684     def get_dependencies(self):
685         deps = []
686         for f in self.fields:
687             deps += f.get_dependencies()
688         return deps
689
690     def get_initializer(self, null_init):
691         return '0, {' + self.fields[0].get_initializer(null_init) + '}'
692
693     def default_decl(self, declaration_only = False):
694         return None
695
696     def tags(self):
697         return '\n'.join([f.tags() for f in self.fields])
698
699     def pb_field_t(self, prev_field_name):
700         result = ',\n'.join([f.pb_field_t(prev_field_name) for f in self.fields])
701         return result
702
703     def largest_field_value(self):
704         largest = FieldMaxSize()
705         for f in self.fields:
706             largest.extend(f.largest_field_value())
707         return largest
708
709     def encoded_size(self, dependencies):
710         largest = EncodedSize(0)
711         for f in self.fields:
712             size = f.encoded_size(dependencies)
713             if size is None:
714                 return None
715             elif size.symbols:
716                 return None # Cannot resolve maximum of symbols
717             elif size.value > largest.value:
718                 largest = size
719
720         return largest
721
722 # ---------------------------------------------------------------------------
723 #                   Generation of messages (structures)
724 # ---------------------------------------------------------------------------
725
726
727 class Message:
728     def __init__(self, names, desc, message_options):
729         self.name = names
730         self.fields = []
731         self.oneofs = {}
732         no_unions = []
733
734         if message_options.msgid:
735             self.msgid = message_options.msgid
736
737         if hasattr(desc, 'oneof_decl'):
738             for i, f in enumerate(desc.oneof_decl):
739                 oneof_options = get_nanopb_suboptions(desc, message_options, self.name + f.name)
740                 if oneof_options.no_unions:
741                     no_unions.append(i) # No union, but add fields normally
742                 elif oneof_options.type == nanopb_pb2.FT_IGNORE:
743                     pass # No union and skip fields also
744                 else:
745                     oneof = OneOf(self.name, f)
746                     self.oneofs[i] = oneof
747                     self.fields.append(oneof)
748
749         for f in desc.field:
750             field_options = get_nanopb_suboptions(f, message_options, self.name + f.name)
751             if field_options.type == nanopb_pb2.FT_IGNORE:
752                 continue
753
754             field = Field(self.name, f, field_options)
755             if (hasattr(f, 'oneof_index') and
756                 f.HasField('oneof_index') and
757                 f.oneof_index not in no_unions):
758                 if f.oneof_index in self.oneofs:
759                     self.oneofs[f.oneof_index].add_field(field)
760             else:
761                 self.fields.append(field)
762         
763         if len(desc.extension_range) > 0:
764             field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions')
765             range_start = min([r.start for r in desc.extension_range])
766             if field_options.type != nanopb_pb2.FT_IGNORE:
767                 self.fields.append(ExtensionRange(self.name, range_start, field_options))
768         
769         self.packed = message_options.packed_struct
770         self.ordered_fields = self.fields[:]
771         self.ordered_fields.sort()
772
773     def get_dependencies(self):
774         '''Get list of type names that this structure refers to.'''
775         deps = []
776         for f in self.fields:
777             deps += f.get_dependencies()
778         return deps
779     
780     def __str__(self):
781         result = 'typedef struct _%s {\n' % self.name
782
783         if not self.ordered_fields:
784             # Empty structs are not allowed in C standard.
785             # Therefore add a dummy field if an empty message occurs.
786             result += '    uint8_t dummy_field;'
787
788         result += '\n'.join([str(f) for f in self.ordered_fields])
789         result += '\n}'
790         
791         if self.packed:
792             result += ' pb_packed'
793         
794         result += ' %s;' % self.name
795         
796         if self.packed:
797             result = 'PB_PACKED_STRUCT_START\n' + result
798             result += '\nPB_PACKED_STRUCT_END'
799         
800         return result
801     
802     def types(self):
803         return ''.join([f.types() for f in self.fields])
804
805     def get_initializer(self, null_init):
806         if not self.ordered_fields:
807             return '{0}'
808     
809         parts = []
810         for field in self.ordered_fields:
811             parts.append(field.get_initializer(null_init))
812         return '{' + ', '.join(parts) + '}'
813     
814     def default_decl(self, declaration_only = False):
815         result = ""
816         for field in self.fields:
817             default = field.default_decl(declaration_only)
818             if default is not None:
819                 result += default + '\n'
820         return result
821
822     def count_required_fields(self):
823         '''Returns number of required fields inside this message'''
824         count = 0
825         for f in self.fields:
826             if not isinstance(f, OneOf):
827                 if f.rules == 'REQUIRED':
828                     count += 1
829         return count
830
831     def count_all_fields(self):
832         count = 0
833         for f in self.fields:
834             if isinstance(f, OneOf):
835                 count += len(f.fields)
836             else:
837                 count += 1
838         return count
839
840     def fields_declaration(self):
841         result = 'extern const pb_field_t %s_fields[%d];' % (self.name, self.count_all_fields() + 1)
842         return result
843
844     def fields_definition(self):
845         result = 'const pb_field_t %s_fields[%d] = {\n' % (self.name, self.count_all_fields() + 1)
846         
847         prev = None
848         for field in self.ordered_fields:
849             result += field.pb_field_t(prev)
850             result += ',\n'
851             if isinstance(field, OneOf):
852                 prev = field.name + '.' + field.fields[-1].name
853             else:
854                 prev = field.name
855         
856         result += '    PB_LAST_FIELD\n};'
857         return result
858
859     def encoded_size(self, dependencies):
860         '''Return the maximum size that this message can take when encoded.
861         If the size cannot be determined, returns None.
862         '''
863         size = EncodedSize(0)
864         for field in self.fields:
865             fsize = field.encoded_size(dependencies)
866             if fsize is None:
867                 return None
868             size += fsize
869         
870         return size
871
872
873 # ---------------------------------------------------------------------------
874 #                    Processing of entire .proto files
875 # ---------------------------------------------------------------------------
876
877 def iterate_messages(desc, names = Names()):
878     '''Recursively find all messages. For each, yield name, DescriptorProto.'''
879     if hasattr(desc, 'message_type'):
880         submsgs = desc.message_type
881     else:
882         submsgs = desc.nested_type
883     
884     for submsg in submsgs:
885         sub_names = names + submsg.name
886         yield sub_names, submsg
887         
888         for x in iterate_messages(submsg, sub_names):
889             yield x
890
891 def iterate_extensions(desc, names = Names()):
892     '''Recursively find all extensions.
893     For each, yield name, FieldDescriptorProto.
894     '''
895     for extension in desc.extension:
896         yield names, extension
897
898     for subname, subdesc in iterate_messages(desc, names):
899         for extension in subdesc.extension:
900             yield subname, extension
901
902 def toposort2(data):
903     '''Topological sort.
904     From http://code.activestate.com/recipes/577413-topological-sort/
905     This function is under the MIT license.
906     '''
907     for k, v in list(data.items()):
908         v.discard(k) # Ignore self dependencies
909     extra_items_in_deps = reduce(set.union, list(data.values()), set()) - set(data.keys())
910     data.update(dict([(item, set()) for item in extra_items_in_deps]))
911     while True:
912         ordered = set(item for item,dep in list(data.items()) if not dep)
913         if not ordered:
914             break
915         for item in sorted(ordered):
916             yield item
917         data = dict([(item, (dep - ordered)) for item,dep in list(data.items())
918                 if item not in ordered])
919     assert not data, "A cyclic dependency exists amongst %r" % data
920
921 def sort_dependencies(messages):
922     '''Sort a list of Messages based on dependencies.'''
923     dependencies = {}
924     message_by_name = {}
925     for message in messages:
926         dependencies[str(message.name)] = set(message.get_dependencies())
927         message_by_name[str(message.name)] = message
928     
929     for msgname in toposort2(dependencies):
930         if msgname in message_by_name:
931             yield message_by_name[msgname]
932
933 def make_identifier(headername):
934     '''Make #ifndef identifier that contains uppercase A-Z and digits 0-9'''
935     result = ""
936     for c in headername.upper():
937         if c.isalnum():
938             result += c
939         else:
940             result += '_'
941     return result
942
943 class ProtoFile:
944     def __init__(self, fdesc, file_options):
945         '''Takes a FileDescriptorProto and parses it.'''
946         self.fdesc = fdesc
947         self.file_options = file_options
948         self.dependencies = {}
949         self.parse()
950         
951         # Some of types used in this file probably come from the file itself.
952         # Thus it has implicit dependency on itself.
953         self.add_dependency(self)
954
955     def parse(self):
956         self.enums = []
957         self.messages = []
958         self.extensions = []
959         
960         if self.fdesc.package:
961             base_name = Names(self.fdesc.package.split('.'))
962         else:
963             base_name = Names()
964     
965         for enum in self.fdesc.enum_type:
966             enum_options = get_nanopb_suboptions(enum, self.file_options, base_name + enum.name)
967             self.enums.append(Enum(base_name, enum, enum_options))
968         
969         for names, message in iterate_messages(self.fdesc, base_name):
970             message_options = get_nanopb_suboptions(message, self.file_options, names)
971             
972             if message_options.skip_message:
973                 continue
974        
975             self.messages.append(Message(names, message, message_options))
976             for enum in message.enum_type:
977                 enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name)
978                 self.enums.append(Enum(names, enum, enum_options))
979         
980         for names, extension in iterate_extensions(self.fdesc, base_name):
981             field_options = get_nanopb_suboptions(extension, self.file_options, names + extension.name)
982             if field_options.type != nanopb_pb2.FT_IGNORE:
983                 self.extensions.append(ExtensionField(names, extension, field_options))
984     
985     def add_dependency(self, other):
986         for enum in other.enums:
987             self.dependencies[str(enum.names)] = enum
988         
989         for msg in other.messages:
990             self.dependencies[str(msg.name)] = msg
991         
992         # Fix field default values where enum short names are used.
993         for enum in other.enums:
994             if not enum.options.long_names:
995                 for message in self.messages:
996                     for field in message.fields:
997                         if field.default in enum.value_longnames:
998                             idx = enum.value_longnames.index(field.default)
999                             field.default = enum.values[idx][0]
1000         
1001         # Fix field data types where enums have negative values.
1002         for enum in other.enums:
1003             if not enum.has_negative():
1004                 for message in self.messages:
1005                     for field in message.fields:
1006                         if field.pbtype == 'ENUM' and field.ctype == enum.names:
1007                             field.pbtype = 'UENUM'
1008
1009     def generate_header(self, includes, headername, options):
1010         '''Generate content for a header file.
1011         Generates strings, which should be concatenated and stored to file.
1012         '''
1013         
1014         yield '/* Automatically generated nanopb header */\n'
1015         if options.notimestamp:
1016             yield '/* Generated by %s */\n\n' % (nanopb_version)
1017         else:
1018             yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
1019         
1020         symbol = make_identifier(headername)
1021         yield '#ifndef PB_%s_INCLUDED\n' % symbol
1022         yield '#define PB_%s_INCLUDED\n' % symbol
1023         try:
1024             yield options.libformat % ('pb.h')
1025         except TypeError:
1026             # no %s specified - use whatever was passed in as options.libformat
1027             yield options.libformat
1028         yield '\n'
1029         
1030         for incfile in includes:
1031             noext = os.path.splitext(incfile)[0]
1032             yield options.genformat % (noext + options.extension + '.h')
1033             yield '\n'
1034
1035         yield '#if PB_PROTO_HEADER_VERSION != 30\n'
1036         yield '#error Regenerate this file with the current version of nanopb generator.\n'
1037         yield '#endif\n'
1038         yield '\n'
1039
1040         yield '#ifdef __cplusplus\n'
1041         yield 'extern "C" {\n'
1042         yield '#endif\n\n'
1043         
1044         if self.enums:
1045             yield '/* Enum definitions */\n'
1046             for enum in self.enums:
1047                 yield str(enum) + '\n\n'
1048         
1049         if self.messages:
1050             yield '/* Struct definitions */\n'
1051             for msg in sort_dependencies(self.messages):
1052                 yield msg.types()
1053                 yield str(msg) + '\n\n'
1054         
1055         if self.extensions:
1056             yield '/* Extensions */\n'
1057             for extension in self.extensions:
1058                 yield extension.extension_decl()
1059             yield '\n'
1060         
1061         if self.messages:
1062             yield '/* Default values for struct fields */\n'
1063             for msg in self.messages:
1064                 yield msg.default_decl(True)
1065             yield '\n'
1066         
1067             yield '/* Initializer values for message structs */\n'
1068             for msg in self.messages:
1069                 identifier = '%s_init_default' % msg.name
1070                 yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False))
1071             for msg in self.messages:
1072                 identifier = '%s_init_zero' % msg.name
1073                 yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True))
1074             yield '\n'
1075         
1076             yield '/* Field tags (for use in manual encoding/decoding) */\n'
1077             for msg in sort_dependencies(self.messages):
1078                 for field in msg.fields:
1079                     yield field.tags()
1080             for extension in self.extensions:
1081                 yield extension.tags()
1082             yield '\n'
1083         
1084             yield '/* Struct field encoding specification for nanopb */\n'
1085             for msg in self.messages:
1086                 yield msg.fields_declaration() + '\n'
1087             yield '\n'
1088         
1089             yield '/* Maximum encoded size of messages (where known) */\n'
1090             for msg in self.messages:
1091                 msize = msg.encoded_size(self.dependencies)
1092                 if msize is not None:
1093                     identifier = '%s_size' % msg.name
1094                     yield '#define %-40s %s\n' % (identifier, msize)
1095             yield '\n'
1096
1097             yield '/* Message IDs (where set with "msgid" option) */\n'
1098             
1099             yield '#ifdef PB_MSGID\n'
1100             for msg in self.messages:
1101                 if hasattr(msg,'msgid'):
1102                     yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name)
1103             yield '\n'
1104
1105             symbol = make_identifier(headername.split('.')[0])
1106             yield '#define %s_MESSAGES \\\n' % symbol
1107
1108             for msg in self.messages:
1109                 m = "-1"
1110                 msize = msg.encoded_size(self.dependencies)
1111                 if msize is not None:
1112                     m = msize
1113                 if hasattr(msg,'msgid'):
1114                     yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name)
1115             yield '\n'
1116
1117             for msg in self.messages:
1118                 if hasattr(msg,'msgid'):
1119                     yield '#define %s_msgid %d\n' % (msg.name, msg.msgid)
1120             yield '\n'
1121
1122             yield '#endif\n\n'
1123
1124         yield '#ifdef __cplusplus\n'
1125         yield '} /* extern "C" */\n'
1126         yield '#endif\n'
1127         
1128         # End of header
1129         yield '\n#endif\n'
1130
1131     def generate_source(self, headername, options):
1132         '''Generate content for a source file.'''
1133         
1134         yield '/* Automatically generated nanopb constant definitions */\n'
1135         if options.notimestamp:
1136             yield '/* Generated by %s */\n\n' % (nanopb_version)
1137         else:
1138             yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
1139         yield options.genformat % (headername)
1140         yield '\n'
1141         
1142         yield '#if PB_PROTO_HEADER_VERSION != 30\n'
1143         yield '#error Regenerate this file with the current version of nanopb generator.\n'
1144         yield '#endif\n'
1145         yield '\n'
1146         
1147         for msg in self.messages:
1148             yield msg.default_decl(False)
1149         
1150         yield '\n\n'
1151         
1152         for msg in self.messages:
1153             yield msg.fields_definition() + '\n\n'
1154         
1155         for ext in self.extensions:
1156             yield ext.extension_def() + '\n'
1157             
1158         # Add checks for numeric limits
1159         if self.messages:
1160             largest_msg = max(self.messages, key = lambda m: m.count_required_fields())
1161             largest_count = largest_msg.count_required_fields()
1162             if largest_count > 64:
1163                 yield '\n/* Check that missing required fields will be properly detected */\n'
1164                 yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
1165                 yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
1166                 yield '       setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
1167                 yield '#endif\n'
1168
1169         max_field = FieldMaxSize()
1170         checks_msgnames = []
1171         for msg in self.messages:
1172             checks_msgnames.append(msg.name)
1173             for field in msg.fields:
1174                 max_field.extend(field.largest_field_value())
1175
1176         worst = max_field.worst
1177         worst_field = max_field.worst_field
1178         checks = max_field.checks
1179
1180         if worst > 255 or checks:
1181             yield '\n/* Check that field information fits in pb_field_t */\n'
1182             
1183             if worst > 65535 or checks:
1184                 yield '#if !defined(PB_FIELD_32BIT)\n'
1185                 if worst > 65535:
1186                     yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field
1187                 else:
1188                     assertion = ' && '.join(str(c) + ' < 65536' for c in checks)
1189                     msgs = '_'.join(str(n) for n in checks_msgnames)
1190                     yield '/* If you get an error here, it means that you need to define PB_FIELD_32BIT\n'
1191                     yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
1192                     yield ' * \n'
1193                     yield ' * The reason you need to do this is that some of your messages contain tag\n'
1194                     yield ' * numbers or field sizes that are larger than what can fit in 8 or 16 bit\n'
1195                     yield ' * field descriptors.\n'
1196                     yield ' */\n'
1197                     yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
1198                 yield '#endif\n\n'
1199             
1200             if worst < 65536:
1201                 yield '#if !defined(PB_FIELD_16BIT) && !defined(PB_FIELD_32BIT)\n'
1202                 if worst > 255:
1203                     yield '#error Field descriptor for %s is too large. Define PB_FIELD_16BIT to fix this.\n' % worst_field
1204                 else:
1205                     assertion = ' && '.join(str(c) + ' < 256' for c in checks)
1206                     msgs = '_'.join(str(n) for n in checks_msgnames)
1207                     yield '/* If you get an error here, it means that you need to define PB_FIELD_16BIT\n'
1208                     yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
1209                     yield ' * \n'
1210                     yield ' * The reason you need to do this is that some of your messages contain tag\n'
1211                     yield ' * numbers or field sizes that are larger than what can fit in the default\n'
1212                     yield ' * 8 bit descriptors.\n'
1213                     yield ' */\n'
1214                     yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
1215                 yield '#endif\n\n'
1216         
1217         # Add check for sizeof(double)
1218         has_double = False
1219         for msg in self.messages:
1220             for field in msg.fields:
1221                 if field.ctype == 'double':
1222                     has_double = True
1223         
1224         if has_double:
1225             yield '\n'
1226             yield '/* On some platforms (such as AVR), double is really float.\n'
1227             yield ' * These are not directly supported by nanopb, but see example_avr_double.\n'
1228             yield ' * To get rid of this error, remove any double fields from your .proto.\n'
1229             yield ' */\n'
1230             yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n'
1231         
1232         yield '\n'
1233
1234 # ---------------------------------------------------------------------------
1235 #                    Options parsing for the .proto files
1236 # ---------------------------------------------------------------------------
1237
1238 from fnmatch import fnmatch
1239
1240 def read_options_file(infile):
1241     '''Parse a separate options file to list:
1242         [(namemask, options), ...]
1243     '''
1244     results = []
1245     data = infile.read()
1246     data = re.sub('/\*.*?\*/', '', data, flags = re.MULTILINE)
1247     data = re.sub('//.*?$', '', data, flags = re.MULTILINE)
1248     data = re.sub('#.*?$', '', data, flags = re.MULTILINE)
1249     for i, line in enumerate(data.split('\n')):
1250         line = line.strip()
1251         if not line:
1252             continue
1253         
1254         parts = line.split(None, 1)
1255         
1256         if len(parts) < 2:
1257             sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
1258                              "Option lines should have space between field name and options. " +
1259                              "Skipping line: '%s'\n" % line)
1260             continue
1261         
1262         opts = nanopb_pb2.NanoPBOptions()
1263         
1264         try:
1265             text_format.Merge(parts[1], opts)
1266         except Exception as e:
1267             sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
1268                              "Unparseable option line: '%s'. " % line +
1269                              "Error: %s\n" % str(e))
1270             continue
1271         results.append((parts[0], opts))
1272
1273     return results
1274
1275 class Globals:
1276     '''Ugly global variables, should find a good way to pass these.'''
1277     verbose_options = False
1278     separate_options = []
1279     matched_namemasks = set()
1280
1281 def get_nanopb_suboptions(subdesc, options, name):
1282     '''Get copy of options, and merge information from subdesc.'''
1283     new_options = nanopb_pb2.NanoPBOptions()
1284     new_options.CopyFrom(options)
1285     
1286     # Handle options defined in a separate file
1287     dotname = '.'.join(name.parts)
1288     for namemask, options in Globals.separate_options:
1289         if fnmatch(dotname, namemask):
1290             Globals.matched_namemasks.add(namemask)
1291             new_options.MergeFrom(options)
1292     
1293     # Handle options defined in .proto
1294     if isinstance(subdesc.options, descriptor.FieldOptions):
1295         ext_type = nanopb_pb2.nanopb
1296     elif isinstance(subdesc.options, descriptor.FileOptions):
1297         ext_type = nanopb_pb2.nanopb_fileopt
1298     elif isinstance(subdesc.options, descriptor.MessageOptions):
1299         ext_type = nanopb_pb2.nanopb_msgopt
1300     elif isinstance(subdesc.options, descriptor.EnumOptions):
1301         ext_type = nanopb_pb2.nanopb_enumopt
1302     else:
1303         raise Exception("Unknown options type")
1304     
1305     if subdesc.options.HasExtension(ext_type):
1306         ext = subdesc.options.Extensions[ext_type]
1307         new_options.MergeFrom(ext)
1308     
1309     if Globals.verbose_options:
1310         sys.stderr.write("Options for " + dotname + ": ")
1311         sys.stderr.write(text_format.MessageToString(new_options) + "\n")
1312     
1313     return new_options
1314
1315
1316 # ---------------------------------------------------------------------------
1317 #                         Command line interface
1318 # ---------------------------------------------------------------------------
1319
1320 import sys
1321 import os.path    
1322 from optparse import OptionParser
1323
1324 optparser = OptionParser(
1325     usage = "Usage: nanopb_generator.py [options] file.pb ...",
1326     epilog = "Compile file.pb from file.proto by: 'protoc -ofile.pb file.proto'. " +
1327              "Output will be written to file.pb.h and file.pb.c.")
1328 optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[],
1329     help="Exclude file from generated #include list.")
1330 optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default=".pb",
1331     help="Set extension to use instead of '.pb' for generated files. [default: %default]")
1332 optparser.add_option("-f", "--options-file", dest="options_file", metavar="FILE", default="%s.options",
1333     help="Set name of a separate generator options file.")
1334 optparser.add_option("-I", "--options-path", dest="options_path", metavar="DIR",
1335     action="append", default = [],
1336     help="Search for .options files additionally in this path")
1337 optparser.add_option("-Q", "--generated-include-format", dest="genformat",
1338     metavar="FORMAT", default='#include "%s"\n',
1339     help="Set format string to use for including other .pb.h files. [default: %default]")
1340 optparser.add_option("-L", "--library-include-format", dest="libformat",
1341     metavar="FORMAT", default='#include <%s>\n',
1342     help="Set format string to use for including the nanopb pb.h header. [default: %default]")
1343 optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=False,
1344     help="Don't add timestamp to .pb.h and .pb.c preambles")
1345 optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False,
1346     help="Don't print anything except errors.")
1347 optparser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False,
1348     help="Print more information.")
1349 optparser.add_option("-s", dest="settings", metavar="OPTION:VALUE", action="append", default=[],
1350     help="Set generator option (max_size, max_count etc.).")
1351
1352 def parse_file(filename, fdesc, options):
1353     '''Parse a single file. Returns a ProtoFile instance.'''
1354     toplevel_options = nanopb_pb2.NanoPBOptions()
1355     for s in options.settings:
1356         text_format.Merge(s, toplevel_options)
1357     
1358     if not fdesc:
1359         data = open(filename, 'rb').read()
1360         fdesc = descriptor.FileDescriptorSet.FromString(data).file[0]
1361     
1362     # Check if there is a separate .options file
1363     had_abspath = False
1364     try:
1365         optfilename = options.options_file % os.path.splitext(filename)[0]
1366     except TypeError:
1367         # No %s specified, use the filename as-is
1368         optfilename = options.options_file
1369         had_abspath = True
1370
1371     paths = ['.'] + options.options_path
1372     for p in paths:
1373         if os.path.isfile(os.path.join(p, optfilename)):
1374             optfilename = os.path.join(p, optfilename)
1375             if options.verbose:
1376                 sys.stderr.write('Reading options from ' + optfilename + '\n')
1377             Globals.separate_options = read_options_file(open(optfilename, "rU"))
1378             break
1379     else:
1380         # If we are given a full filename and it does not exist, give an error.
1381         # However, don't give error when we automatically look for .options file
1382         # with the same name as .proto.
1383         if options.verbose or had_abspath:
1384             sys.stderr.write('Options file not found: ' + optfilename + '\n')
1385         Globals.separate_options = []
1386
1387     Globals.matched_namemasks = set()
1388     
1389     # Parse the file
1390     file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename]))
1391     f = ProtoFile(fdesc, file_options)
1392     f.optfilename = optfilename
1393     
1394     return f
1395
1396 def process_file(filename, fdesc, options, other_files = {}):
1397     '''Process a single file.
1398     filename: The full path to the .proto or .pb source file, as string.
1399     fdesc: The loaded FileDescriptorSet, or None to read from the input file.
1400     options: Command line options as they come from OptionsParser.
1401     
1402     Returns a dict:
1403         {'headername': Name of header file,
1404          'headerdata': Data for the .h header file,
1405          'sourcename': Name of the source code file,
1406          'sourcedata': Data for the .c source code file
1407         }
1408     '''
1409     f = parse_file(filename, fdesc, options)
1410
1411     # Provide dependencies if available
1412     for dep in f.fdesc.dependency:
1413         if dep in other_files:
1414             f.add_dependency(other_files[dep])
1415
1416     # Decide the file names
1417     noext = os.path.splitext(filename)[0]
1418     headername = noext + options.extension + '.h'
1419     sourcename = noext + options.extension + '.c'
1420     headerbasename = os.path.basename(headername)
1421     
1422     # List of .proto files that should not be included in the C header file
1423     # even if they are mentioned in the source .proto.
1424     excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto'] + options.exclude
1425     includes = [d for d in f.fdesc.dependency if d not in excludes]
1426     
1427     headerdata = ''.join(f.generate_header(includes, headerbasename, options))
1428     sourcedata = ''.join(f.generate_source(headerbasename, options))
1429
1430     # Check if there were any lines in .options that did not match a member
1431     unmatched = [n for n,o in Globals.separate_options if n not in Globals.matched_namemasks]
1432     if unmatched and not options.quiet:
1433         sys.stderr.write("Following patterns in " + f.optfilename + " did not match any fields: "
1434                          + ', '.join(unmatched) + "\n")
1435         if not Globals.verbose_options:
1436             sys.stderr.write("Use  protoc --nanopb-out=-v:.   to see a list of the field names.\n")
1437
1438     return {'headername': headername, 'headerdata': headerdata,
1439             'sourcename': sourcename, 'sourcedata': sourcedata}
1440     
1441 def main_cli():
1442     '''Main function when invoked directly from the command line.'''
1443     
1444     options, filenames = optparser.parse_args()
1445     
1446     if not filenames:
1447         optparser.print_help()
1448         sys.exit(1)
1449     
1450     if options.quiet:
1451         options.verbose = False
1452
1453     Globals.verbose_options = options.verbose
1454     
1455     for filename in filenames:
1456         results = process_file(filename, None, options)
1457         
1458         if not options.quiet:
1459             sys.stderr.write("Writing to " + results['headername'] + " and "
1460                              + results['sourcename'] + "\n")
1461     
1462         open(results['headername'], 'w').write(results['headerdata'])
1463         open(results['sourcename'], 'w').write(results['sourcedata'])        
1464
1465 def main_plugin():
1466     '''Main function when invoked as a protoc plugin.'''
1467
1468     import io, sys
1469     if sys.platform == "win32":
1470         import os, msvcrt
1471         # Set stdin and stdout to binary mode
1472         msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
1473         msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
1474     
1475     data = io.open(sys.stdin.fileno(), "rb").read()
1476
1477     request = plugin_pb2.CodeGeneratorRequest.FromString(data)
1478     
1479     try:
1480         # Versions of Python prior to 2.7.3 do not support unicode
1481         # input to shlex.split(). Try to convert to str if possible.
1482         params = str(request.parameter)
1483     except UnicodeEncodeError:
1484         params = request.parameter
1485     
1486     import shlex
1487     args = shlex.split(params)
1488     options, dummy = optparser.parse_args(args)
1489     
1490     Globals.verbose_options = options.verbose
1491     
1492     response = plugin_pb2.CodeGeneratorResponse()
1493     
1494     # Google's protoc does not currently indicate the full path of proto files.
1495     # Instead always add the main file path to the search dirs, that works for
1496     # the common case.
1497     import os.path
1498     options.options_path.append(os.path.dirname(request.file_to_generate[0]))
1499     
1500     # Process any include files first, in order to have them
1501     # available as dependencies
1502     other_files = {}
1503     for fdesc in request.proto_file:
1504         other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options)
1505     
1506     for filename in request.file_to_generate:
1507         for fdesc in request.proto_file:
1508             if fdesc.name == filename:
1509                 results = process_file(filename, fdesc, options, other_files)
1510                 
1511                 f = response.file.add()
1512                 f.name = results['headername']
1513                 f.content = results['headerdata']
1514
1515                 f = response.file.add()
1516                 f.name = results['sourcename']
1517                 f.content = results['sourcedata']    
1518     
1519     io.open(sys.stdout.fileno(), "wb").write(response.SerializeToString())
1520
1521 if __name__ == '__main__':
1522     # Check if we are running as a plugin under protoc
1523     if 'protoc-gen-' in sys.argv[0] or '--protoc-plugin' in sys.argv:
1524         main_plugin()
1525     else:
1526         main_cli()
1527