diff --git a/lib/expr.ex b/lib/expr.ex index a9a152d..f4fd42b 100644 --- a/lib/expr.ex +++ b/lib/expr.ex @@ -344,28 +344,37 @@ defmodule AshPostgres.Expr do when_true = do_dynamic_expr(query, when_true, bindings, pred_embedded? || embedded?, when_true_type) - when_false = - do_dynamic_expr( - query, - when_false, - bindings, - pred_embedded? || embedded?, - when_false_type - ) + {additional_cases, when_false} = + extract_cases(query, when_false, bindings, pred_embedded? || embedded?, when_false_type) + + additional_case_fragments = + additional_cases + |> Enum.flat_map(fn {condition, when_true} -> + [ + raw: " WHEN ", + casted_expr: condition, + raw: " THEN ", + casted_expr: when_true + ] + end) do_dynamic_expr( query, %Fragment{ embedded?: pred_embedded?, - arguments: [ - raw: "(CASE WHEN ", - casted_expr: condition, - raw: " THEN ", - casted_expr: when_true, - raw: " ELSE ", - casted_expr: when_false, - raw: " END)" - ] + arguments: + [ + raw: "(CASE WHEN ", + casted_expr: condition, + raw: " THEN ", + casted_expr: when_true + ] ++ + additional_case_fragments ++ + [ + raw: " ELSE ", + casted_expr: when_false, + raw: " END)" + ] }, bindings, embedded?, @@ -1376,6 +1385,81 @@ defmodule AshPostgres.Expr do end end + defp extract_cases( + query, + expr, + bindings, + embedded?, + type, + acc \\ [] + ) + + defp extract_cases( + query, + %If{arguments: [condition, when_true, when_false], embedded?: pred_embedded?}, + bindings, + embedded?, + type, + acc + ) do + [condition_type, when_true_type, when_false_type] = + case AshPostgres.Types.determine_types(If, [condition, when_true, when_false]) do + [condition_type, when_true] -> + [condition_type, when_true, nil] + + [condition_type, when_true, when_false] -> + [condition_type, when_true, when_false] + end + |> case do + [condition_type, nil, nil] -> + [condition_type, type, type] + + [condition_type, when_true, nil] -> + [condition_type, when_true, type] + + [condition_type, nil, when_false] -> + [condition_type, type, when_false] + + [condition_type, when_true, when_false] -> + [condition_type, when_true, when_false] + end + + condition = + do_dynamic_expr(query, condition, bindings, pred_embedded? || embedded?, condition_type) + + when_true = + do_dynamic_expr(query, when_true, bindings, pred_embedded? || embedded?, when_true_type) + + extract_cases( + query, + when_false, + bindings, + embedded?, + when_false_type, + [{condition, when_true} | acc] + ) + end + + defp extract_cases( + query, + other, + bindings, + embedded?, + type, + acc + ) do + expr = + do_dynamic_expr( + query, + other, + bindings, + embedded?, + type + ) + + {Enum.reverse(acc), expr} + end + defp split_at_paths(type, constraints, next, acc \\ [{:bracket, [], nil, nil}]) defp split_at_paths(_type, _constraints, [], acc) do diff --git a/mix.lock b/mix.lock index 451389b..9e72318 100644 --- a/mix.lock +++ b/mix.lock @@ -1,5 +1,5 @@ %{ - "ash": {:hex, :ash, "2.17.13", "a3bb846238d4eb029da00583554f074d73950cfa4bc2f5964283d2e4491a9543", [:mix], [{:comparable, "~> 1.0", [hex: :comparable, repo: "hexpm", optional: false]}, {:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:earmark, "~> 1.4", [hex: :earmark, repo: "hexpm", optional: false]}, {:ecto, "~> 3.7", [hex: :ecto, repo: "hexpm", optional: false]}, {:ets, "~> 0.8", [hex: :ets, repo: "hexpm", optional: false]}, {:jason, ">= 1.0.0", [hex: :jason, repo: "hexpm", optional: false]}, {:picosat_elixir, "~> 0.2", [hex: :picosat_elixir, repo: "hexpm", optional: false]}, {:plug, ">= 0.0.0", [hex: :plug, repo: "hexpm", optional: true]}, {:spark, ">= 1.1.50 and < 2.0.0-0", [hex: :spark, repo: "hexpm", optional: false]}, {:stream_data, "~> 0.6", [hex: :stream_data, repo: "hexpm", optional: false]}, {:telemetry, "~> 1.1", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "b41a3e029b1553e71a0cab7db66d98a1ea7a883d301b866630ef9d09e64d38ee"}, + "ash": {:hex, :ash, "2.17.14", "ccb68a0eacd7d4f4652a01baa3ceda510d793ab769c82958d644638905f08f7d", [:mix], [{:comparable, "~> 1.0", [hex: :comparable, repo: "hexpm", optional: false]}, {:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:earmark, "~> 1.4", [hex: :earmark, repo: "hexpm", optional: false]}, {:ecto, "~> 3.7", [hex: :ecto, repo: "hexpm", optional: false]}, {:ets, "~> 0.8", [hex: :ets, repo: "hexpm", optional: false]}, {:jason, ">= 1.0.0", [hex: :jason, repo: "hexpm", optional: false]}, {:picosat_elixir, "~> 0.2", [hex: :picosat_elixir, repo: "hexpm", optional: false]}, {:plug, ">= 0.0.0", [hex: :plug, repo: "hexpm", optional: true]}, {:spark, ">= 1.1.50 and < 2.0.0-0", [hex: :spark, repo: "hexpm", optional: false]}, {:stream_data, "~> 0.6", [hex: :stream_data, repo: "hexpm", optional: false]}, {:telemetry, "~> 1.1", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "b09a585924f222a1d353ebd13ec8691c4b9c5e37a8be271ee8960c34feb3fad0"}, "benchee": {:hex, :benchee, "1.1.0", "f3a43817209a92a1fade36ef36b86e1052627fd8934a8b937ac9ab3a76c43062", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}], "hexpm", "7da57d545003165a012b587077f6ba90b89210fd88074ce3c60ce239eb5e6d93"}, "bunt": {:hex, :bunt, "0.2.0", "951c6e801e8b1d2cbe58ebbd3e616a869061ddadcc4863d0a2182541acae9a38", [:mix], [], "hexpm", "7af5c7e09fe1d40f76c8e4f9dd2be7cebd83909f31fee7cd0e9eadc567da8353"}, "certifi": {:hex, :certifi, "2.9.0", "6f2a475689dd47f19fb74334859d460a2dc4e3252a3324bd2111b8f0429e7e21", [:rebar3], [], "hexpm", "266da46bdb06d6c6d35fde799bcb28d36d985d424ad7c08b5bb48f5b5cdd4641"},