diff --git a/Project.toml b/Project.toml index ee29ee1..63561ca 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NumExpr" uuid = "005f7402-6e25-4d9a-960d-a0ddd50a2fba" -version = "1.0.0" +version = "1.1.0" [compat] julia = "1.8" diff --git a/benchmark/Project.toml b/benchmark/Project.toml new file mode 100644 index 0000000..a0fa77c --- /dev/null +++ b/benchmark/Project.toml @@ -0,0 +1,3 @@ +[deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +NumExpr = "005f7402-6e25-4d9a-960d-a0ddd50a2fba" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl new file mode 100644 index 0000000..d8490df --- /dev/null +++ b/benchmark/benchmarks.jl @@ -0,0 +1,118 @@ +using BenchmarkTools +using NumExpr + +const SUITE = BenchmarkGroup() + +# ───────────────────────────────────────────────────────────────────────────── +# Shared test data +# ───────────────────────────────────────────────────────────────────────────── + +const BENCH_VARS = Dict{String,Float64}( + "a" => 1.5, "b" => 2.3, "c" => 0.7, "d" => 4.1, "e" => 5.9, + "f" => 3.2, "g" => 7.8, "h" => 6.4, "i" => 8.1, "j" => 9.3, + "k" => 2.7, "x" => 1.2, +) + +NumExpr.eval_expr(var::NumExpr.Variable) = get(BENCH_VARS, var[], NaN) + +const BENCH_EXPRS = [ + "constant" => "42", + "simple" => "a + b", + "medium" => "a + b * c - d / e", + "trig" => "sin(x) ^ 2 + cos(x) ^ 2", + "func_chain" => "sqrt(abs(a * b - c))", + "complex" => "sin(a) * cos(b) + exp(c) / log(d + 1) - sqrt(abs(e))", + "large" => "a + b * c + d * e + f * g + h * i + j * k", +] + +# ───────────────────────────────────────────────────────────────────────────── +# 1. Parse +# ───────────────────────────────────────────────────────────────────────────── + +SUITE["parse"] = BenchmarkGroup(["parsing"]) + +for (label, expr_str) in BENCH_EXPRS + SUITE["parse"][label] = @benchmarkable parse_expr($expr_str) +end + +# ───────────────────────────────────────────────────────────────────────────── +# 2. Compile (parse + compile) +# ───────────────────────────────────────────────────────────────────────────── + +SUITE["compile"] = BenchmarkGroup(["compilation"]) + +for (label, expr_str) in BENCH_EXPRS + SUITE["compile"][label] = @benchmarkable compile_expr($expr_str, VarContext()) +end + +# ───────────────────────────────────────────────────────────────────────────── +# 3. Eval — tree walk +# ───────────────────────────────────────────────────────────────────────────── + +SUITE["eval"] = BenchmarkGroup(["evaluation"]) +SUITE["eval"]["tree"] = BenchmarkGroup(["tree-walk"]) + +for (label, expr_str) in BENCH_EXPRS + node = parse_expr(expr_str) + SUITE["eval"]["tree"][label] = @benchmarkable eval_expr($node) +end + +# ───────────────────────────────────────────────────────────────────────────── +# 4. Eval — bytecode VM (pre-allocated stack) +# ───────────────────────────────────────────────────────────────────────────── + +SUITE["eval"]["vm"] = BenchmarkGroup(["bytecode", "vm"]) + +for (label, expr_str) in BENCH_EXPRS + ctx = VarContext() + compiled = compile_expr(expr_str, ctx) + values = zeros(Float64, length(ctx)) + for (name, val) in BENCH_VARS + haskey(ctx, name) && (values[ctx[name]] = val) + end + stack = Vector{Float64}(undef, compiled.max_stack) + SUITE["eval"]["vm"][label] = @benchmarkable eval_compiled($compiled, $values, $stack) +end + +# ───────────────────────────────────────────────────────────────────────────── +# 5. Eval — bytecode VM (auto stack) +# ───────────────────────────────────────────────────────────────────────────── + +SUITE["eval"]["vm_auto"] = BenchmarkGroup(["bytecode", "vm", "auto-stack"]) + +for (label, expr_str) in BENCH_EXPRS + ctx = VarContext() + compiled = compile_expr(expr_str, ctx) + values = zeros(Float64, length(ctx)) + for (name, val) in BENCH_VARS + haskey(ctx, name) && (values[ctx[name]] = val) + end + SUITE["eval"]["vm_auto"][label] = @benchmarkable eval_compiled($compiled, $values) +end + +# ───────────────────────────────────────────────────────────────────────────── +# 6. Throughput — bulk evaluation +# ───────────────────────────────────────────────────────────────────────────── + +SUITE["throughput"] = BenchmarkGroup(["bulk", "throughput"]) + +let + ctx = VarContext() + formulas_1k = [compile_expr("$(rand()) + $(rand()) * $(rand())", ctx) for _ in 1:1_000] + formulas_10k = [compile_expr("$(rand()) + $(rand()) * $(rand())", ctx) for _ in 1:10_000] + values = Float64[] + stack_1k = Vector{Float64}(undef, maximum(f.max_stack for f in formulas_1k)) + stack_10k = Vector{Float64}(undef, maximum(f.max_stack for f in formulas_10k)) + + SUITE["throughput"]["1k_formulas"] = @benchmarkable begin + @inbounds for f in $formulas_1k + eval_compiled(f, $values, $stack_1k) + end + end + + SUITE["throughput"]["10k_formulas"] = @benchmarkable begin + @inbounds for f in $formulas_10k + eval_compiled(f, $values, $stack_10k) + end + end +end diff --git a/src/NumExpr.jl b/src/NumExpr.jl index c36cb92..6a61f61 100644 --- a/src/NumExpr.jl +++ b/src/NumExpr.jl @@ -3,7 +3,11 @@ module NumExpr export parse_expr, eval_expr, isglobal_scope, - islocal_scope + islocal_scope, + var_has_tags, + VarContext, + compile_expr, + eval_compiled #__ exceptions @@ -48,5 +52,9 @@ struct LocalScope <: AbstractScope end include("utils.jl") include("parser.jl") include("eval.jl") +include("opcodes.jl") +include("context.jl") +include("compiler.jl") +include("vm.jl") end diff --git a/src/compiler.jl b/src/compiler.jl new file mode 100644 index 0000000..f5b13f1 --- /dev/null +++ b/src/compiler.jl @@ -0,0 +1,221 @@ +# compiler + +""" + CompiledExpr + +Compact bytecode representation of a numeric expression. +Created by [`compile_expr`](@ref) and evaluated by [`eval_compiled`](@ref). + +## Fields +- `code::Vector{UInt8}`: Bytecode in postfix (RPN) order. +- `constants::Vector{Float64}`: Constant pool for numeric literals. +- `nvars::Int`: Maximum variable index referenced. +- `max_stack::Int`: Stack depth required for evaluation. +""" +struct CompiledExpr + code::Vector{UInt8} + constants::Vector{Float64} + nvars::Int + max_stack::Int +end + +function Base.show(io::IO, expr::CompiledExpr) + print( + io, + "CompiledExpr(", + length(expr.code), " bytes, ", + length(expr.constants), " consts, ", + "stack=", expr.max_stack, ")", + ) +end + +#__ emitter + +mutable struct Emitter + const code::Vector{UInt8} + const constants::Vector{Float64} + const const_index::Dict{Float64,Int} + const ctx::VarContext + nvars::Int + depth::Int + max_depth::Int + + Emitter(ctx::VarContext) = new(UInt8[], Float64[], Dict{Float64,Int}(), ctx, 0, 0, 0) +end + +function emit!(e::Emitter, op::UInt8)::Nothing + push!(e.code, op) + return nothing +end + +function emit_u16!(e::Emitter, val::Int)::Nothing + push!(e.code, UInt8(val & 0xFF)) + push!(e.code, UInt8((val >> 8) & 0xFF)) + return nothing +end + +function push_depth!(e::Emitter)::Nothing + e.depth += 1 + e.depth > e.max_depth && (e.max_depth = e.depth) + return nothing +end + +function pop_depth!(e::Emitter)::Nothing + e.depth -= 1 + return nothing +end + +function add_const!(e::Emitter, val::Float64)::Int + idx = get(e.const_index, val, 0) + if idx == 0 + push!(e.constants, val) + idx = length(e.constants) + idx > typemax(UInt16) && error("constant pool exceeds UInt16 limit: $idx") + e.const_index[val] = idx + end + return idx +end + +#__ compile dispatch + +function compile_node!(e::Emitter, val::NumVal{Float64})::Nothing + x = val[] + if isnan(x) + emit!(e, OP_LOAD_NAN) + elseif x === 0.0 + emit!(e, OP_LOAD_ZERO) + elseif x === 1.0 + emit!(e, OP_LOAD_ONE) + else + emit!(e, OP_LOAD_CONST) + emit_u16!(e, add_const!(e, x)) + end + push_depth!(e) + return nothing +end + +function compile_node!(e::Emitter, val::NumVal{Bool})::Nothing + emit!(e, val[] ? OP_LOAD_TRUE : OP_LOAD_FALSE) + push_depth!(e) + return nothing +end + +function compile_node!(e::Emitter, var::Variable)::Nothing + idx = get_or_create!(e.ctx, var[]) + idx > typemax(UInt16) && error("variable index exceeds UInt16 limit: $idx") + idx > e.nvars && (e.nvars = idx) + emit!(e, OP_LOAD_VAR) + emit_u16!(e, idx) + push_depth!(e) + return nothing +end + +function compile_node!(::Emitter, ::StrVal)::Nothing + error("String operations are not supported in compiled mode. Use eval_expr instead.") +end + +function compile_node!(e::Emitter, node::ExprNode)::Nothing + head = node.head + args = node.args + n = length(args) + + # Unary minus + if head isa Arithmetic{:-} && n == 1 + compile_node!(e, args[1]) + emit!(e, OP_NEG) + return nothing + end + + # Unary plus (identity) + if head isa Arithmetic{:+} && n == 1 + compile_node!(e, args[1]) + return nothing + end + + # Functions (unary, binary, ternary) + if head isa AbstractFuncOperator + op = opcode(head) + if op == OP_MEAN + n < 1 && error("mean requires at least 1 argument, got $n") + n > 255 && error("mean supports at most 255 arguments, got $n") + for i in 1:n + compile_node!(e, args[i]) + end + emit!(e, op) + push!(e.code, UInt8(n)) + for _ in 2:n + pop_depth!(e) + end + return nothing + elseif op == OP_IFELSE + n != 3 && error("ifelse requires exactly 3 arguments, got $n") + compile_node!(e, args[1]) + compile_node!(e, args[2]) + compile_node!(e, args[3]) + emit!(e, op) + pop_depth!(e) + pop_depth!(e) + return nothing + elseif op in (OP_MAX2, OP_MIN2, OP_GET, OP_ROUND, OP_ISLESS, OP_DIV_INT, OP_REM) + n != 2 && error("function $(head) requires exactly 2 arguments, got $n") + compile_node!(e, args[1]) + compile_node!(e, args[2]) + emit!(e, op) + pop_depth!(e) + return nothing + else + n != 1 && error("compiled mode supports only unary function $(head), got $n arguments") + compile_node!(e, args[1]) + emit!(e, op) + return nothing + end + end + + # Binary / n-ary operators (left-associative fold) + n < 2 && error("operator $(head) requires at least 2 operands, got $n") + op = opcode(head) + compile_node!(e, args[1]) + for i in 2:n + compile_node!(e, args[i]) + emit!(e, op) + pop_depth!(e) + end + return nothing +end + +#__ public API + +""" + compile_expr(node::Union{AbstractExpr,ExprNode}, ctx::VarContext) -> CompiledExpr + +Compile a parsed expression tree into compact bytecode. +All formulas compiled with the same `ctx` share variable indices. + +""" +function compile_expr(node::Union{AbstractExpr,ExprNode}, ctx::VarContext)::CompiledExpr + e = Emitter(ctx) + compile_node!(e, node) + return CompiledExpr(e.code, e.constants, e.nvars, e.max_depth) +end + +""" + compile_expr(str::AbstractString, ctx::VarContext) -> CompiledExpr + +Parse and compile a string expression into compact bytecode. + +## Examples + +```julia-repl +julia> ctx = VarContext() +VarContext(0 variables) + +julia> f = compile_expr("a + b * sin(c)", ctx) +CompiledExpr(13 bytes, 0 consts, stack=2) + +julia> ctx +VarContext(3 variables) +``` +""" +function compile_expr(str::AbstractString, ctx::VarContext)::CompiledExpr + return compile_expr(parse_expr(str), ctx) +end diff --git a/src/context.jl b/src/context.jl new file mode 100644 index 0000000..c386b9d --- /dev/null +++ b/src/context.jl @@ -0,0 +1,65 @@ +# context + +""" + VarContext(; sizehint::Int = 0) + +Shared variable registry that maps variable names to integer indices. +All formulas compiled with the same `VarContext` share variable indices, +so a single `values::Vector{Float64}` can be used to evaluate any of them. + +## Indexing + +- `ctx[i::Integer]` — variable name by index. +- `ctx[name::String]` — index by variable name. +- `haskey(ctx, name)` — check if a variable is registered. + +## Examples + +```julia-repl +julia> ctx = VarContext() +VarContext(0 variables) + +julia> compile_expr("price + tax", ctx) +CompiledExpr(9 bytes, 0 consts, stack=2) + +julia> ctx["price"] +1 + +julia> ctx["tax"] +2 +``` +""" +struct VarContext + names::Vector{String} + lookup::Dict{String,Int} + + function VarContext(; sizehint::Int = 0) + names = sizehint > 0 ? sizehint!(String[], sizehint) : String[] + lookup = sizehint > 0 ? sizehint!(Dict{String,Int}(), sizehint) : Dict{String,Int}() + return new(names, lookup) + end +end + +#__ accessors + +Base.length(ctx::VarContext)::Int = length(ctx.names) +Base.getindex(ctx::VarContext, i::Integer)::String = ctx.names[i] +Base.getindex(ctx::VarContext, name::AbstractString)::Int = ctx.lookup[name] +Base.haskey(ctx::VarContext, name::AbstractString)::Bool = haskey(ctx.lookup, name) +Base.keys(ctx::VarContext)::Vector{String} = ctx.names + +function Base.show(io::IO, ctx::VarContext) + print(io, "VarContext(", length(ctx), " variables)") +end + +#__ mutation + +function get_or_create!(ctx::VarContext, name::String)::Int + idx = get(ctx.lookup, name, 0) + if idx == 0 + push!(ctx.names, name) + idx = length(ctx.names) + ctx.lookup[name] = idx + end + return idx +end diff --git a/src/eval.jl b/src/eval.jl index f9fa225..7a53e0a 100644 --- a/src/eval.jl +++ b/src/eval.jl @@ -47,9 +47,27 @@ call(::Func{:sqrt}, x::Number) = sqrt(x) call(::Func{:abs}, x::Number) = abs(x) call(::Func{:sin}, x::Number) = sin(x) call(::Func{:cos}, x::Number) = cos(x) +call(::Func{:tg}, x::Number) = tan(x) +call(::Func{:ctg}, x::Number) = cot(x) call(::Func{:atan}, x::Number) = atan(x) call(::Func{:exp}, x::Number) = exp(x) call(::Func{:log}, x::Number) = log(x) +call(::Func{:div}, x::Number, y::Number) = div(x, y) +call(::Func{:rem}, x::Number, y::Number) = rem(x, y) +call(::Func{:mean}, x::Number...) = sum(x) / length(x) + +call(::Func{:isnan}, x::Number) = Float64(isnan(x)) +call(::Func{:not}, x::Number) = isnan(x) ? NaN : Float64(x == 0) +call(::Func{:iszero}, x::Number) = Float64(x == 0) +call(::Func{:isone}, x::Number) = Float64(x == 1) +call(::Func{:floor}, x::Number) = floor(x) +call(::Func{:ceil}, x::Number) = ceil(x) +call(::Func{:max}, a::Number, b::Number) = max(a, b) +call(::Func{:min}, a::Number, b::Number) = min(a, b) +call(::Func{:ifelse}, c::Number, a::Number, b::Number) = isnan(c) ? NaN : (c != 0 ? a : b) +call(::Func{:get}, a::Number, b::Number) = isnan(a) ? b : a +call(::Func{:round}, a::Number, b::Number) = isnan(b) ? NaN : round(a; digits = Int(b)) +call(::Func{:isless}, a::Number, b::Number) = Float64(isless(a, b)) #__ convert diff --git a/src/opcodes.jl b/src/opcodes.jl new file mode 100644 index 0000000..8b82c8e --- /dev/null +++ b/src/opcodes.jl @@ -0,0 +1,118 @@ +# opcodes + +#__ load + +const OP_LOAD_CONST = 0x01 +const OP_LOAD_VAR = 0x02 +const OP_LOAD_TRUE = 0x03 +const OP_LOAD_FALSE = 0x04 +const OP_LOAD_NAN = 0x05 +const OP_LOAD_ZERO = 0x06 +const OP_LOAD_ONE = 0x07 + +#__ arithmetic + +const OP_ADD = 0x08 +const OP_SUB = 0x09 +const OP_MUL = 0x0A +const OP_DIV = 0x0B +const OP_POW = 0x0C +const OP_MOD = 0x0D +const OP_NEG = 0x0E +const OP_REM = 0x0F + +#__ comparison + +const OP_GT = 0x10 +const OP_LT = 0x11 +const OP_GE = 0x12 +const OP_LE = 0x13 +const OP_EQ = 0x14 +const OP_NE = 0x15 + +#__ logic + +const OP_AND = 0x16 +const OP_OR = 0x17 + +#__ function + +const OP_SQRT = 0x18 +const OP_ABS = 0x19 +const OP_SIN = 0x1A +const OP_COS = 0x1B +const OP_ATAN = 0x1C +const OP_EXP = 0x1D +const OP_LOG = 0x1E + +#__ unary predicates + +const OP_ISNAN = 0x1F +const OP_NOT = 0x20 +const OP_ISZERO = 0x21 +const OP_ISONE = 0x22 +const OP_FLOOR = 0x23 +const OP_CEIL = 0x24 +const OP_TG = 0x25 +const OP_CTG = 0x26 + +#__ binary / ternary functions + +const OP_MAX2 = 0x27 +const OP_MIN2 = 0x28 +const OP_IFELSE = 0x29 +const OP_GET = 0x2A +const OP_ROUND = 0x2B +const OP_ISLESS = 0x2C +const OP_DIV_INT = 0x2D +const OP_MEAN = 0x2E + +#__ opcode dispatch + +# Fallback: any function operator without an explicit opcode is unsupported. +opcode(f::AbstractFuncOperator) = + error("compiled mode does not support function '$(operator(f))'") + +opcode(::Arithmetic{:+})::UInt8 = OP_ADD +opcode(::Arithmetic{:-})::UInt8 = OP_SUB +opcode(::Arithmetic{:*})::UInt8 = OP_MUL +opcode(::Arithmetic{:/})::UInt8 = OP_DIV +opcode(::Arithmetic{:^})::UInt8 = OP_POW +opcode(::Arithmetic{:%})::UInt8 = OP_MOD + +opcode(::Logic{:>})::UInt8 = OP_GT +opcode(::Logic{:<})::UInt8 = OP_LT +opcode(::Logic{:>=})::UInt8 = OP_GE +opcode(::Logic{:<=})::UInt8 = OP_LE +opcode(::Logic{:(==)})::UInt8 = OP_EQ +opcode(::Logic{:!=})::UInt8 = OP_NE + +opcode(::Logic{:&&})::UInt8 = OP_AND +opcode(::Logic{:||})::UInt8 = OP_OR + +opcode(::Func{:sqrt})::UInt8 = OP_SQRT +opcode(::Func{:abs})::UInt8 = OP_ABS +opcode(::Func{:sin})::UInt8 = OP_SIN +opcode(::Func{:cos})::UInt8 = OP_COS +opcode(::Func{:atan})::UInt8 = OP_ATAN +opcode(::Func{:exp})::UInt8 = OP_EXP +opcode(::Func{:log})::UInt8 = OP_LOG + +opcode(::Func{:isnan})::UInt8 = OP_ISNAN +opcode(::Func{:not})::UInt8 = OP_NOT +opcode(::Func{:iszero})::UInt8 = OP_ISZERO +opcode(::Func{:isone})::UInt8 = OP_ISONE +opcode(::Func{:floor})::UInt8 = OP_FLOOR +opcode(::Func{:ceil})::UInt8 = OP_CEIL +opcode(::Func{:tg})::UInt8 = OP_TG +opcode(::Func{:ctg})::UInt8 = OP_CTG + +opcode(::Func{:max})::UInt8 = OP_MAX2 +opcode(::Func{:min})::UInt8 = OP_MIN2 +opcode(::Func{:ifelse})::UInt8 = OP_IFELSE +opcode(::Func{:get})::UInt8 = OP_GET +opcode(::Func{:round})::UInt8 = OP_ROUND +opcode(::Func{:isless})::UInt8 = OP_ISLESS +opcode(::Func{:div})::UInt8 = OP_DIV_INT +opcode(::Func{:mean})::UInt8 = OP_MEAN +opcode(::Func{:rem})::UInt8 = OP_REM diff --git a/src/parser.jl b/src/parser.jl index 09add1e..bc7eda7 100644 --- a/src/parser.jl +++ b/src/parser.jl @@ -55,7 +55,7 @@ For more information see section [Variables](@ref variable_vals). ## Fields - `val::String`: Full name of the variable including its tags. - `name::String`: The variable name without tags -- `tags::Dict{String,String}`: Tags of the variable. +- `tags::Union{Nothing,Dict{String,String}}`: Tags of the variable (`nothing` when tag-less). ## Examples @@ -87,7 +87,7 @@ julia> var[] struct Variable{S<:AbstractScope} <: AbstractValue val::String name::String - tags::Dict{String,String} + tags::Union{Nothing,Dict{String,String}} end (isglobal_scope(::Variable{S})::Bool) where {S<:AbstractScope} = S <: GlobalScope @@ -106,7 +106,11 @@ struct NumVal{T<:Real} <: AbstractValue end var_name(x::Variable)::String = x.name -var_tags(x::Variable)::Dict{String,String} = x.tags + +const _EMPTY_TAGS = Dict{String,String}() + +var_tags(x::Variable)::Dict{String,String} = x.tags === nothing ? _EMPTY_TAGS : x.tags +var_has_tags(x::Variable)::Bool = x.tags !== nothing Base.getindex(x::AbstractValue) = getfield(x, :val) Base.show(io::IO, n::AbstractValue) = print(io, n[]) @@ -114,7 +118,7 @@ Base.show(io::IO, n::AbstractValue) = print(io, n[]) function parse_var_format1(::Type{S}, chars::Vector{Char})::Variable{S} where {S<:AbstractScope} str_val = String(chars) - return Variable{S}(str_val, str_val, Dict{String,String}()) + return Variable{S}(str_val, str_val, nothing) end function parse_var_format2(::Type{S}, chars::Vector{Char})::Variable{S} where {S<:AbstractScope} diff --git a/src/vm.jl b/src/vm.jl new file mode 100644 index 0000000..418ca72 --- /dev/null +++ b/src/vm.jl @@ -0,0 +1,378 @@ +# vm + +#__ variable adapter + +struct ResolverAdapter{F} + f::F +end + +@inline Base.getindex(r::ResolverAdapter, idx::Int)::Float64 = r.f(idx) + +#__ bytecode decoding + +@inline function _read_u16(code_ptr::Ptr{UInt8}, pos::Int)::Int + lo = unsafe_load(code_ptr, pos) + hi = unsafe_load(code_ptr, pos + 1) + return Int(lo) | (Int(hi) << 8) +end + + +@inline _vars_pointer(vars::Vector{Float64}) = pointer(vars) +@inline _vars_pointer(_) = nothing + +@inline _read_var(::Vector{Float64}, vars_ptr::Ptr{Float64}, idx::Int) = + unsafe_load(vars_ptr, idx) +@inline _read_var(vars, ::Nothing, idx::Int) = + @inbounds vars[idx] + +#__ core VM loop + +function vm_eval( + code::Vector{UInt8}, + consts::Vector{Float64}, + vars::V, + stack::Vector{Float64}, +)::Float64 where {V} + n = length(code) + GC.@preserve code consts stack vars begin + code_ptr = pointer(code) + consts_ptr = pointer(consts) + stack_ptr = pointer(stack) + vars_ptr = _vars_pointer(vars) + + sp = 0 + pc = 1 + while pc <= n + op = unsafe_load(code_ptr, pc) + + if op == OP_LOAD_CONST + idx = _read_u16(code_ptr, pc + 1) + sp += 1 + unsafe_store!(stack_ptr, unsafe_load(consts_ptr, idx), sp) + pc += 3 + elseif op == OP_LOAD_VAR + idx = _read_u16(code_ptr, pc + 1) + sp += 1 + unsafe_store!(stack_ptr, _read_var(vars, vars_ptr, idx), sp) + pc += 3 + elseif op == OP_LOAD_TRUE + sp += 1 + unsafe_store!(stack_ptr, 1.0, sp) + pc += 1 + elseif op == OP_LOAD_FALSE + sp += 1 + unsafe_store!(stack_ptr, 0.0, sp) + pc += 1 + elseif op == OP_LOAD_NAN + sp += 1 + unsafe_store!(stack_ptr, NaN, sp) + pc += 1 + elseif op == OP_LOAD_ZERO + sp += 1 + unsafe_store!(stack_ptr, 0.0, sp) + pc += 1 + elseif op == OP_LOAD_ONE + sp += 1 + unsafe_store!(stack_ptr, 1.0, sp) + pc += 1 + elseif op == OP_ADD + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, a + b, sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_SUB + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, a - b, sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_MUL + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, a * b, sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_DIV + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, a / b, sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_POW + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, a ^ b, sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_MOD + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, a % b, sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_REM + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, rem(a, b), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_NEG + unsafe_store!(stack_ptr, -unsafe_load(stack_ptr, sp), sp) + pc += 1 + elseif op == OP_GT + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, Float64(a > b), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_LT + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, Float64(a < b), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_GE + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, Float64(a >= b), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_LE + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, Float64(a <= b), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_EQ + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, Float64(a == b), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_NE + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, Float64(a != b), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_AND + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, Float64((a != 0.0) & (b != 0.0)), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_OR + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, Float64((a != 0.0) | (b != 0.0)), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_SQRT + unsafe_store!(stack_ptr, sqrt(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_ABS + unsafe_store!(stack_ptr, abs(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_SIN + unsafe_store!(stack_ptr, sin(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_COS + unsafe_store!(stack_ptr, cos(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_ATAN + unsafe_store!(stack_ptr, atan(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_EXP + unsafe_store!(stack_ptr, exp(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_LOG + unsafe_store!(stack_ptr, log(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_ISNAN + unsafe_store!(stack_ptr, Float64(isnan(unsafe_load(stack_ptr, sp))), sp) + pc += 1 + elseif op == OP_NOT + x = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, isnan(x) ? NaN : (x == 0.0 ? 1.0 : 0.0), sp) + pc += 1 + elseif op == OP_ISZERO + unsafe_store!(stack_ptr, Float64(unsafe_load(stack_ptr, sp) == 0.0), sp) + pc += 1 + elseif op == OP_ISONE + unsafe_store!(stack_ptr, Float64(unsafe_load(stack_ptr, sp) == 1.0), sp) + pc += 1 + elseif op == OP_FLOOR + unsafe_store!(stack_ptr, floor(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_CEIL + unsafe_store!(stack_ptr, ceil(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_TG + unsafe_store!(stack_ptr, tan(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_CTG + unsafe_store!(stack_ptr, cot(unsafe_load(stack_ptr, sp)), sp) + pc += 1 + elseif op == OP_MAX2 + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, max(a, b), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_MIN2 + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, min(a, b), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_IFELSE + c = unsafe_load(stack_ptr, sp - 2) + v_then = unsafe_load(stack_ptr, sp - 1) + v_else = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, + isnan(c) ? NaN : (c != 0.0 ? v_then : v_else), + sp - 2) + sp -= 2 + pc += 1 + elseif op == OP_GET + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, isnan(a) ? b : a, sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_ROUND + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, + isnan(b) ? NaN : round(a; digits = Int(b)), + sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_ISLESS + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, Float64(isless(a, b)), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_DIV_INT + a = unsafe_load(stack_ptr, sp - 1) + b = unsafe_load(stack_ptr, sp) + unsafe_store!(stack_ptr, div(a, b), sp - 1) + sp -= 1 + pc += 1 + elseif op == OP_MEAN + count = Int(unsafe_load(code_ptr, pc + 1)) + s = 0.0 + base = sp - count + 1 + for i in base:sp + s += unsafe_load(stack_ptr, i) + end + unsafe_store!(stack_ptr, s / count, base) + sp = base + pc += 2 + else + error("unknown opcode: 0x$(string(op, base=16, pad=2))") + end + end + + return unsafe_load(stack_ptr, 1) + end +end + +#__ public API + +""" + eval_compiled(expr::CompiledExpr, values::AbstractVector{Float64}, stack::Vector{Float64}) -> Float64 + +Evaluate a compiled expression with a pre-allocated stack (zero heap allocations). +This is the fastest evaluation path, suitable for hot loops. + +## Examples + +```julia-repl +julia> ctx = VarContext() +VarContext(0 variables) + +julia> f = compile_expr("a + b * 2", ctx) +CompiledExpr(9 bytes, 1 consts, stack=2) + +julia> values = [3.0, 4.0] +julia> stack = Vector{Float64}(undef, f.max_stack) + +julia> eval_compiled(f, values, stack) +11.0 +``` +""" +function eval_compiled( + expr::CompiledExpr, + values::AbstractVector{Float64}, + stack::Vector{Float64}, +)::Float64 + return vm_eval(expr.code, expr.constants, values, stack) +end + +""" + eval_compiled(expr::CompiledExpr, values::AbstractVector{Float64}) -> Float64 + +Evaluate a compiled expression (allocates a temporary stack). + +## Examples + +```julia-repl +julia> ctx = VarContext() +VarContext(0 variables) + +julia> f = compile_expr("a + b * 2", ctx) +CompiledExpr(9 bytes, 1 consts, stack=2) + +julia> eval_compiled(f, [3.0, 4.0]) +11.0 +``` +""" +function eval_compiled(expr::CompiledExpr, values::AbstractVector{Float64})::Float64 + stack = Vector{Float64}(undef, expr.max_stack) + return vm_eval(expr.code, expr.constants, values, stack) +end + +""" + eval_compiled(expr::CompiledExpr, resolver) -> Float64 + +Evaluate a compiled expression using a resolver callable for variable lookup. +The resolver is called as `resolver(index::Int) -> Float64` for each variable reference. +Any callable (function, functor, closure) is accepted. +""" +function eval_compiled(expr::CompiledExpr, resolver)::Float64 + stack = Vector{Float64}(undef, expr.max_stack) + return vm_eval(expr.code, expr.constants, ResolverAdapter(resolver), stack) +end + +""" + eval_compiled(expr::CompiledExpr, ctx::VarContext, pairs::Pair...) -> Float64 + +Evaluate a compiled expression with named variable values (convenience method, allocates). + +## Examples + +```julia-repl +julia> ctx = VarContext() +VarContext(0 variables) + +julia> f = compile_expr("price + tax", ctx) +CompiledExpr(9 bytes, 0 consts, stack=2) + +julia> eval_compiled(f, ctx, "price" => 100.0, "tax" => 8.0) +108.0 +``` +""" +function eval_compiled( + expr::CompiledExpr, + ctx::VarContext, + pairs::Pair{<:AbstractString,<:Real}..., +)::Float64 + values = zeros(Float64, length(ctx)) + for (name, val) in pairs + values[ctx[name]] = Float64(val) + end + return eval_compiled(expr, values) +end diff --git a/test/runtests.jl b/test/runtests.jl index e7cbd61..c9b6055 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -234,6 +234,12 @@ end @test eval_expr(parse_expr("atan(2)")) == atan(2) @test eval_expr(parse_expr("exp(2)")) == exp(2) @test eval_expr(parse_expr("log(2)")) == log(2) + @test eval_expr(parse_expr("tg(0)")) == 0.0 + @test eval_expr(parse_expr("tg(1)")) == tan(1) + @test eval_expr(parse_expr("ctg(1)")) ≈ cos(1) / sin(1) + @test eval_expr(parse_expr("mean(2, 4, 6)")) == 4.0 + @test eval_expr(parse_expr("mean(10)")) == 10.0 + @test eval_expr(parse_expr("mean(1, 2, 3, 4, 5)")) == 3.0 end @testset verbose = true "Eval math const" begin @@ -513,7 +519,8 @@ end parsed_local = parse_expr(local_var_str) @test parsed_local.val == "type='tool', size='M4'" - @test isempty(parsed_local.tags) + @test parsed_local.tags === nothing + @test !var_has_tags(parsed_local) end @testset "Parsing: Invalid Syntax Error" begin @@ -557,4 +564,1243 @@ end parse_expr(var_str) end end + + #__ Boolean / NaN literal parsing ──────────────────────────────── + + @testset verbose = true "Boolean literal parsing" begin + # true / false → NumVal{Bool} + @test vectorial(parse_expr("true")) == ["true"] + @test vectorial(parse_expr("false")) == ["false"] + @test eval_expr(parse_expr("true")) === true + @test eval_expr(parse_expr("false")) === false + + # Case-sensitive: mixed / upper case → variables + @test vectorial(parse_expr("True")) == ["True"] + @test vectorial(parse_expr("False")) == ["False"] + @test vectorial(parse_expr("TRUE")) == ["TRUE"] + @test vectorial(parse_expr("FALSE")) == ["FALSE"] + @test vectorial(parse_expr("tRue")) == ["tRue"] + @test vectorial(parse_expr("fAlse")) == ["fAlse"] + + # 5-char words ending in 'e' must NOT be parsed as false + # (regression test for the || → && fix in isfalsenumber) + @test vectorial(parse_expr("state")) == ["state"] + @test vectorial(parse_expr("valve")) == ["valve"] + @test vectorial(parse_expr("parse")) == ["parse"] + @test vectorial(parse_expr("scope")) == ["scope"] + @test vectorial(parse_expr("value")) == ["value"] + @test vectorial(parse_expr("close")) == ["close"] + @test vectorial(parse_expr("price")) == ["price"] + @test vectorial(parse_expr("space")) == ["space"] + @test vectorial(parse_expr("store")) == ["store"] + @test vectorial(parse_expr("trade")) == ["trade"] + + # 4-char words ending in 'e' must NOT be parsed as true + @test vectorial(parse_expr("time")) == ["time"] + @test vectorial(parse_expr("type")) == ["type"] + @test vectorial(parse_expr("rate")) == ["rate"] + @test vectorial(parse_expr("size")) == ["size"] + @test vectorial(parse_expr("cure")) == ["cure"] + @test vectorial(parse_expr("pure")) == ["pure"] + @test vectorial(parse_expr("more")) == ["more"] + @test vectorial(parse_expr("sure")) == ["sure"] + + # Prefixed / suffixed true / false → variables + @test vectorial(parse_expr("truex")) == ["truex"] + @test vectorial(parse_expr("atrue")) == ["atrue"] + @test vectorial(parse_expr("truer")) == ["truer"] + @test vectorial(parse_expr("falsex")) == ["falsex"] + @test vectorial(parse_expr("afalse")) == ["afalse"] + @test vectorial(parse_expr("falser")) == ["falser"] + @test vectorial(parse_expr("istrue")) == ["istrue"] + @test vectorial(parse_expr("isfalse")) == ["isfalse"] + @test vectorial(parse_expr("truefalse")) == ["truefalse"] + @test vectorial(parse_expr("falsetrue")) == ["falsetrue"] + end + + @testset verbose = true "NaN literal parsing" begin + @test vectorial(parse_expr("NaN")) == ["NaN"] + @test eval_expr(parse_expr("NaN")) === NaN + + # Case-sensitive — other cases are variables + @test vectorial(parse_expr("nan")) == ["nan"] + @test vectorial(parse_expr("NAN")) == ["NAN"] + @test vectorial(parse_expr("Nan")) == ["Nan"] + @test vectorial(parse_expr("naN")) == ["naN"] + + # Prefixed / suffixed + @test vectorial(parse_expr("NaNx")) == ["NaNx"] + @test vectorial(parse_expr("aNaN")) == ["aNaN"] + @test vectorial(parse_expr("isNaN")) == ["isNaN"] + @test vectorial(parse_expr("NaN1")) == ["NaN1"] + end + + #__ Operator precedence ────────────────────────────────────────── + + @testset verbose = true "Operator precedence comprehensive" begin + # ^ (6) tighter than * / (5) + @test vectorial(parse_expr("2 * 3 ^ 4")) == ["*", "2.0", ["^", "3.0", "4.0"]] + @test vectorial(parse_expr("2 ^ 3 * 4")) == ["*", ["^", "2.0", "3.0"], "4.0"] + @test vectorial(parse_expr("2 / 3 ^ 4")) == ["/", "2.0", ["^", "3.0", "4.0"]] + + # * / (5) tighter than + - (3) + @test vectorial(parse_expr("2 + 3 * 4")) == ["+", "2.0", ["*", "3.0", "4.0"]] + @test vectorial(parse_expr("2 * 3 + 4")) == ["+", ["*", "2.0", "3.0"], "4.0"] + @test vectorial(parse_expr("2 - 3 / 4")) == ["-", "2.0", ["/", "3.0", "4.0"]] + @test vectorial(parse_expr("2 / 3 - 4")) == ["-", ["/", "2.0", "3.0"], "4.0"] + + # % (4) between + (3) and * (5) + @test vectorial(parse_expr("2 + 3 % 4")) == ["+", "2.0", ["%", "3.0", "4.0"]] + @test vectorial(parse_expr("2 % 3 + 4")) == ["+", ["%", "2.0", "3.0"], "4.0"] + @test vectorial(parse_expr("2 * 3 % 4")) == ["%", ["*", "2.0", "3.0"], "4.0"] + @test vectorial(parse_expr("2 % 3 * 4")) == ["%", "2.0", ["*", "3.0", "4.0"]] + + # Comparison (2) lower than arithmetic (3) + @test vectorial(parse_expr("1 + 2 > 3")) == [">", ["+", "1.0", "2.0"], "3.0"] + @test vectorial(parse_expr("1 > 2 + 3")) == [">", "1.0", ["+", "2.0", "3.0"]] + @test vectorial(parse_expr("1 + 2 == 3 - 1")) == ["==", ["+", "1.0", "2.0"], ["-", "3.0", "1.0"]] + @test vectorial(parse_expr("1 * 2 != 3 / 4")) == ["!=", ["*", "1.0", "2.0"], ["/", "3.0", "4.0"]] + @test vectorial(parse_expr("a >= b - 1")) == [">=", "a", ["-", "b", "1.0"]] + @test vectorial(parse_expr("a + 1 <= b")) == ["<=", ["+", "a", "1.0"], "b"] + + # && (1) lower than comparison (2) + @test vectorial(parse_expr("1 > 2 && 3 < 4")) == ["&&", [">", "1.0", "2.0"], ["<", "3.0", "4.0"]] + @test vectorial(parse_expr("a == b && c != d")) == ["&&", ["==", "a", "b"], ["!=", "c", "d"]] + + # || (0) lower than && (1) + @test vectorial(parse_expr("a || b && c")) == ["||", "a", ["&&", "b", "c"]] + @test vectorial(parse_expr("a && b || c")) == ["||", ["&&", "a", "b"], "c"] + @test vectorial(parse_expr("a && b || c && d")) == ["||", ["&&", "a", "b"], ["&&", "c", "d"]] + @test vectorial(parse_expr("a || b && c || d")) == ["||", "a", ["&&", "b", "c"], "d"] + + # Full precedence chain: || < && < comparison < +- < % < */ < ^ + @test vectorial(parse_expr("1 + 2 * 3 ^ 4 > 5 && 6 || 7")) == + ["||", ["&&", [">", ["+", "1.0", ["*", "2.0", ["^", "3.0", "4.0"]]], "5.0"], "6.0"], "7.0"] + + # Parentheses override precedence + @test vectorial(parse_expr("(1 + 2) * 3")) == ["*", ["+", "1.0", "2.0"], "3.0"] + @test vectorial(parse_expr("2 ^ (1 + 1)")) == ["^", "2.0", ["+", "1.0", "1.0"]] + @test vectorial(parse_expr("(a || b) && c")) == ["&&", ["||", "a", "b"], "c"] + @test vectorial(parse_expr("(1 + 2) * (3 + 4)")) == ["*", ["+", "1.0", "2.0"], ["+", "3.0", "4.0"]] + end + + #__ N-ary flattening ───────────────────────────────────────────── + + @testset verbose = true "N-ary operator flattening" begin + # Same-operator chains flatten into n-ary nodes + @test vectorial(parse_expr("1 + 2 + 3 + 4 + 5")) == ["+", "1.0", "2.0", "3.0", "4.0", "5.0"] + @test vectorial(parse_expr("2 * 3 * 4")) == ["*", "2.0", "3.0", "4.0"] + @test vectorial(parse_expr("2 * 3 * 4 * 5")) == ["*", "2.0", "3.0", "4.0", "5.0"] + @test vectorial(parse_expr("1 - 2 - 3")) == ["-", "1.0", "2.0", "3.0"] + @test vectorial(parse_expr("8 / 4 / 2")) == ["/", "8.0", "4.0", "2.0"] + @test vectorial(parse_expr("2 ^ 3 ^ 2")) == ["^", "2.0", "3.0", "2.0"] + @test vectorial(parse_expr("a || b || c")) == ["||", "a", "b", "c"] + @test vectorial(parse_expr("a && b && c")) == ["&&", "a", "b", "c"] + @test vectorial(parse_expr("a > b > c > d")) == [">", "a", "b", "c", "d"] + @test vectorial(parse_expr("a == b == c")) == ["==", "a", "b", "c"] + + # Different operators at same precedence break the chain + @test vectorial(parse_expr("1 + 2 - 3")) == ["-", ["+", "1.0", "2.0"], "3.0"] + @test vectorial(parse_expr("1 - 2 + 3")) == ["+", ["-", "1.0", "2.0"], "3.0"] + @test vectorial(parse_expr("1 * 2 / 3")) == ["/", ["*", "1.0", "2.0"], "3.0"] + @test vectorial(parse_expr("1 / 2 * 3")) == ["*", ["/", "1.0", "2.0"], "3.0"] + @test vectorial(parse_expr("1 > 2 < 3")) == ["<", [">", "1.0", "2.0"], "3.0"] + + # N-ary eval: left-associative reduce + @test eval_expr(parse_expr("1 + 2 + 3 + 4 + 5")) == 15.0 + @test eval_expr(parse_expr("2 * 3 * 4")) == 24.0 + @test eval_expr(parse_expr("1 - 2 - 3")) == -4.0 # ((1-2)-3) + @test eval_expr(parse_expr("24 / 4 / 3")) == 2.0 # ((24/4)/3) + @test eval_expr(parse_expr("24 / 4 / 3 / 2")) == 1.0 + + # Left-associative exponentiation (unlike standard math) + @test eval_expr(parse_expr("2 ^ 3 ^ 2")) == 64.0 # (2^3)^2, not 2^(3^2)=512 + @test eval_expr(parse_expr("2 ^ (3 ^ 2)")) == 512.0 # right-assoc with parens + + # Mixed chains with eval verification + @test eval_expr(parse_expr("10 - 3 - 2 - 1")) == 4.0 + @test eval_expr(parse_expr("100 / 5 / 4 / 5")) == 1.0 + @test eval_expr(parse_expr("1 + 2 - 3 + 4")) == 4.0 + end + + #__ Unary operators ────────────────────────────────────────────── + + @testset verbose = true "Unary operators comprehensive" begin + # Basic unary parse + @test vectorial(parse_expr("+a")) == ["+", "a"] + @test vectorial(parse_expr("-(1)")) == ["-", "1.0"] + @test vectorial(parse_expr("+(1)")) == ["+", "1.0"] + + # Double and triple unary + @test vectorial(parse_expr("-(-1)")) == ["-", ["-", "1.0"]] + @test vectorial(parse_expr("-(-(1))")) == ["-", ["-", "1.0"]] + @test vectorial(parse_expr("+(+1)")) == ["+", ["+", "1.0"]] + @test vectorial(parse_expr("-(-(-1))")) == ["-", ["-", ["-", "1.0"]]] + + # Unary in binary expressions + @test vectorial(parse_expr("1 + (-2)")) == ["+", "1.0", ["-", "2.0"]] + @test vectorial(parse_expr("1 * (-2)")) == ["*", "1.0", ["-", "2.0"]] + @test vectorial(parse_expr("-1 + -2")) == ["+", ["-", "1.0"], ["-", "2.0"]] + @test vectorial(parse_expr("-1 * -2")) == ["*", ["-", "1.0"], ["-", "2.0"]] + + # Unary minus binds tighter than ^ (handled at atomic level) + @test vectorial(parse_expr("-3 ^ 2")) == ["^", ["-", "3.0"], "2.0"] + + # Eval + @test eval_expr(parse_expr("+1")) == 1.0 + @test eval_expr(parse_expr("-(-1)")) == 1.0 + @test eval_expr(parse_expr("-(-(1))")) == 1.0 + @test eval_expr(parse_expr("-(-(-1))")) == -1.0 + @test eval_expr(parse_expr("1 + (-2)")) == -1.0 + @test eval_expr(parse_expr("-1 + -2")) == -3.0 + @test eval_expr(parse_expr("-1 * -2")) == 2.0 + @test eval_expr(parse_expr("-3 ^ 2")) == 9.0 # (-3)^2 + @test eval_expr(parse_expr("-(3 ^ 2)")) == -9.0 + @test eval_expr(parse_expr("-0")) == 0.0 + @test eval_expr(parse_expr("+0")) == 0.0 + end + + #__ Eval: comparisons ──────────────────────────────────────────── + + @testset verbose = true "Eval comparisons" begin + # All 6 comparison operators + @test eval_expr(parse_expr("1 > 0")) == true + @test eval_expr(parse_expr("0 > 1")) == false + @test eval_expr(parse_expr("1 > 1")) == false + @test eval_expr(parse_expr("1 < 2")) == true + @test eval_expr(parse_expr("2 < 1")) == false + @test eval_expr(parse_expr("1 < 1")) == false + @test eval_expr(parse_expr("1 == 1")) == true + @test eval_expr(parse_expr("1 == 2")) == false + @test eval_expr(parse_expr("1 != 2")) == true + @test eval_expr(parse_expr("1 != 1")) == false + @test eval_expr(parse_expr("1 >= 1")) == true + @test eval_expr(parse_expr("1 >= 2")) == false + @test eval_expr(parse_expr("2 >= 1")) == true + @test eval_expr(parse_expr("1 <= 1")) == true + @test eval_expr(parse_expr("2 <= 1")) == false + @test eval_expr(parse_expr("1 <= 2")) == true + + # Comparisons with arithmetic subexpressions + @test eval_expr(parse_expr("1 + 1 > 1")) == true + @test eval_expr(parse_expr("2 * 3 == 6")) == true + @test eval_expr(parse_expr("10 / 2 != 3")) == true + @test eval_expr(parse_expr("2 ^ 3 >= 8")) == true + @test eval_expr(parse_expr("2 ^ 3 <= 8")) == true + @test eval_expr(parse_expr("2 ^ 3 > 8")) == false + @test eval_expr(parse_expr("2 ^ 3 < 8")) == false + @test eval_expr(parse_expr("3 + 4 == 2 + 5")) == true + @test eval_expr(parse_expr("10 - 3 > 2 * 3")) == true + @test eval_expr(parse_expr("10 - 3 > 2 * 4")) == false + end + + #__ Eval: booleans ─────────────────────────────────────────────── + + @testset verbose = true "Eval booleans" begin + @test eval_expr(parse_expr("true")) === true + @test eval_expr(parse_expr("false")) === false + @test eval_expr(parse_expr("true == true")) == true + @test eval_expr(parse_expr("true == false")) == false + @test eval_expr(parse_expr("false == false")) == true + @test eval_expr(parse_expr("true != false")) == true + @test eval_expr(parse_expr("true != true")) == false + + # Bool <: Number in Julia → boolean arithmetic works + @test eval_expr(parse_expr("1 + true")) == 2.0 + @test eval_expr(parse_expr("1 + false")) == 1.0 + @test eval_expr(parse_expr("1 - true")) == 0.0 + @test eval_expr(parse_expr("true + true")) == 2 + @test eval_expr(parse_expr("true - false")) == 1 + @test eval_expr(parse_expr("true * 5")) == 5.0 + @test eval_expr(parse_expr("false * 100")) == 0.0 + + # Boolean from comparison used in arithmetic + @test eval_expr(parse_expr("(1 > 0) + (2 > 1)")) == 2 + @test eval_expr(parse_expr("(1 > 0) + (2 < 1)")) == 1 + @test eval_expr(parse_expr("(1 < 0) + (2 < 1)")) == 0 + end + + #__ Eval: NaN and Inf ──────────────────────────────────────────── + + @testset verbose = true "Eval NaN and Inf" begin + # NaN propagation + @test isnan(eval_expr(parse_expr("NaN"))) + @test isnan(eval_expr(parse_expr("NaN + 1"))) + @test isnan(eval_expr(parse_expr("NaN - 1"))) + @test isnan(eval_expr(parse_expr("NaN * 2"))) + @test isnan(eval_expr(parse_expr("NaN / 2"))) + @test isnan(eval_expr(parse_expr("NaN * 0"))) + @test isnan(eval_expr(parse_expr("NaN + NaN"))) + @test isnan(eval_expr(parse_expr("0 / 0"))) + + # NaN comparisons (IEEE 754) + @test eval_expr(parse_expr("NaN == NaN")) == false + @test eval_expr(parse_expr("NaN != NaN")) == true + @test eval_expr(parse_expr("NaN > 0")) == false + @test eval_expr(parse_expr("NaN < 0")) == false + @test eval_expr(parse_expr("NaN >= 0")) == false + @test eval_expr(parse_expr("NaN <= 0")) == false + @test eval_expr(parse_expr("NaN > NaN")) == false + @test eval_expr(parse_expr("NaN == 0")) == false + + # Inf + @test eval_expr(parse_expr("1 / 0")) == Inf + @test eval_expr(parse_expr("-1 / 0")) == -Inf + @test eval_expr(parse_expr("1 / 0 + 1")) == Inf + @test eval_expr(parse_expr("1 / 0 + 1 / 0")) == Inf + @test eval_expr(parse_expr("1 / 0 > 1000000")) == true + @test isnan(eval_expr(parse_expr("1 / 0 - 1 / 0"))) + @test isnan(eval_expr(parse_expr("1 / 0 * 0"))) + end + + #__ Eval: functions ────────────────────────────────────────────── + + @testset verbose = true "Eval functions comprehensive" begin + # sqrt + @test eval_expr(parse_expr("sqrt(4)")) == 2.0 + @test eval_expr(parse_expr("sqrt(0)")) == 0.0 + @test eval_expr(parse_expr("sqrt(1)")) == 1.0 + @test eval_expr(parse_expr("sqrt(0.25)")) == 0.5 + + # abs + @test eval_expr(parse_expr("abs(-5)")) == 5.0 + @test eval_expr(parse_expr("abs(5)")) == 5.0 + @test eval_expr(parse_expr("abs(0)")) == 0.0 + @test eval_expr(parse_expr("abs(-0.001)")) == 0.001 + + # exp / log + @test eval_expr(parse_expr("exp(0)")) == 1.0 + @test eval_expr(parse_expr("log(1)")) == 0.0 + @test eval_expr(parse_expr("exp(log(5))")) ≈ 5.0 + @test eval_expr(parse_expr("log(exp(3))")) ≈ 3.0 + @test eval_expr(parse_expr("exp(1)")) ≈ ℯ + + # sin / cos / atan + @test eval_expr(parse_expr("sin(0)")) == 0.0 + @test eval_expr(parse_expr("cos(0)")) == 1.0 + @test eval_expr(parse_expr("atan(0)")) == 0.0 + @test eval_expr(parse_expr("atan(1)")) ≈ atan(1) + + # Trigonometric identity: sin²(x) + cos²(x) = 1 + @test eval_expr(parse_expr("sin(0) ^ 2 + cos(0) ^ 2")) ≈ 1.0 + @test eval_expr(parse_expr("sin(1) ^ 2 + cos(1) ^ 2")) ≈ 1.0 + @test eval_expr(parse_expr("sin(2) ^ 2 + cos(2) ^ 2")) ≈ 1.0 + @test eval_expr(parse_expr("sin(42) ^ 2 + cos(42) ^ 2")) ≈ 1.0 + @test eval_expr(parse_expr("sin(100) ^ 2 + cos(100) ^ 2")) ≈ 1.0 + + # Nested functions + @test eval_expr(parse_expr("sqrt(abs(-16))")) == 4.0 + @test eval_expr(parse_expr("abs(sin(0))")) == 0.0 + @test eval_expr(parse_expr("abs(-sqrt(2))")) ≈ sqrt(2) + @test eval_expr(parse_expr("sqrt(sqrt(16))")) == 2.0 + @test eval_expr(parse_expr("log(sqrt(exp(2)))")) ≈ 1.0 + + # Functions in expressions + @test eval_expr(parse_expr("sin(0) + cos(0)")) == 1.0 + @test eval_expr(parse_expr("2 * sqrt(4)")) == 4.0 + @test eval_expr(parse_expr("sqrt(4) + sqrt(9)")) == 5.0 + @test eval_expr(parse_expr("sqrt(3 ^ 2 + 4 ^ 2)")) == 5.0 + @test eval_expr(parse_expr("sqrt(5 ^ 2 + 12 ^ 2)")) == 13.0 + end + + #__ Eval: string operations ────────────────────────────────────── + + @testset verbose = true "Eval string operations" begin + # Concatenation + @test eval_expr(parse_expr("'a' * 'b'")) == "ab" + @test eval_expr(parse_expr("'hello' * ' ' * 'world'")) == "hello world" + @test eval_expr(parse_expr("'a' * 'b' * 'c' * 'd'")) == "abcd" + + # Repetition + @test eval_expr(parse_expr("'a' ^ 5")) == "aaaaa" + @test eval_expr(parse_expr("'ab' ^ 3")) == "ababab" + @test eval_expr(parse_expr("'x' ^ 1")) == "x" + + # Comparisons + @test eval_expr(parse_expr("'abc' == 'abc'")) == true + @test eval_expr(parse_expr("'abc' == 'def'")) == false + @test eval_expr(parse_expr("'abc' != 'def'")) == true + @test eval_expr(parse_expr("'abc' != 'abc'")) == false + @test eval_expr(parse_expr("'abc' < 'def'")) == true + @test eval_expr(parse_expr("'def' > 'abc'")) == true + @test eval_expr(parse_expr("'abc' <= 'abc'")) == true + @test eval_expr(parse_expr("'abc' >= 'abc'")) == true + @test eval_expr(parse_expr("'a' < 'b'")) == true + @test eval_expr(parse_expr("'z' > 'a'")) == true + end + + #__ Complex mathematical expressions ───────────────────────────── + + @testset verbose = true "Complex mathematical expressions" begin + # Pythagorean theorem + @test eval_expr(parse_expr("sqrt(3 ^ 2 + 4 ^ 2)")) == 5.0 + @test eval_expr(parse_expr("sqrt(5 ^ 2 + 12 ^ 2)")) == 13.0 + @test eval_expr(parse_expr("sqrt(8 ^ 2 + 15 ^ 2)")) == 17.0 + + # Golden ratio + @test eval_expr(parse_expr("(1 + sqrt(5)) / 2")) ≈ (1 + sqrt(5)) / 2 + + # Logarithm/exponent identities + @test eval_expr(parse_expr("log(exp(1))")) ≈ 1.0 + @test eval_expr(parse_expr("exp(log(42))")) ≈ 42.0 + + # Power rules + @test eval_expr(parse_expr("2 ^ 10")) == 1024.0 + @test eval_expr(parse_expr("2 ^ (-1)")) == 0.5 + @test eval_expr(parse_expr("4 ^ 0.5")) == 2.0 + @test eval_expr(parse_expr("8 ^ (1 / 3)")) ≈ 2.0 + @test eval_expr(parse_expr("27 ^ (1 / 3)")) ≈ 3.0 + + # Deeply nested parentheses + @test eval_expr(parse_expr("((((1 + 2))))")) == 3.0 + @test eval_expr(parse_expr("((((((42))))))")) == 42.0 + @test eval_expr(parse_expr("(((1 + 2) * 3) - 4) / 5")) == 1.0 + + # Verification against Julia's eval + for expr_str in [ + "1 + 2 * 3", + "2 ^ 10 - 1", + "abs(-42)", + "(1 + 1) ^ 10", + "(3 + 4) * (5 - 2)", + "100 / (2 + 3) / 4", + "2 * (3 + 4 * (5 - 1))", + "((2 + 3) * (4 - 1)) ^ 2", + ] + @test eval_expr(parse_expr(expr_str)) ≈ Meta.eval(Meta.parse(expr_str)) + end + end + + #__ Scientific notation edge cases ─────────────────────────────── + + @testset verbose = true "Scientific notation comprehensive" begin + # Various formats + @test eval_expr(parse_expr("1e0")) == 1.0 + @test eval_expr(parse_expr("1e1")) == 10.0 + @test eval_expr(parse_expr("1e-1")) == 0.1 + @test eval_expr(parse_expr("1e+1")) == 10.0 + @test eval_expr(parse_expr("1E0")) == 1.0 + @test eval_expr(parse_expr("1E1")) == 10.0 + @test eval_expr(parse_expr("1E-1")) == 0.1 + @test eval_expr(parse_expr("1E+1")) == 10.0 + @test eval_expr(parse_expr("1.5e2")) == 150.0 + @test eval_expr(parse_expr("2.5e-3")) == 0.0025 + @test eval_expr(parse_expr("9.99e0")) == 9.99 + + # In arithmetic expressions + @test eval_expr(parse_expr("1e3 + 1e2")) == 1100.0 + @test eval_expr(parse_expr("1e3 * 2")) == 2000.0 + @test eval_expr(parse_expr("1e3 - 1e3")) == 0.0 + @test eval_expr(parse_expr("1e10 / 1e10")) == 1.0 + @test eval_expr(parse_expr("2.5e2 + 7.5e2")) == 1000.0 + + # In comparisons + @test eval_expr(parse_expr("1e3 > 999")) == true + @test eval_expr(parse_expr("1e3 == 1000")) == true + @test eval_expr(parse_expr("1e-1 < 1")) == true + + # Errors + @test_throws NumExpr.SyntaxError parse_expr("1e") + @test_throws NumExpr.SyntaxError parse_expr("1e+") + @test_throws NumExpr.SyntaxError parse_expr("1e-") + @test_throws NumExpr.SyntaxError parse_expr("1E") + @test_throws NumExpr.SyntaxError parse_expr("1E+") + @test_throws NumExpr.SyntaxError parse_expr("1E-") + @test_throws NumExpr.SyntaxError parse_expr("1.5e") + @test_throws NumExpr.SyntaxError parse_expr("2.5E+") + end + + #__ Parse errors ───────────────────────────────────────────────── + + @testset verbose = true "Parse errors additional" begin + # Invalid ASCII characters + @test_throws NumExpr.SyntaxError parse_expr("@x") + @test_throws NumExpr.SyntaxError parse_expr("#1") + @test_throws NumExpr.SyntaxError parse_expr("~a") + @test_throws NumExpr.SyntaxError parse_expr("\\x") + + # Broadcasting prohibited + @test_throws NumExpr.SyntaxError parse_expr("f.(x)") + @test_throws NumExpr.SyntaxError parse_expr("cos.(1)") + @test_throws NumExpr.SyntaxError parse_expr("abs.(x)") + + # Dot edge cases + @test_throws NumExpr.SyntaxError parse_expr(".1") + @test_throws NumExpr.SyntaxError parse_expr("1..2") + @test_throws NumExpr.SyntaxError parse_expr(".abc") + + # Whitespace variants + @test_throws NumExpr.SyntaxError parse_expr("\t") + @test_throws NumExpr.SyntaxError parse_expr("\n") + + # Unclosed brackets + @test_throws NumExpr.SyntaxError parse_expr("[abc") + @test_throws NumExpr.SyntaxError parse_expr("{abc") + end + + #__ Eval: precision and edge cases ─────────────────────────────── + + @testset verbose = true "Eval precision and edge cases" begin + # Float64 precision matches Julia + @test eval_expr(parse_expr("0.1 + 0.2")) == 0.1 + 0.2 + @test eval_expr(parse_expr("1.0 - 1.0")) == 0.0 + @test eval_expr(parse_expr("2.0 * 0.5")) == 1.0 + @test eval_expr(parse_expr("1.0 / 3.0")) == 1.0 / 3.0 + + # Large numbers + @test eval_expr(parse_expr("1e300 * 1e-300")) == 1.0 + @test eval_expr(parse_expr("1e308")) == 1e308 + @test eval_expr(parse_expr("1e-308")) == 1e-308 + + # Zero arithmetic + @test eval_expr(parse_expr("0 + 0")) == 0.0 + @test eval_expr(parse_expr("0 * 1000000")) == 0.0 + @test eval_expr(parse_expr("0 ^ 1")) == 0.0 + @test eval_expr(parse_expr("1 ^ 0")) == 1.0 + @test eval_expr(parse_expr("0 ^ 0")) == 1.0 # IEEE 754: 0^0 = 1 + + # Single value expressions + @test eval_expr(parse_expr("42")) == 42.0 + @test eval_expr(parse_expr("0")) == 0.0 + @test eval_expr(parse_expr("0.0")) == 0.0 + @test eval_expr(parse_expr("(42)")) == 42.0 + @test eval_expr(parse_expr("((42))")) == 42.0 + end + + #__ Memory / allocation stress ─────────────────────────────────── + + @testset verbose = true "Stress tests" begin + # Long addition chain + long_sum = join(fill("1", 100), " + ") + @test eval_expr(parse_expr(long_sum)) == 100.0 + + # Long multiplication chain + long_prod = join(fill("2", 20), " * ") + @test eval_expr(parse_expr(long_prod)) == 2.0^20 + + # Deeply nested parentheses + deep_paren = "(" ^ 50 * "1" * ")" ^ 50 + @test eval_expr(parse_expr(deep_paren)) == 1.0 + + # Nested function calls + nested_abs = "abs(" ^ 10 * "42" * ")" ^ 10 + @test eval_expr(parse_expr(nested_abs)) == 42.0 + + # Long variable name + long_var = "a" ^ 200 + @test vectorial(parse_expr(long_var)) == [long_var] + + # Complex deeply nested expression + expr_str = "((1 + 2) * (3 - 4) + (5 * 6)) / ((7 - 8) * (9 + 10) + 11)" + @test eval_expr(parse_expr(expr_str)) ≈ Meta.eval(Meta.parse(expr_str)) + end + + @testset verbose = true "Bytecode VM" begin + @testset "VarContext" begin + ctx = VarContext() + @test length(ctx) == 0 + + idx1 = NumExpr.get_or_create!(ctx, "a") + @test idx1 == 1 + @test length(ctx) == 1 + + idx2 = NumExpr.get_or_create!(ctx, "b") + @test idx2 == 2 + @test length(ctx) == 2 + + # Same name returns same index + idx1_again = NumExpr.get_or_create!(ctx, "a") + @test idx1_again == idx1 + @test length(ctx) == 2 + + # Indexing by name + @test ctx["a"] == 1 + @test ctx["b"] == 2 + + # Indexing by position + @test ctx[1] == "a" + @test ctx[2] == "b" + + # haskey + @test haskey(ctx, "a") + @test !haskey(ctx, "z") + end + + @testset "Compile constants" begin + ctx = VarContext() + + # Integer constant + c = compile_expr("42", ctx) + @test eval_compiled(c, Float64[]) == 42.0 + + # Float constant + c = compile_expr("3.14", ctx) + @test eval_compiled(c, Float64[]) ≈ 3.14 + + # NaN + c = compile_expr("NaN", ctx) + @test isnan(eval_compiled(c, Float64[])) + + # Boolean true + c = compile_expr("true", ctx) + @test eval_compiled(c, Float64[]) == 1.0 + + # Boolean false + c = compile_expr("false", ctx) + @test eval_compiled(c, Float64[]) == 0.0 + end + + @testset "Compile variables" begin + ctx = VarContext() + c = compile_expr("a + b", ctx) + @test length(ctx) == 2 + values = [3.0, 7.0] + @test eval_compiled(c, values) == 10.0 + end + + @testset "Arithmetic operations" begin + ctx = VarContext() + values = Float64[] + + @test eval_compiled(compile_expr("2 + 3", ctx), values) == 5.0 + @test eval_compiled(compile_expr("10 - 4", ctx), values) == 6.0 + @test eval_compiled(compile_expr("3 * 7", ctx), values) == 21.0 + @test eval_compiled(compile_expr("15 / 3", ctx), values) == 5.0 + @test eval_compiled(compile_expr("2 ^ 10", ctx), values) == 1024.0 + @test eval_compiled(compile_expr("17 % 5", ctx), values) == 2.0 + @test eval_compiled(compile_expr("-5", ctx), values) == -5.0 + @test eval_compiled(compile_expr("+5", ctx), values) == 5.0 + end + + @testset "N-ary flattening" begin + ctx = VarContext() + values = Float64[] + + @test eval_compiled(compile_expr("1 + 2 + 3 + 4", ctx), values) == 10.0 + @test eval_compiled(compile_expr("2 * 3 * 4", ctx), values) == 24.0 + @test eval_compiled(compile_expr("100 - 10 - 20 - 30", ctx), values) == 40.0 + end + + @testset "Comparison operations" begin + ctx = VarContext() + values = Float64[] + + @test eval_compiled(compile_expr("3 > 2", ctx), values) == 1.0 + @test eval_compiled(compile_expr("2 > 3", ctx), values) == 0.0 + @test eval_compiled(compile_expr("2 < 3", ctx), values) == 1.0 + @test eval_compiled(compile_expr("3 < 2", ctx), values) == 0.0 + @test eval_compiled(compile_expr("3 >= 3", ctx), values) == 1.0 + @test eval_compiled(compile_expr("2 >= 3", ctx), values) == 0.0 + @test eval_compiled(compile_expr("3 <= 3", ctx), values) == 1.0 + @test eval_compiled(compile_expr("4 <= 3", ctx), values) == 0.0 + @test eval_compiled(compile_expr("5 == 5", ctx), values) == 1.0 + @test eval_compiled(compile_expr("5 == 6", ctx), values) == 0.0 + @test eval_compiled(compile_expr("5 != 6", ctx), values) == 1.0 + @test eval_compiled(compile_expr("5 != 5", ctx), values) == 0.0 + end + + @testset "Logical operations" begin + ctx = VarContext() + values = Float64[] + + @test eval_compiled(compile_expr("1 && 1", ctx), values) == 1.0 + @test eval_compiled(compile_expr("1 && 0", ctx), values) == 0.0 + @test eval_compiled(compile_expr("0 && 1", ctx), values) == 0.0 + @test eval_compiled(compile_expr("0 || 1", ctx), values) == 1.0 + @test eval_compiled(compile_expr("0 || 0", ctx), values) == 0.0 + @test eval_compiled(compile_expr("1 || 0", ctx), values) == 1.0 + end + + @testset "Math functions" begin + ctx = VarContext() + values = Float64[] + + @test eval_compiled(compile_expr("sqrt(4)", ctx), values) == 2.0 + @test eval_compiled(compile_expr("abs(-7)", ctx), values) == 7.0 + @test eval_compiled(compile_expr("sin(0)", ctx), values) == 0.0 + @test eval_compiled(compile_expr("cos(0)", ctx), values) == 1.0 + @test eval_compiled(compile_expr("atan(0)", ctx), values) == 0.0 + @test eval_compiled(compile_expr("exp(0)", ctx), values) == 1.0 + @test eval_compiled(compile_expr("log(1)", ctx), values) == 0.0 + end + + @testset "Complex expressions" begin + ctx = VarContext() + + # sin^2 + cos^2 = 1 + c = compile_expr("sin(10) ^ 2 + cos(10) ^ 2", ctx) + @test eval_compiled(c, Float64[]) ≈ 1.0 + + # Nested functions + c = compile_expr("sqrt(abs(-16))", ctx) + @test eval_compiled(c, Float64[]) == 4.0 + + # Mixed arithmetic with functions + ctx2 = VarContext() + c = compile_expr("a + b * sin(c)", ctx2) + values = zeros(Float64, length(ctx2)) + values[1] = 1.0 # a + values[2] = 2.0 # b + values[3] = 0.0 # c (sin(0) = 0) + @test eval_compiled(c, values) == 1.0 + + values[3] = pi / 2 # sin(pi/2) = 1 + @test eval_compiled(c, values) ≈ 3.0 + + # Deeply nested + c = compile_expr("((1 + 2) * (3 - 4) + (5 * 6)) / ((7 - 8) * (9 + 10) + 11)", ctx) + @test eval_compiled(c, Float64[]) ≈ eval_expr(parse_expr("((1 + 2) * (3 - 4) + (5 * 6)) / ((7 - 8) * (9 + 10) + 11)")) + end + + @testset "Correctness vs tree eval" begin + numeric_exprs = [ + "1 + 2", + "3 * 4 + 5", + "10 - 3 * 2", + "2 ^ 3 ^ 2", + "(1 + 2) * 3", + "1 + 2 + 3 + 4 + 5", + "10 / 2 / 5", + "sin(1)", + "cos(0)", + "sqrt(16)", + "abs(-42)", + "exp(1)", + "log(1)", + "atan(1)", + "sin(1) ^ 2 + cos(1) ^ 2", + "-5 + 3", + "-(1 + 2)", + "1 + 2 * 3 + 4", + "2 ^ 10", + "1.5 + 2.5", + "1e2 + 1", + "sqrt(abs(-9))", + "3 > 2", + "2 < 3", + "5 == 5", + "5 != 6", + "3 >= 3", + "2 <= 3", + "true", + "false", + "1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10", + "2 * 3 * 4 * 5", + ] + + for expr_str in numeric_exprs + tree_result = eval_expr(parse_expr(expr_str)) + ctx = VarContext() + compiled = compile_expr(expr_str, ctx) + vm_result = eval_compiled(compiled, Float64[]) + if isnan(tree_result) + @test isnan(vm_result) + elseif tree_result isa Bool + @test vm_result == Float64(tree_result) + else + @test vm_result ≈ Float64(tree_result) atol = 1e-10 + end + end + end + + @testset "Resolver function" begin + ctx = VarContext() + c = compile_expr("a + b * 2", ctx) + result = eval_compiled(c, idx -> idx == 1 ? 10.0 : 20.0) + @test result == 50.0 + end + + @testset "Shared VarContext" begin + ctx = VarContext() + c1 = compile_expr("a + b", ctx) + c2 = compile_expr("b * c", ctx) + c3 = compile_expr("a + c", ctx) + @test length(ctx) == 3 # a, b, c + + values = [1.0, 2.0, 3.0] + @test eval_compiled(c1, values) == 3.0 # 1 + 2 + @test eval_compiled(c2, values) == 6.0 # 2 * 3 + @test eval_compiled(c3, values) == 4.0 # 1 + 3 + end + + @testset "Pre-allocated stack" begin + ctx = VarContext() + c = compile_expr("1 + 2 * 3", ctx) + stack = Vector{Float64}(undef, c.max_stack) + @test eval_compiled(c, Float64[], stack) == 7.0 + # Reuse stack + @test eval_compiled(c, Float64[], stack) == 7.0 + end + + @testset "String operations rejected" begin + ctx = VarContext() + @test_throws ErrorException compile_expr("'hello' * 'world'", ctx) + end + + @testset "Memory compactness" begin + ctx = VarContext() + compiled = compile_expr("a + b * sin(c)", ctx) + @test sizeof(compiled.code) + sizeof(compiled.constants) < 100 + end + + @testset "Zero allocations" begin + ctx = VarContext() + c = compile_expr("a + b * c", ctx) + values = [1.0, 2.0, 3.0] + stack = Vector{Float64}(undef, c.max_stack) + function measure_allocs(c, values, stack) + for _ in 1:5 + eval_compiled(c, values, stack) + end + return @allocated eval_compiled(c, values, stack) + end + allocs = measure_allocs(c, values, stack) + @test allocs == 0 + end + + @testset "Tagged variables" begin + ctx = VarContext() + c = compile_expr("price{store='ny'} + tax[state='nys']", ctx) + @test length(ctx) == 2 + values = [100.0, 8.0] + @test eval_compiled(c, values) == 108.0 + end + + @testset "Bytecode structure" begin + ctx = VarContext() + + # Simple constant + c = compile_expr("42", ctx) + @test c.code[1] == NumExpr.OP_LOAD_CONST + @test c.max_stack == 1 + @test length(c.constants) == 1 + @test c.constants[1] == 42.0 + + # Simple addition + c = compile_expr("1 + 2", ctx) + @test c.max_stack == 2 + + # NaN + c = compile_expr("NaN", ctx) + @test c.code[1] == NumExpr.OP_LOAD_NAN + + # true / false + c = compile_expr("true", ctx) + @test c.code[1] == NumExpr.OP_LOAD_TRUE + + c = compile_expr("false", ctx) + @test c.code[1] == NumExpr.OP_LOAD_FALSE + end + + @testset "Compile from ExprNode" begin + ctx = VarContext() + node = parse_expr("1 + 2 * 3") + c = compile_expr(node, ctx) + @test eval_compiled(c, Float64[]) == 7.0 + end + + @testset "Stress: many formulas" begin + ctx = VarContext() + formulas = NumExpr.CompiledExpr[] + for i in 1:1000 + a = rand() + b = rand() + push!(formulas, compile_expr("$a + $b", ctx)) + end + @test length(formulas) == 1000 + for (i, f) in enumerate(formulas) + result = eval_compiled(f, Float64[]) + @test isfinite(result) + end + end + + @testset "Named values convenience" begin + ctx = VarContext() + c = compile_expr("price + tax * quantity", ctx) + @test eval_compiled(c, ctx, "price" => 100.0, "tax" => 8.0, "quantity" => 5.0) == 140.0 + end + + @testset "VarContext indexing workflow" begin + ctx = VarContext() + f = compile_expr("bid * qty + ask * qty", ctx) + values = zeros(Float64, length(ctx)) + values[ctx["bid"]] = 10.0 + values[ctx["qty"]] = 5.0 + values[ctx["ask"]] = 12.0 + @test eval_compiled(f, values) == 110.0 + end + + @testset "Edge cases" begin + ctx = VarContext() + values = Float64[] + + # Division by zero + c = compile_expr("1 / 0", ctx) + @test eval_compiled(c, values) == Inf + + # Negative division by zero + c = compile_expr("-1 / 0", ctx) + @test eval_compiled(c, values) == -Inf + + # 0/0 = NaN + c = compile_expr("0 / 0", ctx) + @test isnan(eval_compiled(c, values)) + + # Large exponent + c = compile_expr("2 ^ 100", ctx) + @test eval_compiled(c, values) == 2.0^100 + + # Scientific notation + c = compile_expr("1e10 + 1e10", ctx) + @test eval_compiled(c, values) == 2e10 + end + + @testset "Extended unary opcodes" begin + ctx = VarContext() + values = Float64[] + + # isnan + c = compile_expr("isnan(NaN)", ctx) + @test eval_compiled(c, values) == 1.0 + c = compile_expr("isnan(42)", ctx) + @test eval_compiled(c, values) == 0.0 + + # not + c = compile_expr("not(true)", ctx) + @test eval_compiled(c, values) == 0.0 + c = compile_expr("not(false)", ctx) + @test eval_compiled(c, values) == 1.0 + c = compile_expr("not(NaN)", ctx) + @test isnan(eval_compiled(c, values)) + + # iszero + c = compile_expr("iszero(0)", ctx) + @test eval_compiled(c, values) == 1.0 + c = compile_expr("iszero(5)", ctx) + @test eval_compiled(c, values) == 0.0 + + # isone + c = compile_expr("isone(1)", ctx) + @test eval_compiled(c, values) == 1.0 + c = compile_expr("isone(0)", ctx) + @test eval_compiled(c, values) == 0.0 + + # floor + c = compile_expr("floor(3.7)", ctx) + @test eval_compiled(c, values) == 3.0 + c = compile_expr("floor(-2.3)", ctx) + @test eval_compiled(c, values) == -3.0 + + # ceil + c = compile_expr("ceil(3.2)", ctx) + @test eval_compiled(c, values) == 4.0 + c = compile_expr("ceil(-2.8)", ctx) + @test eval_compiled(c, values) == -2.0 + end + + @testset "Extended binary/ternary opcodes" begin + ctx = VarContext() + values = Float64[] + + # max (binary) + c = compile_expr("max(3, 7)", ctx) + @test eval_compiled(c, values) == 7.0 + c = compile_expr("max(10, 2)", ctx) + @test eval_compiled(c, values) == 10.0 + + # min (binary) + c = compile_expr("min(3, 7)", ctx) + @test eval_compiled(c, values) == 3.0 + c = compile_expr("min(10, 2)", ctx) + @test eval_compiled(c, values) == 2.0 + + # ifelse (ternary) + c = compile_expr("ifelse(true, 10, 20)", ctx) + @test eval_compiled(c, values) == 10.0 + c = compile_expr("ifelse(false, 10, 20)", ctx) + @test eval_compiled(c, values) == 20.0 + + # get (binary: isnan(a) ? b : a) + c = compile_expr("get(NaN, 42)", ctx) + @test eval_compiled(c, values) == 42.0 + c = compile_expr("get(7, 42)", ctx) + @test eval_compiled(c, values) == 7.0 + + # round (binary) + c = compile_expr("round(3.456, 2)", ctx) + @test eval_compiled(c, values) ≈ 3.46 + c = compile_expr("round(3.456, 0)", ctx) + @test eval_compiled(c, values) == 3.0 + + # isless (binary) + c = compile_expr("isless(2, 3)", ctx) + @test eval_compiled(c, values) == 1.0 + c = compile_expr("isless(3, 2)", ctx) + @test eval_compiled(c, values) == 0.0 + c = compile_expr("isless(NaN, 1)", ctx) + @test eval_compiled(c, values) == 0.0 + + # div (binary: integer division) + c = compile_expr("div(17, 5)", ctx) + @test eval_compiled(c, values) == 3.0 + c = compile_expr("div(10, 3)", ctx) + @test eval_compiled(c, values) == 3.0 + end + + @testset "Extended opcodes with variables" begin + ctx = VarContext() + c = compile_expr("ifelse(a > 0, a * 2, get(a, b))", ctx) + @test length(ctx) == 2 + + # a=5, b=10 → a>0 → a*2 = 10 + vals = [5.0, 10.0] + @test eval_compiled(c, vals) == 10.0 + + # a=NaN, b=10 → a>0 is false → get(NaN,10) = 10 + vals = [NaN, 10.0] + @test eval_compiled(c, vals) == 10.0 + + # max/min with vars + c2 = compile_expr("max(a, b) - min(a, b)", ctx) + vals = [3.0, 7.0] + @test eval_compiled(c2, vals) == 4.0 + end + + @testset "Extended opcodes correctness vs tree eval" begin + test_cases = [ + "isnan(NaN)", + "isnan(1)", + "not(true)", + "not(false)", + "iszero(0)", + "iszero(1)", + "isone(1)", + "isone(0)", + "floor(3.7)", + "ceil(3.2)", + "max(5, 10)", + "min(5, 10)", + "ifelse(true, 1, 2)", + "ifelse(false, 1, 2)", + "get(NaN, 42)", + "get(7, 42)", + "round(3.456, 2)", + "div(17, 5)", + "tg(0)", + "tg(1)", + "ctg(1)", + "mean(2, 4, 6)", + "mean(10)", + "mean(1, 2, 3, 4, 5)", + ] + for expr_str in test_cases + tree_result = eval_expr(parse_expr(expr_str)) + ctx = VarContext() + compiled = compile_expr(expr_str, ctx) + vm_result = eval_compiled(compiled, Float64[]) + if isnan(tree_result) + @test isnan(vm_result) + else + @test vm_result ≈ Float64(tree_result) atol=1e-10 + end + end + end + + @testset "tg / ctg opcodes" begin + ctx = VarContext() + c = compile_expr("tg(x)", ctx) + @test eval_compiled(c, [0.0]) == 0.0 + @test eval_compiled(c, [1.0]) ≈ tan(1.0) + @test eval_compiled(c, [Float64(π)/4]) ≈ 1.0 + + c2 = compile_expr("ctg(x)", ctx) + @test eval_compiled(c2, [Float64(π)/4]) ≈ 1.0 + @test eval_compiled(c2, [1.0]) ≈ cos(1.0) / sin(1.0) + @test eval_compiled(c2, [Float64(π)/2]) ≈ 0.0 atol=1e-15 + + # Combined expression + c3 = compile_expr("tg(x) * ctg(x)", ctx) + @test eval_compiled(c3, [1.0]) ≈ 1.0 + @test eval_compiled(c3, [0.5]) ≈ 1.0 + end + + @testset "mean opcode" begin + ctx = VarContext() + + # Constants only + c = compile_expr("mean(2, 4, 6)", ctx) + @test eval_compiled(c, Float64[]) == 4.0 + + c2 = compile_expr("mean(10)", ctx) + @test eval_compiled(c2, Float64[]) == 10.0 + + c3 = compile_expr("mean(1, 2, 3, 4, 5)", ctx) + @test eval_compiled(c3, Float64[]) == 3.0 + + # With variables + c4 = compile_expr("mean(a, b, c)", ctx) + @test eval_compiled(c4, [2.0, 4.0, 6.0]) == 4.0 + @test eval_compiled(c4, [10.0, 20.0, 30.0]) == 20.0 + + # Two args + c5 = compile_expr("mean(a, b)", ctx) + @test eval_compiled(c5, [3.0, 7.0, 0.0]) == 5.0 + + # In larger expression + c6 = compile_expr("mean(a, b, c) + 1", ctx) + @test eval_compiled(c6, [2.0, 4.0, 6.0]) == 5.0 + end + + @testset "Callable resolver (functor)" begin + ctx = VarContext() + c = compile_expr("a + b * 2", ctx) + + # Functor (callable struct) + struct TestResolver + data::Vector{Float64} + end + (r::TestResolver)(idx::Int) = r.data[idx] + + resolver = TestResolver([3.0, 4.0]) + @test eval_compiled(c, resolver) == 11.0 + + # Regular function still works + @test eval_compiled(c, idx -> [3.0, 4.0][idx]) == 11.0 + end + + @testset "VarContext sizehint" begin + ctx = VarContext(; sizehint = 10) + @test length(ctx) == 0 + + c = compile_expr("a + b", ctx) + @test length(ctx) == 2 + @test ctx["a"] == 1 + @test ctx["b"] == 2 + @test eval_compiled(c, [1.0, 2.0]) == 3.0 + end + + @testset "Constant pool overflow check" begin + ctx = VarContext() + # Build an expression with more than 65535 distinct constants. + # 0.0 and 1.0 are emitted as immediate opcodes and don't consume the pool, + # so start at 2 to make sure the pool actually overflows. + parts = [string(Float64(i)) for i in 2:65538] + huge_expr = join(parts, " + ") + @test_throws ErrorException compile_expr(huge_expr, ctx) + end + + @testset "rem opcode" begin + ctx = VarContext() + @test eval_compiled(compile_expr("rem(7, 3)", ctx), Float64[]) == rem(7.0, 3.0) + @test eval_compiled(compile_expr("rem(-7, 3)", ctx), Float64[]) == rem(-7.0, 3.0) + @test eval_compiled(compile_expr("rem(7.5, 2.0)", ctx), Float64[]) == rem(7.5, 2.0) + + # tree and VM agree + for s in ("rem(10, 3)", "rem(-10, 3)", "rem(10, -3)", "rem(7.5, 2.0)") + @test eval_expr(parse_expr(s)) == eval_compiled(compile_expr(s, VarContext()), Float64[]) + end + end + + @testset "Load-immediate opcodes for 0 and 1" begin + ctx = VarContext() + + c0 = compile_expr("0", ctx) + @test c0.code[1] == NumExpr.OP_LOAD_ZERO + @test isempty(c0.constants) + @test eval_compiled(c0, Float64[]) == 0.0 + + c1 = compile_expr("1", ctx) + @test c1.code[1] == NumExpr.OP_LOAD_ONE + @test isempty(c1.constants) + @test eval_compiled(c1, Float64[]) == 1.0 + + # 0.0 and 1.0 stay out of the constant pool when used in larger exprs + c2 = compile_expr("a + 1 + 0", ctx) + @test isempty(c2.constants) + + # Other literals still go through the constant pool + c3 = compile_expr("2", VarContext()) + @test c3.code[1] == NumExpr.OP_LOAD_CONST + @test c3.constants == [2.0] + + # -0.0 is NOT folded: parser produces unary-minus over 0.0, + # so we get OP_LOAD_ZERO + OP_NEG → -0.0 + c4 = compile_expr("-0", VarContext()) + @test eval_compiled(c4, Float64[]) === -0.0 + end + + @testset "Unsupported function clean error" begin + ctx = VarContext() + err1 = try; compile_expr("foo(1)", ctx); nothing; catch e; e; end + @test err1 isa ErrorException + @test occursin("compiled mode does not support", err1.msg) + @test occursin("foo", err1.msg) + + @test_throws ErrorException compile_expr("mod(7, 3)", VarContext()) + @test_throws ErrorException compile_expr("xor(1, 0)", VarContext()) + end + + @testset "Function arity errors" begin + ctx = VarContext() + @test_throws ErrorException compile_expr("max(1, 2, 3)", ctx) + @test_throws ErrorException compile_expr("min(1, 2, 3)", ctx) + @test_throws ErrorException compile_expr("ifelse(1)", ctx) + @test_throws ErrorException compile_expr("ifelse(1, 2)", ctx) + @test_throws ErrorException compile_expr("ifelse(1, 2, 3, 4)", ctx) + @test_throws ErrorException compile_expr("get(1)", ctx) + @test_throws ErrorException compile_expr("rem(1, 2, 3)", ctx) + @test_throws ErrorException compile_expr("sqrt(1, 2)", ctx) + @test_throws ErrorException compile_expr("mean()", ctx) + end + + @testset "ifelse / not truthy semantics" begin + ctx = VarContext() + + # Any non-zero, non-NaN condition is "true" for ifelse + @test eval_compiled(compile_expr("ifelse(2, 10, 20)", ctx), Float64[]) == 10.0 + @test eval_compiled(compile_expr("ifelse(-1, 10, 20)", ctx), Float64[]) == 10.0 + @test eval_compiled(compile_expr("ifelse(0.5, 10, 20)", ctx), Float64[]) == 10.0 + @test eval_compiled(compile_expr("ifelse(0, 10, 20)", ctx), Float64[]) == 20.0 + @test isnan(eval_compiled(compile_expr("ifelse(NaN, 10, 20)", ctx), Float64[])) + + # not: any non-zero non-NaN → 0; zero → 1; NaN → NaN + @test eval_compiled(compile_expr("not(0)", ctx), Float64[]) == 1.0 + @test eval_compiled(compile_expr("not(5)", ctx), Float64[]) == 0.0 + @test eval_compiled(compile_expr("not(-2)", ctx), Float64[]) == 0.0 + @test isnan(eval_compiled(compile_expr("not(NaN)", ctx), Float64[])) + + # Tree-walk and VM agree on the new semantics + for s in ( + "ifelse(2, 10, 20)", "ifelse(-3, 10, 20)", "ifelse(0, 10, 20)", + "not(0)", "not(7)", "not(-1)", + ) + tree = eval_expr(parse_expr(s)) + vm = eval_compiled(compile_expr(s, VarContext()), Float64[]) + @test tree == vm + end + for s in ("ifelse(NaN, 10, 20)", "not(NaN)") + @test isnan(eval_expr(parse_expr(s))) + @test isnan(eval_compiled(compile_expr(s, VarContext()), Float64[])) + end + end + end + + @testset verbose = true "Variable tags" begin + # Tag-less variable (format1) has nothing tags + v = parse_expr("abc") + @test v.tags === nothing + @test !var_has_tags(v) + @test NumExpr.var_tags(v) == Dict{String,String}() + + # Tagged variable (format2) has real tags + v2 = parse_expr("abc{key='val'}") + @test v2.tags !== nothing + @test var_has_tags(v2) + @test NumExpr.var_tags(v2) == Dict("key" => "val") + + # Global tag-less + v3 = parse_expr("[myvar]") + @test v3.tags === nothing + @test !var_has_tags(v3) + end end