Skip to content

Commit 8e8a580

Browse files
authored
Merge pull request #1294 from gridap/bugfix-basis-type-instability
Bugfix on basis type instability
2 parents ee7ca9b + dbef316 commit 8e8a580

11 files changed

Lines changed: 38 additions & 4 deletions

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
### Fixed
1515

1616
- Fixed type instability for tensor operations between `MultiValue` and scalars. Since PR[#1293](https://github.com/gridap/Gridap.jl/pull/1293).
17+
- Fixed type instability in basis construction when user gives a non-concrete output type. Since PR[#1294](https://github.com/gridap/Gridap.jl/pull/1294).
1718

1819
## [0.20.5] - 2026-04-28
1920

src/Polynomials/BernsteinBases.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,11 @@ struct BernsteinBasisOnSimplex{D,V,M,K} <: PolynomialBasis{D,V,Bernstein}
247247
function BernsteinBasisOnSimplex{D}(::Type{V},order::Int,vertices=nothing) where {D,V}
248248
_simplex_vertices_checks(Val(D), vertices)
249249

250+
VV = make_concretetype(V)
250251
cart_to_bary_matrix = _compute_cart_to_bary_matrix(vertices, Val(D+1))
251252
M = typeof(cart_to_bary_matrix) # Nothing or SMatrix
252253
K = order
253-
new{D,V,M,K}(cart_to_bary_matrix)
254+
new{D,VV,M,K}(cart_to_bary_matrix)
254255
end
255256
end
256257

src/Polynomials/CartProdPolyBases.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@ struct CartProdPolyBasis{D,V,PT} <: PolynomialBasis{D,V,PT}
3939
orders::NTuple{D,Int},
4040
terms::Vector{CartesianIndex{D}}) where {D,V,PT<:Polynomial}
4141

42+
VV = make_concretetype(V)
4243
@check isconcretetype(PT) "PT needs to be a concrete <:Polynomial type"
4344

4445
K = maximum(orders; init=0)
4546
msg = "Some term contain a higher index than the maximum degree + 1."
4647
@check all( term -> (maximum(Tuple(term), init=0) <= K+1), terms) msg
47-
new{D,V,PT}(K,orders,terms)
48+
new{D,VV,PT}(K,orders,terms)
4849
end
4950
end
5051

src/Polynomials/CompWiseTensorPolyBases.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,13 @@ struct CompWiseTensorPolyBasis{D,V,PT} <: PolynomialBasis{D,V,PT}
5959
@check size(orders,2) == D msg3
6060
@check isconcretetype(PT) "PT needs to be a concrete <:Polynomial type"
6161

62+
VV = make_concretetype(V)
6263
K = maximum(orders)
6364
num_poly = mapreduce(length, +, comp_terms)
6465

6566
#TODO check orders in `orders` greater or equal than max index in terms
6667

67-
new{D,V,PT}(num_poly,K,orders,comp_terms)
68+
new{D,VV,PT}(num_poly,K,orders,comp_terms)
6869
end
6970
end
7071

src/Polynomials/ModalC0Bases.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@ struct ModalC0Basis{D,V,T} <: PolynomialBasis{D,V,ModalC0}
4343
_msg = "The number of bounding box points in a and b should match the number of terms"
4444
@check length(terms) == length(a) == length(b) _msg
4545
@check T == eltype(V) "Point and polynomial values should have the same scalar body"
46+
VV = make_concretetype(V)
4647
K = maximum(orders, init=0)
4748

48-
new{D,V,T}(K,orders,terms,a,b)
49+
new{D,VV,T}(K,orders,terms,a,b)
4950
end
5051
end
5152

src/TensorValues/MultiValueTypes.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ function rand(rng::AbstractRNG,::Random.SamplerType{V}) where V<:MultiValue{D,T}
9797
V(Tuple(vrand))
9898
end
9999

100+
function make_concretetype(::Type{T}) where T <: Number
101+
TT = ifelse(isconcretetype(T),T,typeof(zero(T)))
102+
@check isconcretetype(TT) "Type $(T) cannot be made concrete."
103+
return TT
104+
end
100105

101106
## ATM it is not possible to implement array like axes because lazy_mapping
102107
## operations / broadcast rely on axes(::MultiValue) adopting the Number convention to return ().

src/TensorValues/TensorValues.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ export skew_symmetric_part
6666
export num_components
6767
export num_indep_components
6868
export change_eltype
69+
export make_concretetype
6970
export diagonal_tensor
7071
export
7172
export

test/PolynomialsTests/BernsteinBasesTests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,4 +345,11 @@ Hbx = _Hbx(D,order,x,H,x2λ)
345345
test_field_array(b,x,bx,, grad=Gbx, gradgrad=Hbx)
346346
test_field_array(b,x1,bx[1,:],,grad=Gbx[1,:],gradgrad=Hbx[1,:])
347347

348+
# value_type must be concrete even when user passes a non-concrete tensor type
349+
for V in (TensorValue{2,2,Float64}, SymTensorValue{2,Float64},
350+
SkewSymTensorValue{2,Float64}, SymFourthOrderTensorValue{2,Float64})
351+
@test isconcretetype(value_type(BernsteinBasis(Val(2), V, 1)))
352+
@test isconcretetype(value_type(BernsteinBasisOnSimplex{2}(V, 1)))
353+
end
354+
348355
end # module

test/PolynomialsTests/CompWiseTensorPolyBasesTests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,10 @@ Gbx = hcat( [ getindex.(bi,1) .* VectorValue(1.,0)⊗V(1.,0) for bi in evaluate(
5959

6060
test_field_array(b,x,bx,, grad=Gbx)
6161

62+
# value_type must be concrete even when user passes a non-concrete tensor type
63+
for V in (TensorValue{2,2,Float64}, SymTensorValue{2,Float64}, SymFourthOrderTensorValue{2,Float64})
64+
L = num_indep_components(V)
65+
@test isconcretetype(value_type(CompWiseTensorPolyBasis{2}(Monomial, V, ones(Int, L, 2))))
66+
end
6267

6368
end # module

test/PolynomialsTests/ModalC0BasesTests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,5 +141,10 @@ G = gradient_type(V,x1)
141141
r = zeros(G, (1,1))
142142
@test_throws ErrorException Polynomials._set_derivative_mc0!(r,1,s,0,0,V)
143143

144+
# value_type must be concrete even when user passes a non-concrete tensor type
145+
for V in (TensorValue{2,2,Float64}, SymTensorValue{2,Float64},
146+
SkewSymTensorValue{2,Float64}, SymFourthOrderTensorValue{2,Float64})
147+
@test isconcretetype(value_type(ModalC0Basis{2}(V, 1)))
148+
end
144149

145150
end # module

0 commit comments

Comments
 (0)