FluxML/MLJFlux.jl

MLJFlux

An interface to the Flux deep learning models for the MLJ machine learning framework.

[Badges: Julia v1 compatibility, CPU and GPU continuous integration, and code coverage, for the `master` and `dev` branches; documentation: stable.]

Code Snippet

using MLJ, MLJFlux, RDatasets, Plots

Grab some data and split into features and target:

iris = RDatasets.dataset("datasets", "iris");
y, X = unpack(iris, ==(:Species), colname -> true, rng=123);
X = Float32.(X);      # Flux models train fastest on Float32, especially on GPUs
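To sanity-check the split before training, you can inspect the feature table and the target levels with MLJ's `schema` and `levels` (a quick check, not part of the original example; the exact output depends on your RDatasets version):

```julia
using MLJ  # provides `schema` and re-exports `levels`

# The feature table: four Float32 columns after the conversion above
schema(X)

# The target is a categorical vector whose levels are the iris species
levels(y)
```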

Load model code and instantiate an MLJFlux model:

NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux

clf = NeuralNetworkClassifier(
    builder=MLJFlux.MLP(; hidden=(5,4)),
    batch_size=8,
    epochs=50,
    acceleration=CUDALibs()  # for training on a GPU
)
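The `MLP` builder above is one of several ways to specify the network. For full control over the architecture, MLJFlux also provides the `MLJFlux.@builder` macro, in which `n_in` and `n_out` are substituted at `fit!` time to match the data. A sketch of a hand-rolled equivalent of the `(5, 4)` MLP (reusing the `NeuralNetworkClassifier` loaded above):

```julia
using MLJFlux, Flux

# Two hidden layers of widths 5 and 4; `n_in`/`n_out` are filled in
# by MLJFlux when the model is fit to data.
builder = MLJFlux.@builder Chain(
    Dense(n_in, 5, relu),
    Dense(5, 4, relu),
    Dense(4, n_out),
)

clf2 = NeuralNetworkClassifier(builder=builder, batch_size=8, epochs=50)
```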

Wrap in "iteration controls":

stop_conditions = [
    Step(1),            # Apply controls every epoch
    NumberLimit(1000),  # Don't train for more than 1000 steps
    Patience(4),        # Stop after 4 consecutive increases in validation loss
    NumberSinceBest(5), # Or if the best loss occurred 5 iterations ago
    TimeLimit(30/60),   # Or after 30 minutes (the argument is in hours)
]

validation_losses = []
train_losses = []
callbacks = [
    WithLossDo(loss->push!(validation_losses, loss)),
    WithTrainingLossesDo(losses->push!(train_losses, losses[end])),
]

iterated_model = IteratedModel(
    model=clf,
    resampling=Holdout(fraction_train=0.5),  # loss and stopping are based on out-of-sample
    measures=log_loss,
    controls=vcat(stop_conditions, callbacks),
);

Train the wrapped model:

julia> mach = machine(iterated_model, X, y)
julia> fit!(mach)

[ Info: Training machine(ProbabilisticIteratedModel(model = NeuralNetworkClassifier(builder = MLP(hidden = (5, 4), …), …), …), …).
[ Info: No iteration parameter specified. Using `iteration_parameter=:(epochs)`.
[ Info: final loss: 0.10431026246922499
[ Info: final training loss: 0.046286315
[ Info: Stop triggered by Patience(4) stopping criterion.
[ Info: Total of 349 iterations.
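Once trained, the machine makes predictions in the usual MLJ way. A sketch (the exact probabilities depend on the run):

```julia
# Probabilistic predictions: a vector of UnivariateFinite distributions
yhat = predict(mach, X)

# Point predictions: the most probable species for each observation
predict_mode(mach, X)

# Probability the first observation is assigned to a given class
pdf(yhat[1], "virginica")
```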

Inspect results:

julia> plot(train_losses, label="Training Loss", linewidth=2, size=(800,400))
julia> plot!(validation_losses, label="Validation Loss", linewidth=2, size=(800,400))