mirror of
https://github.com/ash-project/ash.git
synced 2024-09-20 13:33:20 +12:00
improvement: support vector types
This commit is contained in:
parent
cedcda903f
commit
e2855843ca
3 changed files with 131 additions and 27 deletions
|
@ -25,33 +25,31 @@ defmodule Ash.Type do
|
|||
]
|
||||
|
||||
@builtin_short_names [
|
||||
map: "Ash.Type.Map",
|
||||
keyword: "Ash.Type.Keyword",
|
||||
term: "Ash.Type.Term",
|
||||
atom: "Ash.Type.Atom",
|
||||
string: "Ash.Type.String",
|
||||
integer: "Ash.Type.Integer",
|
||||
float: "Ash.Type.Float",
|
||||
duration_name: "Ash.Type.DurationName",
|
||||
function: "Ash.Type.Function",
|
||||
boolean: "Ash.Type.Boolean",
|
||||
struct: "Ash.Type.Struct",
|
||||
uuid: "Ash.Type.UUID",
|
||||
binary: "Ash.Type.Binary",
|
||||
date: "Ash.Type.Date",
|
||||
time: "Ash.Type.Time",
|
||||
decimal: "Ash.Type.Decimal",
|
||||
ci_string: "Ash.Type.CiString",
|
||||
naive_datetime: "Ash.Type.NaiveDatetime",
|
||||
utc_datetime: "Ash.Type.UtcDatetime",
|
||||
utc_datetime_usec: "Ash.Type.UtcDatetimeUsec",
|
||||
url_encoded_binary: "Ash.Type.UrlEncodedBinary",
|
||||
union: "Ash.Type.Union",
|
||||
module: "Ash.Type.Module"
|
||||
]
|
||||
|> Enum.map(fn {key, value} ->
|
||||
{key, Module.concat([value])}
|
||||
end)
|
||||
map: Ash.Type.Map,
|
||||
keyword: Ash.Type.Keyword,
|
||||
term: Ash.Type.Term,
|
||||
atom: Ash.Type.Atom,
|
||||
string: Ash.Type.String,
|
||||
integer: Ash.Type.Integer,
|
||||
float: Ash.Type.Float,
|
||||
duration_name: Ash.Type.DurationName,
|
||||
function: Ash.Type.Function,
|
||||
boolean: Ash.Type.Boolean,
|
||||
struct: Ash.Type.Struct,
|
||||
uuid: Ash.Type.UUID,
|
||||
binary: Ash.Type.Binary,
|
||||
date: Ash.Type.Date,
|
||||
time: Ash.Type.Time,
|
||||
decimal: Ash.Type.Decimal,
|
||||
ci_string: Ash.Type.CiString,
|
||||
naive_datetime: Ash.Type.NaiveDatetime,
|
||||
utc_datetime: Ash.Type.UtcDatetime,
|
||||
utc_datetime_usec: Ash.Type.UtcDatetimeUsec,
|
||||
url_encoded_binary: Ash.Type.UrlEncodedBinary,
|
||||
union: Ash.Type.Union,
|
||||
module: Ash.Type.Module,
|
||||
vector: Ash.Type.Vector
|
||||
]
|
||||
|
||||
@custom_short_names Application.compile_env(:ash, :custom_types, [])
|
||||
|
||||
|
|
49
lib/ash/type/vector.ex
Normal file
49
lib/ash/type/vector.ex
Normal file
|
@ -0,0 +1,49 @@
|
|||
defmodule Ash.Type.Vector do
|
||||
@moduledoc """
|
||||
Represents a vector.
|
||||
|
||||
A builtin type that can be referenced via `:vector`
|
||||
"""
|
||||
|
||||
use Ash.Type
|
||||
|
||||
@impl true
|
||||
def storage_type(_), do: :vector
|
||||
|
||||
@impl true
|
||||
def generator(_constraints) do
|
||||
StreamData.list_of(StreamData.float())
|
||||
end
|
||||
|
||||
@impl true
|
||||
def cast_input(value, _) do
|
||||
Ash.Vector.new(value)
|
||||
end
|
||||
|
||||
@impl true
|
||||
def cast_stored(nil, _), do: {:ok, nil}
|
||||
|
||||
def cast_stored(%Ash.Vector{} = vector, _) do
|
||||
{:ok, vector}
|
||||
end
|
||||
|
||||
def cast_stored(value, _) when is_list(value) do
|
||||
case Ash.Vector.new(value) do
|
||||
{:ok, vector} -> {:ok, vector}
|
||||
{:error, _} -> :error
|
||||
end
|
||||
end
|
||||
|
||||
@impl true
|
||||
def dump_to_native(nil, _), do: {:ok, nil}
|
||||
|
||||
def dump_to_native(%Ash.Vector{data: data}, _) do
|
||||
{:ok, data}
|
||||
end
|
||||
|
||||
def dump_to_native(value, constraints) when is_list(value) do
|
||||
with {:ok, value} <- cast_input(value, constraints) do
|
||||
dump_to_native(value, constraints)
|
||||
end
|
||||
end
|
||||
end
|
57
lib/ash/vector.ex
Normal file
57
lib/ash/vector.ex
Normal file
|
@ -0,0 +1,57 @@
|
|||
defmodule Ash.Vector do
|
||||
@moduledoc """
|
||||
A vector struct for Ash.
|
||||
|
||||
Implementation based off of https://github.com/pgvector/pgvector-elixir/blob/v0.2.0/lib/pgvector.ex
|
||||
"""
|
||||
|
||||
defstruct [:data]
|
||||
|
||||
@doc """
|
||||
Creates a new vector from a list or tensor
|
||||
"""
|
||||
def new(binary) when is_binary(binary) do
|
||||
from_binary(binary)
|
||||
end
|
||||
|
||||
def new(%__MODULE__{} = vector), do: {:ok, vector}
|
||||
|
||||
def new(list) when is_list(list) do
|
||||
# this definitely has failure cases
|
||||
dim = list |> length()
|
||||
bin = for v <- list, into: "", do: <<v::float-32>>
|
||||
{:ok, from_binary(<<dim::unsigned-16, 0::unsigned-16, bin::binary>>)}
|
||||
rescue
|
||||
_ -> {:error, :invalid_vector}
|
||||
end
|
||||
|
||||
@doc """
|
||||
Creates a new vector from its binary representation
|
||||
"""
|
||||
def from_binary(binary) when is_binary(binary) do
|
||||
%Ash.Vector{data: binary}
|
||||
end
|
||||
|
||||
@doc """
|
||||
Converts the vector to its binary representation
|
||||
"""
|
||||
def to_binary(vector) when is_struct(vector, Ash.Vector) do
|
||||
vector.data
|
||||
end
|
||||
|
||||
@doc """
|
||||
Converts the vector to a list
|
||||
"""
|
||||
def to_list(vector) when is_struct(vector, Ash.Vector) do
|
||||
<<dim::unsigned-16, 0::unsigned-16, bin::binary-size(dim)-unit(32)>> = vector.data
|
||||
for <<v::float-32 <- bin>>, do: v
|
||||
end
|
||||
end
|
||||
|
||||
defimpl Inspect, for: Ash.Vector do
|
||||
import Inspect.Algebra
|
||||
|
||||
def inspect(vec, opts) do
|
||||
concat(["Ash.Vector.new(", Inspect.List.inspect(Ash.Vector.to_list(vec), opts), ")"])
|
||||
end
|
||||
end
|
Loading…
Reference in a new issue