From 60e0dad7b6d27432f05ba9210aefac9b0b560df6 Mon Sep 17 00:00:00 2001 From: Zach Daniel Date: Thu, 31 Dec 2020 18:13:53 -0500 Subject: [PATCH] improvement: rework filter creation + subset checking This is one of the most complicated parts of Ash. In order to pass a filter statement to the satisfiability solver that we use, we have to first transpile a *value* statement into a *boolean* statement. This means that we need to embed the knowledge of mutual exclusivity wherever possible. Authorization still works if the system doesn't know the relationship between two value statements, as it will attach the authorization filters if its not sure. But having this in place should represent a fairly significant optimization in many cases. Additionally, filter creation has a set of optimizations around the `eq` and `in` operators to combine them whlie building a boolean statement --- lib/ash/filter/filter.ex | 47 +++++- lib/ash/query/expression.ex | 156 ++++++++++++++++++-- lib/ash/query/operator/in.ex | 38 ++++- lib/ash/query/query.ex | 12 +- lib/sat_solver.ex | 267 ++++++++++++++++++++++++++--------- 5 files changed, 421 insertions(+), 99 deletions(-) diff --git a/lib/ash/filter/filter.ex b/lib/ash/filter/filter.ex index b2235f9c..94b45659 100644 --- a/lib/ash/filter/filter.ex +++ b/lib/ash/filter/filter.ex @@ -780,19 +780,19 @@ defmodule Ash.Filter do end) end - def map(%__MODULE__{expression: nil} = filter, _) do + defp map(%__MODULE__{expression: nil} = filter, _) do filter end - def map(%__MODULE__{expression: expression} = filter, func) do + defp map(%__MODULE__{expression: expression} = filter, func) do %{filter | expression: do_map(func.(expression), func)} end - def map(expression, func) do + defp map(expression, func) do do_map(func.(expression), func) end - def do_map(expression, func) do + defp do_map(expression, func) do case expression do {:halt, expr} -> expr @@ -814,6 +814,45 @@ defmodule Ash.Filter do end end + @doc false + def embed_predicates(nil), do: nil + + def embed_predicates(%__MODULE__{expression: expression} = filter) do + %{filter | expression: embed_predicates(expression)} + end + + def embed_predicates(%Not{expression: expression} = not_expr) do + %{not_expr | expression: embed_predicates(expression)} + end + + def embed_predicates(%Expression{left: left, right: right} = expr) do + %{expr | left: embed_predicates(left), right: embed_predicates(right)} + end + + def embed_predicates(%{__predicate__?: true} = pred) do + %{pred | embedded?: true} + end + + def embed_predicates(other), do: other + + def find(%__MODULE__{expression: nil}, _), do: nil + + def find(%__MODULE__{expression: expression}, func) do + find(expression, func) + end + + def find(%Expression{left: left, right: right}, func) do + find(left, func) || find(right, func) + end + + def find(%Not{expression: not_expr}, func) do + find(not_expr, func) + end + + def find(other, func) do + if func.(other), do: other + end + def list_predicates(%__MODULE__{expression: expression}) do list_predicates(expression) end diff --git a/lib/ash/query/expression.ex b/lib/ash/query/expression.ex index 499bcba2..092824b2 100644 --- a/lib/ash/query/expression.ex +++ b/lib/ash/query/expression.ex @@ -2,6 +2,7 @@ defmodule Ash.Query.Expression do @moduledoc "Represents a boolean expression" alias Ash.Query.Operator.{Eq, In} + alias Ash.Query.Ref defstruct [:op, :left, :right] @@ -13,45 +14,170 @@ defmodule Ash.Query.Expression do %__MODULE__{op: op, left: left, right: right} end - def optimized_new(_, nil, nil), do: nil - def optimized_new(:and, false, _), do: false - def optimized_new(:and, _, false), do: false - def optimized_new(:or, true, _), do: true - def optimized_new(:or, _, true), do: true - def optimized_new(_, nil, right), do: right - def optimized_new(_, left, nil), do: left + def optimized_new(op, left, right, current_op \\ :and) + def optimized_new(_, nil, nil, _), do: nil + def optimized_new(:and, false, _, _), do: false + def optimized_new(:and, _, false, _), do: false + def optimized_new(:or, true, _, _), do: true + def optimized_new(:or, _, true, _), do: true + def optimized_new(_, nil, right, _), do: right + def optimized_new(_, left, nil, _), do: left - def optimized_new(op, left, right) when left > right do - optimized_new(op, right, left) + def optimized_new( + op, + %__MODULE__{op: op} = left_expr, + %__MODULE__{ + op: op, + left: left, + right: right + }, + op + ) do + optimized_new(op, optimized_new(op, left_expr, left, op), right, op) end - def optimized_new(op, %In{} = left, %Eq{} = right) do - optimized_new(op, left, right) + def optimized_new(op, %__MODULE__{} = left, %__MODULE__{} = right, _) do + do_new(op, left, right) end - def optimized_new(:or, %Eq{left: left, right: value}, %In{left: left, right: mapset} = right) do + def optimized_new(op, left, %__MODULE__{} = right, current_op) do + optimized_new(op, right, left, current_op) + end + + def optimized_new(op, %In{} = left, %Eq{} = right, current_op) do + optimized_new(op, right, left, current_op) + end + + def optimized_new(op, %Eq{right: %Ref{}} = left, right, _) do + do_new(op, left, right) + end + + def optimized_new(op, left, %Eq{right: %Ref{}} = right, _) do + do_new(op, left, right) + end + + def optimized_new( + :or, + %Eq{left: left, right: value}, + %In{left: left, right: %{__struct__: MapSet} = mapset} = right, + _ + ) do %{right | right: MapSet.put(mapset, value)} end - def optimized_new(:or, %Eq{left: left, right: left_value}, %Eq{left: left, right: right_value}) do + def optimized_new( + :and, + %Eq{left: left, right: value} = left_expr, + %In{left: left, right: %{__struct__: MapSet} = mapset}, + _ + ) do + if MapSet.member?(mapset, value) do + left_expr + else + false + end + end + + def optimized_new( + :or, + %Eq{left: left, right: left_value}, + %Eq{left: left, right: right_value}, + _ + ) do %In{left: left, right: MapSet.new([left_value, right_value])} end + def optimized_new( + :and, + %Eq{left: left, right: left_value} = left_expr, + %Eq{left: left, right: right_value}, + _ + ) do + if left_value == right_value do + left_expr + else + false + end + end + def optimized_new( :or, %In{left: left, right: left_values}, - %In{left: left, right: right_values} = right + %In{left: left, right: right_values} = right, + _ ) do %{right | right: MapSet.union(left_values, right_values)} end - def optimized_new(op, left, right) do + def optimized_new( + :and, + %In{left: left, right: left_values}, + %In{left: left, right: right_values} = right, + _ + ) do + intersection = MapSet.intersection(left_values, right_values) + + case MapSet.size(intersection) do + 0 -> false + 1 -> %Eq{left: left, right: Enum.at(intersection, 0)} + _ -> %{right | right: intersection} + end + end + + def optimized_new( + op, + %__MODULE__{op: op, left: left, right: right} = left_expr, + right_expr, + op + ) do + case right_expr do + %In{} = in_op -> + with {:left, nil} <- {:left, Ash.Filter.find(left, &simplify?(&1, in_op))}, + {:right, nil} <- {:right, Ash.Filter.find(right, &simplify?(&1, in_op))} do + do_new(:or, left_expr, in_op) + else + {:left, _} -> + %{left_expr | left: optimized_new(:or, left, in_op)} + + {:right, _} -> + %{left_expr | right: optimized_new(:or, right, in_op)} + end + + %Eq{} = eq_op -> + with {:left, nil} <- {:left, Ash.Filter.find(left, &simplify?(&1, eq_op))}, + {:right, nil} <- {:right, Ash.Filter.find(right, &simplify?(&1, eq_op))} do + do_new(:or, left_expr, eq_op) + else + {:left, _} -> + %{left_expr | left: optimized_new(:or, left, eq_op)} + + {:right, _} -> + %{left_expr | right: optimized_new(:or, right, eq_op)} + end + end + end + + def optimized_new(op, left, right, _) do # TODO: more optimization passes! # Remove predicates that are on both sides of an `and` # if a predicate is on both sides of an `or`, lift it to an `and` do_new(op, left, right) end + defp simplify?(%Eq{} = left, %In{} = right), do: simplify?(right, left) + + defp simplify?(%Eq{right: %Ref{}}, _), do: false + defp simplify?(_, %Eq{right: %Ref{}}), do: false + defp simplify?(%Eq{left: left}, %Eq{left: left}), do: true + + defp simplify?( + %Eq{left: left}, + %In{left: left, right: %MapSet{}} + ), + do: true + + defp simplify?(_, _), do: false + defp do_new(op, left, right) do if left == right do left diff --git a/lib/ash/query/operator/in.ex b/lib/ash/query/operator/in.ex index 87d35ec8..86a0a0ba 100644 --- a/lib/ash/query/operator/in.ex +++ b/lib/ash/query/operator/in.ex @@ -31,6 +31,37 @@ defmodule Ash.Query.Operator.In do left in right end + def compare(%__MODULE__{left: left, right: %MapSet{} = left_right}, %__MODULE__{ + left: left, + right: %MapSet{} = right_right + }) do + if MapSet.equal?(left_right, right_right) do + :mutually_inclusive + else + if MapSet.disjoint?(left_right, right_right) do + :mutually_exclusive + else + :unknown + end + end + end + + def compare(%__MODULE__{}, %Ash.Query.Operator.Eq{right: %Ref{}}), + do: false + + def compare(%__MODULE__{left: left, right: %MapSet{} = left_right}, %Ash.Query.Operator.Eq{ + left: left, + right: value + }) do + if MapSet.member?(left_right, value) do + :left_implies_right + else + :mutually_exclusive + end + end + + def compare(_, _), do: :unknown + def to_string(%{right: %Ref{}} = op, opts), do: super(op, opts) def to_string(%{left: left, right: mapset}, opts) do @@ -48,11 +79,4 @@ defmodule Ash.Query.Operator.In do list_doc ]) end - - def simplify(%__MODULE__{left: left, right: right}) do - Enum.reduce(right, nil, fn item, expr -> - {:ok, eq} = Ash.Query.Operator.new(Ash.Query.Operator.Eq, left, item) - Ash.Query.Expression.new(:or, expr, eq) - end) - end end diff --git a/lib/ash/query/query.ex b/lib/ash/query/query.ex index eef73b07..ae28013a 100644 --- a/lib/ash/query/query.ex +++ b/lib/ash/query/query.ex @@ -145,16 +145,10 @@ defmodule Ash.Query do query filter -> - filter = Ash.Filter.parse!(resource, filter) - filter = - Ash.Filter.map(filter, fn - %{__predicate__?: true} = pred -> - %{pred | embedded?: true} - - other -> - other - end) + resource + |> Ash.Filter.parse!(filter) + |> Ash.Filter.embed_predicates() do_filter(query, filter) end diff --git a/lib/sat_solver.ex b/lib/sat_solver.ex index 874fd5f7..eeeb7873 100644 --- a/lib/sat_solver.ex +++ b/lib/sat_solver.ex @@ -304,78 +304,217 @@ defmodule Ash.SatSolver do end defp build_expr_with_predicate_information(expression) do - simplified = simplify(expression) + expression = fully_simplify(expression) - if simplified == expression do - all_predicates = - expression - |> Filter.list_predicates() - |> Enum.uniq() - - comparison_expressions = - all_predicates - |> Enum.filter(fn %module{} -> - :erlang.function_exported(module, :compare, 2) - end) - |> Enum.reduce([], fn predicate, new_expressions -> - all_predicates - |> Enum.reject(&Kernel.==(&1, predicate)) - |> Enum.filter(&shares_ref?(&1, predicate)) - |> Enum.reduce(new_expressions, fn other_predicate, new_expressions -> - # With predicate as a and other_predicate as b - case Ash.Filter.Predicate.compare(predicate, other_predicate) do - :right_includes_left -> - # b || !a - - [b(other_predicate or not predicate) | new_expressions] - - :left_includes_right -> - # a || ! b - [b(predicate or not other_predicate) | new_expressions] - - :mutually_inclusive -> - # (a && b) || (! a && ! b) - [ - b((predicate and other_predicate) or (not predicate and not other_predicate)) - | new_expressions - ] - - :mutually_exclusive -> - [b(not (other_predicate and predicate)) | new_expressions] - - _other -> - # If we can't tell, we assume that both could be true - new_expressions - end - end) - end) - |> Enum.uniq() - - expression = filter_to_expr(expression) - - expression_with_comparisons = - Enum.reduce(comparison_expressions, expression, fn comparison_expression, expression -> - b(comparison_expression and expression) - end) - - all_predicates - |> Enum.map(& &1.__struct__) + all_predicates = + expression + |> Filter.list_predicates() |> Enum.uniq() - |> Enum.flat_map(fn struct -> - if :erlang.function_exported(struct, :bulk_compare, 1) do - struct.bulk_compare(all_predicates) - else - [] - end + + comparison_expressions = + all_predicates + |> Enum.filter(fn %module{} -> + :erlang.function_exported(module, :compare, 2) end) - |> Enum.reduce(expression_with_comparisons, fn comparison_expression, expression -> + |> Enum.reduce([], fn predicate, new_expressions -> + all_predicates + |> Enum.reject(&Kernel.==(&1, predicate)) + |> Enum.filter(&shares_ref?(&1, predicate)) + |> Enum.reduce(new_expressions, fn other_predicate, new_expressions -> + # With predicate as a and other_predicate as b + case Ash.Filter.Predicate.compare(predicate, other_predicate) do + :right_includes_left -> + # b || !a + + [b(other_predicate or not predicate) | new_expressions] + + :left_includes_right -> + # a || ! b + [b(predicate or not other_predicate) | new_expressions] + + :mutually_inclusive -> + # (a && b) || (! a && ! b) + [ + b((predicate and other_predicate) or (not predicate and not other_predicate)) + | new_expressions + ] + + :mutually_exclusive -> + [b(not (other_predicate and predicate)) | new_expressions] + + _other -> + # If we can't tell, we assume that both could be true + new_expressions + end + end) + end) + |> Enum.uniq() + + expression = filter_to_expr(expression) + + expression_with_comparisons = + Enum.reduce(comparison_expressions, expression, fn comparison_expression, expression -> b(comparison_expression and expression) end) - else - build_expr_with_predicate_information(simplified) + + all_predicates + |> Enum.map(& &1.__struct__) + |> Enum.uniq() + |> Enum.flat_map(fn struct -> + if :erlang.function_exported(struct, :bulk_compare, 1) do + struct.bulk_compare(all_predicates) + else + [] + end + end) + |> Enum.reduce(expression_with_comparisons, fn comparison_expression, expression -> + b(comparison_expression and expression) + end) + end + + def fully_simplify(expression) do + expression + |> lift_equals_out_of_in() + |> do_fully_simplify() + end + + defp do_fully_simplify(expression) do + expression + |> simplify() + |> case do + ^expression -> + expression + + simplified -> + fully_simplify(simplified) end end + def lift_equals_out_of_in(expression) do + case find_non_equal_overlap(expression) do + nil -> + expression + + non_equal_overlap -> + expression + |> split_in_expressions(non_equal_overlap) + |> lift_equals_out_of_in() + end + end + + def find_non_equal_overlap(expression) do + Ash.Filter.find(expression, fn sub_expr -> + Ash.Filter.find(expression, fn sub_expr2 -> + overlap?(sub_expr, sub_expr2) + end) + end) + end + + defp new_in(base, right) do + case MapSet.size(right) do + 1 -> + %Ash.Query.Operator.Eq{left: base.left, right: Enum.at(right, 0)} + + _ -> + %Ash.Query.Operator.In{left: base.left, right: right} + end + end + + def split_in_expressions( + %Ash.Query.Operator.In{right: right} = sub_expr, + %Ash.Query.Operator.Eq{right: value} = non_equal_overlap + ) do + if overlap?(non_equal_overlap, sub_expr) do + Ash.Query.Expression.new( + :or, + new_in(sub_expr, MapSet.delete(right, value)), + non_equal_overlap + ) + else + sub_expr + end + end + + def split_in_expressions( + %Ash.Query.Operator.In{} = sub_expr, + %Ash.Query.Operator.In{right: right} = non_equal_overlap + ) do + if overlap?(sub_expr, non_equal_overlap) do + diff = MapSet.difference(sub_expr.right, right) + + if MapSet.size(diff) == 0 do + Enum.reduce(sub_expr.right, nil, fn var, acc -> + Expression.new(:or, %Ash.Query.Operator.Eq{left: sub_expr.left, right: var}, acc) + end) + else + new_right = new_in(sub_expr, MapSet.intersection(sub_expr.right, right)) + + Ash.Query.Expression.new( + :or, + new_in(sub_expr, diff), + new_right + ) + end + else + sub_expr + end + end + + def split_in_expressions(nil, _), do: nil + + def split_in_expressions(%Ash.Filter{expression: expression} = filter, non_equal_overlap), + do: %{filter | expression: split_in_expressions(expression, non_equal_overlap)} + + def split_in_expressions(%Not{expression: expression} = not_expr, non_equal_overlap), + do: %{not_expr | expression: split_in_expressions(expression, non_equal_overlap)} + + def split_in_expressions(%Expression{left: left, right: right} = expr, non_equal_overlap), + do: %{ + expr + | left: split_in_expressions(left, non_equal_overlap), + right: split_in_expressions(right, non_equal_overlap) + } + + def split_in_expressions(other, _), do: other + + def overlap?( + %Ash.Query.Operator.In{left: left, right: %MapSet{} = left_right}, + %Ash.Query.Operator.In{left: left, right: %MapSet{} = right_right} + ) do + if MapSet.equal?(left_right, right_right) do + false + else + overlap? = + left_right + |> MapSet.intersection(right_right) + |> MapSet.size() + |> Kernel.>(0) + + if overlap? do + true + else + false + end + end + end + + def overlap?(_, %Ash.Query.Operator.Eq{right: %Ref{}}), + do: false + + def overlap?(%Ash.Query.Operator.Eq{right: %Ref{}}, _), + do: false + + def overlap?( + %Ash.Query.Operator.Eq{left: left, right: left_right}, + %Ash.Query.Operator.In{left: left, right: %MapSet{} = right_right} + ) do + MapSet.member?(right_right, left_right) + end + + def overlap?(_left, _right) do + false + end + def mutually_exclusive(predicates, acc \\ []) def mutually_exclusive([], acc), do: acc