Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions lib/strukt.ex
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ defmodule Strukt do

NOTE: It is recommended that if you need to perform custom validations, that
you use the `validation/1` and `validation/2` facility for performing custom
validations in a module or function, and if necessary, override `c:validate/1`
instead of performing validations in this callback. If you need to override this
validations in a module or function, and if necessary, override `c:validate/1`
instead of performing validations in this callback. If you need to override this
callback specifically for some reason, make sure you call `super/2` at some point during
your implementation to ensure that validations are run.
"""
Expand Down Expand Up @@ -242,20 +242,22 @@ defmodule Strukt do
[]
end

opaque_fields = Enum.any?(special_attrs, fn {:@, _, [{attr, _, [value]}]} -> attr == :opaque_fields &&!!value end)

fields = Strukt.Field.parse(fields)

define_struct(env, name, meta, moduledoc, derives, schema_attrs, fields, body)
define_struct(env, name, meta, moduledoc, derives, opaque_fields, schema_attrs, fields, body)
end

# This clause handles the edge case where the definition only contains
# a single field and nothing else
defp define_struct(env, name, {type, _, _} = field) when is_supported(type) do
fields = Strukt.Field.parse([field])

define_struct(env, name, [], nil, [], [], fields, [])
define_struct(env, name, [], nil, [], false, [], fields, [])
end

defp define_struct(_env, name, meta, moduledoc, derives, schema_attrs, fields, body) do
defp define_struct(_env, name, meta, moduledoc, derives, opaque_fields, schema_attrs, fields, body) do
# Extract macros which should be defined at the top of the module
{macros, body} =
Enum.split_with(body, fn
Expand Down Expand Up @@ -356,6 +358,7 @@ defmodule Strukt do
end

@schema_name Macro.underscore(__MODULE__)
@opaque_fields unquote(opaque_fields)
@validated_fields unquote(validated_fields)
@cast_embed_fields unquote(Macro.escape(cast_embed_fields))

Expand Down Expand Up @@ -491,6 +494,7 @@ defmodule Strukt do
typespec_ast =
Strukt.Typespec.generate(%Strukt.Typespec{
caller: __MODULE__,
opaque: @opaque_fields,
info: @validated_fields,
fields: @cast_fields,
embeds: @cast_embed_fields
Expand Down
12 changes: 9 additions & 3 deletions lib/typespec.ex
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
defmodule Strukt.Typespec do
@moduledoc false

defstruct [:caller, :info, :fields, :embeds]
defstruct [:caller, :opaque, :info, :fields, :embeds]

@type t :: %__MODULE__{
# The module where the struct is being defined
caller: module,
# Defines whether the typespec should be opaque or the default type
opaque: boolean,
# Metadata about all fields in the struct
info: %{optional(atom) => map},
# A list of all non-embed field names
Expand Down Expand Up @@ -45,7 +47,7 @@ defmodule Strukt.Typespec do
* `fields` - This is a list of all field names which are defined via `field/3`
* `embeds` - This is a list of all field names which are defined via `embeds_one/3` or `embeds_many/3`
"""
def generate(%__MODULE__{caller: caller, info: info, fields: fields, embeds: embeds}) do
def generate(%__MODULE__{caller: caller, opaque: opaque, info: info, fields: fields, embeds: embeds}) do
# Build up the AST for each field's type spec
fields =
fields
Expand Down Expand Up @@ -87,7 +89,11 @@ defmodule Strukt.Typespec do
# Join all fields together
struct_fields = fields ++ embeds

quote(context: caller, do: @type(t :: %__MODULE__{unquote_splicing(struct_fields)}))
if opaque do
quote(context: caller, do: @opaque(t :: %__MODULE__{unquote_splicing(struct_fields)}))
else
quote(context: caller, do: @type(t :: %__MODULE__{unquote_splicing(struct_fields)}))
end
end

defp primitive(atom, args \\ []) when is_atom(atom) and is_list(args),
Expand Down