diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc
index 9019c8ee85..f6efca3b04 100644
--- a/exla/c_src/exla/exla.cc
+++ b/exla/c_src/exla/exla.cc
@@ -14,6 +14,8 @@
 #include "exla_nif_util.h"
 #include "ipc.h"
 #include "mlir/IR/MLIRContext.h"
+#include "shardy/dialect/sdy/ir/dialect.h"
+#include "shardy/dialect/sdy/ir/register.h"
 #include "stablehlo/dialect/ChloOps.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "xla/pjrt/pjrt_api.h"
@@ -67,6 +69,9 @@ mlir_new_context(ErlNifEnv *env,
   context->getOrLoadDialect<mlir::func::FuncDialect>();
   context->getOrLoadDialect<mlir::stablehlo::StablehloDialect>();
   context->getOrLoadDialect<mlir::chlo::ChloDialect>();
+  context->getOrLoadDialect<mlir::sdy::SdyDialect>();
+  mlir::sdy::registerAllDialects(
+      const_cast<mlir::DialectRegistry &>(context->getDialectRegistry()));
 
   return context;
 }
@@ -171,6 +176,69 @@ fine::Ok<> mlir_pop_region(ErlNifEnv *env,
 
 FINE_NIF(mlir_pop_region, 0);
 
+fine::Ok<> mlir_add_mesh(ErlNifEnv *env, fine::ResourcePtr<MLIRModule> module,
+                         std::string mesh_name,
+                         std::vector<std::tuple<std::string, int64_t>> axes) {
+  auto builder = module->builder();
+  auto context = module->module()->getContext();
+
+  llvm::SmallVector<mlir::sdy::MeshAxisAttr> axis_attrs;
+  for (auto [name, size] : axes) {
+    axis_attrs.push_back(mlir::sdy::MeshAxisAttr::get(context, name, size));
+  }
+
+  auto mesh_attr = mlir::sdy::MeshAttr::get(context, axis_attrs);
+
+  // Create the mesh op at the beginning of the module
+  auto module_op = module->module();
+  auto &body_region = module_op.getBodyRegion();
+  mlir::OpBuilder::InsertionGuard guard(*builder);
+  builder->setInsertionPointToStart(&body_region.front());
+
+  mlir::OperationState state(builder->getUnknownLoc(), "sdy.mesh");
+  mlir::sdy::MeshOp::build(*builder, state, mesh_name, mesh_attr);
+  builder->create(state);
+
+  return fine::Ok();
+}
+
+FINE_NIF(mlir_add_mesh, 0);
+
+mlir::sdy::TensorShardingAttr mlir_create_tensor_sharding_attr(
+    mlir::MLIRContext *context, std::string mesh_name,
+    std::vector<std::vector<std::string>> dim_shardings) {
+  llvm::SmallVector<mlir::sdy::DimensionShardingAttr> dim_sharding_attrs;
+  for (const auto &dim : dim_shardings) {
+    llvm::SmallVector<mlir::sdy::AxisRefAttr> axis_refs;
+    for (const auto &axis : dim) {
+      axis_refs.push_back(mlir::sdy::AxisRefAttr::get(context, axis));
+    }
+    dim_sharding_attrs.push_back(mlir::sdy::DimensionShardingAttr::get(
+        context, axis_refs, /*is_closed=*/false, /*priority=*/0));
+  }
+
+  return mlir::sdy::TensorShardingAttr::get(
+      context, mesh_name, dim_sharding_attrs,
+      /*replicated_axes=*/llvm::ArrayRef<mlir::sdy::AxisRefAttr>(),
+      /*unreduced_axes=*/llvm::ArrayRef<mlir::sdy::AxisRefAttr>());
+}
+
+fine::Ok<>
+mlir_set_arg_sharding(ErlNifEnv *env, fine::ResourcePtr<MLIRFunction> function,
+                      int64_t arg_index, std::string mesh_name,
+                      std::vector<std::vector<std::string>> dim_shardings) {
+
+  auto context = function->module()->module()->getContext();
+  auto sharding_attr =
+      mlir_create_tensor_sharding_attr(context, mesh_name, dim_shardings);
+
+  function->function().setArgAttr(arg_index, "sdy.sharding", sharding_attr);
+
+  return fine::Ok();
+}
+
+FINE_NIF(mlir_set_arg_sharding, 0);
+
 mlir::Type mlir_get_typespec(ErlNifEnv *env,
                              fine::ResourcePtr<mlir::Value> value) {
   return value->getType();
diff --git a/exla/c_src/exla/exla_client.cc b/exla/c_src/exla/exla_client.cc
index c2f5d0fd8d..43cb75650c 100644
--- a/exla/c_src/exla/exla_client.cc
+++ b/exla/c_src/exla/exla_client.cc
@@ -152,15 +152,24 @@ PjRtBufferFromBinary(xla::PjRtClient *client, ERL_NIF_TERM source_term,
 tsl::StatusOr<std::vector<std::vector<xla::PjRtBuffer *>>> UnpackRunArguments(
     ErlNifEnv *env, ExlaExecutable::RunArguments arguments,
     std::vector<std::unique_ptr<xla::PjRtBuffer>> &transient_buffers,
-    ExlaClient *client, xla::DeviceAssignment device_assignment,
-    int device_id) {
+    ExlaClient *client, xla::DeviceAssignment device_assignment, int device_id,
+    int num_partitions) {
   std::vector<std::vector<xla::PjRtBuffer *>> arg_buffers;
   arg_buffers.reserve(arguments.size());
 
-  int replica = 0;
+  int index = 0;
   for (const auto &replica_arguments : arguments) {
-    auto device = device_id >= 0 ? device_id : device_assignment(replica, 0);
+    // For automatic SPMD, each input list goes to a different partition
+    // device. device_assignment maps (replica, partition) -> device.
+    // With num_partitions > 1, we iterate through partitions
+    // (replica = 0, partition = 0..N-1). For replication, we iterate through
+    // replicas (replica = 0..N-1, partition = 0).
+    int replica = (num_partitions > 1) ? 0 : index;
+    int partition = (num_partitions > 1) ? index : 0;
+
+    auto device =
+        device_id >= 0 ? device_id : device_assignment(replica, partition);
 
     auto replica_buffers = std::vector<xla::PjRtBuffer *>();
     replica_buffers.reserve(replica_arguments.size());
@@ -200,7 +209,7 @@ tsl::StatusOr<std::vector<std::vector<xla::PjRtBuffer *>>> UnpackRunArguments(
 
     arg_buffers.push_back(std::move(replica_buffers));
 
-    replica++;
+    index++;
   }
 
   return arg_buffers;
@@ -216,7 +225,17 @@ UnpackResult(ErlNifEnv *env,
   for (int i = 0; i < result.size(); i++) {
     auto replica_results = std::vector<fine::ResourcePtr<ExlaBuffer>>();
-    int64_t device = device_id >= 0 ? device_id : device_assignment(i, 0);
+
+    int64_t device;
+    if (device_id >= 0) {
+      device = device_id;
+    } else if (device_assignment.computation_count() > 1) {
+      // SPMD: results correspond to partitions (replica 0, partition i)
+      device = device_assignment(0, i);
+    } else {
+      // Replication: results correspond to replicas (replica i, partition 0)
+      device = device_assignment(i, 0);
+    }
 
     for (auto &pjrt_buf : result.at(i)) {
       pjrt_buf->GetReadyFuture().Await();
@@ -266,20 +285,23 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments,
   // a pmap, but in all other cases it will be equal to 1
   int num_replicas = executable_->num_replicas();
 
+  // the number of partitions is used for SPMD partitioning
+  int num_partitions = executable_->num_partitions();
+
   // input buffers are a list of lists, where each list maps to the args
   // to pass to one of the replicas in a computation, e.g. [replica_args1,
   // replica_args2, ...]
   std::vector<std::vector<xla::PjRtBuffer *>> input_buffers;
 
   // the device assignment is a 2d array which maps coordinates (replica,
-  // partition) to a device; or in this case just maps a replica to a device
+  // partition) to a device
   xla::DeviceAssignment device_assignment;
   if (client_->client()->platform_name() == "METAL") {
     device_assignment = xla::DeviceAssignment(1, 1);
   } else {
-    EXLA_ASSIGN_OR_RETURN(
-        device_assignment,
-        client_->client()->GetDefaultDeviceAssignment(num_replicas, 1));
+    EXLA_ASSIGN_OR_RETURN(device_assignment,
+                          client_->client()->GetDefaultDeviceAssignment(
+                              num_replicas, num_partitions));
   }
 
   // Buffers allocated from binaries for this specific run need to be
@@ -300,15 +322,20 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments,
     EXLA_ASSIGN_OR_RETURN(input_buffers,
                           UnpackRunArguments(env, arguments, transient_buffers,
                                              client_, device_assignment,
-                                             device_id));
+                                             device_id, num_partitions));
   }
 
-  // at this point input buffers is a vector of arguments per replica
-  // and the size of that vector should equal the number of replicas in the
-  // executable, otherwise it is invalid
-  if (num_replicas != input_buffers.size()) {
-    return xla::InvalidArgument("Got %d replica arguments for %d replicas",
-                                input_buffers.size(), num_replicas);
+  // at this point input buffers is a vector of arguments per device.
+  // For automatic SPMD: one input list per partition (num_partitions lists).
+  // For standard replication: one input list per replica (num_replicas lists).
+  // Each input list contains full unreplicated tensors; XLA slices based on
+  // sharding
+  int expected_lists = num_partitions > 1 ? num_partitions : num_replicas;
+  if (input_buffers.size() != expected_lists) {
+    return xla::InvalidArgument("Got %d argument lists, expected %d "
+                                "(num_replicas=%d, num_partitions=%d)",
+                                input_buffers.size(), expected_lists,
+                                num_replicas, num_partitions);
   }
 
   std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>>
@@ -333,10 +360,9 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments,
     // result buffers to unpack
     per_replica_results.push_back(std::move(portable_result));
   } else {
-    // no device ID is present, so it may be a replicated executable which means
-    // we need to use the replica execution path
-    // TODO: This now exposes a `returned_futures` API, does this make sense for
-    // us?
+    // no device ID is present, so it may be a replicated or SPMD executable.
+    // For SPMD with num_partitions > 1, Execute handles partitioned execution
+    // using sharding annotations
     EXLA_ASSIGN_OR_RETURN(per_replica_results,
                           executable_->Execute(input_buffers, options));
   }
@@ -344,9 +370,15 @@ ExlaExecutable::Run(ErlNifEnv *env, ExlaExecutable::RunArguments arguments,
   // EXLA_ASSIGN_OR_RETURN(per_replica_results,
   //                       executable_->Execute(input_buffers, options));
 
-  // sanity check
-  if (per_replica_results.size() != num_replicas) {
-    return xla::FailedPrecondition("Invalid execution.");
+  // sanity check - for SPMD we get results per partition, for replication per
+  // replica
+  int expected_results = num_partitions > 1 ? num_partitions : num_replicas;
+  if (per_replica_results.size() != expected_results) {
+    return xla::FailedPrecondition(
+        "Invalid execution: got %d results, expected %d (num_replicas=%d, "
+        "num_partitions=%d)",
+        per_replica_results.size(), expected_results, num_replicas,
+        num_partitions);
   }
 
   // we need to unpack the results into Erlang terms, the result is a vector
diff --git a/exla/c_src/exla/exla_mlir.h b/exla/c_src/exla/exla_mlir.h
index 095ad4c1a7..5d5ef060ee 100644
--- a/exla/c_src/exla/exla_mlir.h
+++ b/exla/c_src/exla/exla_mlir.h
@@ -29,6 +29,8 @@ class MLIRFunction {
   llvm::MutableArrayRef<mlir::BlockArgument> GetArguments() {
     return func_->getBody().front().getArguments();
   }
+  mlir::func::FuncOp function() { return *func_; }
+  fine::ResourcePtr<MLIRModule> module() { return module_; }
 
  private:
diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index 275528fa28..6ca4dfb43c 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -149,10 +149,19 @@ defmodule EXLA.Defn do
           EXLA.Defn.Buffers.from_nx!(arg, executable)
         end)
 
-      EXLA.Executable.run(executable, [buffers], run_options)
+      input_lists = slice_inputs(buffers, executable)
+
+      EXLA.Executable.run(executable, input_lists, run_options)
     else
       [result] -> [EXLA.Defn.Buffers.to_nx!(result, outputs)]
+
+      results when is_list(results) ->
+        # For SPMD, we get multiple results (one per partition).
+        # For now, we just take the first one to verify execution.
+        # TODO: Implement re-assembly of sharded outputs
+        [first | _] = results
+        [EXLA.Defn.Buffers.to_nx!(first, outputs)]
     after
       EXLA.Defn.Lock.unlock(lock)
     end
 
@@ -160,6 +169,133 @@ defmodule EXLA.Defn do
 
   defp run_key(%{client: %{ref: ref}, device_id: device_id}), do: [ref | device_id]
 
+  defp slice_inputs(buffers, %EXLA.Executable{num_partitions: 1}), do: [buffers]
+
+  defp slice_inputs(
+         buffers,
+         %EXLA.Executable{
+           mesh: mesh,
+           input_shardings: shardings,
+           num_partitions: np
+         }
+       )
+       when np > 1 and not is_nil(mesh) and not is_nil(shardings) do
+    # Build mesh axis map for quick lookup
+    mesh_axes = Map.new(mesh.axes)
+
+    # Generate shards for each partition
+    for partition_idx <- 0..(np - 1) do
+      # Convert linear partition index to mesh coordinates
+      coords = unravel_index(partition_idx, mesh.axes)
+
+      # Slice each buffer according to its sharding spec
+      Enum.zip(buffers, shardings)
+      |> Enum.map(fn {buffer, sharding} ->
+        slice_buffer_for_partition(buffer, sharding, coords, mesh_axes)
+      end)
+    end
+  end
+
+  defp slice_inputs(buffers, %EXLA.Executable{num_partitions: np}),
+    do: List.duplicate(buffers, np)
+
+  # Converts linear partition index to mesh coordinates
+  # Example: index 3 in [{"x", 2}, {"y", 2}] -> %{"x" => 1, "y" => 1}
+  defp unravel_index(index, axes) do
+    {coords, _} =
+      Enum.reduce(Enum.reverse(axes), {%{}, index}, fn {name, size}, {acc, current_idx} ->
+        coord = rem(current_idx, size)
+        remaining = div(current_idx, size)
+        {Map.put(acc, name, coord), remaining}
+      end)
+
+    coords
+  end
+
+  # Slices a single buffer for a specific partition based on sharding spec
+  defp slice_buffer_for_partition(
+         %EXLA.BinaryBuffer{data: data, typespec: typespec},
+         sharding,
+         coords,
+         mesh_axes
+       ) do
+    # Convert binary buffer to Nx tensor
+    tensor = binary_buffer_to_nx(data, typespec)
+
+    # Slice along each dimension according to sharding spec
+    sharded_tensor =
+      tensor.shape
+      |> Tuple.to_list()
+      |> Enum.with_index()
+      |> Enum.reduce(tensor, fn {dim_size, dim_idx}, acc ->
+        axis_names = Enum.at(sharding.axes, dim_idx, [])
+
+        if axis_names == [] do
+          # Dimension is replicated, keep full dimension
+          acc
+        else
+          # Special case: size 1 dimensions cannot be sharded
+          # Treat them as replicated (effectively remove sharding)
+          if dim_size == 1 do
+            acc
+          else
+            # Calculate total number of shards for this dimension
+            # (product of all mesh axes this dimension is sharded on)
+            shards_count =
+              Enum.reduce(axis_names, 1, fn name, acc ->
+                acc * Map.fetch!(mesh_axes, name)
+              end)
+
+            # Error if dimension size is less than shards_count (and not size 1)
+            if dim_size < shards_count do
+              raise ArgumentError,
+                    "Cannot shard dimension #{dim_idx} of size #{dim_size} across #{shards_count} shards. " <>
+                      "Dimension size must be >= shards_count (or size 1 for implicit replication)"
+            end
+
+            # Calculate chunk size (assuming even division)
+            chunk_size = div(dim_size, shards_count)
+
+            # Calculate slice index for this partition
+            slice_idx =
+              case axis_names do
+                [name] ->
+                  Map.fetch!(coords, name)
+
+                _ ->
+                  # Multi-axis sharding: calculate linear index from coordinates
+                  # This handles the cartesian product of mesh axes
+                  Enum.reduce(axis_names, 0, fn name, acc ->
+                    coord = Map.fetch!(coords, name)
+                    axis_size = Map.fetch!(mesh_axes, name)
+                    acc * axis_size + coord
+                  end)
+              end
+
+            # Normal case: evenly divisible
+            start = slice_idx * chunk_size
+            Nx.slice_along_axis(acc, start, chunk_size, axis: dim_idx)
+          end
+        end
+      end)
+
+    # Convert back to BinaryBuffer
+    nx_to_binary_buffer(sharded_tensor)
+  end
+
+  # Converts BinaryBuffer to Nx tensor
+  defp binary_buffer_to_nx(data, %EXLA.Typespec{type: type, shape: shape}) do
+    Nx.from_binary(data, type) |> Nx.reshape(shape)
+  end
+
+  # Converts Nx tensor to BinaryBuffer
+  defp nx_to_binary_buffer(tensor) do
+    %EXLA.BinaryBuffer{
+      data: Nx.to_binary(tensor),
+      typespec: %EXLA.Typespec{type: tensor.type, shape: tensor.shape}
+    }
+  end
+
   ## Compile
 
   defp compile(
@@ -228,6 +364,9 @@ defmodule EXLA.Defn do
     outfeed = Outfeed.new(hooks, defined_hooks)
     comp_key = {ref, client.name, outfeed.used_hooks, lazy_transfers, options}
 
+    mesh = Keyword.get(options, :mesh)
+    input_shardings = Keyword.get(options, :input_shardings, [])
+
     {comp_time, {evaled, {xla_time, executable, inputs_and_typespecs, outfeed}}} =
       :timer.tc(fn ->
         comp_cache_fun.(comp_key, fn ->
@@ -254,6 +393,29 @@ defmodule EXLA.Defn do
             end)
 
           EXLA.MLIR.Module.new(comp_typespecs, out_typespecs, fn builder ->
+            # Add device mesh to module if provided
+            if mesh do
+              EXLA.MLIR.Module.add_mesh(builder.module, mesh)
+            end
+
+            if !mesh and input_shardings != [] do
+              raise ArgumentError, "input sharding configs provided but no device mesh was provided"
+            end
+
+            # Apply sharding annotations to function arguments if provided
+            if input_shardings != [] do
+              num_comp_args = length(comp_typespecs)
+
+              if length(input_shardings) != num_comp_args do
+                raise ArgumentError,
+                      "expected #{num_comp_args} input sharding configs (one per argument), got #{length(input_shardings)}"
+              end
+
+              Enum.with_index(input_shardings, fn sharding, arg_index ->
+                Function.set_arg_sharding(builder, arg_index, sharding)
+              end)
+            end
+
             # Only create the token when we know it will actually be
             # used, that is: streaming, lazy transfers or hooks
             outfeed =
@@ -270,6 +432,19 @@ defmodule EXLA.Defn do
 
     options = Keyword.put(options, :callback_server_pid, callback_server_pid)
 
+    # Compute num_partitions from mesh and enable SPMD if mesh is provided
+    options =
+      if mesh do
+        num_partitions =
+          Enum.reduce(mesh.axes, 1, fn {_name, size}, acc -> acc * size end)
+
+        options
+        |> Keyword.put(:num_partitions, num_partitions)
+        |> Keyword.put(:use_spmd, true)
+      else
+        options
+      end
+
     {xla_time, executable} =
       :timer.tc(fn ->
         EXLA.MLIR.Module.compile(
diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex
index 15ffbbdfe0..c49b74ed13 100644
--- a/exla/lib/exla/executable.ex
+++ b/exla/lib/exla/executable.ex
@@ -7,7 +7,16 @@ defmodule EXLA.Executable do
   alias EXLA.{BinaryBuffer, DeviceBuffer}
 
   @enforce_keys [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id]
-  defstruct [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id]
+  defstruct [
+    :client,
+    :ref,
+    :output_typespecs,
+    :num_replicas,
+    :num_partitions,
+    :device_id,
+    :mesh,
+    :input_shardings
+  ]
 
   @doc """
   Runs the given executable with a list of lists as inputs and the given options.
@@ -45,7 +54,9 @@ defmodule EXLA.Executable do
         output_typespecs: output_typespecs,
         num_replicas: num_replicas,
         num_partitions: num_partitions,
-        device_id: device_id
+        device_id: device_id,
+        mesh: mesh,
+        input_shardings: input_shardings
       }) when node(ref) == node() do
    serialized_exec =
@@ -58,7 +69,9 @@ defmodule EXLA.Executable do
      output_typespecs: output_typespecs,
      num_replicas: num_replicas,
      num_partitions: num_partitions,
-      device_id: device_id
+      device_id: device_id,
+      mesh: mesh,
+      input_shardings: input_shardings
    }
  end
 
@@ -85,6 +98,8 @@ defmodule EXLA.Executable do
       num_replicas: num_replicas,
       num_partitions: num_partitions,
       device_id: device_id,
+      mesh: Map.get(data, :mesh),
+      input_shardings: Map.get(data, :input_shardings),
       ref: ref,
       client: client
     }
diff --git a/exla/lib/exla/mlir/function.ex b/exla/lib/exla/mlir/function.ex
index 7b1157955a..80d59decc9 100644
--- a/exla/lib/exla/mlir/function.ex
+++ b/exla/lib/exla/mlir/function.ex
@@ -36,4 +36,14 @@ defmodule EXLA.MLIR.Function do
   def pop_region(%Function{ref: ref}) do
     EXLA.NIF.mlir_pop_region(ref)
   end
+
+  @doc """
+  Sets sharding annotation for a function argument.
+  """
+  def set_arg_sharding(%Function{ref: ref}, arg_index, %EXLA.Sharding.TensorSharding{
+        mesh_name: mesh,
+        axes: dims
+      }) do
+    EXLA.NIF.mlir_set_arg_sharding(ref, arg_index, mesh, dims)
+  end
 end
diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex
index d1ba3d0b0b..fceb607e7e 100644
--- a/exla/lib/exla/mlir/module.ex
+++ b/exla/lib/exla/mlir/module.ex
@@ -130,10 +130,20 @@ defmodule EXLA.MLIR.Module do
       output_typespecs: return_typespecs,
       num_replicas: num_replicas,
       num_partitions: num_partitions,
-      device_id: device_id
+      device_id: device_id,
+      mesh: Keyword.get(options, :mesh),
+      input_shardings: Keyword.get(options, :input_shardings)
     }
   end
 
+  @doc """
+  Adds a device mesh definition to the module.
+  """
+  def add_mesh(%__MODULE__{ref: module_ref}, %EXLA.Sharding.DeviceMesh{name: name, axes: axes}) do
+    EXLA.NIF.mlir_add_mesh(module_ref, name, axes)
+    :ok
+  end
+
   @doc """
   Returns a human-readable representation of the module using MLIR syntax.
diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex
index 6e70d07d57..85d051f835 100644
--- a/exla/lib/exla/nif.ex
+++ b/exla/lib/exla/nif.ex
@@ -28,6 +28,8 @@ defmodule EXLA.NIF do
   def mlir_op(_function, _op_name, _operands, _result_type, _attributes, _blocks), do: err!()
   def mlir_push_region(_function, _arg_types), do: err!()
   def mlir_pop_region(_function), do: err!()
+  def mlir_add_mesh(_module, _mesh_name, _axes), do: err!()
+  def mlir_set_arg_sharding(_function, _arg_index, _mesh_name, _dim_shardings), do: err!()
   def mlir_build(_function, _root), do: err!()
 
   def mlir_compile(
diff --git a/exla/lib/exla/sharding.ex b/exla/lib/exla/sharding.ex
new file mode 100644
index 0000000000..281ae530bf
--- /dev/null
+++ b/exla/lib/exla/sharding.ex
@@ -0,0 +1,68 @@
+defmodule EXLA.Sharding do
+  @moduledoc """
+  Helper module for defining Shardy device meshes and tensor sharding specifications.
+  """
+
+  defmodule DeviceMesh do
+    @moduledoc """
+    Represents a device mesh configuration.
+    """
+    @enforce_keys [:name, :axes]
+    defstruct [:name, :axes]
+
+    @type axis :: {name :: String.t(), size :: pos_integer()}
+    @type t :: %__MODULE__{
+            name: String.t(),
+            axes: [axis()]
+          }
+  end
+
+  defmodule TensorSharding do
+    @moduledoc """
+    Represents a sharding specification for a tensor.
+    """
+    @enforce_keys [:mesh_name, :axes]
+    defstruct [:mesh_name, :axes]
+
+    @type dim_sharding :: [String.t()]
+    @type t :: %__MODULE__{
+            mesh_name: String.t(),
+            axes: [dim_sharding()]
+          }
+  end
+
+  @doc """
+  Creates a device mesh definition.
+
+  ## Examples
+
+      iex> EXLA.Sharding.mesh(:my_mesh, x: 2, y: 4)
+      %EXLA.Sharding.DeviceMesh{name: "my_mesh", axes: [{"x", 2}, {"y", 4}]}
+  """
+  def mesh(name, axes) when (is_atom(name) or is_binary(name)) and is_list(axes) do
+    normalized_axes =
+      Enum.map(axes, fn {k, v} -> {to_string(k), v} end)
+
+    %DeviceMesh{name: to_string(name), axes: normalized_axes}
+  end
+
+  @doc """
+  Creates a sharding specification for a tensor.
+
+  The `dim_shardings` list must match the rank of the tensor.
+  Each element is a list of axis names that the corresponding dimension is sharded on.
+
+  ## Examples
+
+      # Rank 2 tensor, dim 0 sharded on "x", dim 1 sharded on "y"
+      iex> EXLA.Sharding.sharding(:my_mesh, [["x"], ["y"]])
+      %EXLA.Sharding.TensorSharding{mesh_name: "my_mesh", axes: [["x"], ["y"]]}
+
+      # Rank 2 tensor, dim 0 sharded on "x", dim 1 replicated
+      iex> EXLA.Sharding.sharding(:my_mesh, [["x"], []])
+      %EXLA.Sharding.TensorSharding{mesh_name: "my_mesh", axes: [["x"], []]}
+  """
+  def sharding(mesh_name, dim_shardings) do
+    %TensorSharding{mesh_name: to_string(mesh_name), axes: dim_shardings}
+  end
+end
diff --git a/exla/sharding.exs b/exla/sharding.exs
new file mode 100644
index 0000000000..a43aeecda4
--- /dev/null
+++ b/exla/sharding.exs
@@ -0,0 +1,15 @@
+fun = fn x, y -> {Nx.add(x, y), Nx.multiply(x, y)} end
+args = [Nx.iota({8, 2}), Nx.iota({8, 1})]
+
+mesh = EXLA.Sharding.mesh("mesh", x: 2, y: 2, z: 2)
+
+input_shardings = [EXLA.Sharding.sharding("mesh", [["x", "z"], ["y"]]), EXLA.Sharding.sharding("mesh", [["x", "z"], []])]
+
+result = EXLA.to_mlir_module(fun, args, mesh: mesh, input_shardings: input_shardings)
+
+IO.puts(result.mlir_module)
+
+result = EXLA.jit_apply(fun, args, mesh: mesh, input_shardings: input_shardings)
+dbg(result)
+
+# run with: XLA_FLAGS="--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text --xla_force_host_platform_device_count=10" mix run sharding.exs
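
As a quick sanity check of the slicing logic introduced in EXLA.Defn above, the following standalone Elixir sketch (editor's illustration, not part of the diff) reproduces the unravel_index/2 and slice_buffer_for_partition/4 arithmetic for the mesh and first input sharding used in sharding.exs. The `unravel` and `slice_for` helpers and the Nx version constraint are assumptions made for the example, and it assumes every sharded dimension divides evenly, matching the "normal case" branch in the diff.

# Illustrative sketch only; assumes Nx is available via Mix.install
Mix.install([{:nx, "~> 0.9"}])

mesh_axes = [{"x", 2}, {"y", 2}, {"z", 2}]
axis_sizes = Map.new(mesh_axes)

# Linear partition index -> mesh coordinates, e.g. 5 -> %{"x" => 1, "y" => 0, "z" => 1}
unravel = fn index ->
  {coords, _} =
    Enum.reduce(Enum.reverse(mesh_axes), {%{}, index}, fn {name, size}, {acc, idx} ->
      {Map.put(acc, name, rem(idx, size)), div(idx, size)}
    end)

  coords
end

# Slice one tensor for one partition, given per-dimension mesh axis names
slice_for = fn tensor, dim_axes, coords ->
  tensor.shape
  |> Tuple.to_list()
  |> Enum.with_index()
  |> Enum.reduce(tensor, fn {dim_size, dim}, acc ->
    axes = Enum.at(dim_axes, dim, [])

    if axes == [] or dim_size == 1 do
      # Replicated (or size-1) dimension: keep it whole
      acc
    else
      shards = Enum.reduce(axes, 1, fn name, n -> n * Map.fetch!(axis_sizes, name) end)
      chunk = div(dim_size, shards)

      # Linear shard index over the cartesian product of the listed mesh axes
      idx =
        Enum.reduce(axes, 0, fn name, i ->
          i * Map.fetch!(axis_sizes, name) + Map.fetch!(coords, name)
        end)

      Nx.slice_along_axis(acc, idx * chunk, chunk, axis: dim)
    end
  end)
end

x = Nx.iota({8, 2})

for partition <- 0..7 do
  coords = unravel.(partition)
  shard = slice_for.(x, [["x", "z"], ["y"]], coords)
  IO.inspect({partition, coords, Nx.to_list(shard)})
end

# Partition 5 unravels to %{"x" => 1, "y" => 0, "z" => 1}, so it receives
# rows 6..7 and column 0 of the {8, 2} iota tensor, i.e. [[12], [14]].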