feat: move allocation logic to rust #1835
base: main
Conversation
Compare: 9ff4f18 to a9beab0
Nice
@@ -21,7 +21,12 @@ hf-hub = { version = "0.3.1", features = ["tokio"] }
[profile.release]
debug = 1
incremental = true
panic = "abort"
Do we really need to abort? What's the benefit?
incremental is false by default, but I'm probably misunderstanding the doc, because incremental compilation is definitely active in release profiles:
https://doc.rust-lang.org/cargo/reference/profiles.html
Why debug = 1, btw? Is that for the logs/traces?
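For reference, a minimal sketch of the leaner profile the questions above seem to point at (the values and comments are assumptions, not the PR's actual settings):

```toml
# Hypothetical release profile, per the review questions above.
[profile.release]
debug = 1          # keep line tables so backtraces/traces stay readable (assumption)
incremental = true # explicit opt-in; release profiles default to incremental = false
# panic = "abort"  # only worth it if the smaller-binary / no-unwind benefit is wanted
```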
We should maybe extract this profile change into its own PR so we can merge it fast.
window_size: Option<u32>,
) -> Self {
    // Create channel
    let (sender, receiver) = mpsc::unbounded_channel();
Unrelated to the PR, but maybe it's the culprit for the leaks.
https://docs.rs/tokio/latest/tokio/sync/mpsc/fn.unbounded_channel.html
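As a rough illustration of the bounded alternative (using std's sync channels here rather than tokio's async ones, purely as a sketch of the backpressure behaviour):

```rust
use std::sync::mpsc;
use std::thread;

fn main() {
    // `sync_channel(2)` is the bounded analogue of tokio's `mpsc::channel(cap)`:
    // senders block once 2 messages are in flight, so a slow consumer applies
    // backpressure instead of letting the queue grow without limit, which is
    // the failure mode `unbounded_channel` allows.
    let (tx, rx) = mpsc::sync_channel::<u32>(2);

    let producer = thread::spawn(move || {
        for i in 0..4 {
            tx.send(i).unwrap(); // blocks while the buffer is full
        }
        // `tx` dropped here, which ends the receiver's iterator below.
    });

    let received: Vec<u32> = rx.iter().collect();
    producer.join().unwrap();
    assert_eq!(received, vec![0, 1, 2, 3]);
}
```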
} => {
    let _parent_span = span.enter();
    let next_batch = state
        .next_batch(min_size, max_size, prefill_token_budget, token_budget)
in_scope doesn't work with async, I guess?
Self {
    entries: VecDeque::with_capacity(128),
    next_id: 0,
    next_batch_id: 0,
    requires_padding,
    block_size,
Why is requires_padding gone? We still need it for Inferentia/TPU targets, no?
@@ -1,11 +1,11 @@
-flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
+flash_att_v2_commit_cuda := e6f1bb0f92d0f0c91a89d848523cc70cbf4de8a0
Can't we live on v2.5.9post1 (latest, IIUC)?
@@ -812,18 +884,16 @@ def warmup(self, batch: FlashCausalLMBatch):
# Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
+ cache_manager.num_blocks
+ batch.blocks
num_blocks was a bit easier to understand (blocks makes me wonder why it's not len(blocks)).
let mut free_blocks: Vec<u32> = (1..blocks).collect();
while let Some(cmd) = receiver.recv().await {
    match cmd {
        BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
My allocator-fu is not super advanced, but I feel like the current algorithm has a few drawbacks (not sure how the previous one worked).
Blocks = vec![0, 1, 2, 3, 4]
Many pops means you're de-ordering the blocks (4, 3, 2). If you receive them back in that order (which I think is likely), you're going to .extend them in order, getting back 0, 4, 3, 2.
Meaning the next allocation will either be in order, or shuffled in weird ways. Since cache locality seems to play an active role, we may want to be careful with that.
https://doc.rust-lang.org/std/vec/struct.Vec.html#method.drain would probably help (maybe there's an even simpler method for fixed-size drains).
Pop/push means we're …
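A minimal reproduction of the reordering described above (an assumed simplification, not the PR's actual allocator code):

```rust
fn main() {
    // Free list initialised in order, as in `(1..blocks).collect()`.
    let mut free_blocks: Vec<u32> = (0..5).collect();

    // Allocating three blocks pops from the back, so they are handed out
    // in reverse order.
    let mut allocated = Vec::new();
    for _ in 0..3 {
        allocated.push(free_blocks.pop().unwrap());
    }
    assert_eq!(allocated, vec![4, 3, 2]);

    // Freeing extends the list in whatever order blocks come back, so the
    // free list is no longer sorted and locality degrades over time.
    free_blocks.extend(allocated);
    assert_eq!(free_blocks, vec![0, 1, 4, 3, 2]);

    // Draining from the front instead keeps the remaining list ordered.
    let taken: Vec<u32> = free_blocks.drain(..2).collect();
    assert_eq!(taken, vec![0, 1]);
}
```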
let mut slots = Vec::with_capacity((required_blocks * block_size) as usize);

for _ in 0..required_blocks {
    let block_id = free_blocks.pop().unwrap();
- let block_id = free_blocks.pop().unwrap();
+ let block_id = free_blocks.pop().expect("Having free slots");
Feels like a pretty bad panic; we'd really like to know where it comes from, no?
if required_blocks > free_blocks.len() as u32 {
    response_sender.send(None).unwrap();
} else {
    let mut blocks = Vec::with_capacity(required_blocks as usize);
    let mut slots = Vec::with_capacity((required_blocks * block_size) as usize);

    for _ in 0..required_blocks {
        let block_id = free_blocks.pop().unwrap();
        blocks.push(block_id);
        for s in (block_id * block_size)..((block_id + 1) * block_size) {
            slots.push(s);
        }
    }
    let blocks = blocks.repeat(repeats);
    response_sender.send(Some((blocks, slots))).unwrap();
}
In order to make sure we're always answering something, I'd refactor the whole thing to be:
let response = if required_blocks ...;
response_sender.send(response).unwrap()
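A sketch of that suggested shape (names borrowed from the snippet above, values assumed): compute the response as a single expression first, then send unconditionally, so every request gets an answer even when allocation fails.

```rust
use std::sync::mpsc;

fn main() {
    // Stand-in for the allocator's response channel (assumption: std mpsc
    // here instead of the tokio channel the PR actually uses).
    let (response_sender, response_receiver) = mpsc::channel::<Option<Vec<u32>>>();

    let mut free_blocks: Vec<u32> = vec![1, 2, 3];
    let required_blocks: u32 = 5;

    // Build the response as an expression instead of branching around the send.
    let response = if required_blocks > free_blocks.len() as u32 {
        None
    } else {
        Some(free_blocks.split_off(free_blocks.len() - required_blocks as usize))
    };
    // Exactly one send on every path.
    response_sender.send(response).unwrap();

    // Not enough free blocks, so the caller still gets an explicit `None`.
    assert_eq!(response_receiver.recv().unwrap(), None);
}
```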