Support for M1/M2 Apple Silicon #87

Open · wants to merge 3 commits into main
Binary file added .DS_Store
30 changes: 29 additions & 1 deletion README.md
@@ -8,7 +8,6 @@

---


# Meta Llama 3

We are unlocking the power of large language models. Our latest version of Llama is now accessible to individuals, creators, researchers, and businesses of all sizes so that they can experiment, innovate, and scale their ideas responsibly.
@@ -98,6 +97,35 @@ Different models require different model-parallel (MP) values:

All models support sequence length up to 8192 tokens, but we pre-allocate the cache according to `max_seq_len` and `max_batch_size` values. So set those according to your hardware.
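As a rough illustration of how those two values drive memory use (not part of this diff): the per-layer key and value caches are allocated with shape `(max_batch_size, max_seq_len, n_kv_heads, head_dim)` (see `llama/model.py` below), so for an assumed Llama 3 8B configuration of 32 layers, 8 KV heads, and head dimension 128 in fp16:

```
# Back-of-the-envelope KV-cache sizing. Assumed 8B dimensions: 32 layers, 8 KV heads,
# head_dim 128, 2 bytes per fp16 element. The factor of 2 covers keys plus values.
n_layers, n_kv_heads, head_dim, bytes_per_elem = 32, 8, 128, 2
max_batch_size, max_seq_len = 4, 128  # the values used in the example command below

cache_bytes = 2 * n_layers * max_batch_size * max_seq_len * n_kv_heads * head_dim * bytes_per_elem
print(f"KV cache: {cache_bytes / 2**20:.1f} MiB")  # 64.0 MiB for these settings
```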


# Support for M1/M2 Apple Silicon

This is a fork of https://github.com/facebookresearch/llama adapted to run Llama inference on Apple silicon (M1/M2) Macs via MPS (Metal Performance Shaders).

Note: you must set the `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable to run this code.
It is a workaround for the `aten::polar.out` operator, which is not implemented on MPS.

The `example_text_completion.py` invocation then looks like this:

```
PYTORCH_ENABLE_MPS_FALLBACK=1 torchrun --nproc_per_node 1 example_text_completion.py \
--ckpt_dir Meta-Llama-3-8B/ \
--tokenizer_path Meta-Llama-3-8B/tokenizer.model \
--max_seq_len 128 --max_batch_size 4
```
Distributed initialization now uses the `gloo` backend instead of `nccl` when `nccl` is not available (as is the case on Apple silicon).
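
If you want to verify the `gloo` path independently of the model, here is a minimal standalone sketch (not part of this diff; it sets the rendezvous environment variables by hand instead of relying on `torchrun`, and the address/port values are arbitrary defaults):

```
import os
import torch.distributed as dist

# Single-process rendezvous so init_process_group's default env:// method works without torchrun.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

dist.init_process_group("gloo")
print(dist.get_backend())  # expected: gloo
dist.destroy_process_group()
```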

The device is selected automatically (in `llama/generation.py` and `llama/model.py`):
```
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
```
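
For a quick sanity check that tensors really land on the GPU, a minimal sketch (illustrative, not part of this diff):

```
import torch

# Assumes an M1/M2 Mac; falls back to CPU elsewhere.
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
x = torch.ones(2, 2, device=device)
print(x.device)  # expected: mps:0 on Apple silicon
```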

### Pretrained Models

These models are not finetuned for chat or Q&A. They should be prompted so that the expected answer is the natural continuation of the prompt.
41 changes: 33 additions & 8 deletions llama/generation.py
@@ -19,6 +19,13 @@
 from llama.model import ModelArgs, Transformer
 from llama.tokenizer import ChatFormat, Dialog, Message, Tokenizer
 
+if torch.backends.mps.is_available():
+    device = torch.device('mps')
+elif torch.cuda.is_available():
+    device = torch.device('cuda')
+else:
+    device = torch.device('cpu')
+
 
 class CompletionPrediction(TypedDict, total=False):
     generation: str
@@ -65,14 +72,20 @@ def build(
         and loads the pre-trained model and tokenizer.
         """
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
+            # torch.distributed.init_process_group("nccl")
+            if device.type == "cuda":
+                torch.distributed.init_process_group("nccl")
+            else:
+                torch.distributed.init_process_group("gloo")
         if not model_parallel_is_initialized():
             if model_parallel_size is None:
                 model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
             initialize_model_parallel(model_parallel_size)
 
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        if device.type == "cuda":
+            torch.cuda.set_device(local_rank)
+        # torch.cuda.set_device(local_rank)
 
         # seed must be the same in all processes
         torch.manual_seed(seed)
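On a single Apple silicon machine there is only one process, so `WORLD_SIZE` falls back to 1 here, which also matches the model-parallel (MP) value of 1 that the 8B checkpoints require.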
@@ -98,10 +111,20 @@ def build(
         )
         tokenizer = Tokenizer(model_path=tokenizer_path)
         assert model_args.vocab_size == tokenizer.n_words
-        if torch.cuda.is_bf16_supported():
-            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+
+        # if torch.cuda.is_bf16_supported():
+        #     torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+        # else:
+        #     torch.set_default_tensor_type(torch.cuda.HalfTensor)
+
+        if device.type == "cuda":
+            if torch.cuda.is_bf16_supported():
+                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+            else:
+                torch.set_default_tensor_type(torch.cuda.HalfTensor)
         else:
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
+            torch.set_default_tensor_type(torch.HalfTensor)
+
         model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=False)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
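Note that the non-CUDA branch sets the default to plain `torch.HalfTensor` (fp16 on the CPU allocator); the `torch.cuda.*` tensor types used upstream would fail on a machine without a CUDA runtime.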
@@ -152,14 +175,16 @@ def generate(
         total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
 
         pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        # tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
+
         for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
         if logprobs:
             token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
 
         prev_pos = 0
-        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        eos_reached = torch.tensor([False] * bsz, device=device)
         input_text_mask = tokens != pad_id
         if min_prompt_len == total_len:
             logits = self.model.forward(tokens, prev_pos)
21 changes: 16 additions & 5 deletions llama/model.py
@@ -15,6 +15,13 @@
 )
 from torch import nn
 
+if torch.backends.mps.is_available():
+    device = torch.device('mps')
+elif torch.cuda.is_available():
+    device = torch.device('cuda')
+else:
+    device = torch.device('cpu')
+
 
 @dataclass
 class ModelArgs:
@@ -48,6 +55,7 @@ def forward(self, x):
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    freqs = freqs.to(device)
     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
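Because `freqs` now lives on the selected device, the `torch.polar` call goes through the MPS CPU fallback, which appears to be the reason for the `PYTORCH_ENABLE_MPS_FALLBACK=1` requirement noted in the README (`aten::polar.out` has no MPS kernel).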
@@ -67,12 +75,15 @@ def apply_rotary_emb(
     xk: torch.Tensor,
     freqs_cis: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    xq = xq.to('cpu')
+    xk = xk.to('cpu')
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
+    # return xq_out.type_as(xq), xk_out.type_as(xk)
+    return xq_out.type_as(xq).to(device), xk_out.type_as(xk).to(device)
 
 
 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
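The detour through the CPU here presumably avoids the complex-tensor ops (`torch.view_as_complex` and the `freqs_cis` multiply) that MPS does not fully support; the rotated queries and keys are moved back to `device` before attention.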
@@ -133,15 +144,15 @@ def __init__(self, args: ModelArgs):
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        ).to(device)
         self.cache_v = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        ).to(device)
 
     def forward(
         self,
@@ -283,7 +294,7 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
 
         mask = None
         if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
+            mask = torch.full((seqlen, seqlen), float("-inf"), device=torch.device('cpu'))
 
             mask = torch.triu(mask, diagonal=1)
 
@@ -292,7 +303,7 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
             # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
             # j > cache_len + i, since row i corresponds to token cache_len + i.
             mask = torch.hstack(
-                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
+                [torch.zeros((seqlen, start_pos), device=torch.device('cpu')), mask]
             ).type_as(h)
 
         for layer in self.layers:
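As with the rotary embeddings, the attention mask is built on the CPU rather than on `device`, presumably to work around MPS limitations with `-inf`-filled tensors; it is only cast to the hidden-state dtype via `.type_as(h)` where it is applied.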