From e2855843ca4a9141dcd7f40f47227e13b43f0e00 Mon Sep 17 00:00:00 2001
From: Zach Daniel
Date: Mon, 11 Sep 2023 22:28:51 -0400
Subject: [PATCH] improvement: support vector types

---
Reviewer note (this section is ignored by git am / git apply):
this patch was recovered from a whitespace-collapsed copy. Line breaks and
indentation are reconstructed to match the hunk headers and diffstat below.
The Elixir bitstring segments in lib/ash/vector.ex (new/1 and to_list/1) had
been stripped to "<>" and were restored following the pgvector-elixir v0.2.0
source cited in the moduledoc. TODO confirm both against commit e2855843.
If the removed-line indentation in lib/ash/type/type.ex does not match the
tree, apply with `git apply --ignore-whitespace`. The author e-mail address
was missing from the copy and has not been invented.

 lib/ash/type/type.ex   | 52 ++++++++++++++++++--------------------
 lib/ash/type/vector.ex | 49 ++++++++++++++++++++++++++++++++++++
 lib/ash/vector.ex      | 57 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 131 insertions(+), 27 deletions(-)
 create mode 100644 lib/ash/type/vector.ex
 create mode 100644 lib/ash/vector.ex

diff --git a/lib/ash/type/type.ex b/lib/ash/type/type.ex
index 66058d71..4809ce1c 100644
--- a/lib/ash/type/type.ex
+++ b/lib/ash/type/type.ex
@@ -25,33 +25,31 @@ defmodule Ash.Type do
   ]
 
   @builtin_short_names [
-                         map: "Ash.Type.Map",
-                         keyword: "Ash.Type.Keyword",
-                         term: "Ash.Type.Term",
-                         atom: "Ash.Type.Atom",
-                         string: "Ash.Type.String",
-                         integer: "Ash.Type.Integer",
-                         float: "Ash.Type.Float",
-                         duration_name: "Ash.Type.DurationName",
-                         function: "Ash.Type.Function",
-                         boolean: "Ash.Type.Boolean",
-                         struct: "Ash.Type.Struct",
-                         uuid: "Ash.Type.UUID",
-                         binary: "Ash.Type.Binary",
-                         date: "Ash.Type.Date",
-                         time: "Ash.Type.Time",
-                         decimal: "Ash.Type.Decimal",
-                         ci_string: "Ash.Type.CiString",
-                         naive_datetime: "Ash.Type.NaiveDatetime",
-                         utc_datetime: "Ash.Type.UtcDatetime",
-                         utc_datetime_usec: "Ash.Type.UtcDatetimeUsec",
-                         url_encoded_binary: "Ash.Type.UrlEncodedBinary",
-                         union: "Ash.Type.Union",
-                         module: "Ash.Type.Module"
-                       ]
-                       |> Enum.map(fn {key, value} ->
-                         {key, Module.concat([value])}
-                       end)
+    map: Ash.Type.Map,
+    keyword: Ash.Type.Keyword,
+    term: Ash.Type.Term,
+    atom: Ash.Type.Atom,
+    string: Ash.Type.String,
+    integer: Ash.Type.Integer,
+    float: Ash.Type.Float,
+    duration_name: Ash.Type.DurationName,
+    function: Ash.Type.Function,
+    boolean: Ash.Type.Boolean,
+    struct: Ash.Type.Struct,
+    uuid: Ash.Type.UUID,
+    binary: Ash.Type.Binary,
+    date: Ash.Type.Date,
+    time: Ash.Type.Time,
+    decimal: Ash.Type.Decimal,
+    ci_string: Ash.Type.CiString,
+    naive_datetime: Ash.Type.NaiveDatetime,
+    utc_datetime: Ash.Type.UtcDatetime,
+    utc_datetime_usec: Ash.Type.UtcDatetimeUsec,
+    url_encoded_binary: Ash.Type.UrlEncodedBinary,
+    union: Ash.Type.Union,
+    module: Ash.Type.Module,
+    vector: Ash.Type.Vector
+  ]
 
   @custom_short_names Application.compile_env(:ash, :custom_types, [])
 
diff --git a/lib/ash/type/vector.ex b/lib/ash/type/vector.ex
new file mode 100644
index 00000000..bddc4368
--- /dev/null
+++ b/lib/ash/type/vector.ex
@@ -0,0 +1,49 @@
+defmodule Ash.Type.Vector do
+  @moduledoc """
+  Represents a vector.
+
+  A builtin type that can be referenced via `:vector`
+  """
+
+  use Ash.Type
+
+  @impl true
+  def storage_type(_), do: :vector
+
+  @impl true
+  def generator(_constraints) do
+    StreamData.list_of(StreamData.float())
+  end
+
+  @impl true
+  def cast_input(value, _) do
+    Ash.Vector.new(value)
+  end
+
+  @impl true
+  def cast_stored(nil, _), do: {:ok, nil}
+
+  def cast_stored(%Ash.Vector{} = vector, _) do
+    {:ok, vector}
+  end
+
+  def cast_stored(value, _) when is_list(value) do
+    case Ash.Vector.new(value) do
+      {:ok, vector} -> {:ok, vector}
+      {:error, _} -> :error
+    end
+  end
+
+  @impl true
+  def dump_to_native(nil, _), do: {:ok, nil}
+
+  def dump_to_native(%Ash.Vector{data: data}, _) do
+    {:ok, data}
+  end
+
+  def dump_to_native(value, constraints) when is_list(value) do
+    with {:ok, value} <- cast_input(value, constraints) do
+      dump_to_native(value, constraints)
+    end
+  end
+end
diff --git a/lib/ash/vector.ex b/lib/ash/vector.ex
new file mode 100644
index 00000000..08958d28
--- /dev/null
+++ b/lib/ash/vector.ex
@@ -0,0 +1,57 @@
+defmodule Ash.Vector do
+  @moduledoc """
+  A vector struct for Ash.
+
+  Implementation based off of https://github.com/pgvector/pgvector-elixir/blob/v0.2.0/lib/pgvector.ex
+  """
+
+  defstruct [:data]
+
+  @doc """
+  Creates a new vector from a list or tensor
+  """
+  def new(binary) when is_binary(binary) do
+    from_binary(binary)
+  end
+
+  def new(%__MODULE__{} = vector), do: {:ok, vector}
+
+  def new(list) when is_list(list) do
+    # this definitely has failure cases
+    dim = list |> length()
+    bin = for v <- list, into: "", do: <<v::float-32>>
+    {:ok, from_binary(<<dim::unsigned-16, 0::unsigned-16, bin::binary>>)}
+  rescue
+    _ -> {:error, :invalid_vector}
+  end
+
+  @doc """
+  Creates a new vector from its binary representation
+  """
+  def from_binary(binary) when is_binary(binary) do
+    %Ash.Vector{data: binary}
+  end
+
+  @doc """
+  Converts the vector to its binary representation
+  """
+  def to_binary(vector) when is_struct(vector, Ash.Vector) do
+    vector.data
+  end
+
+  @doc """
+  Converts the vector to a list
+  """
+  def to_list(vector) when is_struct(vector, Ash.Vector) do
+    <<dim::unsigned-16, 0::unsigned-16, bin::binary-size(dim)-unit(32)>> = vector.data
+    for <<v::float-32 <- bin>>, do: v
+  end
+end
+
+defimpl Inspect, for: Ash.Vector do
+  import Inspect.Algebra
+
+  def inspect(vec, opts) do
+    concat(["Ash.Vector.new(", Inspect.List.inspect(Ash.Vector.to_list(vec), opts), ")"])
+  end
+end