fix: fix aggregate/base filters

This commit is contained in:
Zach Daniel 2022-02-11 16:06:51 -05:00
parent a6577d5175
commit e6ff1d8b4b
7 changed files with 315 additions and 178 deletions

View file

@ -18,7 +18,7 @@ jobs:
matrix: matrix:
otp: ["23"] otp: ["23"]
elixir: ["1.11.0"] elixir: ["1.11.0"]
ash: ["master", "1.50.19"] ash: ["master", "1.50.20"]
pg_version: ["9.6", "11"] pg_version: ["9.6", "11"]
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View file

@ -286,6 +286,20 @@ defmodule AshPostgres.Aggregate do
false false
) )
{:ok, aggregate_query} =
if aggregate.query && aggregate.query.filter do
AshPostgres.Join.join_all_relationships(
aggregate_query,
aggregate.query.filter,
nil,
[],
nil,
true
)
else
{:ok, aggregate_query}
end
new_aggregate_query = add_subquery_aggregate_select(aggregate_query, aggregate, resource) new_aggregate_query = add_subquery_aggregate_select(aggregate_query, aggregate, resource)
put_in(join.source.from.source.query, new_aggregate_query) put_in(join.source.from.source.query, new_aggregate_query)
@ -298,7 +312,7 @@ defmodule AshPostgres.Aggregate do
end end
def used_aggregates(filter, relationship, used_calculations, path) do def used_aggregates(filter, relationship, used_calculations, path) do
Ash.Filter.used_aggregates(filter, path ++ [relationship.name]) ++ Ash.Filter.used_aggregates(filter, path) ++
Enum.flat_map( Enum.flat_map(
used_calculations, used_calculations,
fn calculation -> fn calculation ->

View file

@ -600,7 +600,8 @@ defmodule AshPostgres.Expr do
defp ref_binding(%{attribute: %Ash.Resource.Attribute{}} = ref, bindings) do defp ref_binding(%{attribute: %Ash.Resource.Attribute{}} = ref, bindings) do
Enum.find_value(bindings.bindings, fn {binding, data} -> Enum.find_value(bindings.bindings, fn {binding, data} ->
data.path == ref.relationship_path && data.type in [:inner, :left, :root] && binding data.path == ref.relationship_path && data.type in [:inner, :left, :root, :aggregate] &&
binding
end) end)
end end

View file

@ -23,7 +23,14 @@ defmodule AshPostgres.Join do
@known_inner_join_predicates @known_inner_join_functions ++ @known_inner_join_operators @known_inner_join_predicates @known_inner_join_functions ++ @known_inner_join_operators
def join_all_relationships(query, filter, relationship_paths \\ nil, path \\ [], source \\ nil) do def join_all_relationships(
query,
filter,
relationship_paths \\ nil,
path \\ [],
source \\ nil,
use_root_query_bindings? \\ false
) do
relationship_paths = relationship_paths =
relationship_paths || relationship_paths ||
filter filter
@ -63,7 +70,8 @@ defmodule AshPostgres.Join do
Enum.map(path, & &1.name), Enum.map(path, & &1.name),
current_join_type, current_join_type,
source, source,
filter filter,
use_root_query_bindings?
) do ) do
{:ok, joined_query} -> {:ok, joined_query} ->
joined_query_with_distinct = add_distinct(relationship, join_type, joined_query) joined_query_with_distinct = add_distinct(relationship, join_type, joined_query)
@ -73,7 +81,8 @@ defmodule AshPostgres.Join do
filter, filter,
[{join_type, rest_rels}], [{join_type, rest_rels}],
current_path, current_path,
source source,
use_root_query_bindings?
) do ) do
{:ok, query} -> {:ok, query} ->
{:cont, {:ok, query}} {:cont, {:ok, query}}
@ -98,14 +107,22 @@ defmodule AshPostgres.Join do
relationship_path_to_relationships(relationship.destination, rest, [relationship | acc]) relationship_path_to_relationships(relationship.destination, rest, [relationship | acc])
end end
def maybe_get_resource_query(resource, relationship, root_query) do def maybe_get_resource_query(
resource,
relationship,
root_query,
path \\ [],
use_root_query_bindings? \\ false
) do
resource resource
|> Ash.Query.new() |> Ash.Query.new(nil, base_filter?: false)
|> Map.put(:context, root_query.__ash_bindings__.context) |> Ash.Query.set_context(root_query.__ash_bindings__.context)
|> Ash.Query.set_context(relationship.context) |> Ash.Query.set_context(relationship.context)
|> Ash.Query.do_filter(relationship.filter) |> Ash.Query.do_filter(relationship.filter)
|> case do |> case do
%{valid?: true} = query -> %{valid?: true} = query ->
ash_query = query
initial_query = %{ initial_query = %{
AshPostgres.DataLayer.resource_to_query(resource, nil) AshPostgres.DataLayer.resource_to_query(resource, nil)
| prefix: Map.get(root_query, :prefix) | prefix: Map.get(root_query, :prefix)
@ -116,7 +133,15 @@ defmodule AshPostgres.Join do
initial_query: initial_query initial_query: initial_query
) do ) do
{:ok, query} -> {:ok, query} ->
{:ok, AshPostgres.DataLayer.default_bindings(query, resource)} {:ok,
do_base_filter(
query,
root_query,
ash_query,
resource,
path,
use_root_query_bindings?
)}
{:error, error} -> {:error, error} ->
{:error, error} {:error, error}
@ -127,6 +152,34 @@ defmodule AshPostgres.Join do
end end
end end
defp do_base_filter(query, root_query, ash_query, resource, path, use_root_query_bindings?) do
case Ash.Resource.Info.base_filter(resource) do
nil ->
query
filter ->
filter =
resource
|> Ash.Filter.parse!(
filter,
ash_query.aggregates,
ash_query.calculations,
ash_query.context
)
dynamic =
if use_root_query_bindings? do
filter = Ash.Filter.move_to_relationship_path(filter, path)
AshPostgres.Expr.dynamic_expr(root_query, filter, root_query.__ash_bindings__, true)
else
AshPostgres.Expr.dynamic_expr(query, filter, query.__ash_bindings__, true)
end
from(row in query, where: ^dynamic)
end
end
def set_join_prefix(join_query, query, resource) do def set_join_prefix(join_query, query, resource) do
if Ash.Resource.Info.multitenancy_strategy(resource) == :context do if Ash.Resource.Info.multitenancy_strategy(resource) == :context do
%{join_query | prefix: query.prefix || "public"} %{join_query | prefix: query.prefix || "public"}
@ -219,13 +272,29 @@ defmodule AshPostgres.Join do
end end
end end
defp join_relationship(query, relationship, path, join_type, source, filter) do defp join_relationship(
query,
relationship,
path,
join_type,
source,
filter,
use_root_query_bindings?
) do
case Map.get(query.__ash_bindings__.bindings, path) do case Map.get(query.__ash_bindings__.bindings, path) do
%{type: existing_join_type} when join_type != existing_join_type -> %{type: existing_join_type} when join_type != existing_join_type ->
raise "unreachable?" raise "unreachable?"
nil -> nil ->
do_join_relationship(query, relationship, path, join_type, source, filter) do_join_relationship(
query,
relationship,
path,
join_type,
source,
filter,
use_root_query_bindings?
)
_ -> _ ->
{:ok, query} {:ok, query}
@ -238,14 +307,76 @@ defmodule AshPostgres.Join do
path, path,
kind, kind,
source, source,
filter filter,
use_root_query_bindings?
) do ) do
join_relationship = Ash.Resource.Info.relationship(source, relationship.join_relationship) join_relationship = Ash.Resource.Info.relationship(source, relationship.join_relationship)
join_path =
Enum.reverse([
String.to_existing_atom(to_string(relationship.name) <> "_join_assoc") | path
])
full_path = Enum.reverse([relationship.name | path])
initial_ash_bindings = query.__ash_bindings__
binding_data =
case kind do
{:aggregate, name, _agg} ->
%{type: :aggregate, name: name, path: full_path, source: source}
_ ->
%{type: kind, path: full_path, source: source}
end
additional_binding? =
case kind do
{:aggregate, _, _subquery} ->
false
_ ->
true
end
query =
case kind do
{:aggregate, _, _subquery} ->
additional_bindings =
if additional_binding? do
1
else
0
end
query
|> AshPostgres.DataLayer.add_binding(binding_data, additional_bindings)
_ ->
query
|> AshPostgres.DataLayer.add_binding(%{
path: join_path,
type: :left,
source: source
})
|> AshPostgres.DataLayer.add_binding(binding_data)
end
with {:ok, relationship_through} <- with {:ok, relationship_through} <-
maybe_get_resource_query(relationship.through, join_relationship, query), maybe_get_resource_query(
relationship.through,
join_relationship,
query,
join_path,
use_root_query_bindings?
),
{:ok, relationship_destination} <- {:ok, relationship_destination} <-
maybe_get_resource_query(relationship.destination, relationship, query) do maybe_get_resource_query(
relationship.destination,
relationship,
query,
use_root_query_bindings?
) do
relationship_through = relationship_through =
relationship_through relationship_through
|> Ecto.Queryable.to_query() |> Ecto.Queryable.to_query()
@ -266,7 +397,7 @@ defmodule AshPostgres.Join do
end end
current_binding = current_binding =
Enum.find_value(query.__ash_bindings__.bindings, 0, fn {binding, data} -> Enum.find_value(initial_ash_bindings.bindings, 0, fn {binding, data} ->
if data.type == binding_kind && data.path == Enum.reverse(path) do if data.type == binding_kind && data.path == Enum.reverse(path) do
binding binding
end end
@ -276,23 +407,18 @@ defmodule AshPostgres.Join do
Ash.Filter.used_calculations( Ash.Filter.used_calculations(
filter, filter,
relationship.destination, relationship.destination,
path ++ [relationship.name] full_path
) )
used_aggregates = used_aggregates =
AshPostgres.Aggregate.used_aggregates(filter, relationship, used_calculations, path) filter
|> AshPostgres.Aggregate.used_aggregates(relationship, used_calculations, full_path)
Enum.reduce_while(used_aggregates, {:ok, relationship_destination}, fn agg, {:ok, query} -> |> Enum.map(fn aggregate ->
agg = %{agg | load: agg.name} %{aggregate | load: aggregate.name}
case AshPostgres.Aggregate.add_aggregates(query, [agg], relationship.destination) do
{:ok, query} ->
{:cont, {:ok, query}}
{:error, error} ->
{:halt, {:error, error}}
end
end) end)
relationship_destination
|> AshPostgres.Aggregate.add_aggregates(used_aggregates, relationship.destination)
|> case do |> case do
{:ok, relationship_destination} -> {:ok, relationship_destination} ->
relationship_destination = relationship_destination =
@ -304,7 +430,6 @@ defmodule AshPostgres.Join do
subquery(relationship_destination) subquery(relationship_destination)
end end
{new_query, additional_binding?} =
case kind do case kind do
{:aggregate, _, subquery} -> {:aggregate, _, subquery} ->
subquery = subquery =
@ -315,55 +440,60 @@ defmodule AshPostgres.Join do
relationship relationship
) )
q = {:ok,
from([{row, current_binding}] in query, from([{row, current_binding}] in query,
left_lateral_join: through in ^subquery, left_lateral_join: through in ^subquery,
as: ^query.__ash_bindings__.current as: ^initial_ash_bindings.current
) )}
{q, false}
:inner -> :inner ->
q = {:ok,
from([{row, current_binding}] in query, from([{row, current_binding}] in query,
join: through in ^relationship_through, join: through in ^relationship_through,
as: ^query.__ash_bindings__.current, as: ^initial_ash_bindings.current,
on: on:
field(row, ^relationship.source_field) == field(row, ^relationship.source_field) ==
field(through, ^relationship.source_field_on_join_table), field(through, ^relationship.source_field_on_join_table),
join: destination in ^relationship_destination, join: destination in ^relationship_destination,
as: ^query.__ash_bindings__.current, as: ^initial_ash_bindings.current,
on: on:
field(destination, ^relationship.destination_field) == field(destination, ^relationship.destination_field) ==
field(through, ^relationship.destination_field_on_join_table) field(through, ^relationship.destination_field_on_join_table)
) )}
{q, true}
_ -> _ ->
q = {:ok,
from([{row, current_binding}] in query, from([{row, current_binding}] in query,
left_join: through in ^relationship_through, left_join: through in ^relationship_through,
as: ^query.__ash_bindings__.current, as: ^initial_ash_bindings.current,
on: on:
field(row, ^relationship.source_field) == field(row, ^relationship.source_field) ==
field(through, ^relationship.source_field_on_join_table), field(through, ^relationship.source_field_on_join_table),
left_join: destination in ^relationship_destination, left_join: destination in ^relationship_destination,
as: ^(query.__ash_bindings__.current + 1), as: ^(initial_ash_bindings.current + 1),
on: on:
field(destination, ^relationship.destination_field) == field(destination, ^relationship.destination_field) ==
field(through, ^relationship.destination_field_on_join_table) field(through, ^relationship.destination_field_on_join_table)
) )}
{q, true}
end end
join_path = {:error, error} ->
Enum.reverse([ {:error, error}
String.to_existing_atom(to_string(relationship.name) <> "_join_assoc") | path end
]) end
end
defp do_join_relationship(
query,
relationship,
path,
kind,
source,
filter,
use_root_query_bindings?
) do
full_path = Enum.reverse([relationship.name | path]) full_path = Enum.reverse([relationship.name | path])
initial_ash_bindings = query.__ash_bindings__
binding_data = binding_data =
case kind do case kind do
@ -374,38 +504,15 @@ defmodule AshPostgres.Join do
%{type: kind, path: full_path, source: source} %{type: kind, path: full_path, source: source}
end end
case kind do query = AshPostgres.DataLayer.add_binding(query, binding_data)
{:aggregate, _, _subquery} ->
additional_bindings =
if additional_binding? do
1
else
0
end
{:ok, case maybe_get_resource_query(
new_query relationship.destination,
|> AshPostgres.DataLayer.add_binding(binding_data, additional_bindings)} relationship,
query,
_ -> full_path,
{:ok, use_root_query_bindings?
new_query ) do
|> AshPostgres.DataLayer.add_binding(%{
path: join_path,
type: :left,
source: source
})
|> AshPostgres.DataLayer.add_binding(binding_data)}
end
{:error, error} ->
{:error, error}
end
end
end
defp do_join_relationship(query, relationship, path, kind, source, filter) do
case maybe_get_resource_query(relationship.destination, relationship, query) do
{:error, error} -> {:error, error} ->
{:error, error} {:error, error}
@ -425,7 +532,7 @@ defmodule AshPostgres.Join do
end end
current_binding = current_binding =
Enum.find_value(query.__ash_bindings__.bindings, 0, fn {binding, data} -> Enum.find_value(initial_ash_bindings.bindings, 0, fn {binding, data} ->
if data.type == binding_kind && data.path == Enum.reverse(path) do if data.type == binding_kind && data.path == Enum.reverse(path) do
binding binding
end end
@ -435,24 +542,18 @@ defmodule AshPostgres.Join do
Ash.Filter.used_calculations( Ash.Filter.used_calculations(
filter, filter,
relationship.destination, relationship.destination,
path ++ [relationship.name] full_path
) )
used_aggregates = used_aggregates =
AshPostgres.Aggregate.used_aggregates(filter, relationship, used_calculations, path) filter
|> AshPostgres.Aggregate.used_aggregates(relationship, used_calculations, full_path)
Enum.reduce_while(used_aggregates, {:ok, relationship_destination}, fn agg, |> Enum.map(fn aggregate ->
{:ok, query} -> %{aggregate | load: aggregate.name}
agg = %{agg | load: agg.name}
case AshPostgres.Aggregate.add_aggregates(query, [agg], relationship.destination) do
{:ok, query} ->
{:cont, {:ok, query}}
{:error, error} ->
{:halt, {:error, error}}
end
end) end)
relationship_destination
|> AshPostgres.Aggregate.add_aggregates(used_aggregates, relationship.destination)
|> case do |> case do
{:ok, relationship_destination} -> {:ok, relationship_destination} ->
relationship_destination = relationship_destination =
@ -464,7 +565,6 @@ defmodule AshPostgres.Join do
subquery(relationship_destination) subquery(relationship_destination)
end end
new_query =
case kind do case kind do
{:aggregate, _, subquery} -> {:aggregate, _, subquery} ->
subquery = subquery =
@ -475,45 +575,33 @@ defmodule AshPostgres.Join do
relationship relationship
) )
{:ok,
from([{row, current_binding}] in query, from([{row, current_binding}] in query,
left_lateral_join: destination in ^subquery, left_lateral_join: destination in ^subquery,
as: ^query.__ash_bindings__.current as: ^initial_ash_bindings.current
) )}
:inner -> :inner ->
{:ok,
from([{row, current_binding}] in query, from([{row, current_binding}] in query,
join: destination in ^relationship_destination, join: destination in ^relationship_destination,
as: ^query.__ash_bindings__.current, as: ^initial_ash_bindings.current,
on: on:
field(row, ^relationship.source_field) == field(row, ^relationship.source_field) ==
field(destination, ^relationship.destination_field) field(destination, ^relationship.destination_field)
) )}
_ -> _ ->
{:ok,
from([{row, current_binding}] in query, from([{row, current_binding}] in query,
left_join: destination in ^relationship_destination, left_join: destination in ^relationship_destination,
as: ^query.__ash_bindings__.current, as: ^initial_ash_bindings.current,
on: on:
field(row, ^relationship.source_field) == field(row, ^relationship.source_field) ==
field(destination, ^relationship.destination_field) field(destination, ^relationship.destination_field)
) )}
end end
full_path = Enum.reverse([relationship.name | path])
binding_data =
case kind do
{:aggregate, name, _agg} ->
%{type: :aggregate, name: name, path: full_path, source: source}
_ ->
%{type: kind, path: full_path, source: source}
end
{:ok,
new_query
|> AshPostgres.DataLayer.add_binding(binding_data)}
{:error, error} -> {:error, error} ->
{:error, error} {:error, error}
end end

View file

@ -97,7 +97,7 @@ defmodule AshPostgres.MixProject do
{:ecto, github: "elixir-ecto/ecto", branch: "master", override: true}, {:ecto, github: "elixir-ecto/ecto", branch: "master", override: true},
{:jason, "~> 1.0"}, {:jason, "~> 1.0"},
{:postgrex, ">= 0.0.0"}, {:postgrex, ">= 0.0.0"},
{:ash, ash_version("~> 1.50 and >= 1.50.19")}, {:ash, ash_version("~> 1.50 and >= 1.50.20")},
{:git_ops, "~> 2.4.5", only: :dev}, {:git_ops, "~> 2.4.5", only: :dev},
{:ex_doc, "~> 0.22", only: :dev, runtime: false}, {:ex_doc, "~> 0.22", only: :dev, runtime: false},
{:ex_check, "~> 0.11.0", only: :dev}, {:ex_check, "~> 0.11.0", only: :dev},

View file

@ -494,5 +494,34 @@ defmodule AshPostgres.AggregateTest do
]) ])
|> Api.read_one!() |> Api.read_one!()
end end
test "a count aggregate with a related filter returns the proper value" do
post =
Post
|> Ash.Changeset.new(%{title: "title", category: "foo"})
|> Api.create!()
Comment
|> Ash.Changeset.for_create(:create, %{title: "match"})
|> Ash.Changeset.replace_relationship(:post, post)
|> Api.create!()
Comment
|> Ash.Changeset.for_create(:create, %{title: "match"})
|> Ash.Changeset.replace_relationship(:post, post)
|> Api.create!()
Comment
|> Ash.Changeset.for_create(:create, %{title: "match"})
|> Ash.Changeset.replace_relationship(:post, post)
|> Api.create!()
assert %Post{count_of_comments_that_have_a_post: 3} =
Post
|> Ash.Query.load([
:count_of_comments_that_have_a_post
])
|> Api.read_one!()
end
end end
end end

View file

@ -117,6 +117,11 @@ defmodule AshPostgres.Test.Post do
filter(title: "match") filter(title: "match")
end end
# All of them will, but we want to test a related field
count :count_of_comments_that_have_a_post, :comments do
filter(expr(not is_nil(post.id)))
end
sum :sum_of_recent_popular_comment_likes, :popular_comments, :likes do sum :sum_of_recent_popular_comment_likes, :popular_comments, :likes do
# not(is_nil(post_category)) is silly but its here for tests # not(is_nil(post_category)) is silly but its here for tests
filter(expr(created_at > ago(10, :day) and not is_nil(post_category))) filter(expr(created_at > ago(10, :day) and not is_nil(post_category)))