Conversation
# Compute total_vec_size from per-world max row dimensions
max_dims_np = A.max_dims.numpy()
total_vec_size = int(max_dims_np[:, 0].sum())
It feels like the sparse matrix class is missing a sum_of_max_dims field. To avoid having to compute it here, maybe we should add that field to the class (there is a sum_of_num_nzb, max_of_num_nzb, and max_of_max_dims, but no sum_of_max_dims)? @vastsoun what do you think?
Yeah, this is a good idea, let's do that. @camevor can you also include this change into this PR?
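A minimal sketch of the suggestion above, not the actual sparse matrix class: `BlockSparseDims` is a hypothetical stand-in, and it assumes `max_dims` is available as an array of shape (num_worlds, 2) with row dimensions in column 0, as the snippet under review implies. Caching the sum once at construction would let callers read it instead of recomputing it:

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class BlockSparseDims:
    """Hypothetical stand-in for the dimension metadata of the sparse matrix class."""

    # Per-world maximum (rows, cols), shape (num_worlds, 2). Assumed layout.
    max_dims: np.ndarray
    # Cached aggregates, filled in by __post_init__.
    max_of_max_dims: tuple = field(init=False)
    sum_of_max_dims: tuple = field(init=False)

    def __post_init__(self):
        dims = np.asarray(self.max_dims, dtype=np.int64)
        # Column-wise reductions: index 0 is rows, index 1 is columns.
        self.max_of_max_dims = tuple(int(v) for v in dims.max(axis=0))
        self.sum_of_max_dims = tuple(int(v) for v in dims.sum(axis=0))


dims = BlockSparseDims(max_dims=np.array([[6, 6], [12, 8], [3, 3]]))
# Replaces int(max_dims_np[:, 0].sum()) at the call site.
total_vec_size = dims.sum_of_max_dims[0]
```

Whether the real field should hold both row and column sums, or only the row sum, is a design choice for the class; the tuple form here simply mirrors the two-column layout assumed above.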
# Compute total_vec_size from per-world max row dimensions
max_dims_np = A.bsm.max_dims.numpy()
total_vec_size = int(max_dims_np[:, 0].sum())
Guirec-Maloisel left a comment
Thanks for working on this, the code changes look good to me overall! I expected the dot product to be more of an issue, but it seems that the masking that was already there is enough to deal with this as well. We should probably take a pass through the FK solver at some point then, to use the same conventions everywhere.
Did you measure any change in performance (e.g. on DR Legs)?
Thanks, @Guirec-Maloisel and @vastsoun.
On a few quick runs, it looks like it might be 1-2% slower overall. The breakdown of the change in runtime is interesting. It could be explained by caching of the offset arrays, which moves stalls out of the matvec kernels and into the CR kernels.
Guirec-Maloisel left a comment
Thanks for running some timings (and sorry for the delay in replying)! I think I'm happy with the 1-2% regression if it buys us generality, so this is good to merge from my side (but we should probably get @vastsoun's opinion too).
Description
Changes conjugate solvers to use flat arrays with offsets instead of reshaping the arrays to 2D. This allows for individual max_dims per world.
Resolves #193
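The layout change can be illustrated with a small NumPy sketch. This is not the solver code from this PR: the helper names (`world_offsets`, `per_world_dot`) and the sizes are assumptions chosen for illustration. Each world owns a contiguous slice of one flat vector, starting at an exclusive prefix sum of the per-world sizes, so worlds do not need to share a common dimension as they would in a 2D (num_worlds, max_dim) array:

```python
import numpy as np


def world_offsets(sizes):
    """Return offsets of length num_worlds + 1; world w owns [offsets[w], offsets[w + 1])."""
    sizes = np.asarray(sizes, dtype=np.int64)
    offsets = np.zeros(len(sizes) + 1, dtype=np.int64)
    np.cumsum(sizes, out=offsets[1:])
    return offsets


def per_world_dot(x, y, offsets):
    """Compute one dot product per world over the flat vectors x and y."""
    return np.array([
        float(np.dot(x[offsets[w]:offsets[w + 1]], y[offsets[w]:offsets[w + 1]]))
        for w in range(len(offsets) - 1)
    ])


sizes = [2, 3, 1]                # hypothetical per-world vector sizes
offsets = world_offsets(sizes)   # total vector size is offsets[-1]
x = np.arange(1.0, 7.0)          # flat storage for all worlds
y = np.ones(6)
dots = per_world_dot(x, y, offsets)
```

With a 2D layout, every world would be padded to the largest size (3 here), which costs 9 entries instead of 6 and requires masking the padding in reductions.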
Checklist
CHANGELOG.md has been updated (if user-facing change)