
Re-implement Nx.LinAlg.eigh as defn #1027

Open
polvalente opened this issue Dec 16, 2022 · 6 comments
Labels
area:exla Applies to EXLA area:nx Applies to nx

Comments

@polvalente
Contributor

Currently, we have a custom implementation for Nx.BinaryBackend and call the XLA implementation for eigh in EXLA.
However, the XLA implementation seems to suffer from issues similar to the SVD one: it ends up slower and less accurate than the one JAX uses (https://github.com/google/jax/blob/main/jax/_src/lax/eigh.py).

Especially since we already have QDWH implemented in Nx.LinAlg.SVD.qdwh, it seems like a good idea to also reimplement eigh as a defn with optional + custom_grad (like Nx.LinAlg.svd).

@polvalente
Contributor Author

Although #1424 did move eigh to defn, it's still worth looking into using a new implementation for speed and accuracy

@christianjgreen

> Although #1424 did move eigh to defn, it's still worth looking into using a new implementation for speed and accuracy

Do you have a basic test I could use to compare them?

@polvalente
Contributor Author

@christianjgreen in short, you can compare the execution time of jax.numpy.linalg.eigh vs Nx.LinAlg.eigh on a 200x200 f32 tensor, with EXLA as the default compiler and backend. You'll notice that Nx barely handles it -- it takes 42s on my machine -- while JAX handles it just fine -- around 43ms on my machine.

SVD will consequently suffer with the same performance drop because SVD uses eigh under the hood.
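For reference, the JAX half of that comparison can be sketched as below (a minimal sketch, assuming `jax` is installed; the Nx half would be the analogous `Nx.LinAlg.eigh` call timed with `:timer.tc`):

```python
import time

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (200, 200), dtype=jnp.float32)
a = (a + a.T) / 2  # eigh expects a symmetric (Hermitian) matrix

eigh = jax.jit(jnp.linalg.eigh)
w, v = eigh(a)  # first call pays the compilation cost
v.block_until_ready()

start = time.perf_counter()
w, v = eigh(a)
v.block_until_ready()
elapsed_ms = (time.perf_counter() - start) * 1000
print(f"jnp.linalg.eigh on 200x200 f32: {elapsed_ms:.1f} ms")

# Sanity check that the decomposition is valid: A @ V ~= V @ diag(w)
residual = float(jnp.max(jnp.abs(a @ v - v * w)))
```

Jitting the call and blocking on the result before/after timing keeps compilation and async dispatch out of the measurement, which matters for an apples-to-apples comparison with a compiled EXLA defn.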

@christianjgreen

> @christianjgreen in short, you can compare the execution time of jax.numpy.linalg.eigh vs Nx.LinAlg.eigh on a 200x200 f32 tensor, with EXLA as the default compiler and backend. You'll notice that Nx barely handles it -- it takes 42s on my machine -- while JAX handles it just fine -- around 43ms on my machine.
>
> SVD will consequently suffer with the same performance drop because SVD uses eigh under the hood.

Thanks for the info!
I just finished a high-level Jacobi eigh method that solves a 200x200 matrix in 3 seconds on my machine, compared to ~70 seconds using QR. I thought QR was generally faster on larger matrices, but I don't know much about the algorithm. I was planning to turn the Jacobi method into the QDWH-eigh method, but I don't mind making a PR for it in the meantime.
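For readers unfamiliar with the approach: the cyclic Jacobi method repeatedly applies plane rotations that each zero one off-diagonal pair, until the matrix is numerically diagonal. A bare-bones NumPy sketch of the idea (Python rather than an Nx defn, purely for illustration; names and tolerances are mine, not the commenter's code):

```python
import numpy as np

def jacobi_eigh(a, tol=1e-8, max_sweeps=50):
    """Cyclic Jacobi eigensolver for a real symmetric matrix.

    Returns (eigenvalues, eigenvectors) with A @ V ~= V @ diag(w).
    """
    a = np.array(a, dtype=np.float64)
    n = a.shape[0]
    v = np.eye(n)
    for _ in range(max_sweeps):
        # Stop once the off-diagonal mass is negligible
        off = np.sqrt(np.sum(np.tril(a, -1) ** 2))
        if off < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if abs(a[p, q]) < 1e-30:
                    continue
                # Rotation angle that zeroes a[p, q] in J.T @ a @ J
                theta = 0.5 * np.arctan2(2 * a[p, q], a[q, q] - a[p, p])
                c, s = np.cos(theta), np.sin(theta)
                j = np.eye(n)
                j[p, p] = c; j[q, q] = c
                j[p, q] = s; j[q, p] = -s
                # Full matrix products for clarity; production code
                # would update only rows/columns p and q in place
                a = j.T @ a @ j
                v = v @ j
    w = np.diag(a)
    order = np.argsort(w)
    return w[order], v[:, order]
```

Each rotation is cheap, but later rotations reintroduce small entries in previously-zeroed positions, so several full sweeps are needed; convergence is nonetheless quadratic once the matrix is nearly diagonal.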

@christianjgreen

Update: after adding some optimizations to the QR algorithm, I got it down to ~8 seconds, which is still twice as slow as the Jacobi method. That makes me think there is something else that can be optimized. What would be best for the library owners: starting work on a QDWH-eigh based on the Jacobi method, or trying to optimize the QR code so that it falls within its big-O predictions?

@christianjgreen

Last update, and sorry for all the pings! Even though QR-based eigh decomposition is supposed to be much faster than Jacobi on large matrices, I can't get it to even match the performance of the Jacobi method, which leads me to believe something is amiss with the way the QR algorithm gets compiled. I've tried a few things like Wilkinson shifts, deflation, and only checking subdiagonals, but the iteration count still grows too high before converging.

Current testing on my naïve implementation, with the default 1000 iterations and an eps of 1.0e-4, takes 3.2s on my machine vs 88s with the current QR implementation. Adding Wilkinson shifts and other optimizations can bring that down to around 10-30 seconds, but with not much accuracy.
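For context, the shifted-QR-with-deflation scheme being described can be sketched as follows (a NumPy illustration of the textbook algorithm under my own naming, not the commenter's actual code; a production version would first reduce the matrix to tridiagonal form so each QR step is O(n) rather than O(n^3)):

```python
import numpy as np

def wilkinson_shift(app, b, ann):
    """Wilkinson shift from the trailing 2x2 block [[app, b], [b, ann]]."""
    delta = (app - ann) / 2.0
    if delta == 0.0 and b == 0.0:
        return ann
    sign = 1.0 if delta >= 0 else -1.0
    return ann - sign * b * b / (abs(delta) + np.hypot(delta, b))

def qr_eigvalsh(a, tol=1e-10, max_iter=1000):
    """Eigenvalues of a real symmetric matrix via shifted QR with deflation."""
    a = np.array(a, dtype=np.float64)
    n = a.shape[0]
    eig = []
    while n > 1:
        for _ in range(max_iter):
            # Deflate once the off-diagonal part of the last row is negligible
            if np.linalg.norm(a[n - 1, :n - 1]) < tol * np.linalg.norm(a[:n, :n]):
                break
            mu = wilkinson_shift(a[n - 2, n - 2], a[n - 1, n - 2], a[n - 1, n - 1])
            q, r = np.linalg.qr(a[:n, :n] - mu * np.eye(n))
            a[:n, :n] = r @ q + mu * np.eye(n)  # similarity transform Q.T @ A @ Q
        eig.append(a[n - 1, n - 1])  # converged eigenvalue; shrink the active block
        n -= 1
    eig.append(a[0, 0])
    return np.sort(np.array(eig))
```

The shift makes the bottom-right entry converge cubically for symmetric matrices, and deflation keeps later iterations working on ever-smaller blocks, which is where the big-O advantage over Jacobi is supposed to come from.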

I'll defer to @polvalente and @josevalim for next steps as I'm a complete newbie here.
