Skip to content

Commit

Permalink
relax joinobs
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Feb 1, 2025
1 parent 1813426 commit ba98dc5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/obstransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,12 @@ end

# joinumobs

struct JoinedData{T,N} <: AbstractDataContainer
datas::NTuple{N,T}
struct JoinedData{T<:Tuple,N} <: AbstractDataContainer
datas::T
ns::NTuple{N,Int}
end

JoinedData(datas) = JoinedData(datas, numobs.(datas))
JoinedData(datas::Tuple) = JoinedData(datas, numobs.(datas))

Base.length(data::JoinedData) = sum(data.ns)

Expand Down Expand Up @@ -194,7 +194,12 @@ jdata = joinumobs(data1, data2)
getobs(jdata, 15) == 15
```
"""
joinobs(datas...) = JoinedData(datas)
joinobs(datas...) = JoinedData(cleanjoin(datas...))

cleanjoin(x::JoinedData, ys...) = (x.datas..., cleanjoin(ys...)...)
cleanjoin(x, ys...) = (x, cleanjoin(ys...)...)
cleanjoin() = ()


"""
shuffleobs([rng], data)
Expand Down
21 changes: 21 additions & 0 deletions test/obstransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,27 @@ end
@test data[5:6] == [5, 6]
data = joinobs(ones(2, 3), zeros(2, 3))
@test data[3:4] == [[1.0, 1.0], [0.0, 0.0]]

@testset "joins of joins" begin
data1, data2 = 1:10, 11:20
data12 = joinobs(data1, data2)
data3 = 21:30
data123 = joinobs(data12, data3)
@test getobs(data123, 15) == 15
@test getobs(data123, 25) == 25
@test length(data123) == 30
@test data123.datas[1] == data1
@test data123.datas[2] == data2
@test data123.datas[3] == data3
end

@testset "join different types" begin
data1 = 1:5
data2 = ones(2, 3)
data12 = joinobs(data1, data2)
@test data12[3] == 3
@test data12[6] == [1.0, 1.0]
end
end

@testset "shuffleobs" begin
Expand Down

0 comments on commit ba98dc5

Please sign in to comment.