Skip to content

Flat arrays in conjugate solvers#334

Open
camevor wants to merge 3 commits intodev/kaminofrom
dev/ca/conjugate-flat-arrays
Open

Flat arrays in conjugate solvers#334
camevor wants to merge 3 commits intodev/kaminofrom
dev/ca/conjugate-flat-arrays

Conversation

@camevor
Copy link
Copy Markdown
Collaborator

@camevor camevor commented Apr 7, 2026

Description

Changes conjugate solvers to use flat arrays with offsets instead of reshaping the arrays to 2d. This allows for individual max_dims per world.

resolves #193

Checklist

  • New or existing tests cover these changes
  • The documentation is up to date with these changes
  • CHANGELOG.md has been updated (if user-facing change)

@camevor camevor self-assigned this Apr 7, 2026

# Compute total_vec_size from per-world max row dimensions
max_dims_np = A.max_dims.numpy()
total_vec_size = int(max_dims_np[:, 0].sum())
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like the sparse matrix class is missing a sum_of_max_dims, to avoid having to compute it here, maybe we should add that field in the class (there is a sum_of_num_nzb, max_of_num_nzb, max_of_max_dims but no sum_of_max_dims)? @vastsoun what do you think?

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is a good idea, let's do that. @camevor can you also include this change into this PR?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added it


# Compute total_vec_size from per-world max row dimensions
max_dims_np = A.bsm.max_dims.numpy()
total_vec_size = int(max_dims_np[:, 0].sum())
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also applies here

Copy link
Copy Markdown
Collaborator

@Guirec-Maloisel Guirec-Maloisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this, the code changes look good to me overall! I expected the dot product to be more of an issue, but it seems that the masking that was already there is enough to deal with this as well. We should probably take a pass through the FK solver at some point then, to use the same conventions everywhere.

Did you measure any change in performance (e.g. on DR Legs)?


# Compute total_vec_size from per-world max row dimensions
max_dims_np = A.max_dims.numpy()
total_vec_size = int(max_dims_np[:, 0].sum())
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is a good idea, let's do that. @camevor can you also include this change into this PR?

@camevor
Copy link
Copy Markdown
Collaborator Author

camevor commented Apr 8, 2026

Thanks, @Guirec-Maloisel and @vastsoun.

Did you measure any change in performance (e.g. on DR Legs)?

On a few quick runs, it looks like it might be 1-2% slower, overall.

The breakdown of the comparative runtime as a result of the change is interesting:
CR kernel 1: +28%
CR kernel 2: +25%
dot product: +1.5%
Block-sparse transpose matvec kernel: -6%
Block-sparse gemv_regularization kernel: -5%

This could be explained by caching of the offset arrays, moving stalls out of the matvec kernels and into the CR kernels.

@camevor camevor requested a review from vastsoun April 8, 2026 15:23
Copy link
Copy Markdown
Collaborator

@Guirec-Maloisel Guirec-Maloisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for running some timings (and sorry for the delay in replying)! I think I'm happy with the 1-2% regression if it buys us generality, so this is good to merge from my side (but we should probably get @vastsoun's opinion too).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Fix compatibility of CG and sparse solver for heterogenous worlds (2d vs 1d vector stacks)

3 participants