Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement message merging. #60

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 83 additions & 65 deletions src/ocaml_protoc_plugin/deserialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,23 @@ module S = Spec.Deserialize
module C = S.C
open S

type required = Required | Optional

type 'a reader = 'a -> Reader.t -> Field.field_type -> 'a
type 'a getter = 'a -> 'a
type ('a, 'b) getter = 'a -> 'b
type 'a field_spec = (int * 'a reader)
type 'a value = ('a field_spec list * required * 'a * 'a getter)
type _ value = Value: ('b field_spec list * 'b * ('b, 'a) getter) -> 'a value
type extensions = (int * Field.t) list

type (_, _) value_list =
| VNil : ('a, 'a) value_list
| VCons : ('a value) * ('b, 'c) value_list -> ('a -> 'b, 'c) value_list
| VNil_ext : (extensions -> 'a, 'a) value_list
| VCons : 'a value * ('b, 'c) value_list -> ('a -> 'b, 'c) value_list

type sentinel_field_spec = int * (Reader.t -> Field.field_type -> unit)
type 'a sentinel_getter = unit -> 'a

type (_, _) sentinel_list =
| NNil : ('a, 'a) sentinel_list
| NNil_ext: (extensions -> 'a, 'a) sentinel_list
| NCons : (sentinel_field_spec list * 'a sentinel_getter) * ('b, 'c) sentinel_list -> ('a -> 'b, 'c) sentinel_list

let error_wrong_field str field = Result.raise (`Wrong_field_type (str, field))
Expand Down Expand Up @@ -82,10 +83,10 @@ let read_of_spec: type a. a spec -> Field.field_type * (Reader.t -> a) = functio
let v = Bytes.create length in
Bytes.blit_string ~src:data ~src_pos:offset ~dst:v ~dst_pos:0 ~len:length;
v
| Message from_proto -> Length_delimited, fun reader ->
| Message (from_proto, _merge) -> Length_delimited, fun reader ->
let Field.{ offset; length; data } = Reader.read_length_delimited reader in
from_proto (Reader.create ~offset ~length data)

(*
let default_value: type a. a spec -> a = function
| Double -> 0.0
| Float -> 0.0
Expand All @@ -99,7 +100,7 @@ let default_value: type a. a spec -> a = function
| Fixed64 -> Int64.zero
| SFixed32 -> Int32.zero
| SFixed64 -> Int64.zero
| Message of_proto -> of_proto (Reader.create "")
| Message (of_proto, _merge) -> of_proto (Reader.create "")
| String -> ""
| Bytes -> Bytes.empty
| Int32_int -> 0
Expand All @@ -114,7 +115,7 @@ let default_value: type a. a spec -> a = function
| SFixed64_int -> 0
| Enum of_int -> of_int 0
| Bool -> false

*)
(** Identity: hands its argument back untouched. *)
let id v = v

(** [keep_last previous latest] ignores [previous] and yields [latest]. *)
let keep_last _previous latest = latest

Expand All @@ -126,20 +127,29 @@ let read_field ~read:(expect, read_f) ~map v reader field_type =
error_wrong_field "Deserialize" field

let value: type a. a compound -> a value = function
| Basic_req (index, spec) ->
let map _ v2 = Some v2 in
let read = read_field ~read:(read_of_spec spec) ~map in
let getter = function Some v -> v | None -> error_required_field_missing () in
Value ([(index, read)], None, getter)
| Basic (index, spec, default) ->
let read = read_field ~read:(read_of_spec spec) ~map:keep_last in
let required = match default with
| Some _ -> Optional
| None -> Required
let map = keep_last
in
let default = match default with
| None -> default_value spec
| Some default -> default
in
([(index, read)], required, default, id)
let read = read_field ~read:(read_of_spec spec) ~map in
Value ([(index, read)], default, id)
| Basic_opt (index, spec) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun _ v -> Some v) in
([(index, read)], Optional, None, id)
let map = match spec with
| Message (_, merge) ->
let map v1 v2 =
match v1 with
| None -> Some v2
| Some v1 -> Some (merge v1 v2)
in
map
| _ -> fun _ v -> Some v (* Keep last for all other non-repeated types *)
in
let read = read_field ~read:(read_of_spec spec) ~map in
Value ([(index, read)], None, id)
| Repeated (index, spec, Packed) ->
let field_type, read_f = read_of_spec spec in
let rec read_packed_values read_f acc reader =
Expand All @@ -158,37 +168,34 @@ let value: type a. a compound -> a value = function
let field = Reader.read_field_content ft reader in
error_wrong_field "Deserialize" field
in
([(index, read)], Optional, [], List.rev)
Value ([(index, read)], [], List.rev)
| Repeated (index, spec, Not_packed) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun vs v -> v :: vs) in
([(index, read)], Optional, [], List.rev)
Value ([(index, read)], [], List.rev)
| Oneof oneofs ->
let make_reader: a oneof -> a field_spec = fun (Oneof_elem (index, spec, constr)) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun _ -> constr) in
(index, read)
in
(List.map ~f:make_reader oneofs, Optional, `not_set, id)
Value (List.map ~f:make_reader oneofs, `not_set, id)

(** Maps keyed by [int], ordered with [Int.compare]. *)
module IntMap = Map.Make (Int)

(** [in_extension_ranges extension_ranges index] is [true] when [index] lies
    within at least one [(first, last)] pair of [extension_ranges]; both
    bounds are inclusive. An empty list yields [false]. *)
let in_extension_ranges extension_ranges index =
  let covers (first, last) = first <= index && last >= index in
  List.exists ~f:covers extension_ranges

(** Full (slow) deserialization. *)
let deserialize_full: type constr a. (int * int) list -> (constr, (int * Field.t) list -> a) value_list -> constr -> Reader.t -> a = fun extension_ranges values constructor reader ->
(* Need to return the map also! *)
let deserialize_full: type constr a. extension_ranges -> (constr, a) value_list -> constr -> Reader.t -> a = fun extension_ranges values constructor reader ->
let rec make_sentinel_list: type a b. (a, b) value_list -> (a, b) sentinel_list = function
| VNil -> NNil
| VNil_ext -> NNil_ext
(* Consider optimizing when optional is true *)
| VCons ((fields, required, default, getter), rest) ->
let v = ref (default, required) in
let get () = match !v with
| _, Required -> error_required_field_missing ();
| v, Optional-> getter v
in
| VCons (Value (fields, default, getter), rest) ->
let v = ref default in
let get () = getter !v in
let fields =
List.map ~f:(fun (index, read) ->
let read reader field_type = let v' = fst !v in v := (read v' reader field_type, Optional) in
let read reader field_type = (v := read !v reader field_type) in
(index, read)
) fields
in
Expand All @@ -197,20 +204,22 @@ let deserialize_full: type constr a. (int * int) list -> (constr, (int * Field.t

let rec create_map: type a b. _ IntMap.t -> (a, b) sentinel_list -> _ IntMap.t = fun map -> function
| NNil -> map
| NNil_ext -> map
| NCons ((fields, _), rest) ->
let map =
List.fold_left ~init:map ~f:(fun map (index, read)-> IntMap.add index read map) fields
in
create_map map rest
in

let rec apply: type constr t. constr -> (constr, t) sentinel_list -> t = fun constr -> function
let rec apply: type constr a. extensions -> constr -> (constr, a) sentinel_list -> a = fun extensions constr -> function
| NNil -> constr
| NNil_ext -> constr extensions
| NCons ((_, get), rest) ->
apply (constr (get ())) rest
apply extensions (constr (get ())) rest
in

let rec read: (Reader.t -> Field.field_type -> unit) IntMap.t -> (int * Field.t) list -> (int * Field.t) list = fun map extensions ->
let rec read: (Reader.t -> Field.field_type -> unit) IntMap.t -> extensions -> extensions = fun map extensions ->
match Reader.has_more reader with
| false -> List.rev extensions
| true ->
Expand All @@ -229,78 +238,87 @@ let deserialize_full: type constr a. (int * int) list -> (constr, (int * Field.t
let sentinels = make_sentinel_list values in
let map = create_map IntMap.empty sentinels in
let extensions = read map [] in
apply constructor sentinels extensions
apply extensions constructor sentinels

let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t -> a = fun spec constr ->

(* Raised when fast deserialization cannot complete, signalling that we must fall back to full deserialization *)
let exception Restart_full in

let rec extension_ranges: type a b. (a, b) compound_list -> extension_ranges = function
| Nil -> []
| Nil_ext extension_ranges -> extension_ranges
| Cons (_, rest) -> extension_ranges rest
in

let deserialize: type constr a. (int * int) list -> (constr, (int * Field.t) list -> a) compound_list -> constr -> Reader.t -> a = fun extension_ranges spec constr ->
let rec make_values: type a b. (a, b) compound_list -> (a, b) value_list = function
| Nil -> VNil
| Nil_ext _extension_ranges -> VNil_ext
| Cons (spec, rest) ->
let value = value spec in
let values = make_values rest in
VCons (value, values)
in
let values = make_values spec in

let next_field reader =
match Reader.has_more reader with
| true -> Reader.read_field_header reader
| false -> Field.Varint, Int.max_int
in

let rec read_values: type constr a. (int * int) list -> Field.field_type -> int -> Reader.t -> constr -> (int * Field.t) list -> (constr, (int * Field.t) list -> a) value_list -> a option = fun extension_ranges tpe idx reader constr extensions ->
let rec read_repeated tpe index read_f default get reader =
let rec read_values: type constr a. extension_ranges -> Field.field_type -> int -> Reader.t -> constr -> extensions -> (constr, a) value_list -> a = fun extension_ranges tpe idx reader constr extensions ->
let rec read_repeated tpe index read_f default reader =
let default = read_f default reader tpe in
let (tpe, idx) = next_field reader in
match idx = index with
| true -> read_repeated tpe index read_f default get reader
| true -> read_repeated tpe index read_f default reader
| false -> default, tpe, idx
in
function
| VCons (([index, read_f], _required, default, get), vs) when index = idx ->
| VNil when idx = Int.max_int ->
constr
| VNil_ext when idx = Int.max_int ->
constr (List.rev extensions)
(* All fields read successfully. Apply extensions and return result. *)
| VCons (Value ([index, read_f], default, get), vs) when index = idx ->
(* Read all values, and apply constructor once all fields have been read.
This pattern is the most likely to be matched for all values, and is added
as an optimization to avoid reconstructing the value list for each recursion.
*)
let default, tpe, idx = read_repeated tpe index read_f default get reader in
let default, tpe, idx = read_repeated tpe index read_f default reader in
let constr = (constr (get default)) in
read_values extension_ranges tpe idx reader constr extensions vs
| VCons (((index, read_f) :: fields, _required, default, get), vs) when index = idx ->
| VCons (Value ((index, read_f) :: fields, default, get), vs) when index = idx ->
(* Read all values for the given field *)
let default, tpe, idx = read_repeated tpe index read_f default get reader in
read_values extension_ranges tpe idx reader constr extensions (VCons ((fields, Optional, default, get), vs))
let default, tpe, idx = read_repeated tpe index read_f default reader in
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, default, get), vs))
| vs when in_extension_ranges extension_ranges idx ->
(* Extensions may be sent inline. Store all valid extensions, before starting to apply constructors *)
let extensions = (idx, Reader.read_field_content tpe reader) :: extensions in
let (tpe, idx) = next_field reader in
read_values extension_ranges tpe idx reader constr extensions vs
| VCons (([], Required, _default, _get), _vs) ->
(* If there are no more fields to be read we will never find the value.
If all values are read, then raise, else revert to full deserialization *)
begin match (idx = Int.max_int) with
| true -> error_required_field_missing ()
| false -> None
end
| VCons ((_ :: fields, optional, default, get), vs) ->
| VCons (Value (_ :: fields, default, get), vs) ->
(* Drop the field, as we don't expect to find it. *)
read_values extension_ranges tpe idx reader constr extensions (VCons ((fields, optional, default, get), vs))
| VCons (([], Optional, default, get), vs) ->
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, default, get), vs))
| VCons (Value ([], default, get), vs) ->
(* Apply destructor. This case is only relevant for oneof fields *)
read_values extension_ranges tpe idx reader (constr (get default)) extensions vs
| VNil when idx = Int.max_int ->
(* All fields read successfully. Apply extensions and return result. *)
Some (constr (List.rev extensions))
| VNil ->
| VNil | VNil_ext ->
(* This implies that there are still fields to be read.
Revert to full deserialization.
*)
None
raise Restart_full
in

let extension_ranges = extension_ranges spec in
let values = make_values spec in

fun reader ->
let offset = Reader.offset reader in
let (tpe, idx) = next_field reader in
read_values extension_ranges tpe idx reader constr [] values
|> function
| Some t -> t
| None ->
try
read_values extension_ranges tpe idx reader constr [] values
with (Restart_full | Result.Error `Required_field_missing) ->
(* Revert to full deserialization *)
Reader.reset reader offset;
deserialize_full extension_ranges values constr reader
4 changes: 2 additions & 2 deletions src/ocaml_protoc_plugin/deserialize.mli
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module C = Spec.Deserialize.C

val deserialize: (int * int) list ->
('constr, (int * Field.t) list -> 'a) Spec.Deserialize.compound_list ->
val deserialize:
('constr, 'a) Spec.Deserialize.compound_list ->
'constr -> Reader.t -> 'a
5 changes: 3 additions & 2 deletions src/ocaml_protoc_plugin/extensions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ let compare _ _ = 0
(** [index_of_spec spec] extracts the index carried by a serialization spec.
    @raise Failure when [spec] is a [Oneof], which extensions do not allow. *)
let index_of_spec: type a. a Spec.Serialize.compound -> int = fun spec ->
  match spec with
  | Oneof _ -> failwith "Oneof fields not allowed in extensions"
  | Basic (idx, _, _) -> idx
  | Basic_opt (idx, _) -> idx
  | Basic_req (idx, _) -> idx
  | Repeated (idx, _, _) -> idx

let get: type a. a Spec.Deserialize.compound -> t -> a = fun spec t ->
let writer = Writer.of_list t in
let reader = Writer.contents writer |> Reader.create in
Deserialize.deserialize [] Spec.Deserialize.(Cons (spec, Nil)) (fun a _ -> a) reader
Deserialize.deserialize Spec.Deserialize.(Cons (spec, Nil)) (fun a -> a) reader

let set: type a. a Spec.Serialize.compound -> t -> a -> t = fun spec t v ->
let writer = Writer.init () in
let writer = Serialize.serialize [] Spec.Serialize.(Cons (spec, Nil)) [] writer v in
let writer = Serialize.serialize Spec.Serialize.(Cons (spec, Nil)) writer v in
let index = index_of_spec spec in
let fields =
Writer.contents writer
Expand Down
29 changes: 29 additions & 0 deletions src/ocaml_protoc_plugin/merge.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
(** [merge spec prev next] merges two values of the field described by
    [spec]. The spec has to be inspected because message-typed fields are
    merged recursively, using the merge function stored in their spec,
    rather than simply replaced.

    @raise Failure for a [Basic] spec holding a message, and for [Oneof]
    specs. *)
let merge: type t. t Spec.Deserialize.compound -> t -> t -> t = fun spec prev next ->
  let module D = Spec.Deserialize in
  match spec with
  | D.Basic (_, Message _, _) -> failwith "Messages with defaults cannot happen"
  | D.Basic (_, _, default) ->
    (* A new value equal to the default does not override the previous one. *)
    if next = default then prev else next

  (* The spec states that proto2 required fields must be transmitted exactly
     once, so strictly speaking such fields cannot be merged. This essentially
     means that proto2 messages containing required fields cannot be merged.
     This implementation chooses to ignore that restriction: required
     messages are merged, and every other required value is 'keep last'. *)
  | D.Basic_req (_, Message (_, merge_message)) -> merge_message prev next
  | D.Basic_req (_, _) -> next
  | D.Basic_opt (_, Message (_, merge_message)) ->
    begin
      match prev, next with
      | Some a, Some b -> Some (merge_message a b)
      | Some a, None -> Some a
      | None, Some b -> Some b
      | None, None -> None
    end
  | D.Basic_opt (_, _) ->
    begin
      (* Keep last: only an absent new value lets the previous one survive. *)
      match next with
      | None -> prev
      | Some _ -> next
    end
  | D.Repeated (_, _, _) -> prev @ next
  (* Oneof fields are not merged here; a [`not_set] shortcut was considered
     but is left out. *)
  | D.Oneof _ -> failwith "Implementation is part of generated code"
1 change: 1 addition & 0 deletions src/ocaml_protoc_plugin/ocaml_protoc_plugin.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Serialize = Serialize
module Deserialize = Deserialize
module Spec = Spec
module Runtime = Runtime
module Field = Field
(**/**)

module Reader = Reader
Expand Down
1 change: 1 addition & 0 deletions src/ocaml_protoc_plugin/runtime.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ module Runtime' = struct
module Extensions = Extensions
module Reader = Reader
module Writer = Writer
module Merge = Merge
end
Loading
Loading