diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 6395b1770a9..58fb62b5fee 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -16,9 +16,7 @@ package org.tensorflow; import static org.tensorflow.Graph.resolveOutputs; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrType; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; @@ -38,8 +36,12 @@ import org.tensorflow.internal.c_api.TF_SessionOptions; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.registry.TensorTypeRegistry; import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReadVariableOp; import org.tensorflow.proto.framework.ConfigProto; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.util.SaverDef; @@ -192,6 +194,11 @@ public Runner feed(String operation, int index, Tensor t) { * @return this session runner */ public Runner feed(Operand operand, Tensor t) { + if (operand.env() != graph) { + throw new IllegalStateException("Can't feed value for operand " + operand + ", it is from " + + (operand.env().isEager() ? "an eager session" : "a different graph") + "."); + } + inputs.add(operand.asOutput()); inputTensors.add(t); return this; @@ -200,6 +207,8 @@ public Runner feed(Operand operand, Tensor t) { /** * Make {@link #run()} return the output of {@code operation}. * + * If the output is a resource variable, will fetch the value. + * * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code * fetch(operation, 0)}, or it is a string of the form * operation_name:output_index , in which case this method acts like {@code @@ -215,6 +224,8 @@ public Runner fetch(String operation) { /** * Make {@link #run()} return the {@code index}-th output of {@code operation}. * + * If the output is a resource variable, will fetch the value. + * *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one to return. * @@ -225,24 +236,61 @@ public Runner fetch(String operation) { */ public Runner fetch(String operation, int index) { Operation op = graph.operationOrThrow(operation); - outputs.add(op.output(index)); - return this; + return fetch(op.output(index)); } /** * Makes {@link #run()} return the Tensor referred to by {@code output}. * + * If {@code output} is a resource variable, will fetch the value. + * * @param output the node to fetch the tensor from * @return this session runner */ public Runner fetch(Output output) { - outputs.add(output); + if (output.env() != graph) { + throw new IllegalStateException("Can't fetch output " + output + ", it is from " + + (output.env().isEager() ? "an eager session" : "a different graph") + "."); + } + + if (output.dataType() == DataType.DT_RESOURCE) { + int[] rawDt = new int[1]; + + GraphOperation graphOp = (GraphOperation) output.op(); + + try (PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + TF_OperationGetAttrType(graphOp.getUnsafeNativeHandle(), "dtype", rawDt, status); + status.throwExceptionIfNotOK(); + } + + DataType valueDt = DataType.forNumber(rawDt[0]); + + Operand read = null; + for (GraphOperation op : graphOp.consumers()) { + if (op.dtype(0) == valueDt && op.type().equals(ReadVariableOp.OP_NAME)) { + read = op.output(0); + break; + } + } + + if (read == null) { + read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read") + .readVariableOp(output, TensorTypeRegistry.find(valueDt).type()); + } + + outputs.add(read.asOutput()); + } else { + outputs.add(output); + } return this; } /** * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. * + * If {@code operand} is a resource variable, will fetch the value. + * * @param operand the node to fetch the tensor from, as an operand * @return this session runner */ @@ -258,9 +306,7 @@ public Runner fetch(Operand operand) { * @throws IllegalArgumentException if no operation exists with the provided name */ public Runner addTarget(String operation) { - GraphOperation op = graph.operationOrThrow(operation); - targets.add(op); - return this; + return addTarget(graph.operationOrThrow(operation)); } /** @@ -269,13 +315,12 @@ public Runner addTarget(String operation) { * @param operation the operation to execute * @return this session runner * @throws IllegalArgumentException if the operation is not a {@link GraphOperation} + * @throws IllegalStateException if the operation is not from the session's graph. */ public Runner addTarget(Operation operation) { - if (!(operation instanceof GraphOperation)) { - throw new IllegalArgumentException( - "Operation of type " - + operation.getClass().getName() - + " is not supported in graph sessions"); + if (operation.env() != graph) { + throw new IllegalStateException("Can't target operation " + operation + ", it is from " + + (operation.env().isEager() ? "an eager session" : "a different graph") + "."); } targets.add((GraphOperation) operation); return this; @@ -594,12 +639,12 @@ private static void delete(TF_Session handle) { * * @param handle to the C API TF_Session object (Session.nativeHandle) * @param runOptions A RunOptions protocol buffer, or null - * @param inputOpHandles (see inputOpIndices) - * @param inputOpIndices (see inputTensorHandles) * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed" * (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a * Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, * it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. + * @param inputOpHandles (see inputOpIndices) + * @param inputOpIndices (see inputTensorHandles) * @param outputOpHandles (see outputOpIndices) * @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The * outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length == diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index cd8ac7e2ae4..ff93e317805 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -27,8 +27,8 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.Collections; -import java.util.Map; import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.ndarray.FloatNdArray; @@ -292,21 +292,29 @@ public void pythonTfFunction() { ConcreteFunction add = bundle.function("add"); Map args = new HashMap(); try (TFloat32 a = TFloat32.scalarOf(10.0f); - TFloat32 b = TFloat32.scalarOf(15.5f)) { + TFloat32 b = TFloat32.scalarOf(15.5f)) { args.put("a", a); args.put("b", b); Map result = add.call(args); assertEquals(result.size(), 1); - try (TFloat32 c = (TFloat32)result.values().iterator().next()) { + try (TFloat32 c = (TFloat32) result.values().iterator().next()) { assertEquals(25.5f, c.getFloat()); } } + + // variable unwrapping happens in Session, which is used by ConcreteFunction.call + ConcreteFunction getVariable = bundle.function("get_variable"); + try (TFloat32 v = (TFloat32) getVariable.call(new HashMap<>()) + .get(getVariable.signature().outputNames().iterator().next())) { + assertEquals(2f, v.getFloat()); + } + } } private static Signature buildGraphWithVariables(Ops tf, Shape xShape) { Placeholder x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape)); - Variable y = tf + Variable y = tf.withName("variable") .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class)); ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); Init init = tf.init(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index d7ea381d315..4223a03ee23 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -20,15 +20,17 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -import java.io.BufferedOutputStream; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.util.Comparator; - +import java.util.Iterator; import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Split; @@ -38,13 +40,12 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.GraphDef; import org.tensorflow.proto.framework.RunOptions; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.StdArrays; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.Session}. */ +/** + * Unit tests for {@link org.tensorflow.Session}. + */ public class SessionTest { @Test @@ -52,12 +53,12 @@ public void runUsingOperationNames() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); + transpose_A_times_X(tf, new int[][]{{2}, {3}}); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}})); AutoCloseableList outputs = new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); + assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } } } @@ -67,14 +68,14 @@ public void runUsingOperationHandles() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); + transpose_A_times_X(tf, new int[][]{{2}, {3}}); Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}})); AutoCloseableList outputs = new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); + assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } } } @@ -88,18 +89,18 @@ public void runUsingColonSeparatedNames() { tf.math.add(split.output().get(0), split.output().get(1)); // Fetch using colon separated names. - try (TInt32 fetched = (TInt32)s.runner().fetch("Split:1").run().get(0)) { + try (TInt32 fetched = (TInt32) s.runner().fetch("Split:1").run().get(0)) { assertEquals(3, fetched.getInt(0)); assertEquals(4, fetched.getInt(1)); } // Feed using colon separated names. try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1); TInt32 fetched = (TInt32) s.runner() - .feed("Split:0", fed) - .feed("Split:1", fed) - .fetch("Add") - .run() - .get(0)) { + .feed("Split:0", fed) + .feed("Split:1", fed) + .fetch("Add") + .run() + .get(0)) { assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched); } } @@ -110,17 +111,17 @@ public void runWithMetadata() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { + transpose_A_times_X(tf, new int[][]{{2}, {3}}); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}))) { Session.Run result = s.runner() - .feed("X", x) - .fetch("Y") - .setOptions(fullTraceRunOptions()) - .runAndFetchMetadata(); + .feed("X", x) + .fetch("Y") + .setOptions(fullTraceRunOptions()) + .runAndFetchMetadata(); // Sanity check on outputs. AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); + assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); // Sanity check on metadata assertNotNull(result.metadata); assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); @@ -139,8 +140,8 @@ public void runMultipleOutputs() { AutoCloseableList outputs = new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run()); assertEquals(2, outputs.size()); - assertEquals(31415, ((TInt32)outputs.get(0)).getInt()); - assertEquals(2718, ((TInt32)outputs.get(1)).getInt()); + assertEquals(31415, ((TInt32) outputs.get(0)).getInt()); + assertEquals(2718, ((TInt32) outputs.get(1)).getInt()); outputs.close(); } } @@ -162,7 +163,8 @@ public void failOnUseAfterClose() { @Test public void createWithConfigProto() { try (Graph g = new Graph(); - Session s = new Session(g, singleThreadConfigProto())) {} + Session s = new Session(g, singleThreadConfigProto())) { + } } @Test @@ -213,12 +215,14 @@ public void runInitByName() { } @Test - public void saveAndRestore() throws IOException { + public void saveAndRestore() throws IOException { Path testFolder = Files.createTempDirectory("tf-session-save-restore-test"); try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable x = tf.withName("x").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); - Variable y = tf.withName("y").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable x = tf.withName("x") + .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable y = tf.withName("y") + .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); Init init = tf.init(); try (Session s = new Session(g)) { @@ -231,9 +235,10 @@ public void saveAndRestore() throws IOException { try (Session restoredSession = new Session(restoredGraph)) { restoredSession.restore(testFolder.resolve("checkpoint").toString()); try (AutoCloseableList oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); - AutoCloseableList newList = new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())){ - assertEquals(oldList.get(0),newList.get(0)); - assertEquals(oldList.get(1),newList.get(1)); + AutoCloseableList newList = new AutoCloseableList<>( + restoredSession.runner().fetch("x").fetch("y").run())) { + assertEquals(oldList.get(0), newList.get(0)); + assertEquals(oldList.get(1), newList.get(1)); } } } @@ -244,9 +249,54 @@ public void saveAndRestore() throws IOException { // Cleanup test dir Files.walk(testFolder) - .sorted(Comparator.reverseOrder()) - .map(Path::toFile) - .forEach(File::delete); + .sorted(Comparator.reverseOrder()) + .map(Path::toFile) + .forEach(File::delete); + } + + @Test + public static void testFetchVariable() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + Ops tf = Ops.create(g); + Operand variable = tf.varHandleOp(TInt32.class, Shape.scalar()); + Op assign = tf.assignVariableOp(variable, tf.constant(2)); + + try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) { + assertEquals(2, value.getInt()); + } + + } + } + + private static int numOperations(Graph g) { + int numOperations = 0; + for (Iterator it = g.operations(); it.hasNext(); ) { + Operation o = it.next(); + numOperations++; + } + return numOperations; + } + + @Test + public static void testFetchVariableReusingRead() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + Ops tf = Ops.create(g); + Operand variable = tf.varHandleOp(TInt32.class, Shape.scalar()); + Op assign = tf.assignVariableOp(variable, tf.constant(2)); + + Operand read = tf.readVariableOp(variable, TInt32.class); + + int ops = numOperations(g); + + try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) { + assertEquals(2, value.getInt()); + } + + assertEquals(0, numOperations(g) - ops); + + } } private static RunOptions fullTraceRunOptions() { diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb index 169e0095a3e..d9498dd4b74 100644 Binary files a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb differ diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 index ac369237d31..2f870d59b9c 100644 Binary files a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index index 8be702a36ed..ed8ff96c1d6 100644 Binary files a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index and b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index differ diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py index f5160401515..c8dfc5a15bf 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py +++ b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/source_model.py @@ -28,6 +28,7 @@ def __init__(self): self.const_scalar = tf.constant(0.0) self.const_vector = tf.constant([0.0, 0.0, 0.0]) self.const_matrix = tf.constant([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + self.variable = tf.Variable(2.0) @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='request')]) def serve(self, x): @@ -41,7 +42,8 @@ def get_scalar(self, x): def get_vector(self, x): return self.const_vector + x - @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32, name='input')]) + @tf.function(input_signature=[ + tf.TensorSpec(shape=None, dtype=tf.float32, name='input')]) def get_matrix(self, x): return self.const_matrix + x @@ -51,13 +53,18 @@ def get_matrix(self, x): def add(self, a, b): return a + b + @tf.function(input_signature=[]) + def get_variable(self): + return self.variable + model = MyModel() signatures = { "get_const_scalar": model.get_scalar, "get_const_vector": model.get_vector, "get_const_matrix": model.get_matrix, - "add": model.add + "add": model.add, + "get_variable": model.get_variable } tf.saved_model.save(obj=model, export_dir='model', signatures=signatures)