
MAML & LLMs #827

Answered by awni
pharringtonp19 asked this question in Q&A
Mar 13, 2024 · 4 comments · 18 replies

Yes, for sure. You can compose vjp / value_and_grad / grad to any depth and it will work, so for a bilevel setup you would do something like:

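import mlx.core as mx
import mlx.nn as nn

def step(outer_w, inner_w, x, y):
    def loss(inner_w, x, y):
        return nn.losses.mse_loss(inner_w @ x, y)

    # Gradient of the inner loss w.r.t. inner_w, computed inside the outer grad
    dloss_dinner_w = mx.grad(loss)(inner_w, x, y)
    inner_w = inner_w + (outer_w @ x) * dloss_dinner_w
    # mx.grad needs a scalar output, so return the loss at the updated weights
    return loss(inner_w, x, y)

dstep_douter_w = mx.grad(step)(outer_w, inner_w, x, y)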

(Super simple and untested, but it should give you the flavor of how that could go.)

Answer selected by pharringtonp19
Category: Q&A · Labels: none yet · 2 participants