diff --git a/lib/polaris/shared.ex b/lib/polaris/shared.ex index 5bf77fa..30bb58d 100644 --- a/lib/polaris/shared.ex +++ b/lib/polaris/shared.ex @@ -21,11 +21,11 @@ defmodule Polaris.Shared do """ deftransform fulls_like(params, value, opts \\ []) do opts = Keyword.validate!(opts, [:type]) - fun = &Nx.broadcast(Nx.tensor(value, type: &2), &1) + fun = &Nx.broadcast(Nx.tensor(value, type: &2), Nx.shape(&1), names: Nx.names(&1)) deep_new(params, fn x -> type = opts[:type] || Nx.type(x) - fun.(Nx.shape(x), type) + fun.(x, type) end) end