Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NOT FOR MERGE] Rwitten host offload demo #535

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

rwitten
Copy link
Collaborator

@rwitten rwitten commented Mar 19, 2024

[rwitten@t1v-n-621261c1-w-0 2024-03-19 23:26:33] ~/maxtext (rwitten_shmap_collective_matmul_finalized) python3 pedagogical_examples/host_offload.py 
F0319 23:26:38.216995 1800494 llo_decomposer.cc:893] Unexpected opcode: dma-vmem-to-host-ram
*** Check failure stack trace: ***
    @     0x7f2ff084ec24  (unknown)                                                     
    @     0x7f2ff084e744  (unknown)   
    @     0x7f2ff084ef89  (unknown)
    @     0x7f2fe8526dec  (unknown)                                                     
    @     0x7f2fe8520bdd  (unknown)
    @     0x7f2fe7bcefde  (unknown)
    @     0x7f2fe7bcbfc8  (unknown)
    @     0x7f2fe7c23c59  (unknown)                                                     
    @     0x7f2ff0456dc3  (unknown)
    @     0x7f2ff045d3a4  (unknown)                                                     
    @     0x7f2ff0466405  (unknown)   
    @     0x7f2ff07250e3  (unknown)
    @     0x7f30a0894ac3  (unknown)                                                     
https://symbolize.stripped_domain/r/?trace=7f2ff084ec24,7f2ff084e743,7f2ff084ef88,7f2fe8526deb,7f2fe8520bdc,7f2fe7bcefdd,7f2fe7bcbfc7,7f2fe7c23c58,7f2ff0456dc2,7f2ff045d3a3,7f2
ff0466404,7f2ff07250e2,7f30a0894ac2&map= 
https://symbolize.stripped_domain/r/?trace=7f30a08969fc,7f30a084251f&map= 
*** SIGABRT received by PID 1799158 (TID 1800494) on cpu 59 from PID 1799158; ***
E0319 23:26:38.220272 1800494 coredump_hook.cc:455] RAW: Remote crash data gathering hook invoked.
E0319 23:26:38.220281 1800494 coredump_hook.cc:494] RAW: Skipping coredump since rlimit was 0 at process start.
E0319 23:26:38.220291 1800494 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0319 23:26:38.220295 1800494 coredump_hook.cc:550] RAW: Sending fingerprint to remote end.
E0319 23:26:38.220307 1800494 coredump_hook.cc:559] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/
remote_coredump.socket (Is the listener running?): No such file or directory
E0319 23:26:38.220313 1800494 coredump_hook.cc:611] RAW: Dumping core locally.
F0319 23:26:38.216995 1800494 llo_decomposer.cc:893] Unexpected opcode: dma-vmem-to-host-ram
E0319 23:26:38.427992 1800494 process_state.cc:799] RAW: Raising signal 6 with default behavior
Aborted (core dumped) 

(Crashing because the features aren't ready yet.)

data = [generate_array() for i in range(num_tensors)]
shardings = jax.tree.map(lambda x : x.sharding, data)

host_out_shardings = jax.tree.map(lambda x : x.with_memory_kind('unpinned_host'), shardings)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use pinned_host, then the crash will go away (after the changes go in)

@rwitten
Copy link
Collaborator Author

rwitten commented Mar 20, 2024

New crash:

Traceback (most recent call last):
File "/home/rwitten/maxtext/pedagogical_examples/host_offload.py", line 55, in
host_out_shardings = jax.tree.map(lambda x : x.with_memory_kind('pinned_host'), shardings)
File "/home/rwitten/.local/lib/python3.10/site-packages/jax/_src/tree.py", line 61, in map
return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
File "/home/rwitten/.local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 312, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/rwitten/.local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 312, in
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/rwitten/maxtext/pedagogical_examples/host_offload.py", line 55, in
host_out_shardings = jax.tree.map(lambda x : x.with_memory_kind('pinned_host'), shardings)
File "/home/rwitten/.local/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 378, in with_memory_kind
return NamedSharding(self.mesh, self.spec, memory_kind=kind)
ValueError: Could not find memory addressable by device TPU v4. Device TPU v4 can address the following memory kinds: device, unpinned_host. Got memory kind: pinned_host

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants