Cache graph functions and add If gradient test #637
nfeybesse wants to merge 1 commit into tensorflow:master from
Conversation
Force-pushed 846b6cf to 16b27d2
We've removed Windows support in preparation for the 1.2 release, so please don't add it back to the CI. Google doesn't release the libtensorflow binaries on Windows, which we've used for the past few releases, and building libtensorflow from source isn't tenable with the GitHub Actions resources we have access to.
Force-pushed 16b27d2 to 9c3b19d
```java
 * @return a cached {@link ConcreteFunction} whose name starts with {@code prefix}, or {@code
 *     null} if none is found
 */
public ConcreteFunction getFunctionCached(String prefix) {
```
Thanks for the PR @nfeybesse, can you please remove this method? It looks like it is not being used, and it might not be desirable either, since if you have multiple functions with the same prefix, you don't know which one it is going to return (unless you return the whole list of matching functions?)
Thanks for the feedback, that makes sense regarding the prefix-based lookup.
I tried removing access to the cached functions completely and relying only on Graph.getFunction(exactName), but this makes the implementation substantially more complicated. In particular, during custom gradient construction, calling Graph.getFunction(...) may end up scanning/querying the native function library while the graph is already being manipulated by the gradient builder. In my test case this can hang, so resolving the gradient functions through the native function library does not seem safe in that context.
I can still avoid the ambiguous prefix lookup by keeping an exact-name Java-side map in the test/code that creates the gradient functions. That works, but it means duplicating bookkeeping outside Graph even though Graph already has the information.
Maybe a middle-ground would be to expose a read-only view of the cached function names, for example a keySet() or functionNames() method. Then callers could resolve ambiguity themselves, choose an exact name deterministically, and still avoid exposing a method that returns an arbitrary function for a prefix.
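A read-only view along those lines could be sketched as follows. This is only an illustration of the idea, not the actual `Graph` internals: `functionCache` is a hypothetical field, and `Object` stands in for `ConcreteFunction`.

```java
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Sketch of the proposed read-only accessor. The cache field and its
// name are assumptions for illustration, not the real Graph fields.
public class FunctionNameView {
  private final Map<String, Object> functionCache = new ConcurrentHashMap<>();

  public void attach(String definedName, Object function) {
    functionCache.put(definedName, function);
  }

  /** Unmodifiable view of the names of all currently cached functions. */
  public Set<String> functionNames() {
    return Collections.unmodifiableSet(functionCache.keySet());
  }
}
```

With such a view, callers can filter the names themselves and pick an exact name deterministically, so no method ever has to return an arbitrary match for a prefix.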
Let me know which option you prefer.
This PR introduces a small cache of ConcreteFunction instances attached to a Graph.
When functions are attached via attachFunction, they are stored in a local cache
indexed by their defined name. This avoids repeatedly scanning the native
TensorFlow function library when resolving functions during gradient construction.
A helper method getFunctionCached(String prefix) is also added to allow quick lookup
of cached functions by name prefix.
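The caching scheme described above can be sketched roughly like this. The class and field names are illustrative, and `Object` stands in for `ConcreteFunction`; this is not the PR's actual implementation.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Minimal sketch of the described cache: attachFunction stores each
// function under its defined name, so later lookups do not need to
// scan the native TensorFlow function library again.
public class GraphFunctionCache {
  private final Map<String, Object> functionCache = new LinkedHashMap<>();

  public void attachFunction(String definedName, Object function) {
    functionCache.put(definedName, function);
  }

  /** Returns a cached function whose name starts with {@code prefix}, or null if none. */
  public Object getFunctionCached(String prefix) {
    for (Map.Entry<String, Object> entry : functionCache.entrySet()) {
      if (entry.getKey().startsWith(prefix)) {
        return entry.getValue();
      }
    }
    return null;
  }
}
```

Note that if several cached names share a prefix, the result depends on insertion order, which is exactly the ambiguity raised in the review comments above.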
In addition, this PR introduces IfGradientTest, a unit test validating correct
gradient propagation through a StatefulIf operation.