mirror of
https://github.com/ash-project/ash.git
synced 2024-09-20 13:33:20 +12:00
improvement: support vector types
This commit is contained in:
parent
cedcda903f
commit
e2855843ca
3 changed files with 131 additions and 27 deletions
|
@ -25,33 +25,31 @@ defmodule Ash.Type do
|
||||||
]
|
]
|
||||||
|
|
||||||
@builtin_short_names [
|
@builtin_short_names [
|
||||||
map: "Ash.Type.Map",
|
map: Ash.Type.Map,
|
||||||
keyword: "Ash.Type.Keyword",
|
keyword: Ash.Type.Keyword,
|
||||||
term: "Ash.Type.Term",
|
term: Ash.Type.Term,
|
||||||
atom: "Ash.Type.Atom",
|
atom: Ash.Type.Atom,
|
||||||
string: "Ash.Type.String",
|
string: Ash.Type.String,
|
||||||
integer: "Ash.Type.Integer",
|
integer: Ash.Type.Integer,
|
||||||
float: "Ash.Type.Float",
|
float: Ash.Type.Float,
|
||||||
duration_name: "Ash.Type.DurationName",
|
duration_name: Ash.Type.DurationName,
|
||||||
function: "Ash.Type.Function",
|
function: Ash.Type.Function,
|
||||||
boolean: "Ash.Type.Boolean",
|
boolean: Ash.Type.Boolean,
|
||||||
struct: "Ash.Type.Struct",
|
struct: Ash.Type.Struct,
|
||||||
uuid: "Ash.Type.UUID",
|
uuid: Ash.Type.UUID,
|
||||||
binary: "Ash.Type.Binary",
|
binary: Ash.Type.Binary,
|
||||||
date: "Ash.Type.Date",
|
date: Ash.Type.Date,
|
||||||
time: "Ash.Type.Time",
|
time: Ash.Type.Time,
|
||||||
decimal: "Ash.Type.Decimal",
|
decimal: Ash.Type.Decimal,
|
||||||
ci_string: "Ash.Type.CiString",
|
ci_string: Ash.Type.CiString,
|
||||||
naive_datetime: "Ash.Type.NaiveDatetime",
|
naive_datetime: Ash.Type.NaiveDatetime,
|
||||||
utc_datetime: "Ash.Type.UtcDatetime",
|
utc_datetime: Ash.Type.UtcDatetime,
|
||||||
utc_datetime_usec: "Ash.Type.UtcDatetimeUsec",
|
utc_datetime_usec: Ash.Type.UtcDatetimeUsec,
|
||||||
url_encoded_binary: "Ash.Type.UrlEncodedBinary",
|
url_encoded_binary: Ash.Type.UrlEncodedBinary,
|
||||||
union: "Ash.Type.Union",
|
union: Ash.Type.Union,
|
||||||
module: "Ash.Type.Module"
|
module: Ash.Type.Module,
|
||||||
]
|
vector: Ash.Type.Vector
|
||||||
|> Enum.map(fn {key, value} ->
|
]
|
||||||
{key, Module.concat([value])}
|
|
||||||
end)
|
|
||||||
|
|
||||||
@custom_short_names Application.compile_env(:ash, :custom_types, [])
|
@custom_short_names Application.compile_env(:ash, :custom_types, [])
|
||||||
|
|
||||||
|
|
49
lib/ash/type/vector.ex
Normal file
49
lib/ash/type/vector.ex
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
defmodule Ash.Type.Vector do
|
||||||
|
@moduledoc """
|
||||||
|
Represents a vector.
|
||||||
|
|
||||||
|
A builtin type that can be referenced via `:vector`
|
||||||
|
"""
|
||||||
|
|
||||||
|
use Ash.Type
|
||||||
|
|
||||||
|
@impl true
|
||||||
|
def storage_type(_), do: :vector
|
||||||
|
|
||||||
|
@impl true
|
||||||
|
def generator(_constraints) do
|
||||||
|
StreamData.list_of(StreamData.float())
|
||||||
|
end
|
||||||
|
|
||||||
|
@impl true
|
||||||
|
def cast_input(value, _) do
|
||||||
|
Ash.Vector.new(value)
|
||||||
|
end
|
||||||
|
|
||||||
|
@impl true
|
||||||
|
def cast_stored(nil, _), do: {:ok, nil}
|
||||||
|
|
||||||
|
def cast_stored(%Ash.Vector{} = vector, _) do
|
||||||
|
{:ok, vector}
|
||||||
|
end
|
||||||
|
|
||||||
|
def cast_stored(value, _) when is_list(value) do
|
||||||
|
case Ash.Vector.new(value) do
|
||||||
|
{:ok, vector} -> {:ok, vector}
|
||||||
|
{:error, _} -> :error
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@impl true
|
||||||
|
def dump_to_native(nil, _), do: {:ok, nil}
|
||||||
|
|
||||||
|
def dump_to_native(%Ash.Vector{data: data}, _) do
|
||||||
|
{:ok, data}
|
||||||
|
end
|
||||||
|
|
||||||
|
def dump_to_native(value, constraints) when is_list(value) do
|
||||||
|
with {:ok, value} <- cast_input(value, constraints) do
|
||||||
|
dump_to_native(value, constraints)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
57
lib/ash/vector.ex
Normal file
57
lib/ash/vector.ex
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
defmodule Ash.Vector do
|
||||||
|
@moduledoc """
|
||||||
|
A vector struct for Ash.
|
||||||
|
|
||||||
|
Implementation based off of https://github.com/pgvector/pgvector-elixir/blob/v0.2.0/lib/pgvector.ex
|
||||||
|
"""
|
||||||
|
|
||||||
|
defstruct [:data]
|
||||||
|
|
||||||
|
@doc """
|
||||||
|
Creates a new vector from a list or tensor
|
||||||
|
"""
|
||||||
|
def new(binary) when is_binary(binary) do
|
||||||
|
from_binary(binary)
|
||||||
|
end
|
||||||
|
|
||||||
|
def new(%__MODULE__{} = vector), do: {:ok, vector}
|
||||||
|
|
||||||
|
def new(list) when is_list(list) do
|
||||||
|
# this definitely has failure cases
|
||||||
|
dim = list |> length()
|
||||||
|
bin = for v <- list, into: "", do: <<v::float-32>>
|
||||||
|
{:ok, from_binary(<<dim::unsigned-16, 0::unsigned-16, bin::binary>>)}
|
||||||
|
rescue
|
||||||
|
_ -> {:error, :invalid_vector}
|
||||||
|
end
|
||||||
|
|
||||||
|
@doc """
|
||||||
|
Creates a new vector from its binary representation
|
||||||
|
"""
|
||||||
|
def from_binary(binary) when is_binary(binary) do
|
||||||
|
%Ash.Vector{data: binary}
|
||||||
|
end
|
||||||
|
|
||||||
|
@doc """
|
||||||
|
Converts the vector to its binary representation
|
||||||
|
"""
|
||||||
|
def to_binary(vector) when is_struct(vector, Ash.Vector) do
|
||||||
|
vector.data
|
||||||
|
end
|
||||||
|
|
||||||
|
@doc """
|
||||||
|
Converts the vector to a list
|
||||||
|
"""
|
||||||
|
def to_list(vector) when is_struct(vector, Ash.Vector) do
|
||||||
|
<<dim::unsigned-16, 0::unsigned-16, bin::binary-size(dim)-unit(32)>> = vector.data
|
||||||
|
for <<v::float-32 <- bin>>, do: v
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
defimpl Inspect, for: Ash.Vector do
|
||||||
|
import Inspect.Algebra
|
||||||
|
|
||||||
|
def inspect(vec, opts) do
|
||||||
|
concat(["Ash.Vector.new(", Inspect.List.inspect(Ash.Vector.to_list(vec), opts), ")"])
|
||||||
|
end
|
||||||
|
end
|
Loading…
Reference in a new issue