Add fetchVariable method to Session to get value of resource variable #261

Merged · 8 commits · Apr 6, 2021
Changes from 6 commits
org/tensorflow/Session.java
@@ -16,9 +16,7 @@
package org.tensorflow;

import static org.tensorflow.Graph.resolveOutputs;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrType;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;

@@ -38,8 +36,12 @@
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.internal.types.registry.TensorTypeRegistry;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ReadVariableOp;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.RunMetadata;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.proto.util.SaverDef;
@@ -192,6 +194,10 @@ public Runner feed(String operation, int index, Tensor t) {
* @return this session runner
*/
public Runner feed(Operand<?> operand, Tensor t) {
if (operand.env() != graph) {
throw new IllegalStateException("Can't feed value to operand " + operand + ", it is from a different graph.");
}

inputs.add(operand.asOutput());
inputTensors.add(t);
return this;
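
A minimal sketch of how the new check behaves from the caller's side; the class name and operand names are illustrative only, and the placeholder and constant calls are assumed from the Ops API used elsewhere in this diff:

```java
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TInt32;

// Illustrative only; not part of this PR.
public class FeedCheckSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph();
        Graph other = new Graph();
        Session s = new Session(g)) {
      Operand<TInt32> x = Ops.create(g).placeholder(TInt32.class);  // lives in the session's graph
      Operand<TInt32> foreign = Ops.create(other).constant(1);      // lives in a different graph
      try (TInt32 t = TInt32.scalarOf(5)) {
        s.runner().feed(x, t);        // accepted: x.env() == g
        s.runner().feed(foreign, t);  // rejected with IllegalStateException by the new check
      }
    }
  }
}
```

Because the check uses Operand.env(), the mismatch is reported at feed time rather than surfacing later during run().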
@@ -200,6 +206,8 @@ public Runner feed(Operand<?> operand, Tensor t) {
/**
* Make {@link #run()} return the output of {@code operation}.
*
* <p>If the output is a resource variable, its value will be fetched instead of the resource handle.
*
* @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code
* fetch(operation, 0)}, or it is a string of the form
* <tt>operation_name:output_index</tt> , in which case this method acts like {@code
@@ -215,6 +223,8 @@ public Runner fetch(String operation) {
/**
* Make {@link #run()} return the {@code index}-th output of {@code operation}.
*
* <p>If the output is a resource variable, its value will be fetched instead of the resource handle.
*
* <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
* one to return.
*
@@ -225,24 +235,59 @@ public Runner fetch(String operation) {
*/
public Runner fetch(String operation, int index) {
Operation op = graph.operationOrThrow(operation);
outputs.add(op.output(index));
return this;
return fetch(op.output(index));
}

/**
* Makes {@link #run()} return the Tensor referred to by {@code output}.
*
* <p>If {@code output} is a resource variable, its value will be fetched instead of the resource handle.
*
* @param output the node to fetch the tensor from
* @return this session runner
*/
public Runner fetch(Output<?> output) {
outputs.add(output);
if (output.env() != graph) {
throw new IllegalStateException("Can't fetch output " + output + ", it is from a different graph.");
}

if (output.dataType() == DataType.DT_RESOURCE) {
int[] rawDt = new int[1];

GraphOperation graphOp = (GraphOperation) output.op();

try (PointerScope scope = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
TF_OperationGetAttrType(graphOp.getUnsafeNativeHandle(), "dtype", rawDt, status);
status.throwExceptionIfNotOK();
}

DataType valueDt = DataType.forNumber(rawDt[0]);

Operand<?> read = null;
for (GraphOperation op : graphOp.consumers()) {
if (op.dtype(0) == valueDt && op.type().equals(ReadVariableOp.OP_NAME)) {
read = op.output(0);
}
}

if (read == null) {
read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read")
.readVariableOp(output, TensorTypeRegistry.find(valueDt).type());
}

outputs.add(read.asOutput());
} else {
outputs.add(output);
}
return this;
}
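
For readers of the diff, a minimal usage sketch of what this branch enables, following the same pattern as the new testFetchVariable test further down (the class name is illustrative; the ops mirror the test):

```java
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TInt32;

// Illustrative only; mirrors testFetchVariable in SessionTest.
public class FetchVariableSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph();
        Session s = new Session(g)) {
      Ops tf = Ops.create(g);
      Operand<?> variable = tf.varHandleOp(TInt32.class, Shape.scalar()); // DT_RESOURCE output
      Op assign = tf.assignVariableOp(variable, tf.constant(2));

      // fetch(variable) resolves to a ReadVariableOp, so the runner returns the value 2
      // rather than the resource handle.
      try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) {
        System.out.println(value.getInt()); // 2
      }
    }
  }
}
```

An existing ReadVariableOp consumer with the variable's value dtype is reused when present; otherwise the read is created under the session_reads sub-scope, named after the variable op with a "_read" suffix, which is what keeps repeated fetches from growing the graph (see testFetchVariableReusingRead below).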

/**
* Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
*
* <p>If {@code operand} is a resource variable, its value will be fetched instead of the resource handle.
*
* @param operand the node to fetch the tensor from, as an operand
* @return this session runner
*/
@@ -258,9 +303,7 @@ public Runner fetch(Operand<?> operand) {
* @throws IllegalArgumentException if no operation exists with the provided name
*/
public Runner addTarget(String operation) {
GraphOperation op = graph.operationOrThrow(operation);
targets.add(op);
return this;
return addTarget(graph.operationOrThrow(operation));
}

/**
@@ -269,13 +312,11 @@ public Runner addTarget(String operation) {
* @param operation the operation to execute
* @return this session runner
* @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
* @throws IllegalStateException if the operation is not from the session's graph.
*/
public Runner addTarget(Operation operation) {
if (!(operation instanceof GraphOperation)) {
throw new IllegalArgumentException(
"Operation of type "
+ operation.getClass().getName()
+ " is not supported in graph sessions");
if (operation.env() != graph) {
throw new IllegalStateException("Can't fetch operation " + operation + ", it is from a different graph.");
}
targets.add((GraphOperation) operation);
return this;
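
A small illustrative sketch of the revised contract, using the addTarget(Op) overload that the new tests also exercise; the no-op targets and class name are hypothetical:

```java
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;

// Illustrative only; not part of this PR.
public class AddTargetSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph();
        Graph other = new Graph();
        Session s = new Session(g)) {
      Op local = Ops.create(g).noOp();        // a target from the session's own graph
      s.runner().addTarget(local).run();      // accepted

      Op foreign = Ops.create(other).noOp();  // built against a different graph
      // s.runner().addTarget(foreign);       // would now fail with IllegalStateException
    }
  }
}
```

The behavioral change is that a target from another graph now fails up front with IllegalStateException, whereas the previous code only rejected non-GraphOperation instances with IllegalArgumentException.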
@@ -594,12 +635,12 @@ private static void delete(TF_Session handle) {
*
* @param handle to the C API TF_Session object (Session.nativeHandle)
* @param runOptions A RunOptions protocol buffer, or null
* @param inputOpHandles (see inputOpIndices)
* @param inputOpIndices (see inputTensorHandles)
* @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed"
* (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a
* Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus,
* it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
* @param inputOpHandles (see inputOpIndices)
* @param inputOpIndices (see inputTensorHandles)
* @param outputOpHandles (see outputOpIndices)
@param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The
outputOpIndices[i]-th output of the Operation outputOpHandles[i]. It is required that outputOpHandles.length ==
outputOpIndices.length.
org/tensorflow/SessionTest.java
@@ -20,15 +20,17 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;

import java.util.Iterator;
import org.junit.jupiter.api.Test;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Init;
import org.tensorflow.op.core.Split;
@@ -38,26 +40,25 @@
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

/** Unit tests for {@link org.tensorflow.Session}. */
/**
* Unit tests for {@link org.tensorflow.Session}.
*/
public class SessionTest {

@Test
public void runUsingOperationNames() {
try (Graph g = new Graph();
Session s = new Session(g)) {
Ops tf = Ops.create(g);
transpose_A_times_X(tf, new int[][] {{2}, {3}});
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}));
transpose_A_times_X(tf, new int[][]{{2}, {3}});
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}));
AutoCloseableList<Tensor> outputs =
new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) {
assertEquals(1, outputs.size());
assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0));
assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0));
}
}
}
@@ -67,14 +68,14 @@ public void runUsingOperationHandles() {
try (Graph g = new Graph();
Session s = new Session(g)) {
Ops tf = Ops.create(g);
transpose_A_times_X(tf, new int[][] {{2}, {3}});
transpose_A_times_X(tf, new int[][]{{2}, {3}});
Output<TInt32> feed = g.operation("X").output(0);
Output<TInt32> fetch = g.operation("Y").output(0);
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}));
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}));
AutoCloseableList<Tensor> outputs =
new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) {
assertEquals(1, outputs.size());
assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0));
assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0));
}
}
}
@@ -88,18 +89,18 @@ public void runUsingColonSeparatedNames() {
tf.math.add(split.output().get(0), split.output().get(1));

// Fetch using colon separated names.
try (TInt32 fetched = (TInt32)s.runner().fetch("Split:1").run().get(0)) {
try (TInt32 fetched = (TInt32) s.runner().fetch("Split:1").run().get(0)) {
assertEquals(3, fetched.getInt(0));
assertEquals(4, fetched.getInt(1));
}
// Feed using colon separated names.
try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1);
TInt32 fetched = (TInt32) s.runner()
.feed("Split:0", fed)
.feed("Split:1", fed)
.fetch("Add")
.run()
.get(0)) {
.feed("Split:0", fed)
.feed("Split:1", fed)
.fetch("Add")
.run()
.get(0)) {
assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched);
}
}
@@ -110,17 +111,17 @@ public void runWithMetadata() {
try (Graph g = new Graph();
Session s = new Session(g)) {
Ops tf = Ops.create(g);
transpose_A_times_X(tf, new int[][] {{2}, {3}});
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) {
transpose_A_times_X(tf, new int[][]{{2}, {3}});
try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}))) {
Session.Run result = s.runner()
.feed("X", x)
.fetch("Y")
.setOptions(fullTraceRunOptions())
.runAndFetchMetadata();
.feed("X", x)
.fetch("Y")
.setOptions(fullTraceRunOptions())
.runAndFetchMetadata();
// Sanity check on outputs.
AutoCloseableList<Tensor> outputs = new AutoCloseableList<>(result.outputs);
assertEquals(1, outputs.size());
assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0));
assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0));
// Sanity check on metadata
assertNotNull(result.metadata);
assertTrue(result.metadata.hasStepStats(), result.metadata.toString());
@@ -139,8 +140,8 @@ public void runMultipleOutputs() {
AutoCloseableList<Tensor> outputs =
new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run());
assertEquals(2, outputs.size());
assertEquals(31415, ((TInt32)outputs.get(0)).getInt());
assertEquals(2718, ((TInt32)outputs.get(1)).getInt());
assertEquals(31415, ((TInt32) outputs.get(0)).getInt());
assertEquals(2718, ((TInt32) outputs.get(1)).getInt());
outputs.close();
}
}
@@ -162,7 +163,8 @@ public void failOnUseAfterClose() {
@Test
public void createWithConfigProto() {
try (Graph g = new Graph();
Session s = new Session(g, singleThreadConfigProto())) {}
Session s = new Session(g, singleThreadConfigProto())) {
}
}

@Test
Expand Down Expand Up @@ -213,12 +215,14 @@ public void runInitByName() {
}

@Test
public void saveAndRestore() throws IOException {
public void saveAndRestore() throws IOException {
Path testFolder = Files.createTempDirectory("tf-session-save-restore-test");
try (Graph g = new Graph()) {
Ops tf = Ops.create(g);
Variable<TFloat32> x = tf.withName("x").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
Variable<TFloat32> y = tf.withName("y").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
Variable<TFloat32> x = tf.withName("x")
.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
Variable<TFloat32> y = tf.withName("y")
.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
Init init = tf.init();

try (Session s = new Session(g)) {
@@ -231,9 +235,10 @@ public void saveAndRestore() throws IOException {
try (Session restoredSession = new Session(restoredGraph)) {
restoredSession.restore(testFolder.resolve("checkpoint").toString());
try (AutoCloseableList<Tensor> oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run());
AutoCloseableList<Tensor> newList = new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())){
assertEquals(oldList.get(0),newList.get(0));
assertEquals(oldList.get(1),newList.get(1));
AutoCloseableList<Tensor> newList = new AutoCloseableList<>(
restoredSession.runner().fetch("x").fetch("y").run())) {
assertEquals(oldList.get(0), newList.get(0));
assertEquals(oldList.get(1), newList.get(1));
}
}
}
@@ -244,9 +249,54 @@ public void saveAndRestore() throws IOException {

// Cleanup test dir
Files.walk(testFolder)
.sorted(Comparator.reverseOrder())
.map(Path::toFile)
.forEach(File::delete);
.sorted(Comparator.reverseOrder())
.map(Path::toFile)
.forEach(File::delete);
}

@Test
public void testFetchVariable() {
try (Graph g = new Graph();
Session s = new Session(g)) {
Ops tf = Ops.create(g);
Operand<?> variable = tf.varHandleOp(TInt32.class, Shape.scalar());
Op assign = tf.assignVariableOp(variable, tf.constant(2));

try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) {
assertEquals(2, value.getInt());
}

}
}

private static int numOperations(Graph g) {
int numOperations = 0;
for (Iterator<Operation> it = g.operations(); it.hasNext(); it.next()) {
numOperations++;
}
return numOperations;
}

@Test
public void testFetchVariableReusingRead() {
try (Graph g = new Graph();
Session s = new Session(g)) {
Ops tf = Ops.create(g);
Operand<?> variable = tf.varHandleOp(TInt32.class, Shape.scalar());
Op assign = tf.assignVariableOp(variable, tf.constant(2));

Operand<TInt32> read = tf.readVariableOp(variable, TInt32.class);

int ops = numOperations(g);

try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) {
assertEquals(2, value.getInt());
}

assertEquals(0, numOperations(g) - ops);

}
}

private static RunOptions fullTraceRunOptions() {