diff --git a/src/obstransform.jl b/src/obstransform.jl
index e5c6168..e2aefbb 100644
--- a/src/obstransform.jl
+++ b/src/obstransform.jl
@@ -208,11 +208,7 @@ accomplish that, which means that the return value is likely of a
 different type than `data`.
 
 Optionally, a random number generator `rng` can be passed as the
-first argument.
-
-The optional parameter `rng` allows one to specify the
-random number generator used for shuffling. This is useful when
-reproducible results are desired.
+first argument. 
 
 For this function to work, the type of `data` must implement
 [`numobs`](@ref) and [`getobs`](@ref). 
diff --git a/src/splitobs.jl b/src/splitobs.jl
index 7061d43..9ea6e15 100644
--- a/src/splitobs.jl
+++ b/src/splitobs.jl
@@ -2,7 +2,7 @@
     splitobs(n::Int; at) -> Tuple
 
 Compute the indices for two or more disjoint subsets of
-the range `1:n` with splits given by `at`.
+the range `1:n` with split sizes determined by `at`.
 
 # Examples
 
@@ -18,13 +18,12 @@ splitobs(n::Int; at) = _splitobs(n, at)
 
 _splitobs(n::Int, at::Integer) = _splitobs(n::Int, at / n) 
 _splitobs(n::Int, at::NTuple{N, <:Integer}) where {N} = _splitobs(n::Int, at ./ n) 
-
 _splitobs(n::Int, at::Tuple{}) = (1:n,)
 
 function _splitobs(n::Int, at::AbstractFloat)
     0 <= at <= 1 || throw(ArgumentError("the parameter \"at\" must be in interval (0, 1)"))
-    n1 = clamp(round(Int, at*n), 0, n)
-    (1:n1, n1+1:n)
+    n1 = round(Int, n * at)
+    return (1:n1, n1+1:n)
 end
 
 function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
@@ -37,22 +36,24 @@ function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
     return (a, rest...)
 end
 
+
 """
-    splitobs([rng], data; at, shuffle=false) -> Tuple
+    splitobs([rng,] data; at, shuffle=false, stratified=nothing) -> Tuple
 
 Partition the `data` into two or more subsets.
 
-When `at` is a number between 0 and 1, this specifies the proportion in the first subset.
-
-When `at` is an integer, it specifies the number of observations in the first subset.
-
-When `at` is a tuple, entries specifies the number or proportion in each subset, except
+The argument `at` specifies how to split the data:
+- When `at` is a number between 0 and 1, this specifies the proportion in the first subset.
+- When `at` is an integer, it specifies the number of observations in the first subset.
+- When `at` is a tuple, entries specifies the number or proportion in each subset, except
 for the last which will contain the remaning observations. 
 The number of returned subsets is `length(at)+1`.
 
 If `shuffle=true`, randomly permute the observations before splitting.
 A random number generator `rng` can be optionally passed as the first argument.
 
+If `stratified` is not `nothing`, it should be an array of labels with the same length as the data.
+The observations will be split in such a way that the proportion of each label is preserved in each subset.
 
 Supports any datatype implementing [`numobs`](@ref). 
 
@@ -74,14 +75,41 @@ julia> train, test = splitobs((reshape(1.0:100.0, 1, :), 101:200), at=0.7, shuff
 
 julia> vec(test[1]) .+ 100 == test[2]
 true
+
+julia> splitobs(1:10, at=0.5, stratified=[0,0,0,0,1,1,1,1,1,1]) # 2 zeros and 3 ones in each subset
+([1, 2, 5, 6, 7], [3, 4, 8, 9, 10])
 ```
 """
 splitobs(data; kws...) = splitobs(Random.default_rng(), data; kws...)
 
-function splitobs(rng::AbstractRNG, data; at, shuffle::Bool=false)
+function splitobs(rng::AbstractRNG, data; at, 
+        shuffle::Bool=false, 
+        stratified::Union{Nothing,AbstractVector}=nothing)
+    n = numobs(data)
+    at = _normalize_at(n, at)
     if shuffle
-        data = shuffleobs(rng, data)
+        perm = randperm(rng, n)
+        data = obsview(data, perm) # same as shuffleobs(rng, data), but make it explicit to keep perm
     end
-    n = numobs(data)
-    return map(idx -> obsview(data, idx), splitobs(n; at))
+    if stratified !== nothing
+        @assert length(stratified) == n
+        if shuffle
+            stratified = stratified[perm]
+        end
+        idxs_groups = group_indices(stratified)
+        idxs_splits = ntuple(i -> Int[], length(at)+1)
+        for (lbl, idxs) in idxs_groups
+            new_idxs_splits = splitobs(idxs; at, shuffle=false)
+            for i in 1:length(idxs_splits)
+                append!(idxs_splits[i], new_idxs_splits[i])
+            end
+        end
+    else
+        idxs_splits = splitobs(n; at)
+    end
+    return map(idxs -> obsview(data, idxs), idxs_splits)
 end
+
+_normalize_at(n, at::Integer) = at / n
+_normalize_at(n, at::NTuple{N, <:Integer}) where N = at ./ n
+_normalize_at(n, at) = at
\ No newline at end of file
diff --git a/src/utils.jl b/src/utils.jl
index b928135..82ca407 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -380,7 +380,7 @@ function batch(xs::Vector{<:NamedTuple})
     all_keys = [sort(collect(keys(x))) for x in xs]
     ks = all_keys[1]
     @assert all(==(ks), all_keys) "Cannot batch named tuples with different keys"
-    NamedTuple(k => batch([x[k] for x in xs]) for k in ks)
+    return NamedTuple(k => batch([x[k] for x in xs]) for k in ks)
 end
 
 function batch(xs::Vector{<:Dict})
@@ -388,7 +388,7 @@ function batch(xs::Vector{<:Dict})
     all_keys = [sort(collect(keys(x))) for x in xs]
     ks = all_keys[1]
     @assert all(==(ks), all_keys) "cannot batch dicts with different keys"
-    Dict(k => batch([x[k] for x in xs]) for k in ks)
+    return Dict(k => batch([x[k] for x in xs]) for k in ks)
 end
 
 """
diff --git a/test/splitobs.jl b/test/splitobs.jl
index f5ce335..04a9e62 100644
--- a/test/splitobs.jl
+++ b/test/splitobs.jl
@@ -90,3 +90,20 @@ end
     p2, _ = splitobs(rng, data, at=3, shuffle=true)
     @test p1 == p2
 end
+
+@testset "stratified" begin
+    data = (a=zeros(Float32, 2, 10), b=[0,0,0,0,1,1,1,1,1,1])
+    d1, d2 = splitobs(data, at=0.5, stratified=data.b)
+    @test d1.b == [0,0,1,1,1]
+    @test d2.b == [0,0,1,1,1]
+    d1, d2 = splitobs(data, at=0.25, stratified=data.b)
+    @test d1.b == [0,1,1]
+    @test d2.b == [0,0,0,1,1,1,1]
+
+    d1, d2 = splitobs(data, at=0., stratified=data.b)
+    @test d1.b == []
+    @test d2.b == [0,0,0,0,1,1,1,1,1,1]
+    d1, d2 = splitobs(data, at=1., stratified=data.b)
+    @test d1.b == [0,0,0,0,1,1,1,1,1,1]
+    @test d2.b == []
+end