Cache graph functions and add If gradient test #637

Open
nfeybesse wants to merge 1 commit into tensorflow:master from nfeybesse:custom/graph-function-cache

@nfeybesse
Contributor

This PR introduces a small cache of ConcreteFunction instances attached to a Graph.

When functions are attached via attachFunction, they are stored in a local cache
indexed by their defined name. This avoids repeatedly scanning the native
TensorFlow function library when resolving functions during gradient construction.

A helper method getFunctionCached(String prefix) is also added to allow quick lookup
of cached functions by name prefix.
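
As a rough illustration of the idea (not the code in this PR), a name-indexed cache with a prefix lookup could look like the sketch below. `FunctionCache`, `attach`, `get`, and `getByPrefix` are hypothetical names, and the type parameter `F` stands in for `ConcreteFunction` so the example runs without TensorFlow on the classpath.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical stand-in for the Java-side cache; F plays the role of ConcreteFunction. */
final class FunctionCache<F> {
  // Insertion-ordered, so a prefix lookup returns the earliest attached match.
  private final Map<String, F> byName = new LinkedHashMap<>();

  /** Records a function under its defined name, as attachFunction is described as doing. */
  void attach(String definedName, F function) {
    byName.put(definedName, function);
  }

  /** Exact-name lookup; returns null when nothing was attached under that name. */
  F get(String definedName) {
    return byName.get(definedName);
  }

  /** Prefix lookup in the spirit of getFunctionCached: first match in attach order, or null. */
  F getByPrefix(String prefix) {
    for (Map.Entry<String, F> entry : byName.entrySet()) {
      if (entry.getKey().startsWith(prefix)) {
        return entry.getValue();
      }
    }
    return null;
  }
}
```

Both lookups read only the Java-side map, which is the property the description relies on: nothing here touches the native function library.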

In addition, this PR introduces IfGradientTest, a unit test validating correct
gradient propagation through a StatefulIf operation.

@nfeybesse force-pushed the custom/graph-function-cache branch from 846b6cf to 16b27d2 on March 12, 2026 13:17
@Craigacp
Collaborator

We've removed Windows support in preparation for the 1.2 release, so please don't add it back to the CI. Google doesn't release Windows builds of the libtensorflow binaries we've used for the past few releases, and building libtensorflow from source isn't tenable with the GitHub Actions resources we have access to.

@nfeybesse force-pushed the custom/graph-function-cache branch from 16b27d2 to 9c3b19d on March 12, 2026 14:30
* @return a cached {@link ConcreteFunction} whose name starts with {@code prefix}, or {@code
* null} if none is found
*/
public ConcreteFunction getFunctionCached(String prefix) {
Collaborator

Thanks for the PR @nfeybesse, can you please remove this method? It looks like it is not being used, and it might not be desirable either: if multiple functions share the same prefix, you don't know which one it is going to return (unless you return the whole list of matching functions?)
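
To make the ambiguity concrete, here is a minimal sketch of the "return the whole list" alternative. `PrefixLookup` and `findAllByPrefix` are hypothetical names, and `F` again stands in for `ConcreteFunction`; this is an illustration under those assumptions, not a proposed API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical helper: returns all matches instead of picking one arbitrarily. */
final class PrefixLookup {
  private PrefixLookup() {}

  /** Collects every cached function whose name starts with prefix, in the map's iteration order. */
  static <F> List<F> findAllByPrefix(Map<String, F> cache, String prefix) {
    List<F> matches = new ArrayList<>();
    for (Map.Entry<String, F> entry : cache.entrySet()) {
      if (entry.getKey().startsWith(prefix)) {
        matches.add(entry.getValue());
      }
    }
    return matches;
  }
}
```

With two functions attached under names sharing a prefix, the returned list has two elements, so the caller sees the collision instead of silently receiving one of them.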

Contributor Author

Thanks for the feedback, that makes sense regarding the prefix-based lookup.

I tried removing access to the cached functions completely and relying only on Graph.getFunction(exactName), but this makes the implementation substantially more complicated. In particular, during custom gradient construction, calling Graph.getFunction(...) may end up scanning/querying the native function library while the graph is already being manipulated by the gradient builder. In my test case this can hang, so resolving the gradient functions through the native function library does not seem safe in that context.

I can still avoid the ambiguous prefix lookup by keeping an exact-name Java-side map in the test/code that creates the gradient functions. That works, but it means duplicating bookkeeping outside Graph even though Graph already has the information.

Maybe a middle ground would be to expose a read-only view of the cached function names, for example a keySet() or functionNames() method. Then callers could resolve ambiguity themselves, choose an exact name deterministically, and still avoid exposing a method that returns an arbitrary function for a prefix.
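
A minimal sketch of that middle ground, assuming a hypothetical `NamedFunctionRegistry` in place of `Graph` and a hypothetical `GradientNameResolver` on the caller side (`F` stands in for `ConcreteFunction`). The tie-break shown, taking the lexicographically smallest matching name, is only one possible deterministic choice.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/** Hypothetical stand-in for the Graph-side cache. */
final class NamedFunctionRegistry<F> {
  private final Map<String, F> byName = new LinkedHashMap<>();

  void attach(String name, F function) {
    byName.put(name, function);
  }

  /** Read-only view of the cached names: callers can inspect it but not modify it. */
  Set<String> functionNames() {
    return Collections.unmodifiableSet(byName.keySet());
  }

  /** Exact-name lookup only; no prefix matching is exposed. */
  Optional<F> get(String exactName) {
    return Optional.ofNullable(byName.get(exactName));
  }
}

/** Hypothetical caller-side policy for resolving a prefix to one exact name. */
final class GradientNameResolver {
  private GradientNameResolver() {}

  /** Picks the lexicographically smallest name with the given prefix, so the result is deterministic. */
  static Optional<String> resolve(Set<String> names, String prefix) {
    return names.stream().filter(n -> n.startsWith(prefix)).sorted().findFirst();
  }
}
```

The registry itself never decides between candidates; the caller applies its own rule and then performs an exact-name lookup.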

Let me know which approach you prefer.

3 participants