Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ authors = ["Ritter.Marc <[email protected]>, Hiroshi Shinaoka <
version = "0.9.15"

[deps]
BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"

[compat]
BitIntegers = "0.3.5"
EllipsisNotation = "1"
QuadGK = "2.9"
julia = "1.6"
Expand Down
1 change: 1 addition & 0 deletions src/TensorCrossInterpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module TensorCrossInterpolation

using LinearAlgebra
using EllipsisNotation
using BitIntegers
import QuadGK

# To add a method for rank(tci)
Expand Down
7 changes: 5 additions & 2 deletions src/cachedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function using the usual function call syntax.
The type `K` denotes the type of the keys used to cache function values, which could be an integer type. This defaults to `UInt128`. A safer but slower alternative is `BigInt`, which is better suited for functions with a large number of arguments.
`CachedFunction` does not support batch evaluation of function values.
"""
struct CachedFunction{ValueType,K<:Union{UInt32,UInt64,UInt128,BigInt}} <: BatchEvaluator{ValueType}
struct CachedFunction{ValueType,K<:Union{UInt32,UInt64,UInt128,BigInt,BitIntegers.AbstractBitUnsigned}} <: BatchEvaluator{ValueType}
f::Function
localdims::Vector{Int}
cache::Dict{K,ValueType}
Expand All @@ -15,8 +15,11 @@ struct CachedFunction{ValueType,K<:Union{UInt32,UInt64,UInt128,BigInt}} <: Batch
for n in 2:length(localdims)
coeffs[n] = localdims[n-1] * coeffs[n-1]
end
if K == BigInt
@warn "Using BigInt for keys. This is SUPER slower and uses more memory. The use of BigInt is kept only for compatibility with older code. Use BitIntegers.UInt256 or bigger integer types with fixed size instead."
end
if K != BigInt
sum(coeffs .* (localdims .- 1)) < typemax(K) || error("Too many dimensions. Use BigInt instead of UInt128.")
sum(coeffs .* (localdims .- 1)) < typemax(K) || error("Overflow in CachedFunction. Use ValueType = a bigger type with fixed size, e.g., BitIntegers.UInt256")
end
new(f, localdims, cache, coeffs)
end
Expand Down
4 changes: 2 additions & 2 deletions test/test_cachedfunction.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Test
import TensorCrossInterpolation as TCI
import TensorCrossInterpolation: BatchEvaluator, MultiIndex

import BitIntegers
struct TestF <: TCI.BatchEvaluator{Float64}
end

Expand Down Expand Up @@ -96,7 +96,7 @@ end
f(x) = 1.0
nint = 4
N = 64 * nint
cf = TCI.CachedFunction{Float64,BigInt}(f, fill(2, N))
cf = TCI.CachedFunction{Float64,BitIntegers.UInt512}(f, fill(2, N))
x = ones(Int, N)
@test cf(x) == 1.0
@test TCI._key(cf, x) == 0
Expand Down
Loading