mirror of
https://github.com/ash-project/ash.git
synced 2024-09-20 21:43:02 +12:00
665a9fb5c4
1. Only convert to CNF once. 2. Group predicates that only appear in specific combinations, to limit the number of variables provided to the SAT solver. Number 2 above does technically slow down all cases a bit, but the optimization is really important when it matters, and cases that don't need this optimization still complete on the order of microseconds anyway.
1112 lines
30 KiB
Elixir
1112 lines
30 KiB
Elixir
defmodule Ash.SatSolver do
|
|
@moduledoc """
|
|
Tools for working with the SAT solver that drives filter subset checking (for authorization)
|
|
"""
|
|
|
|
alias Ash.Filter
|
|
alias Ash.Query.{BooleanExpression, Not, Ref}
|
|
|
|
@dialyzer {:nowarn_function, overlap?: 2}
|
|
|
|
# Builds a boolean term from ordinary `and` / `or` / `not` syntax.
#
# At compile time every `and`/`or` call inside `statement` is rewritten into a
# `{:and, left, right}` / `{:or, left, right}` tuple and every `not` call into
# `{:not, operand}`; all other AST nodes are left exactly as written. The
# generated code passes the resulting term through `Ash.SatSolver.balance/1`.
defmacro b(statement) do
  rewrite_node = fn
    {:and, _meta, [lhs, rhs]} ->
      quote(do: {:and, unquote(lhs), unquote(rhs)})

    {:or, _meta, [lhs, rhs]} ->
      quote(do: {:or, unquote(lhs), unquote(rhs)})

    {:not, _meta, [operand]} ->
      quote(do: {:not, unquote(operand)})

    node ->
      node
  end

  tuple_ast = Macro.prewalk(statement, rewrite_node)

  quote do
    Ash.SatSolver.balance(unquote(tuple_ast))
  end
end
|
|
|
|
# Normalises a tuple-form boolean expression.
#
# * `{op, left, right}`: both children are balanced, then ordered by Erlang
#   term order, so the two operand orders of a binary node yield the same term.
# * `{:not, {:not, x}}`: the double negation collapses to the balanced `x`.
# * `{:not, x}`: the operand is balanced.
# * Anything else is returned untouched.
def balance(expression) do
  case expression do
    {op, left, right} ->
      [lower, higher] = Enum.sort([balance(left), balance(right)])
      {op, lower, higher}

    {:not, {:not, inner}} ->
      balance(inner)

    {:not, inner} ->
      {:not, balance(inner)}

    leaf ->
      leaf
  end
end
|
|
|
|
# Decides whether `candidate` is a strict subset of `filter`.
#
# Returns `true` whenever `filter.expression` is `nil` (whatever the candidate
# is), and `false` when only the candidate's expression is `nil`. All other
# cases are delegated to `do_strict_filter_subset/2`, which answers `true`,
# `false` or `:maybe`.
def strict_filter_subset(%{expression: nil}, _candidate), do: true
def strict_filter_subset(_filter, %{expression: nil}), do: false
def strict_filter_subset(filter, candidate), do: do_strict_filter_subset(filter, candidate)
|
|
|
|
# Compares `candidate` with `filter` using two SAT queries:
#
# 1. If `filter and candidate` is unsatisfiable, nothing can match both, so
#    the answer is `false`.
# 2. Otherwise, if `not filter and candidate` is unsatisfiable, nothing can
#    match the candidate without also matching the filter, so the answer is
#    `true`.
# 3. Otherwise the answer is `:maybe`.
#
# Both queries are solved against `filter.resource`.
defp do_strict_filter_subset(filter, candidate) do
  overlap = BooleanExpression.new(:and, filter.expression, candidate.expression)

  case transform_and_solve(filter.resource, overlap) do
    {:error, :unsatisfiable} ->
      false

    {:ok, _scenario} ->
      # No global side effects here: the previous debug call
      # `Application.put_env(:foo, :bar, true)` has been removed.
      outside_filter =
        BooleanExpression.new(:and, Not.new(filter.expression), candidate.expression)

      case transform_and_solve(filter.resource, outside_filter) do
        {:error, :unsatisfiable} -> true
        {:ok, _scenario} -> :maybe
      end
  end
end
|
|
|
|
# Converts an Ash filter (or bare expression) into the tuple form consumed by
# `balance/1` and the CNF helpers:
#
# * `nil`, `true` and `false` are returned as-is;
# * a `%Filter{}` is unwrapped to its expression;
# * predicates (maps with `__predicate__?`) and `exists` nodes stay as leaves;
# * `%Not{}` becomes `{:not, expr}` (built and balanced through `b/1`);
# * `%BooleanExpression{}` becomes `{op, left, right}`.
#
# Raises `ArgumentError` for anything else.
defp filter_to_expr(literal) when literal in [nil, true, false], do: literal
defp filter_to_expr(%Filter{expression: inner}), do: filter_to_expr(inner)
defp filter_to_expr(%{__predicate__?: _} = predicate), do: predicate
defp filter_to_expr(%Ash.Query.Exists{} = exists), do: exists

defp filter_to_expr(%Not{expression: inner}) do
  b(not filter_to_expr(inner))
end

defp filter_to_expr(%BooleanExpression{op: op, left: lhs, right: rhs}) do
  {op, filter_to_expr(lhs), filter_to_expr(rhs)}
end

defp filter_to_expr(unsupported) do
  raise ArgumentError, message: "Invalid filter expression #{inspect(unsupported)}"
end
|
|
|
|
# Prepares `expression` (evaluated against `resource`) for the SAT solver.
# Steps, in order:
#
# 1. Merge synonymous relationship paths.
# 2. Rewrite refs onto join keys where possible.
# 3. Simplify and add the constraints implied by comparing the predicates
#    (`build_expr_with_predicate_information/1`).
def transform(resource, expression) do
  consolidated = consolidate_relationships(expression, resource)
  upgraded = upgrade_related_filters_to_join_keys(consolidated, resource)
  build_expr_with_predicate_information(upgraded)
end

# Transforms `expression`, converts it to CNF exactly once and hands the
# resulting clauses to the solver. The bindings returned by `to_cnf/1` are not
# needed here.
def transform_and_solve(resource, expression) do
  {cnf, _bindings} = to_cnf(transform(resource, expression))
  solve_expression(cnf)
end
|
|
|
|
# Walks a filter expression and passes the operands of every operator, and the
# arguments of every function, through `upgrade_ref/2`.
#
# * Boolean expressions and negations are rebuilt around their upgraded
#   children.
# * `exists` nodes are upgraded relative to the resource reached by their
#   `path`.
# * Anything else is returned unchanged.
defp upgrade_related_filters_to_join_keys(%BooleanExpression{} = expression, resource) do
  %BooleanExpression{op: op, left: lhs, right: rhs} = expression

  BooleanExpression.new(
    op,
    upgrade_related_filters_to_join_keys(lhs, resource),
    upgrade_related_filters_to_join_keys(rhs, resource)
  )
end

defp upgrade_related_filters_to_join_keys(%Not{expression: negated}, resource) do
  negated
  |> upgrade_related_filters_to_join_keys(resource)
  |> Not.new()
end

defp upgrade_related_filters_to_join_keys(
       %Ash.Query.Exists{path: path, expr: inner} = exists,
       resource
     ) do
  destination = Ash.Resource.Info.related(resource, path)
  %{exists | expr: upgrade_related_filters_to_join_keys(inner, destination)}
end

defp upgrade_related_filters_to_join_keys(
       %{__operator__?: true, left: lhs, right: rhs} = operator,
       resource
     ) do
  %{operator | left: upgrade_ref(lhs, resource), right: upgrade_ref(rhs, resource)}
end

defp upgrade_related_filters_to_join_keys(
       %{__function__?: true, arguments: args} = function,
       resource
     ) do
  %{function | arguments: Enum.map(args, fn arg -> upgrade_ref(arg, resource) end)}
end

defp upgrade_related_filters_to_join_keys(unchanged, _resource), do: unchanged
|
|
|
|
# Rewrites a ref `path ++ [rel]` pointing at `rel`'s destination attribute into
# a ref to `rel`'s source attribute on `path`. The comparison then no longer
# traverses the last relationship hop. Refs that don't fit this shape are
# returned unchanged.
#
# `{key, ref}` pairs (e.g. keyword-style function arguments) are upgraded
# value-wise.
defp upgrade_ref({key, ref}, resource) when is_atom(key) do
  {key, upgrade_ref(ref, resource)}
end

defp upgrade_ref(
       %Ash.Query.Ref{attribute: attribute, relationship_path: path} = ref,
       resource
     )
     when path != [] do
  # Match the attribute's name instead of calling `attribute.name`, so a ref
  # whose attribute is not a map (e.g. a bare atom) is left alone instead of
  # crashing.
  with %{name: attribute_name} <- attribute,
       relationship when not is_nil(relationship) <-
         Ash.Resource.Info.relationship(resource, path),
       true <- attribute_name == relationship.destination_attribute,
       new_attribute when not is_nil(new_attribute) <-
         Ash.Resource.Info.attribute(relationship.source, relationship.source_attribute) do
    %{
      ref
      | relationship_path: :lists.droplast(path),
        attribute: new_attribute,
        # `new_attribute` belongs to `relationship.source` (the resource at the
        # end of the shortened path), not to the root resource.
        resource: relationship.source
    }
  else
    _ -> ref
  end
end

defp upgrade_ref(other, _resource), do: other
|
|
|
|
# Rewrites every relationship path in `expression` that is synonymous with a
# path seen earlier (per `synonymous_relationship_paths?/3`) onto that earlier
# path. Equivalent references then share one path.
defp consolidate_relationships(expression, resource) do
  replacements = relationship_path_replacements(expression, resource)
  do_consolidate_relationships(expression, replacements, resource)
end

# Builds `%{path => synonymous_earlier_path}` for the relationship paths used in
# `expression` (evaluated against `resource`). Paths that have no synonym are
# kept as candidates for later paths.
defp relationship_path_replacements(expression, resource) do
  {replacements, _kept_paths} =
    expression
    |> Filter.relationship_paths(true)
    |> Enum.uniq()
    |> Enum.reduce({%{}, []}, fn path, {replacements, kept_paths} ->
      case find_synonymous_relationship_path(resource, kept_paths, path) do
        nil ->
          {replacements, [path | kept_paths]}

        synonymous_path ->
          # Keep the accumulator shape: returning a bare map here used to crash
          # the next iteration (FunctionClauseError) or the final match.
          {Map.put(replacements, path, synonymous_path), kept_paths}
      end
    end)

  replacements
end

# Applies `replacements` throughout the expression tree.
defp do_consolidate_relationships(
       %BooleanExpression{op: op, left: left, right: right},
       replacements,
       resource
     ) do
  BooleanExpression.new(
    op,
    do_consolidate_relationships(left, replacements, resource),
    do_consolidate_relationships(right, replacements, resource)
  )
end

defp do_consolidate_relationships(%Not{expression: expression}, replacements, resource) do
  Not.new(do_consolidate_relationships(expression, replacements, resource))
end

# `exists` nodes:
# 1. Replace `at_path` using the outer replacements.
# 2. Compute replacements for the inner expression relative to the resource at
#    the original `at_path`, and use them for `path`.
# 3. Consolidate the inner expression against the fully related resource.
#
# NOTE(review): the inner replacements are derived from the paths used inside
# `expr` but are also looked up for `path` — confirm that is intended.
defp do_consolidate_relationships(
       %Ash.Query.Exists{at_path: at_path, path: path, expr: expr} = exists,
       replacements,
       resource
     ) do
  exists =
    case Map.get(replacements, at_path) do
      nil -> exists
      replacement -> %{exists | at_path: replacement}
    end

  related = Ash.Resource.Info.related(resource, at_path)
  inner_replacements = relationship_path_replacements(expr, related)

  exists =
    case Map.get(inner_replacements, path) do
      nil -> exists
      replacement -> %{exists | path: replacement}
    end

  full_related = Ash.Resource.Info.related(related, path)

  %{exists | expr: consolidate_relationships(expr, full_related)}
end

defp do_consolidate_relationships(
       %Ash.Query.Ref{relationship_path: path} = ref,
       replacements,
       _resource
     )
     when path != [] do
  case Map.get(replacements, path) do
    nil -> ref
    replacement -> %{ref | relationship_path: replacement}
  end
end

defp do_consolidate_relationships(
       %{__function__?: true, arguments: args} = func,
       replacements,
       resource
     ) do
  %{func | arguments: Enum.map(args, &do_consolidate_relationships(&1, replacements, resource))}
end

defp do_consolidate_relationships(
       %{__operator__?: true, left: left, right: right} = op,
       replacements,
       resource
     ) do
  %{
    op
    | left: do_consolidate_relationships(left, replacements, resource),
      right: do_consolidate_relationships(right, replacements, resource)
  }
end

defp do_consolidate_relationships(other, _, _), do: other

# Returns the first of `paths` that is synonymous with `path`, or `nil`.
defp find_synonymous_relationship_path(resource, paths, path) do
  Enum.find(paths, &synonymous_relationship_paths?(resource, &1, path))
end
|
|
|
|
# def synonymous_relationship_paths?(_, [], []), do: true
|
|
|
|
# def synonymous_relationship_paths?(_resource, candidate_path, path)
|
|
# when length(candidate_path) != length(path),
|
|
# do: false
|
|
|
|
# def synonymous_relationship_paths?(resource, [candidate_first | candidate_rest], [first | rest])
|
|
# when first == candidate_first do
|
|
# synonymous_relationship_paths?(
|
|
# Ash.Resource.Info.relationship(resource, candidate_first).destination,
|
|
# candidate_rest,
|
|
# rest
|
|
# )
|
|
# end
|
|
|
|
# Returns `true` when relationship paths `candidate` and `search` are equivalent
# hop by hop.
#
# Resolution:
# * `search` is resolved against `left_resource`.
# * `candidate` is resolved against `right_resource`, which defaults to
#   `left_resource`.
# * Paths of different lengths, or hops that don't resolve to a relationship,
#   are not synonymous.
#
# Hop comparison:
# * A `many_to_many` hop paired with a `has_many` hop (either way round) is
#   compared through the `many_to_many`'s join relationship.
# * Any other pair of hops must agree on their source/destination attributes,
#   join-resource attributes and destination resource. Comparison then
#   continues from that destination.
#
# NOTE(review): in the two many_to_many/has_many branches the remaining hops are
# compared against `left_resource`/`right_resource` rather than the
# relationships' destinations — confirm this is intended.
def synonymous_relationship_paths?(
      left_resource,
      candidate,
      search,
      right_resource \\ nil
    )

def synonymous_relationship_paths?(_, [], [], _), do: true
def synonymous_relationship_paths?(_, [], _, _), do: false
def synonymous_relationship_paths?(_, _, [], _), do: false

def synonymous_relationship_paths?(
      left_resource,
      [candidate_first | candidate_rest],
      [first | rest],
      right_resource
    ) do
  right_resource = right_resource || left_resource
  relationship = Ash.Resource.Info.relationship(left_resource, first)
  candidate_relationship = Ash.Resource.Info.relationship(right_resource, candidate_first)

  cond do
    !relationship || !candidate_relationship ->
      false

    relationship.type == :many_to_many && candidate_relationship.type == :has_many ->
      synonymous_relationship_paths?(left_resource, [relationship.join_relationship], [
        candidate_first
      ]) &&
        synonymous_relationship_paths?(
          left_resource,
          candidate_rest,
          rest,
          right_resource
        )

    relationship.type == :has_many && candidate_relationship.type == :many_to_many ->
      synonymous_relationship_paths?(left_resource, [relationship.name], [
        candidate_relationship.join_relationship
      ]) &&
        synonymous_relationship_paths?(
          left_resource,
          candidate_rest,
          rest,
          right_resource
        )

    true ->
      # Each key listed once (`:destination_attribute` was previously duplicated).
      comparison_keys = [
        :source_attribute,
        :destination_attribute,
        :source_attribute_on_join_resource,
        :destination_attribute_on_join_resource,
        :destination
      ]

      Map.take(relationship, comparison_keys) ==
        Map.take(candidate_relationship, comparison_keys) and
        synonymous_relationship_paths?(relationship.destination, candidate_rest, rest)
  end
end
|
|
|
|
# Simplifies `expression`, converts it to tuple form with `filter_to_expr/1`,
# and conjoins extra constraints describing how its predicates relate:
#
# * Pairwise: for each predicate whose module exports `compare/2`, every other
#   predicate sharing a ref with it is compared. The answer is encoded by
#   `predicate_relation_constraint/2`.
# * Bulk: for each predicate module exporting `bulk_compare/1`, the constraints
#   it returns for the full predicate list are conjoined as well.
#
# NOTE(review): `:erlang.function_exported/3` only reports functions of modules
# that are already loaded.
defp build_expr_with_predicate_information(expression) do
  simplified = fully_simplify(expression)

  predicates =
    simplified
    |> Filter.list_predicates()
    |> Enum.uniq()

  comparable =
    Enum.filter(predicates, fn %module{} -> :erlang.function_exported(module, :compare, 2) end)

  pairwise_constraints =
    for predicate <- comparable,
        other_predicate <- predicates,
        other_predicate != predicate,
        shares_ref?(other_predicate, predicate),
        constraint <- [predicate_relation_constraint(predicate, other_predicate)],
        constraint != nil,
        do: constraint

  # Latest-generated constraint first, then de-duplicated. This order decides
  # how the constraints are nested by the reduce below.
  comparison_expressions =
    pairwise_constraints
    |> Enum.reverse()
    |> Enum.uniq()

  with_comparisons =
    Enum.reduce(comparison_expressions, filter_to_expr(simplified), fn constraint, acc ->
      b(constraint and acc)
    end)

  predicates
  |> Enum.map(& &1.__struct__)
  |> Enum.uniq()
  |> Enum.flat_map(fn module ->
    if :erlang.function_exported(module, :bulk_compare, 1) do
      module.bulk_compare(predicates)
    else
      []
    end
  end)
  |> Enum.reduce(with_comparisons, fn constraint, acc -> b(constraint and acc) end)
end

# Encodes the result of `Ash.Filter.Predicate.compare(a, b)` as a boolean
# constraint over the two predicates, or returns `nil` when the comparison
# gives no usable information.
defp predicate_relation_constraint(predicate, other_predicate) do
  case Ash.Filter.Predicate.compare(predicate, other_predicate) do
    :right_includes_left ->
      # b || !a
      b(other_predicate or not predicate)

    :left_includes_right ->
      # a || !b
      b(predicate or not other_predicate)

    :mutually_inclusive ->
      # (a && b) || (!a && !b)
      b((predicate and other_predicate) or (not predicate and not other_predicate))

    :mutually_exclusive ->
      b(not (other_predicate and predicate))

    :mutually_exclusive_and_collectively_exhaustive ->
      b(
        not (other_predicate and predicate) and
          not (not other_predicate and not predicate)
      )

    _other ->
      # If we can't tell, we assume that both could be true
      nil
  end
end
|
|
|
|
# Simplifies `expression` in three steps:
# 1. Run `simplify/1` until it stops changing the expression.
# 2. Split overlapping `in`/`==` predicates with `lift_equals_out_of_in/1`.
# 3. Run `simplify/1` to a fixed point again.
def fully_simplify(expression) do
  stable = do_fully_simplify(expression)
  lifted = lift_equals_out_of_in(stable)
  do_fully_simplify(lifted)
end

# Stops as soon as `simplify/1` returns a term identical (`===`) to its input.
# When it does change something, the result re-enters `fully_simplify/1`, so
# the lifting step runs again as well.
defp do_fully_simplify(expression) do
  simplified = simplify(expression)

  if simplified === expression do
    expression
  else
    fully_simplify(simplified)
  end
end
|
|
|
|
# Repeatedly splits `in` predicates that overlap another predicate (see
# `find_non_equal_overlap/1` and `split_in_expressions/2`) until
# `find_non_equal_overlap/1` returns `nil`.
def lift_equals_out_of_in(expression) do
  expression
  |> find_non_equal_overlap()
  |> case do
    nil -> expression
    overlap -> lift_equals_out_of_in(split_in_expressions(expression, overlap))
  end
end

# Returns the first sub-expression `candidate`, in `Ash.Filter.find/2` order,
# for which `overlap?(candidate, other)` holds for some sub-expression `other`
# of `expression`. Presumably `nil` when none exists (the caller relies on
# that).
def find_non_equal_overlap(expression) do
  Ash.Filter.find(expression, fn candidate ->
    Ash.Filter.find(expression, &overlap?(candidate, &1))
  end)
end
|
|
|
|
# Builds the narrowest membership operator on `base.left`:
# * a single value becomes an `==`;
# * any other number of values stays an `in` over the `MapSet`.
defp new_in(base, values) do
  if MapSet.size(values) == 1 do
    [only] = MapSet.to_list(values)
    %Ash.Query.Operator.Eq{left: base.left, right: only}
  else
    %Ash.Query.Operator.In{left: base.left, right: values}
  end
end
|
|
|
|
# Rewrites `in` predicates that overlap `non_equal_overlap` so the overlapping
# part becomes its own disjunct. Recurses through filters, negations and
# boolean expressions; every other node (and every `in` that does not overlap)
# is returned unchanged.
#
# * `in` vs `==`: the `==` value is removed from the set and OR-ed back in as
#   the `==` itself.
# * `in` vs `in`, all values shared: the `in` expands into a chain of `==`
#   disjuncts.
# * `in` vs `in`, partial overlap: the `in` splits into "values not in the
#   other set" OR "values in both sets".
def split_in_expressions(
      %Ash.Query.Operator.In{right: values} = in_expr,
      %Ash.Query.Operator.Eq{right: value} = eq_expr
    ) do
  if overlap?(eq_expr, in_expr) do
    remaining = new_in(in_expr, MapSet.delete(values, value))
    Ash.Query.BooleanExpression.new(:or, remaining, eq_expr)
  else
    in_expr
  end
end

def split_in_expressions(
      %Ash.Query.Operator.In{right: values} = in_expr,
      %Ash.Query.Operator.In{right: other_values} = other_in
    ) do
  cond do
    not overlap?(in_expr, other_in) ->
      in_expr

    MapSet.subset?(values, other_values) ->
      Enum.reduce(values, nil, fn value, acc ->
        BooleanExpression.new(:or, %Ash.Query.Operator.Eq{left: in_expr.left, right: value}, acc)
      end)

    true ->
      Ash.Query.BooleanExpression.new(
        :or,
        new_in(in_expr, MapSet.difference(values, other_values)),
        new_in(in_expr, MapSet.intersection(values, other_values))
      )
  end
end

def split_in_expressions(nil, _), do: nil

def split_in_expressions(%Ash.Filter{expression: inner} = filter, non_equal_overlap) do
  %{filter | expression: split_in_expressions(inner, non_equal_overlap)}
end

def split_in_expressions(%Not{expression: inner} = not_expr, non_equal_overlap) do
  %{not_expr | expression: split_in_expressions(inner, non_equal_overlap)}
end

def split_in_expressions(%BooleanExpression{left: lhs, right: rhs} = expr, non_equal_overlap) do
  %{
    expr
    | left: split_in_expressions(lhs, non_equal_overlap),
      right: split_in_expressions(rhs, non_equal_overlap)
  }
end

def split_in_expressions(other, _), do: other
|
|
|
|
# Reports whether two predicates on the same left-hand side overlap in a way
# `split_in_expressions/2` can exploit:
#
# * two `in` predicates whose value sets differ but intersect;
# * an `==` whose value is a member of an `in` set.
#
# An `==` whose right-hand side is a `%Ref{}` never overlaps anything, and
# every other combination returns `false`.
def overlap?(
      %Ash.Query.Operator.In{left: left, right: %MapSet{} = left_values},
      %Ash.Query.Operator.In{left: left, right: %MapSet{} = right_values}
    ) do
  not MapSet.equal?(left_values, right_values) and
    not MapSet.disjoint?(left_values, right_values)
end

def overlap?(_, %Ash.Query.Operator.Eq{right: %Ref{}}), do: false

def overlap?(%Ash.Query.Operator.Eq{right: %Ref{}}, _), do: false

def overlap?(
      %Ash.Query.Operator.Eq{left: left, right: value},
      %Ash.Query.Operator.In{left: left, right: %MapSet{} = values}
    ) do
  MapSet.member?(values, value)
end

def overlap?(_left, _right), do: false
|
|
|
|
# Prepends to `acc` one `not (a and b)` constraint for every unordered pair of
# `predicates`, forbidding any two of them from holding at once.
def mutually_exclusive(predicates, acc \\ [])
def mutually_exclusive([], acc), do: acc

def mutually_exclusive([predicate | rest], acc) do
  exclusions =
    Enum.reduce(rest, acc, fn other, constraints ->
      [b(not (predicate and other)) | constraints]
    end)

  mutually_exclusive(rest, exclusions)
end
|
|
|
|
# Constraints saying exactly one of `predicates` holds.
#
# For every predicate it builds "exactly one of `predicate` and the
# disjunction of all other predicates holds", then prepends the pairwise
# exclusions from `mutually_exclusive/1`. Fewer than two predicates produce no
# constraints.
def mutually_exclusive_and_collectively_exhaustive([]), do: []
def mutually_exclusive_and_collectively_exhaustive([_]), do: []

def mutually_exclusive_and_collectively_exhaustive(predicates) do
  # `b/1` yields a single tuple per predicate, so this has to be `Enum.map/2`.
  # `Enum.flat_map/2` tried to enumerate that tuple and raised.
  exhaustive =
    Enum.map(predicates, fn predicate ->
      other_predicates_union =
        predicates
        |> Enum.reject(&(&1 == predicate))
        |> Enum.reduce(nil, fn other_predicate, union ->
          if union, do: b(union or other_predicate), else: other_predicate
        end)

      b(
        not (predicate and other_predicates_union) and
          not (not predicate and not other_predicates_union)
      )
    end)

  mutually_exclusive(predicates) ++ exhaustive
end
|
|
|
|
# `not (left and right)`: the two predicates cannot both hold.
def left_excludes_right(left, right), do: b(not (left and right))

# The same exclusion as `left_excludes_right/2`, with the operands written the
# other way round.
def right_excludes_left(left, right), do: b(not (right and left))
|
|
|
|
# Prepends to `acc` one constraint `(a and b) or (not a and not b)` for every
# unordered pair of `predicates`. All of them must therefore be true together
# or false together.
def mutually_inclusive(predicates, acc \\ [])
def mutually_inclusive([], acc), do: acc

def mutually_inclusive([predicate | rest], acc) do
  new_acc =
    Enum.reduce(rest, acc, fn other_predicate, acc ->
      [b((predicate and other_predicate) or (not predicate and not other_predicate)) | acc]
    end)

  # Recurse into `mutually_inclusive/2` itself. Calling `mutually_exclusive/2`
  # here made every pair after the first predicate mutually *exclusive*.
  mutually_inclusive(rest, new_acc)
end
|
|
|
|
# "right implies left", encoded as `not (right and not left)`.
def right_implies_left(left, right), do: b(not (right and not left))

# "left implies right", encoded as `not (left and not right)`.
def left_implies_right(left, right), do: b(not (left and not right))
|
|
|
|
# Two predicates share a ref when a `%Ref{}` operand/argument of one also
# appears as an operand/argument of the other.
defp shares_ref?(left, right), do: any_refs_in_common?(refs(left), refs(right))

defp any_refs_in_common?(left_refs, right_refs) do
  Enum.any?(left_refs, fn ref -> Enum.member?(right_refs, ref) end)
end

# The direct `%Ref{}` operands of an operator, or the direct `%Ref{}`
# arguments of a function. Anything else contributes no refs.
defp refs(%{__operator__?: true, left: left, right: right}),
  do: for(%Ref{} = ref <- [left, right], do: ref)

defp refs(%{__function__?: true, arguments: arguments}),
  do: for(%Ref{} = ref <- arguments, do: ref)

defp refs(_), do: []
|
|
|
|
# One simplification pass over the expression tree.
#
# * Boolean expressions, negations and `exists` bodies are rebuilt from their
#   simplified children.
# * Predicates whose module exports `simplify/1` are replaced by its result,
#   falling back to the original predicate when it returns `nil`/`false`.
# * Everything else is left as-is.
defp simplify(%BooleanExpression{op: op, left: lhs, right: rhs}),
  do: BooleanExpression.new(op, simplify(lhs), simplify(rhs))

defp simplify(%Not{expression: inner}), do: Not.new(simplify(inner))

defp simplify(%Ash.Query.Exists{expr: inner} = exists), do: %{exists | expr: simplify(inner)}

defp simplify(%module{__predicate__?: true} = predicate) do
  if :erlang.function_exported(module, :simplify, 1) do
    module.simplify(predicate) || predicate
  else
    predicate
  end
end

defp simplify(other), do: other
|
|
|
|
# Converts a tuple-form boolean expression into integer clauses for picosat.
#
# Pipeline:
# 1. Replace leaves with integer variables.
# 2. Convert to CNF and lift into clause lists.
# 3. Turn negations into negative integers.
# 4. Sort each clause by variable (the negative literal first on ties).
# 5. Group co-occurring literals.
# 6. Renumber densely with `rebind/1`.
#
# Returns `rebind/1`'s `{clauses, %{temp_bindings: ..., old_bindings: ...}}`.
def to_cnf(expression) do
  # `true` and `false` are conjoined so the constants get their own variables,
  # which end up as unit clauses forcing them true / false respectively.
  {bindings, bound} = extract_bindings(b(true and not false and expression))

  bound
  |> to_conjunctive_normal_form()
  |> lift_clauses()
  |> negations_to_negative_numbers()
  |> Enum.map(fn clause -> Enum.sort_by(clause, &{abs(&1), &1}) end)
  |> group_predicates(bindings)
  |> rebind()
end
|
|
|
|
# Replaces sets of literals that always occur together across the clauses with
# one fresh "group" variable per set. A set qualifies when
# `can_be_used_as_group?/3` accepts it (sets are produced by
# `Ash.SatSolver.Utils.ordered_sublists/1`). The mapping is recorded in
# `bindings[:groups]` / `bindings[:reverse_groups]`.
#
# A single-clause expression is returned untouched. Otherwise
# `{clauses, bindings}` is returned with the clause list in reverse order.
defp group_predicates([_single] = expression, bindings), do: {expression, bindings}

defp group_predicates(scenarios, bindings) do
  Enum.reduce(scenarios, {[], bindings}, fn scenario, {grouped, bindings} ->
    {grouped_scenario, bindings} = group_scenario_predicates(scenario, scenarios, bindings)
    {[grouped_scenario | grouped], bindings}
  end)
end

# Applies every usable group to one clause, longest groups first.
defp group_scenario_predicates(scenario, all_scenarios, bindings) do
  usable_groups =
    scenario
    |> Ash.SatSolver.Utils.ordered_sublists()
    |> Enum.filter(&can_be_used_as_group?(&1, all_scenarios, bindings))
    |> Enum.sort_by(&(-length(&1)))

  Enum.reduce(usable_groups, {scenario, bindings}, fn group, {current, bindings} ->
    bindings = add_group_binding(bindings, group)
    group_variable = bindings[:groups][group]

    {Ash.SatSolver.Utils.replace_ordered_sublist(current, group, group_variable), bindings}
  end)
end
|
|
|
|
# Maps lists of literals numbered by `rebind/1` back to the variables used
# before renumbering.
#
# * A positive group variable expands, via `expand_groups/1`, into one list per
#   group member.
# * A negated group variable becomes the negation of every member.
#
# Returns `{lists, old_bindings}`.
def unbind(expression, %{temp_bindings: temp_bindings, old_bindings: old_bindings}) do
  unbound =
    Enum.flat_map(expression, fn statement ->
      statement
      |> Enum.flat_map(&unbind_literal(&1, temp_bindings, old_bindings))
      |> expand_groups()
    end)

  {unbound, old_bindings}
end

# Translates one renumbered literal into the literal(s) it stands for. A
# positive group literal becomes an `{:expand, group}` placeholder.
defp unbind_literal(literal, temp_bindings, old_bindings) do
  negated? = literal < 0
  original = temp_bindings[abs(literal)]

  case {old_bindings[:reverse_groups][original], negated?} do
    {nil, true} -> [-original]
    {nil, false} -> [original]
    {group, true} -> Enum.map(group, &(-&1))
    {group, false} -> [{:expand, group}]
  end
end

# Expands every `{:expand, group}` placeholder into the cartesian product of
# its alternatives, giving one list per combination. A list without
# placeholders is simply wrapped: `[expression]`.
def expand_groups(expression) do
  if Enum.any?(expression, &match?({:expand, _}, &1)),
    do: do_expand_groups(expression),
    else: [expression]
end

defp do_expand_groups(literals) do
  List.foldr(literals, [[]], fn
    {:expand, group}, tails -> for choice <- group, tail <- tails, do: [choice | tail]
    literal, tails -> Enum.map(tails, &[literal | &1])
  end)
end
|
|
|
|
# Renumbers the variables of `expression` densely from 1, in order of first
# appearance, keeping each literal's sign. The clause list comes back in
# reverse order; literals within a clause keep their order.
#
# Returns `{clauses, %{temp_bindings: temp, old_bindings: bindings}}`, where
# `temp` holds:
# * `new_number => original_variable` for every renumbered variable;
# * `:current`, the highest number used;
# * `:reverse`, mapping original variable to new number.
defp rebind({expression, bindings}) do
  {renumbered, temp_bindings} =
    Enum.reduce(expression, {[], %{current: 0}}, fn clause, {done, state} ->
      {renumbered_clause, state} = Enum.map_reduce(clause, state, &rebind_literal/2)
      {[renumbered_clause | done], state}
    end)

  {renumbered, %{temp_bindings: temp_bindings, old_bindings: bindings}}
end

# Renumbers one literal, allocating a new number the first time its variable
# is seen.
defp rebind_literal(literal, state) do
  variable = abs(literal)
  sign = if literal < 0, do: -1, else: 1

  case state[:reverse][variable] do
    nil ->
      fresh = state.current + 1

      state =
        state
        |> Map.put(:current, fresh)
        |> Map.update(:reverse, %{variable => fresh}, &Map.put(&1, variable, fresh))
        |> Map.put(fresh, variable)

      {sign * fresh, state}

    existing ->
      {sign * existing, state}
  end
end
|
|
|
|
# A group of literals can stand in for its members when either:
# * it already has a binding in `bindings[:groups]`; or
# * every scenario mentions none of its variables, or contains it as an
#   ordered sublist.
def can_be_used_as_group?(group, scenarios, bindings) do
  Map.has_key?(bindings[:groups] || %{}, group) or
    Enum.all?(scenarios, &(has_no_overlap?(&1, group) or group_in_scenario?(&1, group)))
end

# `true` when no variable of `group` (ignoring sign) occurs in `scenario`.
defp has_no_overlap?(scenario, group) do
  Enum.all?(group, fn group_literal ->
    Enum.all?(scenario, fn scenario_literal ->
      abs(group_literal) != abs(scenario_literal)
    end)
  end)
end

defp group_in_scenario?(scenario, group),
  do: Ash.SatSolver.Utils.is_ordered_sublist_of?(group, scenario)

# Allocates the next variable (`bindings[:current] + 1`) for `group`, recording
# it in both `:groups` and `:reverse_groups`, unless the group already has one.
defp add_group_binding(bindings, group) do
  if bindings[:groups][group] do
    bindings
  else
    next = bindings[:current] + 1

    bindings
    |> Map.put(:current, next)
    |> Map.update(:reverse_groups, %{next => group}, &Map.put(&1, next, group))
    |> Map.update(:groups, %{group => next}, &Map.put(&1, group, next))
  end
end
|
|
|
|
# Hands integer CNF clauses to picosat and returns its result unchanged.
def solve_expression(cnf), do: Picosat.solve(cnf)

# `true` when `sublist` appears as a contiguous run inside `list`. An empty
# `list` contains nothing, not even `[]`.
def contains?([], _sublist), do: false

def contains?([_ | rest] = list, sublist),
  do: List.starts_with?(list, sublist) || contains?(rest, sublist)
|
|
|
|
# Turns a solver solution (signed variable numbers) into
# `%{fact => true | false}` using `bindings` (variable => fact). The result
# starts from `%{true: [], false: []}`.
#
# Raises `Ash.Error.Framework.AssumptionFailed` when a variable has no bound
# fact.
def solutions_to_predicate_values(solution, bindings) do
  Enum.reduce(solution, %{true: [], false: []}, fn literal, values ->
    case Map.get(bindings, abs(literal)) do
      nil ->
        raise Ash.Error.Framework.AssumptionFailed.exception(
                message: """
                A fact from the sat solver had no corresponding bound fact:

                Bindings:
                #{inspect(bindings)}

                Missing:
                #{inspect(literal)}
                """
              )

      fact ->
        Map.put(values, fact, literal > 0)
    end
  end)
end
|
|
|
|
# Replaces every leaf of a tuple-form boolean expression with an integer
# variable. Leaves that compare equal (`==`) share a variable.
#
# `bindings` maps each variable to the leaf it stands for, plus `:current`,
# the next unused variable (starting at 1). Returns
# `{bindings, expression_over_integers}`.
defp extract_bindings(expr, bindings \\ %{current: 1})

defp extract_bindings({operator, left, right}, bindings) do
  {bindings, bound_left} = extract_bindings(left, bindings)
  {bindings, bound_right} = extract_bindings(right, bindings)
  {bindings, {operator, bound_left, bound_right}}
end

defp extract_bindings({:not, operand}, bindings) do
  {bindings, bound_operand} = extract_bindings(operand, bindings)
  {bindings, b(not bound_operand)}
end

defp extract_bindings(leaf, %{current: next_variable} = bindings) do
  existing_variable =
    Enum.find_value(bindings, fn {key, bound_leaf} ->
      if key != :current and bound_leaf == leaf, do: key
    end)

  case existing_variable do
    nil ->
      updated =
        bindings
        |> Map.put(:current, next_variable + 1)
        |> Map.put(next_variable, leaf)

      {updated, next_variable}

    variable ->
      {bindings, variable}
  end
end
|
|
|
|
# Formats clauses as the DIMACS-style text we would give to picosat:
# * a `p cnf <variables> <clauses>` header;
# * one line per clause, terminated by ` 0`;
# * negated literals (`{:not, var}`) written as `-var`.
@doc false
def to_picosat(clauses, variable_count) do
  header = "p cnf #{variable_count} #{Enum.count(clauses)}\n"
  header <> Enum.map_join(clauses, "\n", &(format_clause(&1) <> " 0"))
end

# Turns each lifted clause into a list of signed integers.
# * `{:not, n}` becomes `-n`; nested negations flip the sign again via
#   `negate_var/2`.
# * A bare integer (or negated integer) not wrapped in a list becomes a
#   one-element clause.
defp negations_to_negative_numbers(clauses), do: Enum.map(clauses, &clause_to_integers/1)

defp clause_to_integers({:not, var}) when is_integer(var), do: [negate_var(var)]
defp clause_to_integers(var) when is_integer(var), do: [var]
defp clause_to_integers(clause), do: Enum.map(clause, &literal_to_integer/1)

defp literal_to_integer({:not, var}), do: negate_var(var)
defp literal_to_integer(var), do: var

# Multiplies the innermost value by `multiplier` (default -1). Each extra
# `{:not, _}` wrapper flips the multiplier, so double negations cancel out.
defp negate_var(var, multiplier \\ -1)
defp negate_var({:not, inner}, multiplier), do: negate_var(inner, -multiplier)
defp negate_var(value, multiplier), do: value * multiplier

defp format_clause(clause) do
  clause
  |> Enum.map(fn
    {:not, var} -> "-#{var}"
    var -> "#{var}"
  end)
  |> Enum.join(" ")
end
|
|
|
|
# Flattens a CNF tuple tree into a list of clauses, each a list of literals.
# * Conjunctions concatenate their clause lists.
# * A disjunction becomes one clause holding all of its disjuncts.
# * Any other term is a single-literal clause.
defp lift_clauses({:and, left, right}), do: lift_clauses(left) ++ lift_clauses(right)
defp lift_clauses({:or, _left, _right} = disjunction), do: [lift_or_clauses(disjunction)]
defp lift_clauses(literal), do: [[literal]]

# Collects the leaves of a (possibly nested) disjunction, left to right.
defp lift_or_clauses({:or, left, right}), do: lift_or_clauses(left) ++ lift_or_clauses(right)
defp lift_or_clauses(literal), do: [literal]
|
|
|
|
# Converts a tuple-form boolean expression over integer variables into
# conjunctive normal form in two stages:
# 1. De Morgan's laws push negations inward.
# 2. OR is distributed over AND.
# Each stage is repeated until the expression stops changing.
defp to_conjunctive_normal_form(expression) do
  distributive_law(demorgans_law(expression))
end

defp distributive_law(expression),
  do: rewrite_until_stable(expression, &apply_distributive_law/1)

# Applies `rewrite` repeatedly until its output equals (`==`) its input.
defp rewrite_until_stable(expression, rewrite) do
  rewritten = rewrite.(expression)

  if expression == rewritten do
    expression
  else
    rewrite_until_stable(rewritten, rewrite)
  end
end

# One distributive pass. Clause order matters: an `or` with an `and` on the
# right is handled before one with an `and` on the left.
defp apply_distributive_law({:or, lhs, {:and, rhs1, rhs2}}) do
  shared = apply_distributive_law(lhs)

  {:and, {:or, shared, apply_distributive_law(rhs1)},
   {:or, shared, apply_distributive_law(rhs2)}}
end

defp apply_distributive_law({:or, {:and, lhs1, lhs2}, rhs}) do
  shared = apply_distributive_law(rhs)

  {:and, {:or, apply_distributive_law(lhs1), shared},
   {:or, apply_distributive_law(lhs2), shared}}
end

defp apply_distributive_law({:not, operand}), do: {:not, apply_distributive_law(operand)}

defp apply_distributive_law({op, lhs, rhs}) when op in [:and, :or],
  do: {op, apply_distributive_law(lhs), apply_distributive_law(rhs)}

defp apply_distributive_law(variable) when is_integer(variable), do: variable

defp demorgans_law(expression),
  do: rewrite_until_stable(expression, &apply_demorgans_law/1)

# One De Morgan pass. A negated `or` is rewritten without recursing into its
# operands; the surrounding fixpoint loop picks those up on the next pass.
defp apply_demorgans_law({:not, {:and, lhs, rhs}}),
  do: {:or, {:not, apply_demorgans_law(lhs)}, {:not, apply_demorgans_law(rhs)}}

defp apply_demorgans_law({:not, {:or, lhs, rhs}}), do: {:and, {:not, lhs}, {:not, rhs}}

defp apply_demorgans_law({op, lhs, rhs}) when op in [:or, :and],
  do: {op, apply_demorgans_law(lhs), apply_demorgans_law(rhs)}

defp apply_demorgans_law({:not, operand}), do: {:not, apply_demorgans_law(operand)}

defp apply_demorgans_law(variable) when is_integer(variable), do: variable
|
|
end
|