diff --git a/src/Reactant.jl b/src/Reactant.jl index 1bf87de1d..1dbb846cd 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -103,6 +103,8 @@ end @inline ConcreteRArray(data::T) where {T <: Number} = ConcreteRArray{T, (), 0}(data) +Base.similar(x::ConcreteRArray{T, Shape, N}, ::Type{T2}) where {T, Shape, N, T2} = ConcreteRArray{T, Shape, N}(x.data) + mutable struct TracedRArray{ElType,Shape,N} <: RArray{ElType, Shape, N} paths::Tuple mlir_data::Union{Nothing,MLIR.IR.Value} diff --git a/test/basic.jl b/test/basic.jl index c26b4170a..e762f2644 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -98,3 +98,8 @@ end @test r ≈ mul(ones(50,70),ones(70,30)) end + +@testset "ConcreteRArray" begin + c = Reactant.ConcreteRArray(ones(50,70)) + similar(c) +end