Skip to content

Commit 519ff2a

Browse files
committed
Execute graph initializers in a single call
1 parent 2cdccbc commit 519ff2a

File tree

3 files changed

+106
-2
lines changed

3 files changed

+106
-2
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,20 @@ public synchronized void addInitializer(Op initializer) {
183183
}
184184

185185
/**
186-
* Returns an op which initializers all the variables.
187-
* @return The initializer operation.
186+
* Create an op which executes all registered initializers in the graph.
187+
* @return the initializer op
188188
*/
189189
public NoOp variablesInitializer() {
190190
return variablesInitializer(DEFAULT_INIT_NAME);
191191
}
192192

193+
/**
194+
* Returns an op with the given {@code name} which executes all registered initializers in the
195+
* graph.
196+
*
197+
* @param name name to give to the initializer op
198+
* @return the initializer op
199+
*/
193200
public NoOp variablesInitializer(String name) {
194201
Scope scope = new Scope(this);
195202
scope = scope.withName(name).withControlDependencies(initializers);

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.tensorflow.internal.c_api.TF_SessionOptions;
3737
import org.tensorflow.internal.c_api.TF_Status;
3838
import org.tensorflow.internal.c_api.TF_Tensor;
39+
import org.tensorflow.op.Op;
3940

4041
/**
4142
* Driver for {@link Graph} execution.
@@ -432,6 +433,42 @@ public Runner runner() {
432433
return new Runner();
433434
}
434435

436+
/**
437+
* Run all graph initializers.
438+
*
439+
* <p>Initializers must be executed once before running the graph using a session
440+
* {@link Runner}.</p>
441+
*
442+
* <p>This method can be used only if the initializers were added programatically to the graph
443+
* via {@link Graph#addInitializer(Op)}. Otherwise, {@link #runInitializers(String)} must be used
444+
* to provide the name of the initializer operation to execute.
445+
*/
446+
public void runInitializers() {
447+
runner().addTarget(graph.variablesInitializer().op()).run();
448+
}
449+
450+
/**
451+
* Run all graph initializers grouped under the {@code initializerOpName} operation.
452+
*
453+
* <p>Initializers must be executed once before running the graph using a session
454+
* {@link Runner}.</p>
455+
*
456+
* <p>The {@code initializerOpName} is the name of a single operation already registered to the
457+
* graph that executes all graph initializers at once. If the initializers were created using
458+
* {@link Graph#variablesInitializer(String)}, the names in both methods must match. Otherwise,
459+
* {@link Graph#DEFAULT_INIT_NAME} should be used.</p>
460+
*
461+
* @param initializerOpName name of the initializer operation.
462+
*/
463+
public void runInitializers(String initializerOpName) {
464+
Operation operation = graph.operation(initializerOpName);
465+
if (operation == null) {
466+
throw new IllegalArgumentException("Initializer operation named '"
467+
+ initializerOpName + "' cannot be found in the graph");
468+
}
469+
runner().addTarget(operation).run();
470+
}
471+
435472
/**
436473
* Output tensors and metadata obtained when executing a session.
437474
*

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
import org.junit.runners.JUnit4;
2525
import org.tensorflow.op.Ops;
2626
import org.tensorflow.op.core.Split;
27+
import org.tensorflow.op.core.Variable;
2728
import org.tensorflow.op.linalg.MatMul;
29+
import org.tensorflow.op.math.Add;
2830
import org.tensorflow.tools.Shape;
2931
import org.tensorflow.tools.ndarray.NdArrays;
3032
import org.tensorflow.tools.ndarray.StdArrays;
@@ -161,6 +163,64 @@ public void createWithConfigProto() {
161163
Session s = new Session(g, singleThreadConfigProto())) {}
162164
}
163165

166+
@Test
167+
public void runInitializers() {
168+
try (Graph g = new Graph()) {
169+
Ops tf = Ops.create(g);
170+
171+
Variable<TInt32> var1 = tf.variable(Shape.scalar(), TInt32.DTYPE);
172+
Variable<TInt32> var2 = tf.variable(Shape.scalar(), TInt32.DTYPE);
173+
Add<TInt32> add = tf.math.add(var1, var2);
174+
175+
g.addInitializer(tf.assign(var1, tf.constant(10)));
176+
g.addInitializer(tf.assign(var2, tf.constant(20)));
177+
178+
try (Session s = new Session(g)) {
179+
s.runInitializers();
180+
181+
try (Tensor<TInt32> t = s.runner().fetch(add).run().get(0).expect(TInt32.DTYPE)) {
182+
assertEquals(30, t.data().getInt());
183+
}
184+
}
185+
}
186+
}
187+
188+
@Test
189+
public void runInitializersByName() {
190+
try (Graph g = new Graph()) {
191+
Ops tf = Ops.create(g);
192+
193+
Variable<TInt32> var1 = tf.variable(Shape.scalar(), TInt32.DTYPE);
194+
Variable<TInt32> var2 = tf.variable(Shape.scalar(), TInt32.DTYPE);
195+
Add<TInt32> add = tf.math.add(var1, var2);
196+
197+
g.addInitializer(tf.assign(var1, tf.constant(10)));
198+
g.addInitializer(tf.assign(var2, tf.constant(20)));
199+
g.variablesInitializer("init_test");
200+
try {
201+
g.variablesInitializer("init_test");
202+
fail();
203+
} catch (IllegalArgumentException e) {
204+
// as expected, cannot register initializers twice
205+
}
206+
207+
try (Session s = new Session(g)) {
208+
s.runInitializers("init_test");
209+
210+
try (Tensor<TInt32> t = s.runner().fetch(add).run().get(0).expect(TInt32.DTYPE)) {
211+
assertEquals(30, t.data().getInt());
212+
}
213+
214+
try {
215+
s.runInitializers("wrong_name");
216+
fail();
217+
} catch (IllegalArgumentException e) {
218+
// as expected
219+
}
220+
}
221+
}
222+
}
223+
164224
private static byte[] fullTraceRunOptions() {
165225
// Ideally this would use the generated Java sources for protocol buffers
166226
// and end up with something like the snippet below. However, generating

0 commit comments

Comments
 (0)