Extension support implemented for decoder.
authorPetteri Aimonen <jpa@git.mail.kapsi.fi>
Wed, 17 Jul 2013 17:21:51 +0000 (20:21 +0300)
committerPetteri Aimonen <jpa@git.mail.kapsi.fi>
Wed, 17 Jul 2013 17:21:51 +0000 (20:21 +0300)
Testing is still needed. Also only 'optional' extension fields
are supported now, 'repeated' fields are not yet supported.

generator/nanopb_generator.py
pb_decode.c
pb_encode.c

index 61f4d7b..3bac9a9 100644 (file)
@@ -276,9 +276,13 @@ class Field:
 
 
 class ExtensionRange(Field):
-    def __init__(self, struct_name, desc, field_options):
-        '''desc is ExtensionRange'''
-        self.tag = desc.start
+    def __init__(self, struct_name, range_start, field_options):
+        '''Implements a special pb_extension_t* field in an extensible message
+        structure. The range_start signifies the index at which the extensions
+        start. Not necessarily all tags above this are extensions, it is merely
+        a speed optimization.
+        '''
+        self.tag = range_start
         self.struct_name = struct_name
         self.name = 'extensions'
         self.pbtype = 'EXTENSION'
@@ -304,6 +308,10 @@ class ExtensionField(Field):
         self.fullname = struct_name + desc.name
         self.extendee_name = names_from_type_name(desc.extendee)
         Field.__init__(self, self.fullname + 'struct', desc, field_options)
+        
+        if self.rules != 'OPTIONAL':
+            raise NotImplementedError("Only 'optional' is supported for extension fields. "
+               + "(%s.rules == %s)" % (self.fullname, self.rules))
 
     def extension_decl(self):
         '''Declaration of the extension type in the .pb.h file'''
@@ -341,8 +349,9 @@ class Message:
         
         if len(desc.extension_range) > 0:
             field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions')
+            range_start = min([r.start for r in desc.extension_range])
             if field_options.type != nanopb_pb2.FT_IGNORE:
-                self.fields.append(ExtensionRange(self.name, desc.extension_range[0], field_options))
+                self.fields.append(ExtensionRange(self.name, range_start, field_options))
         
         self.packed = message_options.packed_struct
         self.ordered_fields = self.fields[:]
index a079556..e3be412 100644 (file)
@@ -28,7 +28,8 @@ static const pb_decoder_t PB_DECODERS[PB_LTYPES_COUNT] = {
     
     &pb_dec_bytes,
     &pb_dec_string,
-    &pb_dec_submessage
+    &pb_dec_submessage,
+    NULL /* extensions */
 };
 
 /**************
@@ -336,8 +337,11 @@ static bool checkreturn pb_field_find(pb_field_iterator_t *iter, uint32_t tag)
     unsigned start = iter->field_index;
     
     do {
-        if (iter->pos->tag == tag)
+        if (iter->pos->tag == tag &&
+            PB_LTYPE(iter->pos->type) != PB_LTYPE_EXTENSION)
+        {
             return true;
+        }
         pb_field_next(iter);
     } while (iter->field_index != start);
     
@@ -472,6 +476,70 @@ static bool checkreturn decode_field(pb_istream_t *stream, pb_wire_type_t wire_t
     }
 }
 
+/* Default handler for extension fields. Expects a pb_field_t structure
+ * in extension->type->arg. */
+static bool checkreturn default_extension_handler(pb_istream_t *stream,
+    pb_extension_t *extension, uint32_t tag, pb_wire_type_t wire_type)
+{
+    const pb_field_t *field = (const pb_field_t*)extension->type->arg;
+    pb_field_iterator_t iter;
+    bool dummy;
+    
+    if (field->tag != tag)
+        return true;
+    
+    iter.start = field;
+    iter.pos = field;
+    iter.field_index = 0;
+    iter.required_field_index = 0;
+    iter.dest_struct = extension->dest;
+    iter.pData = extension->dest;
+    iter.pSize = &dummy;
+    
+    return decode_field(stream, wire_type, &iter);
+}
+
+/* Try to decode an unknown field as an extension field. Tries each extension
+ * decoder in turn, until one of them handles the field or loop ends. */
+static bool checkreturn decode_extension(pb_istream_t *stream,
+    uint32_t tag, pb_wire_type_t wire_type, pb_field_iterator_t *iter)
+{
+    pb_extension_t *extension = *(pb_extension_t* const *)iter->pData;
+    size_t pos = stream->bytes_left;
+    
+    while (extension && pos == stream->bytes_left)
+    {
+        bool status;
+        if (extension->type->decode)
+            status = extension->type->decode(stream, extension, tag, wire_type);
+        else
+            status = default_extension_handler(stream, extension, tag, wire_type);
+
+        if (!status)
+            return false;
+        
+        extension = extension->next;
+    }
+    
+    return true;
+}
+
+/* Step through the iterator until an extension field is found or until all
+ * entries have been checked. There can be only one extension field per
+ * message. Returns false if no extension field is found. */
+static bool checkreturn find_extension_field(pb_field_iterator_t *iter)
+{
+    unsigned start = iter->field_index;
+    
+    do {
+        if (PB_LTYPE(iter->pos->type) == PB_LTYPE_EXTENSION)
+            return true;
+        pb_field_next(iter);
+    } while (iter->field_index != start);
+    
+    return false;
+}
+
 /* Initialize message fields to default values, recursively */
 static void pb_message_set_to_defaults(const pb_field_t fields[], void *dest_struct)
 {
@@ -528,6 +596,7 @@ static void pb_message_set_to_defaults(const pb_field_t fields[], void *dest_str
 bool checkreturn pb_decode_noinit(pb_istream_t *stream, const pb_field_t fields[], void *dest_struct)
 {
     uint8_t fields_seen[(PB_MAX_REQUIRED_FIELDS + 7) / 8] = {0}; /* Used to check for required fields */
+    uint32_t extension_range_start = 0;
     pb_field_iterator_t iter;
     
     pb_field_init(&iter, fields, dest_struct);
@@ -548,6 +617,29 @@ bool checkreturn pb_decode_noinit(pb_istream_t *stream, const pb_field_t fields[
         
         if (!pb_field_find(&iter, tag))
         {
+            /* No match found, check if it matches an extension. */
+            if (tag >= extension_range_start)
+            {
+                if (!find_extension_field(&iter))
+                    extension_range_start = (uint32_t)-1;
+                else
+                    extension_range_start = iter.pos->tag;
+                
+                if (tag >= extension_range_start)
+                {
+                    size_t pos = stream->bytes_left;
+                
+                    if (!decode_extension(stream, tag, wire_type, &iter))
+                        return false;
+                    
+                    if (pos != stream->bytes_left)
+                    {
+                        /* The field was handled */
+                        continue;                    
+                    }
+                }
+            }
+        
             /* No match found, skip data */
             if (!pb_skip_field(stream, wire_type))
                 return false;
index f3c62a1..58d76a7 100644 (file)
@@ -28,7 +28,8 @@ static const pb_encoder_t PB_ENCODERS[PB_LTYPES_COUNT] = {
     
     &pb_enc_bytes,
     &pb_enc_string,
-    &pb_enc_submessage
+    &pb_enc_submessage,
+    NULL /* extensions */
 };
 
 /* pb_ostream_t implementation */