fix: don't use Aggregate.new for builtin aggregates

This commit is contained in:
Zach Daniel 2024-02-24 19:17:35 -05:00
parent 4d39ab8a80
commit 993355f938
3 changed files with 140 additions and 135 deletions

View file

@ -8,10 +8,10 @@ defmodule Ash.Actions.Aggregate do
action = query.action || Ash.Resource.Info.primary_action!(query.resource, :read) action = query.action || Ash.Resource.Info.primary_action!(query.resource, :read)
opts = Keyword.put_new(opts, :read_action, action.name) opts = Keyword.put_new(opts, :read_action, action.name)
with {:ok, aggregates} <- validate_aggregates(query, aggregates, opts), with %{valid?: true} = query <- Ash.Actions.Read.handle_attribute_multitenancy(query) do
%{valid?: true} = query <- Ash.Actions.Read.handle_attribute_multitenancy(query) do
aggregates aggregates
|> Enum.group_by(fn aggregate -> |> Enum.group_by(fn
%Ash.Query.Aggregate{} = aggregate ->
agg_authorize? = aggregate.authorize? && opts[:authorize?] agg_authorize? = aggregate.authorize? && opts[:authorize?]
read_action = read_action =
@ -19,9 +19,18 @@ defmodule Ash.Actions.Aggregate do
Ash.Resource.Info.primary_action!(query.resource, :read).name Ash.Resource.Info.primary_action!(query.resource, :read).name
{agg_authorize?, read_action} {agg_authorize?, read_action}
{_name, _kind} ->
{!!opts[:authorize?], opts[:read_action]}
{_name, _kind, agg_opts} ->
authorize? =
Keyword.get(agg_opts, :authorize?, true) && opts[:authorize?]
{authorize?, agg_opts[:read_action] || opts[:read_action] || action.name}
end) end)
|> Enum.reduce_while({:ok, %{}}, fn {{agg_authorize?, read_action}, aggregates}, |> Enum.reduce_while({:ok, %{}}, fn
{:ok, acc} -> {{agg_authorize?, read_action}, aggregates}, {:ok, acc} ->
query = query =
if query.__validated_for_action__ == read_action do if query.__validated_for_action__ == read_action do
query query
@ -52,14 +61,17 @@ defmodule Ash.Actions.Aggregate do
Ash.Tracer.telemetry_span [:ash, Ash.Api.Info.short_name(query.api), :aggregate], Ash.Tracer.telemetry_span [:ash, Ash.Api.Info.short_name(query.api), :aggregate],
metadata do metadata do
Ash.Tracer.set_metadata(opts[:tracer], :action, metadata) Ash.Tracer.set_metadata(opts[:tracer], :action, metadata)
query = Map.put(query, :aggregates, Map.new(aggregates, &{&1.name, &1}))
with {:ok, query} <- authorize_query(query, opts, agg_authorize?), with {:ok, query} <- authorize_query(query, opts, agg_authorize?),
{:ok, aggregates} <- validate_aggregates(query, aggregates, opts),
{:ok, data_layer_query} <- {:ok, data_layer_query} <-
Ash.Query.data_layer_query(Ash.Query.new(query.resource)), Ash.Query.data_layer_query(Ash.Query.new(query.resource)),
aggregates <- merge_query_into_aggregates(query, aggregates),
{:ok, result} <- {:ok, result} <-
Ash.DataLayer.run_aggregate_query(data_layer_query, aggregates, query.resource) do Ash.DataLayer.run_aggregate_query(
data_layer_query,
aggregates,
query.resource
) do
{:cont, {:ok, Map.merge(acc, result)}} {:cont, {:ok, Map.merge(acc, result)}}
else else
{:error, error} -> {:error, error} ->
@ -71,21 +83,15 @@ defmodule Ash.Actions.Aggregate do
end end
end end
defp merge_query_into_aggregates(query, aggregates) do defp merge_query(left, right) do
Enum.map(aggregates, fn aggregate -> left
%{ |> Ash.Query.do_filter(right.filter)
aggregate |> Ash.Query.sort(right.sort, prepend?: true)
| query: |> Ash.Query.distinct_sort(right.distinct_sort, prepend?: true)
aggregate.query |> Ash.Query.limit(right.limit)
|> Ash.Query.do_filter(query.filter) |> Ash.Query.set_tenant(right.tenant)
|> Ash.Query.sort(query.sort, prepend?: true) |> merge_offset(right.offset)
|> Ash.Query.distinct_sort(query.distinct_sort, prepend?: true) |> Ash.Query.set_context(right.context)
|> Ash.Query.limit(query.limit)
|> Ash.Query.set_tenant(query.tenant)
|> merge_offset(query.offset)
|> Ash.Query.set_context(query.context)
}
end)
end end
defp merge_offset(query, offset) do defp merge_offset(query, offset) do
@ -128,7 +134,7 @@ defmodule Ash.Actions.Aggregate do
query.resource, query.resource,
name, name,
kind, kind,
set_opts([], opts) set_opts(query, [], opts)
) do ) do
{:ok, aggregate} -> {:ok, aggregate} ->
{:cont, {:ok, [aggregate | aggregates]}} {:cont, {:ok, [aggregate | aggregates]}}
@ -138,7 +144,7 @@ defmodule Ash.Actions.Aggregate do
end end
{name, kind, agg_opts}, {:ok, aggregates} -> {name, kind, agg_opts}, {:ok, aggregates} ->
case Ash.Query.Aggregate.new(query.resource, name, kind, set_opts(agg_opts, opts)) do case Ash.Query.Aggregate.new(query.resource, name, kind, set_opts(query, agg_opts, opts)) do
{:ok, aggregate} -> {:ok, aggregate} ->
{:cont, {:ok, [aggregate | aggregates]}} {:cont, {:ok, [aggregate | aggregates]}}
@ -148,8 +154,23 @@ defmodule Ash.Actions.Aggregate do
end) end)
end end
defp set_opts(specified, others) do defp set_opts(query, specified, others) do
{agg_opts, _} = Ash.Query.Aggregate.split_aggregate_opts(others) {agg_opts, _} = Ash.Query.Aggregate.split_aggregate_opts(others)
Keyword.merge(agg_opts, specified)
query =
case agg_opts[:query] do
%Ash.Query{} = agg_query ->
merge_query(agg_query, query)
nil ->
query
opts ->
Ash.Query.build(query, opts)
end
agg_opts
|> Keyword.merge(specified)
|> Keyword.put(:query, query)
end end
end end

View file

@ -1623,16 +1623,36 @@ defmodule Ash.Api do
end end
end end
@doc """
Runs an aggregate or aggregates over a resource query
If you pass an `%Ash.Query.Aggregate{}`, gotten from `Ash.Query.Aggregate.new()`,
the query provided as the first argument to this function will not apply. For this
reason, it is preferred that you pass in the tuple format, i.e
Prefer this:
`Api.aggregate(query, {:count_of_things, :count})`
Over this:
`Api.aggregate(query, Ash.Query.Aggregate.new(...))`
#{Spark.OptionsHelpers.docs(@aggregate_opts)}
"""
@callback aggregate( @callback aggregate(
Ash.Query.t(), Ash.Query.t(),
Ash.Api.aggregate() | list(Ash.Api.aggregate()), aggregate() | list(aggregate()),
opts :: Keyword.t() opts :: Keyword.t()
) :: ) ::
{:ok, any} | {:error, Ash.Error.t()} {:ok, any} | {:error, Ash.Error.t()}
@doc """
Runs an aggregate or aggregates over a resource query
See `c:aggregate/3` for more.
"""
@callback aggregate!( @callback aggregate!(
Ash.Query.t(), Ash.Query.t(),
Ash.Api.aggregate() | list(Ash.Api.aggregate()), aggregate() | list(aggregate()),
opts :: Keyword.t() opts :: Keyword.t()
) :: ) ::
any | no_return any | no_return

View file

@ -81,19 +81,13 @@ defmodule Ash.Api.Interface do
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts) {aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
case Ash.Query.Aggregate.new(query.resource, :count, :count, aggregate_opts) do case Api.aggregate(__MODULE__, query, {:count, :count, aggregate_opts}, opts) do
{:ok, aggregate} ->
case Api.aggregate(__MODULE__, query, aggregate, opts) do
{:ok, %{count: count}} -> {:ok, %{count: count}} ->
count count
{:error, error} -> {:error, error} ->
raise Ash.Error.to_error_class(error) raise Ash.Error.to_error_class(error)
end end
{:error, error} ->
raise Ash.Error.to_error_class(error)
end
end end
def count(query, opts \\ []) do def count(query, opts \\ []) do
@ -108,19 +102,13 @@ defmodule Ash.Api.Interface do
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts) {aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
case Ash.Query.Aggregate.new(query.resource, :count, :count, aggregate_opts) do case Api.aggregate(__MODULE__, query, {:count, :count, aggregate_opts}, opts) do
{:ok, aggregate} ->
case Api.aggregate(__MODULE__, query, aggregate, opts) do
{:ok, %{count: count}} -> {:ok, %{count: count}} ->
{:ok, count} {:ok, count}
{:error, error} -> {:error, error} ->
{:error, Ash.Error.to_error_class(error)} {:error, Ash.Error.to_error_class(error)}
end end
{:error, error} ->
{:error, Ash.Error.to_error_class(error)}
end
end end
def exists?(query, opts \\ []) do def exists?(query, opts \\ []) do
@ -135,19 +123,13 @@ defmodule Ash.Api.Interface do
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts) {aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
case Ash.Query.Aggregate.new(query.resource, :exists, :exists, aggregate_opts) do case Api.aggregate(__MODULE__, query, {:exists, :exists, aggregate_opts}, opts) do
{:ok, aggregate} ->
case Api.aggregate(__MODULE__, query, aggregate, opts) do
{:ok, %{exists: exists}} -> {:ok, %{exists: exists}} ->
exists exists
{:error, error} -> {:error, error} ->
raise Ash.Error.to_error_class(error) raise Ash.Error.to_error_class(error)
end end
{:error, error} ->
raise Ash.Error.to_error_class(error)
end
end end
def exists(query, opts \\ []) do def exists(query, opts \\ []) do
@ -162,19 +144,13 @@ defmodule Ash.Api.Interface do
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts) {aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
case Ash.Query.Aggregate.new(query.resource, :exists, :exists, aggregate_opts) do case Api.aggregate(__MODULE__, query, {:exists, :exists, aggregate_opts}, opts) do
{:ok, aggregate} ->
case Api.aggregate(__MODULE__, query, aggregate, opts) do
{:ok, %{exists: exists}} -> {:ok, %{exists: exists}} ->
{:ok, exists} {:ok, exists}
{:error, error} -> {:error, error} ->
{:error, Ash.Error.to_error_class(error)} {:error, Ash.Error.to_error_class(error)}
end end
{:error, error} ->
{:error, Ash.Error.to_error_class(error)}
end
end end
for kind <- [:first, :sum, :list, :max, :min, :avg] do for kind <- [:first, :sum, :list, :max, :min, :avg] do
@ -190,24 +166,18 @@ defmodule Ash.Api.Interface do
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts) {aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
case Ash.Query.Aggregate.new( case Api.aggregate(
query.resource, __MODULE__,
unquote(kind), query,
unquote(kind), {unquote(kind), unquote(kind), Keyword.put(aggregate_opts, :field, field)},
Keyword.put(aggregate_opts, :field, field) opts
) do ) do
{:ok, aggregate} ->
case Api.aggregate(__MODULE__, query, aggregate, opts) do
{:ok, %{unquote(kind) => value}} -> {:ok, %{unquote(kind) => value}} ->
{:ok, value} {:ok, value}
{:error, error} -> {:error, error} ->
{:error, Ash.Error.to_error_class(error)} {:error, Ash.Error.to_error_class(error)}
end end
{:error, error} ->
{:error, Ash.Error.to_error_class(error)}
end
end end
# sobelow_skip ["DOS.BinToAtom"] # sobelow_skip ["DOS.BinToAtom"]
@ -223,24 +193,18 @@ defmodule Ash.Api.Interface do
opts opts
end end
case Ash.Query.Aggregate.new( case Api.aggregate(
query.resource, __MODULE__,
unquote(kind), query,
unquote(kind), {unquote(kind), unquote(kind), Keyword.put(aggregate_opts, :field, field)},
Keyword.put(aggregate_opts, :field, field) opts
) do ) do
{:ok, aggregate} ->
case Api.aggregate(__MODULE__, query, aggregate, opts) do
{:ok, %{unquote(kind) => value}} -> {:ok, %{unquote(kind) => value}} ->
value value
{:error, error} -> {:error, error} ->
raise Ash.Error.to_error_class(error) raise Ash.Error.to_error_class(error)
end end
{:error, error} ->
raise Ash.Error.to_error_class(error)
end
end end
end end