improvement: optimize sat solving

1. only convert to CNF once
2. group predicates that only appear in specific combinations to limit amount of variables provided to the sat solver

Number 2 above does technically slow down all cases a bit, but the optimization is really important when it matters. And cases that don't need this optimization still happen on the order microseconds anyway.
This commit is contained in:
Zach Daniel 2022-11-15 01:34:57 -05:00
parent 2cf18e6a3a
commit 665a9fb5c4
8 changed files with 359 additions and 56 deletions

39
benchmarks/sat_solver.exs Normal file
View file

@ -0,0 +1,39 @@
list = Enum.to_list(1..10_000)
map_fun = fn i -> [i, i * i] end
mixed = fn count ->
Enum.reduce(1..count, 0, fn var, expr ->
cond do
rem(var, 4) == 0 ->
{:or, var, expr}
rem(var, 3) == 0 ->
{:and, expr, var}
rem(var, 2) == 0 ->
{:and, -var, expr}
true ->
{:or, -var, expr}
end
end) |> Ash.Policy.SatSolver.solve()
end
Benchee.run(
%{
solve: fn input ->
Ash.Policy.SatSolver.solve(input)
end
},
inputs: %{
"3 conjunctive" => Enum.to_list(1..3) |> Enum.reduce(0, fn var, expr -> {:and, var, expr} end) |> Ash.Policy.SatSolver.solve(),
"3 disjunctive" => Enum.to_list(1..3) |> Enum.reduce(0, fn var, expr -> {:or, var, expr} end) |> Ash.Policy.SatSolver.solve(),
"3 mixed" => mixed.(3),
"5 conjunctive" => Enum.to_list(1..5) |> Enum.reduce(0, fn var, expr -> {:and, var, expr} end) |> Ash.Policy.SatSolver.solve(),
"5 disjunctive" => Enum.to_list(1..5) |> Enum.reduce(0, fn var, expr -> {:or, var, expr} end) |> Ash.Policy.SatSolver.solve(),
"5 mixed" => mixed.(5),
"7 conjunctive" => Enum.to_list(1..7) |> Enum.reduce(0, fn var, expr -> {:and, var, expr} end) |> Ash.Policy.SatSolver.solve(),
"7 disjunctive" => Enum.to_list(1..7) |> Enum.reduce(0, fn var, expr -> {:or, var, expr} end) |> Ash.Policy.SatSolver.solve(),
"7 mixed" => mixed.(7),
}
)

View file

@ -488,7 +488,12 @@ defmodule Ash.Policy.Authorizer do
end
{{check_module, check_opts}, false} ->
result = check_module.auto_filter_not(authorizer.actor, authorizer, check_opts)
result =
if :erlang.function_exported(check_module, :auto_filter_not, 3) do
check_module.auto_filter_not(authorizer.actor, authorizer, check_opts)
else
[not: check_module.auto_filter(authorizer.actor, authorizer, check_opts)]
end
if is_nil(result) do
false
@ -576,7 +581,11 @@ defmodule Ash.Policy.Authorizer do
if required_status do
check_module.auto_filter(authorizer.actor, authorizer, check_opts)
else
check_module.auto_filter_not(authorizer.actor, authorizer, check_opts)
if :erlang.function_exported(check_module, :auto_filter_not, 3) do
check_module.auto_filter_not(authorizer.actor, authorizer, check_opts)
else
[not: check_module.auto_filter(authorizer.actor, authorizer, check_opts)]
end
end
scenarios = remove_clause(authorizer.scenarios, {check_module, check_opts})

View file

@ -14,10 +14,21 @@ defmodule Ash.Policy.Policy do
@type t :: %__MODULE__{}
@static_checks [
{Ash.Policy.Check.Static, [result: true]},
{Ash.Policy.Check.Static, [result: false]},
true,
false
]
def solve(authorizer) do
authorizer.policies
|> build_requirements_expression(authorizer.facts)
|> Ash.Policy.SatSolver.solve()
|> Ash.Policy.SatSolver.solve(fn scenario, bindings ->
scenario
|> Ash.SatSolver.solutions_to_predicate_values(bindings)
|> Map.drop(@static_checks)
end)
end
defp build_requirements_expression(policies, facts) do

View file

@ -1,33 +1,51 @@
defmodule Ash.Policy.SatSolver do
@moduledoc false
def solve(expression) do
expression
def solve(expression, mapper \\ nil) do
{cnf, bindings} = Ash.SatSolver.to_cnf(expression)
cnf
|> add_negations_and_solve([])
|> get_all_scenarios(expression)
|> get_all_scenarios(cnf)
|> case do
[] ->
{:error, :unsatisfiable}
scenarios ->
static_checks = [
{Ash.Policy.Check.Static, [result: true]},
{Ash.Policy.Check.Static, [result: false]}
]
{scenarios, bindings} = Ash.SatSolver.unbind(scenarios, bindings)
{:ok,
scenarios
|> Enum.map(&Map.drop(&1, static_checks))
|> Enum.uniq()}
if mapper do
{:ok,
Enum.map(scenarios, fn scenario ->
mapper.(scenario, bindings)
end)}
else
{:ok, scenarios}
end
end
end
defp get_all_scenarios(scenario_result, expression, scenarios \\ [])
defp get_all_scenarios({:error, :unsatisfiable}, _, scenarios), do: scenarios
defp get_all_scenarios(
scenario_result,
expression,
scenarios \\ [],
negations \\ []
)
defp get_all_scenarios(
{:error, :unsatisfiable},
_,
scenarios,
_negations
),
do: scenarios
defp get_all_scenarios({:ok, scenario}, expression, scenarios, negations) do
all_scenarios = [scenario | scenarios]
negations = [Enum.map(scenario, &(-&1)) | negations]
defp get_all_scenarios({:ok, scenario}, expression, scenarios) do
expression
|> add_negations_and_solve([Map.drop(scenario, [true, false]) | scenarios])
|> get_all_scenarios(expression, [Map.drop(scenario, [true, false]) | scenarios])
|> add_negations_and_solve(negations)
|> get_all_scenarios(expression, all_scenarios, negations)
end
def simplify_clauses([scenario]), do: [scenario]
@ -82,29 +100,8 @@ defmodule Ash.Policy.SatSolver do
end
@spec add_negations_and_solve(term, term) :: term | no_return()
defp add_negations_and_solve(requirements_expression, negations) do
negations =
Enum.reduce(negations, nil, fn negation, expr ->
negation_statement =
negation
|> Map.drop([true, false])
|> facts_to_statement()
if expr do
{:and, expr, {:not, negation_statement}}
else
{:not, negation_statement}
end
end)
full_expression =
if negations do
{:and, requirements_expression, negations}
else
requirements_expression
end
solve_expression(full_expression)
defp add_negations_and_solve(cnf, negations) do
solve_expression(cnf ++ negations)
end
def facts_to_statement(facts) do

View file

@ -85,6 +85,7 @@ defmodule Ash.SatSolver do
{:ok, _scenario} ->
expr = BooleanExpression.new(:and, Not.new(filter.expression), candidate.expression)
Application.put_env(:foo, :bar, true)
case transform_and_solve(
filter.resource,
@ -125,6 +126,8 @@ defmodule Ash.SatSolver do
def transform_and_solve(resource, expression) do
resource
|> transform(expression)
|> to_cnf()
|> elem(0)
|> solve_expression()
end
@ -732,7 +735,7 @@ defmodule Ash.SatSolver do
defp simplify(other), do: other
def solve_expression(expression) do
def to_cnf(expression) do
expression_with_constants = b(true and not false and expression)
{bindings, expression} = extract_bindings(expression_with_constants)
@ -741,23 +744,201 @@ defmodule Ash.SatSolver do
|> to_conjunctive_normal_form()
|> lift_clauses()
|> negations_to_negative_numbers()
|> Enum.uniq()
|> Picosat.solve()
|> solutions_to_predicate_values(bindings)
|> Enum.map(fn scenario ->
Enum.sort_by(scenario, fn item ->
{abs(item), item}
end)
end)
|> group_predicates(bindings)
|> rebind()
end
defp solutions_to_predicate_values({:ok, solution}, bindings) do
scenario =
Enum.reduce(solution, %{true: [], false: []}, fn var, state ->
fact = Map.get(bindings, abs(var))
defp group_predicates(expression, bindings) do
case expression do
[_] ->
{expression, bindings}
Map.put(state, fact, var > 0)
scenarios ->
Enum.reduce(scenarios, {[], bindings}, fn scenario, {new_scenarios, bindings} ->
{scenario, bindings} = group_scenario_predicates(scenario, scenarios, bindings)
{[scenario | new_scenarios], bindings}
end)
end
end
defp group_scenario_predicates(scenario, all_scenarios, bindings) do
scenario
|> Ash.SatSolver.Utils.ordered_sublists()
|> Enum.filter(&can_be_used_as_group?(&1, all_scenarios, bindings))
|> Enum.sort_by(&(-length(&1)))
|> Enum.reduce({scenario, bindings}, fn group, {scenario, bindings} ->
bindings = add_group_binding(bindings, group)
{Ash.SatSolver.Utils.replace_ordered_sublist(scenario, group, bindings[:groups][group]),
bindings}
end)
end
def unbind(expression, %{temp_bindings: temp_bindings, old_bindings: old_bindings}) do
expression =
Enum.flat_map(expression, fn statement ->
Enum.flat_map(statement, fn var ->
neg? = var < 0
old_binding = temp_bindings[abs(var)]
case old_bindings[:reverse_groups][old_binding] do
nil ->
if neg? do
[-old_binding]
else
[old_binding]
end
group ->
if neg? do
Enum.map(group, &(-&1))
else
[{:expand, group}]
end
end
end)
|> expand_groups()
end)
{:ok, scenario}
{expression, old_bindings}
end
defp solutions_to_predicate_values({:error, error}, _), do: {:error, error}
def expand_groups(expression) do
if Enum.any?(expression, &match?({:expand, _}, &1)) do
do_expand_groups(expression)
else
[expression]
end
end
defp do_expand_groups([]), do: [[]]
defp do_expand_groups([{:expand, group} | rest]) do
Enum.flat_map(group, fn var ->
Enum.map(do_expand_groups(rest), fn future ->
[var | future]
end)
end)
end
defp do_expand_groups([var | rest]) do
Enum.map(do_expand_groups(rest), fn future ->
[var | future]
end)
end
defp rebind({expression, bindings}) do
{expression, temp_bindings} =
Enum.reduce(expression, {[], %{current: 0}}, fn statement, {statements, acc} ->
{statement, acc} =
Enum.reduce(statement, {[], acc}, fn var, {statement, acc} ->
case acc[:reverse][abs(var)] do
nil ->
binding = acc.current + 1
value =
if var < 0 do
-binding
else
binding
end
{[value | statement],
acc
|> Map.put(:current, binding)
|> Map.update(:reverse, %{abs(var) => binding}, &Map.put(&1, abs(var), binding))
|> Map.put(binding, abs(var))}
value ->
value =
if var < 0 do
-value
else
value
end
{[value | statement], acc}
end
end)
{[Enum.reverse(statement) | statements], acc}
end)
bindings_with_old_bindings = %{temp_bindings: temp_bindings, old_bindings: bindings}
{expression, bindings_with_old_bindings}
end
def can_be_used_as_group?(group, scenarios, bindings) do
Map.has_key?(bindings[:groups] || %{}, group) ||
Enum.all?(scenarios, fn scenario ->
has_no_overlap?(scenario, group) || group_in_scenario?(scenario, group)
end)
end
defp has_no_overlap?(scenario, group) do
not Enum.any?(group, fn group_predicate ->
Enum.any?(scenario, fn scenario_predicate ->
abs(group_predicate) == abs(scenario_predicate)
end)
end)
end
defp group_in_scenario?(scenario, group) do
Ash.SatSolver.Utils.is_ordered_sublist_of?(group, scenario)
end
defp add_group_binding(bindings, group) do
if bindings[:groups][group] do
bindings
else
new_binding = bindings[:current] + 1
bindings
|> Map.put(:current, new_binding)
|> Map.put_new(:reverse_groups, %{})
|> Map.update!(:reverse_groups, &Map.put(&1, new_binding, group))
|> Map.put_new(:groups, %{})
|> Map.update!(:groups, &Map.put(&1, group, new_binding))
end
end
def solve_expression(cnf) do
Picosat.solve(cnf)
end
def contains?([], _), do: false
def contains?([_ | t] = l1, l2) do
List.starts_with?(l1, l2) or contains?(t, l2)
end
def solutions_to_predicate_values(solution, bindings) do
Enum.reduce(solution, %{true: [], false: []}, fn var, state ->
fact = Map.get(bindings, abs(var))
if is_nil(fact) do
raise Ash.Error.Framework.AssumptionFailed.exception(
message: """
A fact from the sat solver had no corresponding bound fact:
Bindings:
#{inspect(bindings)}
Missing:
#{inspect(var)}
"""
)
end
Map.put(state, fact, var > 0)
end)
end
defp extract_bindings(expr, bindings \\ %{current: 1})

62
lib/sat_solver/utils.ex Normal file
View file

@ -0,0 +1,62 @@
defmodule Ash.SatSolver.Utils do
@moduledoc false
def replace_ordered_sublist([], _sublist, _replacement), do: []
def replace_ordered_sublist([h | tail] = list, sublist, replacement) do
if List.starts_with?(list, sublist) do
[replacement | Enum.drop(list, length(sublist))]
else
[h | replace_ordered_sublist(tail, sublist, replacement)]
end
end
def is_ordered_sublist_of?(_, []), do: false
def is_ordered_sublist_of?(sublist, [_ | rest] = list) do
List.starts_with?(list, sublist) || is_ordered_sublist_of?(sublist, rest)
end
def ordered_sublists(list) do
front = sublists_front(list)
back =
list
|> sublists_back()
|> Enum.reject(&(&1 in front))
front ++ back
end
def sublists_back([_, _]), do: []
def sublists_back([]), do: []
def sublists_back([_]), do: []
def sublists_back(list) do
list = :lists.droplast(list)
[list | sublists_back(list)]
end
def sublists_front(list) do
list
|> do_sublists_front()
|> Enum.reject(fn
^list ->
true
[_] ->
true
_ ->
false
end)
end
def do_sublists_front([]) do
[]
end
def do_sublists_front([_first | rest] = list) do
[list | do_sublists_front(rest)]
end
end

View file

@ -222,6 +222,7 @@ defmodule Ash.MixProject do
{:ecto, "~> 3.7"},
{:ets, "~> 0.8.0"},
{:decimal, "~> 2.0"},
# {:picosat_elixir, path: "../picosat_elixir"},
{:picosat_elixir, "~> 0.2"},
{:nimble_options, "~> 0.4.0"},
{:comparable, "~> 1.0"},
@ -238,8 +239,8 @@ defmodule Ash.MixProject do
{:git_ops, "~> 2.5", only: :dev},
{:mix_test_watch, "~> 1.0", only: :dev, runtime: false},
{:parse_trans, "3.3.0", only: [:dev, :test], override: true},
{:elixir_sense, github: "elixir-lsp/elixir_sense", only: [:dev, :test, :docs]},
{:plug, ">= 0.0.0", only: [:dev, :test], runtime: false}
{:plug, ">= 0.0.0", only: [:dev, :test], runtime: false},
{:benchee, "~> 1.1", only: [:dev, :test]}
]
end

View file

@ -1,15 +1,17 @@
%{
"benchee": {:hex, :benchee, "1.1.0", "f3a43817209a92a1fade36ef36b86e1052627fd8934a8b937ac9ab3a76c43062", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}], "hexpm", "7da57d545003165a012b587077f6ba90b89210fd88074ce3c60ce239eb5e6d93"},
"bunt": {:hex, :bunt, "0.2.0", "951c6e801e8b1d2cbe58ebbd3e616a869061ddadcc4863d0a2182541acae9a38", [:mix], [], "hexpm", "7af5c7e09fe1d40f76c8e4f9dd2be7cebd83909f31fee7cd0e9eadc567da8353"},
"certifi": {:hex, :certifi, "2.6.1", "dbab8e5e155a0763eea978c913ca280a6b544bfa115633fa20249c3d396d9493", [:rebar3], [], "hexpm", "524c97b4991b3849dd5c17a631223896272c6b0af446778ba4675a1dff53bb7e"},
"comparable": {:hex, :comparable, "1.0.0", "bb669e91cedd14ae9937053e5bcbc3c52bb2f22422611f43b6e38367d94a495f", [:mix], [{:typable, "~> 0.1", [hex: :typable, repo: "hexpm", optional: false]}], "hexpm", "277c11eeb1cd726e7cd41c6c199e7e52fa16ee6830b45ad4cdc62e51f62eb60c"},
"credo": {:hex, :credo, "1.6.4", "ddd474afb6e8c240313f3a7b0d025cc3213f0d171879429bf8535d7021d9ad78", [:mix], [{:bunt, "~> 0.2.0", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2.8", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "c28f910b61e1ff829bffa056ef7293a8db50e87f2c57a9b5c3f57eee124536b7"},
"decimal": {:hex, :decimal, "2.0.0", "a78296e617b0f5dd4c6caf57c714431347912ffb1d0842e998e9792b5642d697", [:mix], [], "hexpm", "34666e9c55dea81013e77d9d87370fe6cb6291d1ef32f46a1600230b1d44f577"},
"deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"},
"dialyxir": {:hex, :dialyxir, "1.2.0", "58344b3e87c2e7095304c81a9ae65cb68b613e28340690dfe1a5597fd08dec37", [:mix], [{:erlex, ">= 0.2.6", [hex: :erlex, repo: "hexpm", optional: false]}], "hexpm", "61072136427a851674cab81762be4dbeae7679f85b1272b6d25c3a839aff8463"},
"docsh": {:hex, :docsh, "0.7.2", "f893d5317a0e14269dd7fe79cf95fb6b9ba23513da0480ec6e77c73221cae4f2", [:rebar3], [{:providers, "1.8.1", [hex: :providers, repo: "hexpm", optional: false]}], "hexpm", "4e7db461bb07540d2bc3d366b8513f0197712d0495bb85744f367d3815076134"},
"earmark": {:hex, :earmark, "1.4.24", "1923e201c3742af421860b983560967cc3e3deacc59c12966bc991a5435565e6", [:mix], [{:earmark_parser, "~> 1.4.25", [hex: :earmark_parser, repo: "hexpm", optional: false]}], "hexpm", "9724242f241f2ad634756d8f2bb57a3d0992cedd10c51842fa655703b4da7c67"},
"earmark_parser": {:hex, :earmark_parser, "1.4.25", "2024618731c55ebfcc5439d756852ec4e85978a39d0d58593763924d9a15916f", [:mix], [], "hexpm", "56749c5e1c59447f7b7a23ddb235e4b3defe276afc220a6227237f3efe83f51e"},
"ecto": {:hex, :ecto, "3.7.2", "44c034f88e1980754983cc4400585970b4206841f6f3780967a65a9150ef09a8", [:mix], [{:decimal, "~> 1.6 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "a600da5772d1c31abbf06f3e4a1ffb150e74ed3e2aa92ff3cee95901657a874e"},
"elixir_make": {:hex, :elixir_make, "0.6.2", "7dffacd77dec4c37b39af867cedaabb0b59f6a871f89722c25b28fcd4bd70530", [:mix], [], "hexpm", "03e49eadda22526a7e5279d53321d1cced6552f344ba4e03e619063de75348d9"},
"elixir_make": {:hex, :elixir_make, "0.6.3", "bc07d53221216838d79e03a8019d0839786703129599e9619f4ab74c8c096eac", [:mix], [], "hexpm", "f5cbd651c5678bcaabdbb7857658ee106b12509cd976c2c2fca99688e1daf716"},
"elixir_sense": {:git, "https://github.com/elixir-lsp/elixir_sense.git", "85d4a87d216678dae30f348270eb90f9ed49ce20", []},
"erlex": {:hex, :erlex, "0.2.6", "c7987d15e899c7a2f34f5420d2a2ea0d659682c06ac607572df55a43753aa12e", [:mix], [], "hexpm", "2ed2e25711feb44d52b17d2780eabf998452f6efda104877a3881c2f8c0c0c75"},
"ets": {:hex, :ets, "0.8.1", "8ff9bcda5682b98493f8878fc9dbd990e48d566cba8cce59f7c2a78130da29ea", [:mix], [], "hexpm", "6be41b50adb5bc5c43626f25ea2d0af1f4a242fb3fad8d53f0c67c20b78915cc"},
@ -41,6 +43,7 @@
"sourceror": {:hex, :sourceror, "0.11.2", "549ce48be666421ac60cfb7f59c8752e0d393baa0b14d06271d3f6a8c1b027ab", [:mix], [], "hexpm", "9ab659118896a36be6eec68ff7b0674cba372fc8e210b1e9dc8cf2b55bb70dfb"},
"spark": {:hex, :spark, "0.2.8", "adce1887be100fae124c65f9dcfb6675fefdcd0ffd6355bb02f638bf88e26630", [:mix], [{:nimble_options, "~> 0.4.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:sourceror, "~> 0.1", [hex: :sourceror, repo: "hexpm", optional: false]}], "hexpm", "1619dc8fd899980e33a7a63939e0a02783ed32c3344d192ff6669bf44ff637a6"},
"ssl_verify_fun": {:hex, :ssl_verify_fun, "1.1.6", "cf344f5692c82d2cd7554f5ec8fd961548d4fd09e7d22f5b62482e5aeaebd4b0", [:make, :mix, :rebar3], [], "hexpm", "bdb0d2471f453c88ff3908e7686f86f9be327d065cc1ec16fa4540197ea04680"},
"statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"},
"stream_data": {:hex, :stream_data, "0.5.0", "b27641e58941685c75b353577dc602c9d2c12292dd84babf506c2033cd97893e", [:mix], [], "hexpm", "012bd2eec069ada4db3411f9115ccafa38540a3c78c4c0349f151fc761b9e271"},
"telemetry": {:hex, :telemetry, "1.1.0", "a589817034a27eab11144ad24d5c0f9fab1f58173274b1e9bae7074af9cbee51", [:rebar3], [], "hexpm", "b727b2a1f75614774cff2d7565b64d0dfa5bd52ba517f16543e6fc7efcc0df48"},
"typable": {:hex, :typable, "0.3.0", "0431e121d124cd26f312123e313d2689b9a5322b15add65d424c07779eaa3ca1", [:mix], [], "hexpm", "880a0797752da1a4c508ac48f94711e04c86156f498065a83d160eef945858f8"},