
Initial review of MLJ interface. #120

@ablaom


I'm posting this in response to the request at JuliaAI/MLJModels.jl#571.

I can see that some work has gone into understanding MLJ's API requirements (and the internals of MLJFlux).

I have not made an exhaustive review of the interface, but I list below some issues identified so far. Read point 4 first, as it is the most serious.

1. Form of predictions

Whenever possible, probabilistic predictions must take the form of a vector of distributions, where a "distribution" is something implementing Distributions.pdf and Random.rand (docs). So, instead of returning raw probabilities, the classifier should return a vector with element type UnivariateFinite (owned by CategoricalDistributions.jl). For example, here's what MLJFlux.NeuralNetworkClassifier predictions look like:

julia> predict(mach, rows=1:3)
3-element UnivariateFiniteVector{Multiclass{3}, String, UInt32, Float32}:
 UnivariateFinite{Multiclass{3}}(setosa=>0.342, versicolor=>0.349, virginica=>0.308)
 UnivariateFinite{Multiclass{3}}(setosa=>0.334, versicolor=>0.345, virginica=>0.322)
 UnivariateFinite{Multiclass{3}}(setosa=>0.337, versicolor=>0.345, virginica=>0.318)

Perhaps you can mirror the code for that model, here.

Similarly, the regressor should return a vector of whatever Distributions.jl distribution you are returning, e.g. a vector of Distributions.Normal objects, rather than simply returning the parameters.
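
To illustrate, here's a minimal sketch (classifier_predict, regressor_predict, y_train, probs, μ and σ are hypothetical stand-ins for whatever your predict method computes):

using MLJBase, Distributions

# Classifier: wrap an n × c matrix of raw probabilities into a vector of
# UnivariateFinite distributions, one per row. Columns of `probs` must
# follow the order of classes(y_train), the full pool of target levels.
function classifier_predict(y_train, probs::AbstractMatrix)
    levels = MLJBase.classes(y_train)
    return MLJBase.UnivariateFinite(levels, probs)
end

# Regressor: return Distributions.Normal objects, not raw (μ, σ) parameters.
regressor_predict(μ::AbstractVector, σ::AbstractVector) = Distributions.Normal.(μ, σ)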

2. Table handling

I suspect there is something not generic about your table handling. If I train the classifier using data X, y = MLJBase.@load_iris I get an error, although training with X, y = make_moons() works fine. Getting the number of rows of a generic table (if that's the issue) has always been a bit of a problem, because the Tables.jl API was designed to include tables without length. I think the idea is that you should use DataAPI.nrow (row singular) for this, but I think MLJModelInterface.nrows or MLJBase.nrows (rows plural) are probably okay.
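
For example (a quick sketch; nrows should work on any Tables.jl-compatible table, as well as on matrices):

using MLJBase

X, y = MLJBase.@load_iris   # X is a column table (a NamedTuple of vectors)
MLJBase.nrows(X)            # 150; generic over Tables.jl tables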

3. Metadata/traits

The load_paths are wrong (see correction below).

Your input and target types need some tweaking. For example, I'm getting warnings with the above data sets about the type of data when I do machine(model, X, y). One problem is you have Finite in some places where you probably want <:Finite, because Finite is a UnionAll type (parameterised, Finite{N}). See my suggestion below.
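
A quick REPL illustration of the difference (these subtype checks assume only that the scientific types are in scope, e.g. via using MLJBase):

julia> Multiclass{3} <: Finite    # Finite is the UnionAll Finite{N} where N
true

julia> AbstractVector{Multiclass{3}} <: AbstractVector{Finite}    # invariance bites
false

julia> AbstractVector{Multiclass{3}} <: AbstractVector{<:Finite}
true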

a. Do you really support inputs X with categorical features? (If you do, you may be interested in the pending MLJFlux PR which adds entity embedding for categorical features for the non-image models. This might be more useful than static one-hot encoding, if that is what you do to handle categoricals.)

b. Do you really support classification for non-categorical targets y (you currently allow y to be Continuous)?

c. Do you really intend to support regression with categorical targets y? What would that mean?

d. Do you really intend to exclude mixed data types in the input X (some categorical, some continuous)?

e. Do you handle OrderedFactor and Multiclass differently (as you probably should)? If not, perhaps you mean to restrict to Multiclass and have the user coerce OrderedFactor to Continuous (assuming you do not already do this under the hood); see the sketch just below.
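
For (e), the user-side coercion would look something like this sketch (assuming ScientificTypes supports coercing an OrderedFactor vector to its level codes):

using MLJBase, CategoricalArrays

y = coerce(["low", "high", "medium", "low"], OrderedFactor)
levels!(y, ["low", "medium", "high"])   # declare the intended level order
coerce(y, Continuous)                   # level codes as floats: [1.0, 3.0, 2.0, 1.0]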

Assuming the answers to a–d are yes, no, no, no, here's my stab at a revised metadata declaration:

MLJBase.metadata_model(
    LaplaceClassification;
    input_scitype=Union{
        AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types
        MLJBase.Table(MLJBase.Finite, MLJBase.Continuous)            # table with mixed types
    },
    target_scitype=AbstractArray{<:MLJBase.Finite}, # ordered factor or multiclass
    load_path="LaplaceRedux.LaplaceClassification",
)
# metadata for each model
MLJBase.metadata_model(
    LaplaceRegression;
    input_scitype=Union{
        AbstractMatrix{<:Union{MLJBase.Finite, MLJBase.Continuous}}, # matrix with mixed types
        MLJBase.Table(MLJBase.Finite, MLJBase.Continuous)            # table with mixed types
    },
    target_scitype=AbstractArray{MLJBase.Continuous},
    load_path="LaplaceRedux.MLJFlux.LaplaceRegression",
)
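
With those declarations in place, a quick sanity check against the data above (a sketch, assuming the model types are in scope):

using MLJBase

X, y = MLJBase.@load_iris
scitype(X) <: MLJBase.input_scitype(LaplaceClassification)   # should be true
scitype(y) <: MLJBase.target_scitype(LaplaceClassification)  # should be true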

4. Use of private API (more serious)

The overloaded methods MLJFlux.shape, MLJFlux.build(::FluxModel, ...), MLJFlux.fitresult, and MLJFlux.train are not public API. They are simply abstractions that arose to try to remove some code duplication between the different models provided by MLJFlux. I am consequently reluctant to make them public. Indeed, the entity embedding PR referred to above breaks MLJFlux.fitresult, and future patch releases may break the API further. There may be a good argument for making this API public, but I feel this requires a substantial rethink. Indeed, your own attempt to "hack" aspects of this API reveals its inadequacies: the fact that you feel the need to overload MLJFlux.train at all; the fact that the chain gets modified in train rather than at some earlier stage; and so on.

Unfortunately, I personally don't have the bandwidth for this kind of refactoring of MLJFlux any time soon. Your best option may simply be to cut and paste the MLJFlux code you need and have LaplaceRedux own independent versions of the private MLJFlux API methods referenced above. Alternatively, you could leave things as they are and live with breakages as they occur. I'm not sure how keen I am on registering such a model, however. Perhaps we wait and see how stable the internal API winds up being.

5. Minor nomenclature point

For consistency with other MLJ models, I suggest LaplaceRegressor over LaplaceRegression and LaplaceClassifier over LaplaceClassification. Of course I understand you may have other reasons for the name choices.
