fix: properly authorize access to query aggregates in all cases

This commit is contained in:
Zach Daniel 2023-10-11 19:44:50 -04:00
parent ea51d4d096
commit 6e0be43c51
4 changed files with 80 additions and 26 deletions

View file

@ -6,8 +6,9 @@ defmodule Ash.Actions.Aggregate do
query = %{query | api: api}
{query, opts} = Ash.Actions.Helpers.add_process_context(query.api, query, opts)
action = query.action || Ash.Resource.Info.primary_action!(query.resource, :read)
opts = Keyword.put_new(opts, :read_action, action.name)
case validate_aggregates(query, aggregates) do
case validate_aggregates(query, aggregates, opts) do
{:ok, aggregates} ->
aggregates
|> Enum.group_by(fn aggregate ->
@ -120,14 +121,19 @@ defmodule Ash.Actions.Aggregate do
end
end
defp validate_aggregates(query, aggregates) do
defp validate_aggregates(query, aggregates, opts) do
aggregates
|> Enum.reduce_while({:ok, []}, fn
%Ash.Query.Aggregate{} = aggregate, {:ok, aggregates} ->
{:cont, {:ok, [aggregate | aggregates]}}
{name, kind}, {:ok, aggregates} ->
case Ash.Query.Aggregate.new(query.resource, name, kind) do
case Ash.Query.Aggregate.new(
query.resource,
name,
kind,
set_opts([], opts)
) do
{:ok, aggregate} ->
{:cont, {:ok, [aggregate | aggregates]}}
@ -135,8 +141,8 @@ defmodule Ash.Actions.Aggregate do
{:halt, {:error, error}}
end
{name, kind, opts}, {:ok, aggregates} ->
case Ash.Query.Aggregate.new(query.resource, name, kind, opts) do
{name, kind, agg_opts}, {:ok, aggregates} ->
case Ash.Query.Aggregate.new(query.resource, name, kind, set_opts(agg_opts, opts)) do
{:ok, aggregate} ->
{:cont, {:ok, [aggregate | aggregates]}}
@ -145,4 +151,9 @@ defmodule Ash.Actions.Aggregate do
end
end)
end
defp set_opts(specified, others) do
{agg_opts, _} = Ash.Query.Aggregate.split_aggregate_opts(others)
Keyword.merge(agg_opts, specified)
end
end

View file

@ -79,7 +79,7 @@ defmodule Ash.Api.Interface do
opts
end
{aggregate_opts, opts} = split_aggregate_opts(opts)
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
case Ash.Query.Aggregate.new(query.resource, :count, :count, aggregate_opts) do
{:ok, aggregate} ->
@ -106,7 +106,7 @@ defmodule Ash.Api.Interface do
opts
end
{aggregate_opts, opts} = split_aggregate_opts(opts)
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
case Ash.Query.Aggregate.new(query.resource, :count, :count, aggregate_opts) do
{:ok, aggregate} ->
@ -133,7 +133,7 @@ defmodule Ash.Api.Interface do
opts
end
{aggregate_opts, opts} = split_aggregate_opts(opts)
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
case Ash.Query.Aggregate.new(query.resource, :exists, :exists, aggregate_opts) do
{:ok, aggregate} ->
@ -160,7 +160,7 @@ defmodule Ash.Api.Interface do
opts
end
{aggregate_opts, opts} = split_aggregate_opts(opts)
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
case Ash.Query.Aggregate.new(query.resource, :exists, :exists, aggregate_opts) do
{:ok, aggregate} ->
@ -188,7 +188,7 @@ defmodule Ash.Api.Interface do
opts
end
{aggregate_opts, opts} = split_aggregate_opts(opts)
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
case Ash.Query.Aggregate.new(
query.resource,
@ -214,7 +214,7 @@ defmodule Ash.Api.Interface do
def unquote(:"#{kind}!")(query, field, opts \\ []) do
query = Ash.Query.to_query(query)
{aggregate_opts, opts} = split_aggregate_opts(opts)
{aggregate_opts, opts} = Ash.Query.Aggregate.split_aggregate_opts(opts)
opts =
if query.action do
@ -244,18 +244,6 @@ defmodule Ash.Api.Interface do
end
end
defp split_aggregate_opts(opts) do
{left, right} = Keyword.split(opts, Ash.Query.Aggregate.opt_keys())
case Keyword.fetch(left, :authorize?) do
{:ok, value} ->
{left, Keyword.put(right, :authorize?, value)}
:error ->
{left, right}
end
end
def aggregate(query, aggregate_or_aggregates, opts \\ []) do
case Api.aggregate(__MODULE__, query, aggregate_or_aggregates, opts) do
{:ok, result} ->

View file

@ -123,12 +123,22 @@ defmodule Ash.Query.Aggregate do
end)
with {:ok, opts} <- Spark.OptionsHelpers.validate(opts, @schema) do
query =
opts[:query] || Ash.Query.new(Ash.Resource.Info.related(resource, opts[:path] || []))
query =
if opts[:read_action] && query.__validated_for_action__ != opts[:read_action] do
Ash.Query.for_read(query, opts[:read_action], %{})
else
query
end
new(
resource,
name,
kind,
opts[:path] || [],
opts[:query] || Ash.Query.new(Ash.Resource.Info.related(resource, opts[:path] || [])),
query,
opts[:field],
opts[:default],
Keyword.get(opts, :filterable?, true),
@ -267,6 +277,28 @@ defmodule Ash.Query.Aggregate do
end
end
@doc false
def split_aggregate_opts(opts) do
{left, right} = Keyword.split(opts, opt_keys())
right =
case Keyword.fetch(left, :authorize?) do
{:ok, value} ->
Keyword.put(right, :authorize?, value)
:error ->
right
end
case Keyword.fetch(right, :action) do
{:ok, action} ->
{Keyword.put(left, :read_action, action), right}
:error ->
{left, right}
end
end
def default_value(:count), do: 0
def default_value(:first), do: nil
def default_value(:sum), do: nil
@ -476,6 +508,13 @@ defmodule Ash.Query.Aggregate do
request_path ++
[:aggregate, relationship_path, {authorize?, read_action}, :authorization_filter]
action =
if read_action do
Ash.Resource.Info.action(related, read_action)
else
Ash.Resource.Info.primary_action!(related, :read)
end
Request.new(
resource: related,
api: initial_query.api,
@ -483,7 +522,7 @@ defmodule Ash.Query.Aggregate do
query: related_query,
path: request_path ++ [:aggregate, relationship_path, {authorize?, read_action}],
strict_check_only?: true,
action: Ash.Resource.Info.primary_action(related, :read),
action: action,
name: "authorize aggregate: #{Enum.join(relationship_path, ".")}",
data:
Request.resolve([auth_filter_path], fn data ->

View file

@ -46,6 +46,7 @@ defmodule Ash.Test.Actions.AggregateTest do
actions do
defaults [:create, :read, :update, :destroy]
read :unpublic
end
attributes do
@ -70,9 +71,13 @@ defmodule Ash.Test.Actions.AggregateTest do
end
policies do
policy always() do
policy action(:read) do
authorize_if expr(public == true)
end
policy action(:unpublic) do
authorize_if expr(public == false)
end
end
end
@ -124,6 +129,17 @@ defmodule Ash.Test.Actions.AggregateTest do
assert %{count: 0} = Api.aggregate!(Post, {:count, :count}, authorize?: true)
assert 0 = Api.count!(Post, authorize?: true)
assert %{count: 1} =
Post
|> Ash.Query.for_read(:unpublic)
|> Api.aggregate!({:count, :count}, actor: nil)
assert %{count: 1} =
Api.aggregate!(Post, {:count, :count}, actor: nil, action: :unpublic)
assert 1 = Api.count!(Post, actor: nil, action: :unpublic)
assert 0 = Api.count!(Post, authorize?: true)
Post
|> Ash.Changeset.for_create(:create, %{title: "title", public: true})
|> Api.create!()