I'm curious about the memory reuse / garbage collection mechanism during a single eval. For example, take the following script:

import mlx.core as mx
import mlx.nn as nn
mx.metal.set_memory_limit(10000, relaxed=True)
mx.metal.set_cache_limit(0)
B = 100
N = 10
D = 64
module_list = []
for _ in range(N):
    module_list.append(nn.Linear(D, D, bias=False))
    module_list.append(nn.ReLU())
model = nn.Sequential(*module_list)
x = mx.random.normal((B, D))
y = model(x)
mx.eval(y)
print(mx.metal.get_peak_memory())
print(mx.metal.get_active_memory())

I got

Could someone help me understand how this magic happens?
I read deeper into the code. Is it because, after eval is called on each intermediate array, arr.detach() is called, so that the ref counts of its inputs are decremented and the inputs eventually get destructed?
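
To make that guess concrete, here is a toy model in plain Python. It is only a sketch of the idea, not MLX's actual implementation: Node, Stats, and peak_buffers are hypothetical names, every op is assumed to have a single input, and the script relies on CPython freeing an object as soon as its last reference goes away. It evaluates a chain of ops in order and counts how many result buffers are alive at once, with and without dropping the input references after each op is evaluated.

```python
from collections import deque


class Stats:
    """Counts simulated buffers that are currently alive, and the peak."""
    live = 0
    peak = 0


class Node:
    """Hypothetical stand-in for a lazily evaluated array (not an MLX class)."""

    def __init__(self, inputs=()):
        self.inputs = list(inputs)  # strong references keep the inputs alive
        self.has_buffer = False

    def evaluate(self):
        # Pretend to compute the result and allocate its output buffer.
        self.has_buffer = True
        Stats.live += 1
        Stats.peak = max(Stats.peak, Stats.live)

    def detach(self):
        # Drop the references to the inputs, so their ref counts go down.
        self.inputs = []

    def __del__(self):
        # Runs when the last reference disappears (immediately in CPython).
        if self.has_buffer:
            Stats.live -= 1


def peak_buffers(depth, detach_after_eval):
    """Evaluate a chain of `depth` ops and return the peak live buffer count."""
    Stats.live = Stats.peak = 0
    node = Node()                # the input, playing the role of x
    tape = deque([node])         # evaluation order, first op first
    for _ in range(depth):
        node = Node([node])      # each op consumes the previous result
        tape.append(node)
    del node                     # only the tape and consumers hold references now
    while tape:
        arr = tape.popleft()
        arr.evaluate()
        if detach_after_eval:
            arr.detach()
        del arr
    return Stats.peak


print(peak_buffers(20, detach_after_eval=False))  # 21
print(peak_buffers(20, detach_after_eval=True))   # 2
```

With 20 chained ops (10 Linear plus 10 ReLU in the script above), keeping the input references keeps every intermediate alive until the end, while detaching after each evaluation leaves only the current result and its immediate input alive. The toy ignores the weights, which the model itself keeps alive in the real script.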