diff --git a/src_plugins/ord/ppx_deriving_ord.cppo.ml b/src_plugins/ord/ppx_deriving_ord.cppo.ml index c36d57d..0e781d9 100644 --- a/src_plugins/ord/ppx_deriving_ord.cppo.ml +++ b/src_plugins/ord/ppx_deriving_ord.cppo.ml @@ -35,10 +35,14 @@ let reduce_compare l = | [] -> [%expr 0] | x :: xs -> List.fold_left compare_reduce x xs -let wildcard_case int_cases = +let wildcard_case ?typ int_cases = let loc = !Ast_helper.default_loc in + let typ = match typ with + | Some typ -> typ + | None -> [%type: _] (* don't constrain *) + in Exp.case [%pat? _] [%expr - let to_int = [%e Exp.function_ int_cases] in + let to_int (x: [%t typ]) = [%e Exp.match_ [%expr x] int_cases] in Ppx_deriving_runtime.compare (to_int lhs) (to_int rhs)] let pattn side typs = @@ -185,6 +189,24 @@ let sig_of_type ~options ~path type_decl = let str_of_type ~options ~path ({ ptype_loc = loc } as type_decl) = parse_options options; let quoter = Ppx_deriving.create_quoter () in + (* Capture type in helper module outside Ppx_deriving_runtime wrapper (added by sanitize). + Required for to_int constraint in variant type wildcard_case if the type name + conflicts with a Stdlib type from Ppx_deriving_runtime (e.g. bool in test). + In that case we must refer to the type being declared, not the one opened by Ppx_deriving_runtime. *) + let helper_type = + Type.mk ~loc + ~params:type_decl.ptype_params + ~manifest:(Ppx_deriving.core_type_of_type_decl type_decl) + (mkloc "t" loc) + in + let helper_typ = + let name = mkloc (Longident.parse "Ppx_deriving_ord_helper.t") loc in + let params = match helper_type.ptype_params with + | [] -> [] + | _ :: _ -> [Typ.any ()] (* match all params with single wildcard *) + in + Typ.constr name params + in let comparator = match type_decl.ptype_kind, type_decl.ptype_manifest with | Ptype_abstract, Some manifest -> expr_of_typ quoter manifest @@ -208,7 +230,7 @@ let str_of_type ~options ~path ({ ptype_loc = loc } as type_decl) = ) in [%expr fun lhs rhs -> - [%e Exp.match_ [%expr lhs, rhs] (cases @ [wildcard_case int_cases])]] + [%e Exp.match_ [%expr lhs, rhs] (cases @ [wildcard_case ~typ:helper_typ int_cases])]] | Ptype_record labels, _ -> let exprs = labels |> List.map (fun ({ pld_name = { txt = name }; _ } as pld) -> @@ -235,9 +257,18 @@ let str_of_type ~options ~path ({ ptype_loc = loc } as type_decl) = core_type_of_decl ~options ~path type_decl in let out_var = pvar (Ppx_deriving.mangle_type_decl (`Prefix "compare") type_decl) in + let comparator_with_helper = + [%expr let module Ppx_deriving_ord_helper = + struct + [@@@warning "-unused-type-declaration"] + [%%i Str.type_ Nonrecursive [helper_type]] + end + in + [%e Ppx_deriving.sanitize ~quoter (eta_expand (polymorphize comparator))]] + in [Vb.mk ~attrs:[Ppx_deriving.attr_warning [%expr "-39"]] (Pat.constraint_ out_var out_type) - (Ppx_deriving.sanitize ~quoter (eta_expand (polymorphize comparator)))] + comparator_with_helper] let () = Ppx_deriving.(register (create deriver diff --git a/src_test/eq/test_deriving_eq.cppo.ml b/src_test/eq/test_deriving_eq.cppo.ml index a2c4674..6a7c0b6 100644 --- a/src_test/eq/test_deriving_eq.cppo.ml +++ b/src_test/eq/test_deriving_eq.cppo.ml @@ -131,6 +131,8 @@ and 'a poly_abs_custom = 'a module List = struct type 'a t = [`Cons of 'a | `Nil] [@@deriving eq] + type 'a u = Cons of 'a | Nil + [@@deriving eq] end type 'a std_clash = 'a List.t option [@@deriving eq] @@ -148,6 +150,13 @@ let test_result_result ctxt = assert_equal ~printer false (eq (Ok "123") (Error 123)); assert_equal ~printer false (eq (Error 123) (Error 0)) +module ResultOverride = struct + type t = + | Ok + | Error + [@@deriving eq] +end + let suite = "Test deriving(eq)" >::: [ "test_simple" >:: test_simple; "test_array" >:: test_arr; diff --git a/src_test/ord/test_deriving_ord.cppo.ml b/src_test/ord/test_deriving_ord.cppo.ml index fdd4d0d..ab23f9d 100644 --- a/src_test/ord/test_deriving_ord.cppo.ml +++ b/src_test/ord/test_deriving_ord.cppo.ml @@ -158,6 +158,8 @@ and 'a poly_abs_custom = 'a module List = struct type 'a t = [`Cons of 'a | `Nil] [@@deriving ord] + type 'a u = Cons of 'a | Nil + [@@deriving ord] end type 'a std_clash = 'a List.t option [@@deriving ord] @@ -179,6 +181,13 @@ let test_record_order ctxt = assert_equal ~printer (0) (compare_ab { a = 1; b = 2; } { a = 1; b = 2; }); assert_equal ~printer (1) (compare_ab { a = 2; b = 2; } { a = 1; b = 2; }) +module ResultOverride = struct + type t = + | Ok + | Error + [@@deriving ord] +end + let suite = "Test deriving(ord)" >::: [ "test_simple" >:: test_simple; "test_variant" >:: test_variant;