Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #37: WIP - M1 NCCL Error - Utilizing Llama2 M1 Bug Fix #44

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,11 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/



# models
Meta-Llama-3-8B
Meta-Llama-3-8B-Instruct
Meta-Llama-3-70B
Meta-Llama-3-70B-Instruct
43 changes: 30 additions & 13 deletions llama/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,30 @@ def build(
or if the model parallel size does not match the number of checkpoint files.

Note:
This method initializes the distributed process group, sets the device to CUDA,
This method initializes the distributed process group, sets the device based on availability,
and loads the pre-trained model and tokenizer.
"""
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if torch.backends.mps.is_available():
torch.distributed.init_process_group("gloo")
else:
torch.distributed.init_process_group("nccl")

if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
if torch.backends.mps.is_available():
device = torch.device("mps")
local_rank = 0
else:
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
device = torch.device("cuda")
else:
device = torch.device("cpu")

# seed must be the same in all processes
torch.manual_seed(seed)
Expand All @@ -98,12 +110,18 @@ def build(
)
tokenizer = Tokenizer(model_path=tokenizer_path)
assert model_args.vocab_size == tokenizer.n_words
if torch.cuda.is_bf16_supported():

if torch.backends.mps.is_available():
torch.set_default_tensor_type(torch.FloatTensor)
elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
torch.set_default_tensor_type(torch.FloatTensor)

model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
model.to(device)

print(f"Loaded in {time.time() - start_time:.2f} seconds")

return Llama(model, tokenizer)
Expand Down Expand Up @@ -140,7 +158,6 @@ def generate(
Note:
This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
If logprobs is True, token log probabilities are computed for each generated token.

"""
params = self.model.params
bsz = len(prompt_tokens)
Expand All @@ -152,14 +169,15 @@ def generate(
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
device = next(self.model.parameters()).device
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device=device)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
eos_reached = torch.tensor([False] * bsz, device=device)
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
logits = self.model.forward(tokens, prev_pos)
Expand All @@ -170,7 +188,7 @@ def generate(
ignore_index=pad_id,
)

stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens))
stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens), device=device)

for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
Expand Down Expand Up @@ -249,7 +267,6 @@ def text_completion(
Note:
This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
If logprobs is True, token log probabilities are computed for each generated token.

"""
if max_gen_len is None:
max_gen_len = self.model.params.max_seq_len - 1
Expand Down
26 changes: 18 additions & 8 deletions llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
)
from torch import nn

if torch.backends.mps.is_available():
device = torch.device('mps')
elif torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')

@dataclass
class ModelArgs:
Expand Down Expand Up @@ -67,12 +73,17 @@ def apply_rotary_emb(
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not torch.cuda.is_available():
xq = xq.to('cpu')
xk = xk.to('cpu')
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
if not torch.cuda.is_available():
freqs_cis = freqs_cis.to('cpu')
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
return xq_out.type_as(xq).to(device), xk_out.type_as(xk).to(device)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
Expand All @@ -84,7 +95,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
) #.to(device)


class Attention(nn.Module):
Expand Down Expand Up @@ -133,15 +144,15 @@ def __init__(self, args: ModelArgs):
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
).to(device)
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
).to(device)

def forward(
self,
Expand Down Expand Up @@ -256,7 +267,7 @@ def __init__(self, params: ModelArgs):
self.n_layers = params.n_layers

self.tok_embeddings = VocabParallelEmbedding(
params.vocab_size, params.dim, init_method=lambda x: x
params.vocab_size, params.dim, init_method=lambda x: x,
)

self.layers = torch.nn.ModuleList()
Expand All @@ -278,12 +289,11 @@ def __init__(self, params: ModelArgs):
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.full((seqlen, seqlen), float("-inf"), device=torch.device('cpu'))

mask = torch.triu(mask, diagonal=1)

Expand All @@ -292,7 +302,7 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack(
[torch.zeros((seqlen, start_pos), device=tokens.device), mask]
[torch.zeros((seqlen, start_pos), device=tokens.device), (mask.to(device) if mask is not None else mask)]
).type_as(h)

for layer in self.layers:
Expand Down
2 changes: 1 addition & 1 deletion llama/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def decode(self, t: Sequence[int]) -> str:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t))
return self.model.decode(list(filter(lambda tk: tk != -1, t)))

@staticmethod
def _split_whitespaces_or_nonwhitespaces(
Expand Down