
[Inference] Optimized some scattered optimization points in the framework #5544

Open

yuehuayingxueluo wants to merge 1 commit into feature/colossal-infer

Conversation

@yuehuayingxueluo (Contributor) commented Apr 2, 2024

📌 Checklist before creating the PR

  • I have created an issue for this PR for traceability
  • The title follows the standard format: [doc/gemini/tensor/...]: A concise description
  • I have added relevant tags if possible for us to better distinguish different PRs

🚨 Issue number

Link this PR to your issue with words like fixed to automatically close the linked issue upon merge

e.g. fixed #1234, closed #1234, resolved #1234

📝 What does this PR do?

Summarize your work here.
If you have any plots/diagrams/screenshots/tables, please attach them here.
pytest: [screenshot of passing test results attached in the PR]
model benchmark:

| bsz | in_len | out_len | Throughput (tokens/sec, before -> after) |
|-----|--------|---------|------------------------------------------|
| 16  | 128    | 128     | 1823.16 -> 1831.51                       |
| 32  | 128    | 128     | 3144.30 -> 3164.13                       |
| 64  | 128    | 128     | 5024.28 -> 5130.96                       |
| 16  | 128    | 256     | 1791.81 -> 1844.73                       |
| 32  | 128    | 256     | 3134.06 -> 3153.95                       |
| 64  | 128    | 256     | 5056.01 -> 5102.04                       |

💥 Checklist before requesting a review

  • I have linked my PR to an issue (instruction)
  • My issue clearly describes the problem/feature/proposal, with diagrams/charts/table/code if possible
  • I have performed a self-review of my code
  • I have added thorough tests.
  • I have added docstrings for all the functions/methods I implemented

⭐️ Do you enjoy contributing to Colossal-AI?

  • Yes, I do.
  • No, I don't.

Tell us more if you don't enjoy contributing to Colossal-AI.

@yuehuayingxueluo yuehuayingxueluo marked this pull request as ready for review April 3, 2024 04:45
@yuehuayingxueluo yuehuayingxueluo requested a review from a team as a code owner April 3, 2024 04:45
colossalai/inference/core/request_handler.py
  if end_indexes.numel() > 0:
      # contiguous cache exists
      end_idx = end_indexes[0].item() + 1  # open interval
      start_idx = end_idx - num_blocks_required  # closed interval
-     alloc_block_ids = torch.arange(start_idx, end_idx)
+     alloc_block_ids = torch.arange(start_idx, end_idx, device=block_tables.device)
Contributor:

Assigning `block_tables.device` to alloc_block_ids might trigger an error at L259:

self._block_states[alloc_block_ids] = 0

Notice that self._block_states lives on the host. If the passed-in block tables tensor is on a device, you will get the runtime error "Expected all tensors to be on the same device, but found ...".

At the moment, adding device=block_tables.device makes no difference, because in the batch bucket class the block tables tensor is kept on the host, so this change causes no error but also adds no functionality.
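
A minimal illustration of the concern, using hypothetical stand-in tensors rather than the actual manager code: indexing a host tensor with a device index tensor raises a device-mismatch error, while host indices work as before.

```python
import torch

# Hypothetical stand-ins for self._block_states and alloc_block_ids.
block_states = torch.ones(16, dtype=torch.int64)  # lives on the host (CPU)

if torch.cuda.is_available():
    cuda_ids = torch.arange(4, 8, device="cuda")  # indices on the GPU
    try:
        block_states[cuda_ids] = 0  # CPU data indexed with CUDA indices
    except RuntimeError as err:
        print(err)  # device-mismatch RuntimeError, as described above

# Host indices (the current situation, since the batch bucket keeps its
# block tables on the CPU) cause no error:
cpu_ids = torch.arange(4, 8)
block_states[cpu_ids] = 0
```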

Contributor Author:

OK, I will fix it.

@@ -34,18 +34,8 @@ __global__ void act_and_mul_kernel(

  // Note(LiuYang): This func is designed for calculation mode like
  // silu(x[:half_1stdim]) * (x[half_1stdim:])
- torch::Tensor silu_and_mul(const torch::Tensor& ins)
+ void silu_and_mul(const torch::Tensor& ins, torch::Tensor& outs)
Contributor:

This doesn't handle the case where outs is None.

Contributor Author:

If a None value is passed in, it is an illegal operation and the C++ side will raise an error.

Contributor:

I mean the case should still be handled, whether or not you dispatch to a different kernel. The modification here loses the ability to call the kernel the regular way (with inputs only).
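
One hedged way to keep both calling conventions is a thin Python wrapper around the extension. This is a sketch only: the wrapper itself is hypothetical, and `inference_ops` is assumed to be the already-loaded extension module used in the test below.

```python
from typing import Optional

import torch


def silu_and_mul(ins: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Compute silu(ins[:half]) * ins[half:], writing into `out` when given."""
    if out is None:
        # Preserve the original calling convention by allocating the output here.
        half = ins.size(0) // 2
        out = torch.empty_like(ins[:half])
    # `inference_ops` is the extension module assumed to be in scope;
    # the kernel fills `out` in place.
    inference_ops.silu_and_mul(ins, out)
    return out
```

Callers that pre-allocate a buffer avoid the `torch.empty_like`, while existing call sites that pass only the input keep working unchanged, mirroring PyTorch's `out=` convention.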

Contributor Author:

Okay, let me think about how to fix it.

@@ -20,7 +20,8 @@ def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
      act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
      ref_out = act_out * ref_input[1]

-     origin_out = inference_ops.silu_and_mul(origin_input)
+     origin_out = torch.empty_like(ref_out)
+     inference_ops.silu_and_mul(origin_input, origin_out)
Contributor:

Same as above: there is no test for None as the output tensor.

Contributor Author:

Same as above.

Contributor:

See the reply above.

@@ -167,6 +171,7 @@ def llama_decoder_layer_forward(
      kv_seq_len: int = 0,
      output_tensor: torch.Tensor = None,
      norm_output: torch.Tensor = None,
+     silu_and_mul_output: torch.Tensor = None,
Contributor:

I'm not sure it's a good idea to pass the silu_and_mul output tensor as an argument, module by module, down to the MLP layer.

Contributor Author:

Yes, I also feel that there are too many parameters being passed this way. In the future, we could put all these temporary outputs into a struct for unified management.

Contributor Author:

Then we would only need to pass this struct each time.
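
A rough sketch of that idea (class and field names below are illustrative, not taken from the codebase): bundle the pre-allocated intermediates into one object and pass only that object down through the decoder layers.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ForwardBuffers:
    """Pre-allocated intermediate tensors shared by the decoder layers."""

    output_tensor: Optional[torch.Tensor] = None
    norm_output: Optional[torch.Tensor] = None
    silu_and_mul_output: Optional[torch.Tensor] = None


# A layer signature would then shrink to something like:
# def llama_decoder_layer_forward(..., buffers: Optional[ForwardBuffers] = None): ...
```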

Contributor:

Just some advice :) First, it's not a good idea to design an activation API that requires an output tensor as an argument; if you really want to go that way, you'd be better off making it an in-place API. Second, I don't think it's a good idea to take over this kind of memory management from torch yourself before you really understand it or have designed a proper memory management system. Meanwhile, the performance gain looks small and may just be normal fluctuation, so this optimization point may not pay off. Finally, it is probably not worth writing tricky code for such a small performance gain.

Contributor Author:

In testing, the performance benefit is stable, and compared with the other optimizations it already looks fairly significant. Also, this is not about managing memory on torch's behalf; rather, it addresses our own unreasonable use of memory. That said, I agree this operator should be implemented as an in-place operator, which would avoid redundant memory allocations.

Contributor Author:

I feel this is only a temporary optimization; the best solution would be to implement this operator as an in-place one. We can leave a TODO here.
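
For reference, the in-place version being described could look roughly like this at the Python level. This is purely a sketch of the TODO: the trailing underscore follows PyTorch's in-place naming convention, and `silu_and_mul_` is not an existing kernel.

```python
import torch


def silu_and_mul_(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical in-place form: overwrite the first half of `x` with
    silu(x[:half]) * x[half:], so no separate output buffer is allocated."""
    half = x.size(0) // 2
    torch.nn.functional.silu(x[:half], inplace=True)  # silu applied in place on the first half
    x[:half].mul_(x[half:])                           # multiply by the second half, in place
    return x[:half]                                   # result is a view into `x`
```

This avoids allocating silu_and_mul_output at all, at the cost of clobbering the input buffer.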
