diff --git a/lib/mars/aggregator.rb b/lib/mars/aggregator.rb index 0866e9f..d21b3bd 100644 --- a/lib/mars/aggregator.rb +++ b/lib/mars/aggregator.rb @@ -2,12 +2,11 @@ module MARS class Aggregator < Runnable - attr_reader :name, :operation + attr_reader :operation def initialize(name = "Aggregator", operation: nil, **kwargs) - super(**kwargs) + super(name: name, **kwargs) - @name = name @operation = operation || ->(inputs) { inputs } end diff --git a/lib/mars/gate.rb b/lib/mars/gate.rb index 21e407d..8a2bafd 100644 --- a/lib/mars/gate.rb +++ b/lib/mars/gate.rb @@ -2,12 +2,9 @@ module MARS class Gate < Runnable - attr_reader :name - def initialize(name = "Gate", condition:, branches:, **kwargs) - super(**kwargs) + super(name: name, **kwargs) - @name = name @condition = condition @branches = branches end diff --git a/lib/mars/runnable.rb b/lib/mars/runnable.rb index f7a4909..ea071e5 100644 --- a/lib/mars/runnable.rb +++ b/lib/mars/runnable.rb @@ -2,10 +2,28 @@ module MARS class Runnable + include Hooks + + attr_reader :name, :formatter attr_accessor :state - def initialize(state: {}) + class << self + def step_name + return @step_name if defined?(@step_name) + return unless name + + name.split("::").last.gsub(/([a-z])([A-Z])/, '\1_\2').downcase + end + + def formatter(klass = nil) + klass ? @formatter_class = klass : @formatter_class + end + end + + def initialize(name: self.class.step_name, state: {}, formatter: nil) + @name = name @state = state + @formatter = formatter || self.class.formatter&.new || Formatter.new end def run(input) diff --git a/lib/mars/workflows/parallel.rb b/lib/mars/workflows/parallel.rb index da65a27..ef8f3f6 100644 --- a/lib/mars/workflows/parallel.rb +++ b/lib/mars/workflows/parallel.rb @@ -3,12 +3,9 @@ module MARS module Workflows class Parallel < Runnable - attr_reader :name - def initialize(name, steps:, aggregator: nil, **kwargs) - super(**kwargs) + super(name: name, **kwargs) - @name = name @steps = steps @aggregator = aggregator || Aggregator.new("#{name} Aggregator") end diff --git a/lib/mars/workflows/sequential.rb b/lib/mars/workflows/sequential.rb index 0db625e..df673c6 100644 --- a/lib/mars/workflows/sequential.rb +++ b/lib/mars/workflows/sequential.rb @@ -3,12 +3,9 @@ module MARS module Workflows class Sequential < Runnable - attr_reader :name - def initialize(name, steps:, **kwargs) - super(**kwargs) + super(name: name, **kwargs) - @name = name @steps = steps end diff --git a/spec/mars/runnable_spec.rb b/spec/mars/runnable_spec.rb index 6c3e5f7..972a549 100644 --- a/spec/mars/runnable_spec.rb +++ b/spec/mars/runnable_spec.rb @@ -42,6 +42,73 @@ def run(input) end end + describe "#name" do + it "defaults to nil for anonymous classes" do + klass = Class.new(described_class) + expect(klass.new.name).to be_nil + end + + it "can be set via the name keyword" do + runnable = described_class.new(name: "my_step") + expect(runnable.name).to eq("my_step") + end + + it "derives step_name from the class name" do + stub_const("MARS::MyCustomStep", Class.new(described_class)) + expect(MARS::MyCustomStep.new.name).to eq("my_custom_step") + end + end + + describe "#formatter" do + it "defaults to a Formatter instance" do + runnable = described_class.new + expect(runnable.formatter).to be_a(MARS::Formatter) + end + + it "can be set via the formatter keyword" do + custom_formatter = MARS::Formatter.new + runnable = described_class.new(formatter: custom_formatter) + expect(runnable.formatter).to eq(custom_formatter) + end + + it "uses the class-level formatter when declared" do + custom_formatter_class = Class.new(MARS::Formatter) + klass = Class.new(described_class) do + formatter custom_formatter_class + end + + expect(klass.new.formatter).to be_a(custom_formatter_class) + end + end + + describe "hooks" do + it "includes Hooks module" do + expect(described_class.ancestors).to include(MARS::Hooks) + end + + it "supports before_run hooks" do + klass = Class.new(described_class) + calls = [] + klass.before_run { |_ctx, step| calls << step.name } + + step = klass.new(name: "test") + step.run_before_hooks(MARS::ExecutionContext.new(input: "x")) + + expect(calls).to eq(["test"]) + end + + it "supports after_run hooks" do + klass = Class.new(described_class) + calls = [] + klass.after_run { |_ctx, result, _step| calls << result } + + step = klass.new(name: "test") + step.run_after_hooks(MARS::ExecutionContext.new(input: "x"), "result") + + expect(calls).to eq(["result"]) + end + end + describe "inheritance" do it "can be inherited" do subclass = Class.new(described_class)