Skip to content

Commit

Permalink
fixes for Term extension (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
ExpandingMan authored Jul 23, 2023
1 parent 30c316d commit 359175c
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "XGBoost"
uuid = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
version = "2.3.0"
version = "2.3.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
73 changes: 42 additions & 31 deletions ext/XGBoostTermExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function _features_display_string(fs, n)
end
end

function Term.Panel(dm::XGBoost.DMatrix)
function Term.Panel(dm::XGBoost.DMatrix; kw...)
str = if !XGBoost.hasdata(dm)
"{dim}(values not allocated){/dim}"
else
Expand All @@ -25,62 +25,73 @@ function Term.Panel(dm::XGBoost.DMatrix)
end
end
Term.Panel(_features_display_string(XGBoost.getfeaturenames(dm), size(dm,2)),
str;
style="magenta",
title="XGBoost.DMatrix",
title_style="bold cyan",
subtitle,
subtitle_style="blue",
)
str;
style="magenta",
title="XGBoost.DMatrix",
title_style="bold cyan",
subtitle,
subtitle_style="blue",
kw...
)
end

function Term.Panel(b::XGBoost.Booster)
Base.show(io::IO, m::MIME"text/plain", dm::XGBoost.DMatrix) = show(io, m, Term.Panel(dm))

function Term.Panel(b::XGBoost.Booster; kw...)
info = if isempty(b.params)
()
else
(paramspanel(b.params; header_style="bold green", columns_style=["bold yellow", "default"], box=:SIMPLE,),)
end
Term.Panel(_features_display_string(b.feature_names, XGBoost.nfeatures(b)),
info...;
style="magenta",
title="XGBoost.Booster",
title_style="bold cyan",
subtitle="boosted rounds: $(XGBoost.getnrounds(b))",
subtitle_style="blue",
)
info...;
style="magenta",
title="XGBoost.Booster",
title_style="bold cyan",
subtitle="boosted rounds: $(XGBoost.getnrounds(b))",
subtitle_style="blue",
kw...
)
end

function paramspanel(params::AbstractDict; kwargs...)
names = sort!(collect(keys(params)))
vals = map(k -> params[k], names)
Term.Table(OrderedDict(:Parameter=>names, :Value=>vals), kwargs...)
Term.Table(OrderedDict(:Parameter=>names, :Value=>vals); kwargs...)
end

function Term.Tree(
node::XGBoost.Node;
title="XGBoost Tree (from this node)",
title_style="bold green",
kwargs...,
)
Base.show(io::IO, m::MIME"text/plain", b::XGBoost.Booster) = show(io, m, Term.Panel(b))

function Term.Tree(node::XGBoost.Node;
title="XGBoost Tree (from this node)",
title_style="bold green",
kwargs...,
)
td = isempty(children(node)) ? Dict(repr(node)=>"leaf") : _tree_display(node)
Term.Tree(td; title, title_style, kwargs...)
end

function Term.Panel(node::XGBoost.Node)
function Term.Panel(node::XGBoost.Node; width::Union{Nothing,Int}=120, kw...)
subtitle = if isempty(children(node))
"{bold green}leaf{/bold green}"
else
string(length(children(node)), " children")
end

Term.Panel(paramstable(node; header_style="bold yellow", box=:SIMPLE),
Term.Tree(node);
style="magenta",
title="XGBoost.Node {italic blue}(id=$(node.id), depth=$(node.depth)){/italic blue}",
title_style="bold cyan",
subtitle,
subtitle_style="blue",
)
Term.Tree(node);
style="magenta",
title="XGBoost.Node {italic blue}(id=$(node.id), depth=$(node.depth)){/italic blue}",
title_style="bold cyan",
subtitle,
subtitle_style="blue",
width,
kw...
)
end

Base.show(io::IO, m::MIME"text/plain", node::XGBoost.Node) = show(io, m, Term.Panel(node))

function paramstable(node::XGBoost.Node; kwargs...)
if isempty(children(node))
_paramstable(node, :cover, :leaf; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/dmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ function DMatrix(tbl;
feature_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing,
kw...
)
cols = Tables.columns(tbl)
cols = Tables.Columns(tbl)
if feature_names === nothing
feature_names = [string(x) for x in Tables.columnnames(cols)]
end
Expand Down
15 changes: 15 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,21 @@ end
@test typeof(importancereport(bst)) <: Term.Tables.Table
end

# these just ensure we don't have any exceptions
@testset "Term extension" begin
dtrain = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"))
dtest = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"))

bst = xgboost(dtrain, num_round=5,
η=1.0, max_depth=2,
objective="binary:logistic",
watchlist=Dict(),
)

@test Term.Panel(dtrain) isa Term.Panel
@test Term.Panel(bst) isa Term.Panel
end

@testset "Booster Save/Load/Serialize" begin
dtrain = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"))
dtest = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"))
Expand Down

0 comments on commit 359175c

Please sign in to comment.