diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java
index 6395b1770a9..58fb62b5fee 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java
@@ -16,9 +16,7 @@
package org.tensorflow;
import static org.tensorflow.Graph.resolveOutputs;
-import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession;
-import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession;
-import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession;
+import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrType;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;
@@ -38,8 +36,12 @@
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_Tensor;
+import org.tensorflow.internal.types.registry.TensorTypeRegistry;
import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.ReadVariableOp;
import org.tensorflow.proto.framework.ConfigProto;
+import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.RunMetadata;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.proto.util.SaverDef;
@@ -192,6 +194,11 @@ public Runner feed(String operation, int index, Tensor t) {
* @return this session runner
*/
public Runner feed(Operand> operand, Tensor t) {
+ if (operand.env() != graph) {
+ throw new IllegalStateException("Can't feed value for operand " + operand + ", it is from " +
+ (operand.env().isEager() ? "an eager session" : "a different graph") + ".");
+ }
+
inputs.add(operand.asOutput());
inputTensors.add(t);
return this;
@@ -200,6 +207,8 @@ public Runner feed(Operand> operand, Tensor t) {
/**
* Make {@link #run()} return the output of {@code operation}.
*
+ * If the output is a resource variable, will fetch the value.
+ *
* @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code
* fetch(operation, 0)}, or it is a string of the form
* operation_name:output_index , in which case this method acts like {@code
@@ -215,6 +224,8 @@ public Runner fetch(String operation) {
/**
* Make {@link #run()} return the {@code index}-th output of {@code operation}.
*
+ * If the output is a resource variable, will fetch the value.
+ *
*
Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
* one to return.
*
@@ -225,24 +236,61 @@ public Runner fetch(String operation) {
*/
public Runner fetch(String operation, int index) {
Operation op = graph.operationOrThrow(operation);
- outputs.add(op.output(index));
- return this;
+ return fetch(op.output(index));
}
/**
* Makes {@link #run()} return the Tensor referred to by {@code output}.
*
+ * If {@code output} is a resource variable, will fetch the value.
+ *
* @param output the node to fetch the tensor from
* @return this session runner
*/
public Runner fetch(Output> output) {
- outputs.add(output);
+ if (output.env() != graph) {
+ throw new IllegalStateException("Can't fetch output " + output + ", it is from " +
+ (output.env().isEager() ? "an eager session" : "a different graph") + ".");
+ }
+
+ if (output.dataType() == DataType.DT_RESOURCE) {
+ int[] rawDt = new int[1];
+
+ GraphOperation graphOp = (GraphOperation) output.op();
+
+ try (PointerScope scope = new PointerScope()) {
+ TF_Status status = TF_Status.newStatus();
+ TF_OperationGetAttrType(graphOp.getUnsafeNativeHandle(), "dtype", rawDt, status);
+ status.throwExceptionIfNotOK();
+ }
+
+ DataType valueDt = DataType.forNumber(rawDt[0]);
+
+ Operand> read = null;
+ for (GraphOperation op : graphOp.consumers()) {
+ if (op.dtype(0) == valueDt && op.type().equals(ReadVariableOp.OP_NAME)) {
+ read = op.output(0);
+ break;
+ }
+ }
+
+ if (read == null) {
+ read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read")
+ .readVariableOp(output, TensorTypeRegistry.find(valueDt).type());
+ }
+
+ outputs.add(read.asOutput());
+ } else {
+ outputs.add(output);
+ }
return this;
}
/**
* Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
*
+ * If {@code operand} is a resource variable, will fetch the value.
+ *
* @param operand the node to fetch the tensor from, as an operand
* @return this session runner
*/
@@ -258,9 +306,7 @@ public Runner fetch(Operand> operand) {
* @throws IllegalArgumentException if no operation exists with the provided name
*/
public Runner addTarget(String operation) {
- GraphOperation op = graph.operationOrThrow(operation);
- targets.add(op);
- return this;
+ return addTarget(graph.operationOrThrow(operation));
}
/**
@@ -269,13 +315,12 @@ public Runner addTarget(String operation) {
* @param operation the operation to execute
* @return this session runner
* @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
+ * @throws IllegalStateException if the operation is not from the session's graph.
*/
public Runner addTarget(Operation operation) {
- if (!(operation instanceof GraphOperation)) {
- throw new IllegalArgumentException(
- "Operation of type "
- + operation.getClass().getName()
- + " is not supported in graph sessions");
+ if (operation.env() != graph) {
+ throw new IllegalStateException("Can't target operation " + operation + ", it is from " +
+ (operation.env().isEager() ? "an eager session" : "a different graph") + ".");
}
targets.add((GraphOperation) operation);
return this;
@@ -594,12 +639,12 @@ private static void delete(TF_Session handle) {
*
* @param handle to the C API TF_Session object (Session.nativeHandle)
* @param runOptions A RunOptions protocol buffer, or null
- * @param inputOpHandles (see inputOpIndices)
- * @param inputOpIndices (see inputTensorHandles)
* @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed"
* (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a
* Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus,
* it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
+ * @param inputOpHandles (see inputOpIndices)
+ * @param inputOpIndices (see inputTensorHandles)
* @param outputOpHandles (see outputOpIndices)
* @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The
* outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length ==
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
index cd8ac7e2ae4..ff93e317805 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java
@@ -27,8 +27,8 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
-import java.util.Map;
import java.util.HashMap;
+import java.util.Map;
import org.junit.jupiter.api.Test;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.ndarray.FloatNdArray;
@@ -292,21 +292,29 @@ public void pythonTfFunction() {
ConcreteFunction add = bundle.function("add");
Map args = new HashMap();
try (TFloat32 a = TFloat32.scalarOf(10.0f);
- TFloat32 b = TFloat32.scalarOf(15.5f)) {
+ TFloat32 b = TFloat32.scalarOf(15.5f)) {
args.put("a", a);
args.put("b", b);
Map result = add.call(args);
assertEquals(result.size(), 1);
- try (TFloat32 c = (TFloat32)result.values().iterator().next()) {
+ try (TFloat32 c = (TFloat32) result.values().iterator().next()) {
assertEquals(25.5f, c.getFloat());
}
}
+
+ // variable unwrapping happens in Session, which is used by ConcreteFunction.call
+ ConcreteFunction getVariable = bundle.function("get_variable");
+ try (TFloat32 v = (TFloat32) getVariable.call(new HashMap<>())
+ .get(getVariable.signature().outputNames().iterator().next())) {
+ assertEquals(2f, v.getFloat());
+ }
+
}
}
private static Signature buildGraphWithVariables(Ops tf, Shape xShape) {
Placeholder x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape));
- Variable y = tf
+ Variable y = tf.withName("variable")
.variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class));
ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1));
Init init = tf.init();
diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java
index d7ea381d315..4223a03ee23 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java
+++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java
@@ -20,15 +20,17 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
-import java.io.BufferedOutputStream;
import java.io.File;
-import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
-
+import java.util.Iterator;
import org.junit.jupiter.api.Test;
+import org.tensorflow.ndarray.NdArrays;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.StdArrays;
+import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Init;
import org.tensorflow.op.core.Split;
@@ -38,13 +40,12 @@
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.proto.framework.RunOptions;
-import org.tensorflow.ndarray.Shape;
-import org.tensorflow.ndarray.NdArrays;
-import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
-/** Unit tests for {@link org.tensorflow.Session}. */
+/**
+ * Unit tests for {@link org.tensorflow.Session}.
+ */
public class SessionTest {
@Test
@@ -52,12 +53,12 @@ public void runUsingOperationNames() {
try (Graph g = new Graph();
Session s = new Session(g)) {
Ops tf = Ops.create(g);
- transpose_A_times_X(tf, new int[][] {{2}, {3}});
- try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}));
+ transpose_A_times_X(tf, new int[][]{{2}, {3}});
+ try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}));
AutoCloseableList outputs =
new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) {
assertEquals(1, outputs.size());
- assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0));
+ assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0));
}
}
}
@@ -67,14 +68,14 @@ public void runUsingOperationHandles() {
try (Graph g = new Graph();
Session s = new Session(g)) {
Ops tf = Ops.create(g);
- transpose_A_times_X(tf, new int[][] {{2}, {3}});
+ transpose_A_times_X(tf, new int[][]{{2}, {3}});
Output feed = g.operation("X").output(0);
Output fetch = g.operation("Y").output(0);
- try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}));
+ try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}));
AutoCloseableList outputs =
new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) {
assertEquals(1, outputs.size());
- assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0));
+ assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0));
}
}
}
@@ -88,18 +89,18 @@ public void runUsingColonSeparatedNames() {
tf.math.add(split.output().get(0), split.output().get(1));
// Fetch using colon separated names.
- try (TInt32 fetched = (TInt32)s.runner().fetch("Split:1").run().get(0)) {
+ try (TInt32 fetched = (TInt32) s.runner().fetch("Split:1").run().get(0)) {
assertEquals(3, fetched.getInt(0));
assertEquals(4, fetched.getInt(1));
}
// Feed using colon separated names.
try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1);
TInt32 fetched = (TInt32) s.runner()
- .feed("Split:0", fed)
- .feed("Split:1", fed)
- .fetch("Add")
- .run()
- .get(0)) {
+ .feed("Split:0", fed)
+ .feed("Split:1", fed)
+ .fetch("Add")
+ .run()
+ .get(0)) {
assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched);
}
}
@@ -110,17 +111,17 @@ public void runWithMetadata() {
try (Graph g = new Graph();
Session s = new Session(g)) {
Ops tf = Ops.create(g);
- transpose_A_times_X(tf, new int[][] {{2}, {3}});
- try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) {
+ transpose_A_times_X(tf, new int[][]{{2}, {3}});
+ try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}))) {
Session.Run result = s.runner()
- .feed("X", x)
- .fetch("Y")
- .setOptions(fullTraceRunOptions())
- .runAndFetchMetadata();
+ .feed("X", x)
+ .fetch("Y")
+ .setOptions(fullTraceRunOptions())
+ .runAndFetchMetadata();
// Sanity check on outputs.
AutoCloseableList outputs = new AutoCloseableList<>(result.outputs);
assertEquals(1, outputs.size());
- assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0));
+ assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0));
// Sanity check on metadata
assertNotNull(result.metadata);
assertTrue(result.metadata.hasStepStats(), result.metadata.toString());
@@ -139,8 +140,8 @@ public void runMultipleOutputs() {
AutoCloseableList outputs =
new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run());
assertEquals(2, outputs.size());
- assertEquals(31415, ((TInt32)outputs.get(0)).getInt());
- assertEquals(2718, ((TInt32)outputs.get(1)).getInt());
+ assertEquals(31415, ((TInt32) outputs.get(0)).getInt());
+ assertEquals(2718, ((TInt32) outputs.get(1)).getInt());
outputs.close();
}
}
@@ -162,7 +163,8 @@ public void failOnUseAfterClose() {
@Test
public void createWithConfigProto() {
try (Graph g = new Graph();
- Session s = new Session(g, singleThreadConfigProto())) {}
+ Session s = new Session(g, singleThreadConfigProto())) {
+ }
}
@Test
@@ -213,12 +215,14 @@ public void runInitByName() {
}
@Test
- public void saveAndRestore() throws IOException {
+ public void saveAndRestore() throws IOException {
Path testFolder = Files.createTempDirectory("tf-session-save-restore-test");
try (Graph g = new Graph()) {
Ops tf = Ops.create(g);
- Variable x = tf.withName("x").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
- Variable y = tf.withName("y").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
+ Variable x = tf.withName("x")
+ .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
+ Variable y = tf.withName("y")
+ .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
Init init = tf.init();
try (Session s = new Session(g)) {
@@ -231,9 +235,10 @@ public void saveAndRestore() throws IOException {
try (Session restoredSession = new Session(restoredGraph)) {
restoredSession.restore(testFolder.resolve("checkpoint").toString());
try (AutoCloseableList oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run());
- AutoCloseableList newList = new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())){
- assertEquals(oldList.get(0),newList.get(0));
- assertEquals(oldList.get(1),newList.get(1));
+ AutoCloseableList newList = new AutoCloseableList<>(
+ restoredSession.runner().fetch("x").fetch("y").run())) {
+ assertEquals(oldList.get(0), newList.get(0));
+ assertEquals(oldList.get(1), newList.get(1));
}
}
}
@@ -244,9 +249,54 @@ public void saveAndRestore() throws IOException {
// Cleanup test dir
Files.walk(testFolder)
- .sorted(Comparator.reverseOrder())
- .map(Path::toFile)
- .forEach(File::delete);
+ .sorted(Comparator.reverseOrder())
+ .map(Path::toFile)
+ .forEach(File::delete);
+ }
+
+ @Test
+ public static void testFetchVariable() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+ Ops tf = Ops.create(g);
+ Operand> variable = tf.varHandleOp(TInt32.class, Shape.scalar());
+ Op assign = tf.assignVariableOp(variable, tf.constant(2));
+
+ try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) {
+ assertEquals(2, value.getInt());
+ }
+
+ }
+ }
+
+ private static int numOperations(Graph g) {
+ int numOperations = 0;
+ for (Iterator it = g.operations(); it.hasNext(); ) {
+ Operation o = it.next();
+ numOperations++;
+ }
+ return numOperations;
+ }
+
+ @Test
+ public static void testFetchVariableReusingRead() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+ Ops tf = Ops.create(g);
+ Operand> variable = tf.varHandleOp(TInt32.class, Shape.scalar());
+ Op assign = tf.assignVariableOp(variable, tf.constant(2));
+
+ Operand read = tf.readVariableOp(variable, TInt32.class);
+
+ int ops = numOperations(g);
+
+ try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) {
+ assertEquals(2, value.getInt());
+ }
+
+ assertEquals(0, numOperations(g) - ops);
+
+ }
}
private static RunOptions fullTraceRunOptions() {
diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb
index 169e0095a3e..d9498dd4b74 100644
Binary files a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb differ
diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001
index ac369237d31..2f870d59b9c 100644
Binary files a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 differ
diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index
index 8be702a36ed..ed8ff96c1d6 100644
Binary files a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index differ
diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py
index f5160401515..c8dfc5a15bf 100644
--- a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py
+++ b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py
@@ -28,6 +28,7 @@ def __init__(self):
self.const_scalar = tf.constant(0.0)
self.const_vector = tf.constant([0.0, 0.0, 0.0])
self.const_matrix = tf.constant([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+ self.variable = tf.Variable(2.0)
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='request')])
def serve(self, x):
@@ -41,7 +42,8 @@ def get_scalar(self, x):
def get_vector(self, x):
return self.const_vector + x
- @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
+ @tf.function(input_signature=[
+ tf.TensorSpec(shape=None, dtype=tf.float32, name='input')])
def get_matrix(self, x):
return self.const_matrix + x
@@ -51,13 +53,18 @@ def get_matrix(self, x):
def add(self, a, b):
return a + b
+ @tf.function(input_signature=[])
+ def get_variable(self):
+ return self.variable
+
model = MyModel()
signatures = {
"get_const_scalar": model.get_scalar,
"get_const_vector": model.get_vector,
"get_const_matrix": model.get_matrix,
- "add": model.add
+ "add": model.add,
+ "get_variable": model.get_variable
}
tf.saved_model.save(obj=model, export_dir='model', signatures=signatures)