Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matmul: dispatch on specific blas paths using an enum #55002

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

jishnub
Copy link
Contributor

@jishnub jishnub commented Jul 2, 2024

This expands on the approach taken by #54552.

We pass on more type information to generic_matmatmul_wrapper!, which lets us convert the branches to method dispatches. This helps spread the latency around, so that instead of compiling all the branches in the first call, we now compile the branches only when they are actually taken. While this reduces the latency in individual branches, there is no reduction in latency if all the branches are reachable.

julia> A = rand(2,2);

julia> @time A * A;
  0.479805 seconds (809.66 k allocations: 40.764 MiB, 99.93% compilation time) # 1.12.0-DEV.806
  0.346739 seconds (633.17 k allocations: 31.320 MiB, 99.90% compilation time) # This PR

julia> @time A * A';
  0.030413 seconds (101.98 k allocations: 5.359 MiB, 98.54% compilation time) # v1.12.0-DEV.806
  0.148118 seconds (219.51 k allocations: 11.652 MiB, 99.72% compilation time) # This PR

The latency is spread between the two calls here.

In fresh sessions:

julia> A = rand(2,2);

julia> @time A * A';
  0.473630 seconds (825.65 k allocations: 41.554 MiB, 99.91% compilation time) # v1.12.0-DEV.806
  0.490305 seconds (774.87 k allocations: 38.824 MiB, 99.90% compilation time) # This PR

In this case, both the syrk and gemm branches are reachable, so there is no reduction in latency.

Analogously, there is a reduction in latency in the second set of matrix multiplications where we call symm!/hemm! or _generic_matmatmul:

julia> using LinearAlgebra

julia> A = rand(2,2);

julia> @time Symmetric(A) * A;
  0.711178 seconds (2.06 M allocations: 103.878 MiB, 2.20% gc time, 99.98% compilation time) # v1.12.0-DEV.806
  0.540669 seconds (904.12 k allocations: 43.576 MiB, 2.60% gc time, 97.36% compilation time) # This PR

@jishnub jishnub added domain:linear algebra Linear algebra compiler:latency Compiler latency labels Jul 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
compiler:latency Compiler latency domain:linear algebra Linear algebra
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant