Skip to content

Commit bba783a

Browse files
karllessardKarl Lessard
and
Karl Lessard
authored
Save models as functions (#103)
* Draft: Java API to use tf.function available on SavedModel. (#89) Python models that contain tf.function are inconvenient to consume from Java clients. This proposal provides an API to (a) invoke a tf.function, given the signature name, and (b) retrieve the node name in the graph corresponding to a tf.function. Co-authored-by: Shajan Dasan <[email protected]> * Change API for creating concrete functions and exporting them to a saved model. Co-authored-by: Karl Lessard <[email protected]>
1 parent d00975b commit bba783a

File tree

9 files changed

+1058
-260
lines changed

9 files changed

+1058
-260
lines changed
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
/*
2+
* Copyright 2020 The TensorFlow Authors. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.tensorflow;
17+
18+
import java.io.IOException;
19+
import java.util.List;
20+
import java.util.ListIterator;
21+
import java.util.HashMap;
22+
import java.util.Map;
23+
import java.util.function.Function;
24+
import org.tensorflow.op.Ops;
25+
import org.tensorflow.proto.framework.SignatureDef;
26+
import org.tensorflow.proto.framework.TensorInfo;
27+
28+
/**
29+
* A graph that can be invoked as a single function, with an input and output signature.
30+
*
31+
* <p>A function can also invoke a
32+
* <a href="https://www.tensorflow.org/api_docs/python/tf/function">tf.function</a>
33+
* defined in a {@link SavedModelBundle}.
34+
*
35+
* <pre>{@code
36+
* ConcreteFunction myFunction = savedModelBundle.function("myFunctionSignatureName");
37+
* Map<String, Tensor<?>> outputTensorMap = myFunction.call(inputTensorMap);
38+
* }</pre>
39+
*/
40+
public class ConcreteFunction implements AutoCloseable {
41+
42+
/**
43+
* Creates a function by building a new graph.
44+
*
45+
* <p/>The {@code functionBuilder} must initialize the function graph from the provided
46+
* {@link Ops} instance and return a valid signature that will be used to feed the input tensors
47+
* and fetch the output tensors on execution.
48+
*
49+
* <p/>The function will be the owner of the new graph and its resulting session. Therefore,
50+
* the function must be enclosed properly with a try-with-resources block to guarantee that
51+
* all native resources will be freed once the function is discarded. For example:
52+
*
53+
* <pre>{@code
54+
* public class MyModel {
55+
*
56+
* public static Signature addTwo(Ops tf) {
57+
* Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
58+
* Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
59+
* return Signature.builder("addTwo").input("x", input).output("y", output).build();
60+
* }
61+
*
62+
* public static void main(String args[]) {
63+
* try (ConcreteFunction function = ConcreteFunction.create(MyModel::addTwo);
64+
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
65+
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
66+
* }
67+
* }
68+
* }
69+
* }</pre>
70+
*
71+
* @param functionBuilder function builder
72+
* @return the new function
73+
*/
74+
public static ConcreteFunction create(Function<Ops, Signature> functionBuilder) {
75+
Graph graph = new Graph();
76+
try {
77+
Ops tf = Ops.create(graph);
78+
Signature signature = functionBuilder.apply(tf);
79+
return new ConcreteFunction(signature, graph, new Session(graph), Ownership.GRAPH_AND_SESSION);
80+
} catch (Exception e) {
81+
graph.close();
82+
throw e;
83+
}
84+
}
85+
86+
/**
87+
* Create a function from a signature and an existing graph.
88+
*
89+
* <p/>The function will keep the ownership of the session used to run the graph but not
90+
* the graph itself, meaning that the lifetime of the latter can extend beyond the scope
91+
* of the function. For example:
92+
*
93+
* <pre>{@code
94+
* try (Graph g = new Graph()) {
95+
* Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
96+
* Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
97+
* Signature signature = Signature.builder().input("x", input).output("y", output).build();
98+
*
99+
* try (ConcreteFunction f = ConcreteFunction.create(signature, g);
100+
* Tensor<TFloat32> x = TFloat32.scalarOf(2.0f)) {
101+
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
102+
* }
103+
* // Graph g is still valid at this point
104+
* }
105+
* }</pre>
106+
*
107+
* @param signature signature of the function to create
108+
* @param graph a valid and initialized graph
109+
* @return a new function
110+
*/
111+
public static ConcreteFunction create(Signature signature, Graph graph) {
112+
return new ConcreteFunction(signature, graph, new Session(graph), Ownership.SESSION_ONLY);
113+
}
114+
115+
/**
116+
* Create a function from a signature and a valid graph session.
117+
*
118+
* <p/>The function will not own the session nor its graph, meaning that their lifetime
119+
* can extend beyond the scope of the function. Therefore the function does not need to be
120+
* closed after its usage. For example:
121+
*
122+
* <pre>{@code
123+
* try (Graph g = new Graph()) {
124+
* Placeholder<TFloat32> input = tf.placeholder(TFloat32.DTYPE);
125+
* Add<TFloat32> output = tf.math.add(input, tf.constant(2.0f));
126+
* Signature signature = Signature.builder().input("x", input).output("y", output).build();
127+
*
128+
* try (Session s = new Session(g)) {
129+
* // Auto-closing the function just as an example but this is not required since it has
130+
* // no effect
131+
* try (ConcreteFunction f = ConcreteFunction.create(signature, s);
132+
* Tensor<TFloat32> t = TFloat32.scalarOf(2.0f)) {
133+
* assertEquals(4.0f, function.call(x).expect(TFloat32.DTYPE).data().getFloat());
134+
* }
135+
* // Session s is still valid at this point
136+
* }
137+
* // Graph g is still valid at this point
138+
* }
139+
* }</pre>
140+
*
141+
* @param signature signature of the function to create
142+
* @param graph a valid session to an initialized graph
143+
* @return a new function
144+
*/
145+
public static ConcreteFunction create(Signature signature, Session session) {
146+
return new ConcreteFunction(signature, session.graph(), session, Ownership.NONE);
147+
}
148+
149+
/**
150+
* Returns the signature of this function
151+
*/
152+
public Signature signature() {
153+
return signature;
154+
}
155+
156+
/**
157+
* Invokes a function.
158+
*
159+
* <p>Caller is responsible for closing all Tensors.
160+
*
161+
* @param tensor input tensor
162+
* @return output tensor
163+
*/
164+
public Map<String, Tensor<?>> call(Map<String, Tensor<?>> arguments)
165+
throws IllegalArgumentException {
166+
167+
final SignatureDef signatureDef = signature.asSignatureDef();
168+
final Session.Runner runner = session.runner();
169+
170+
signatureDef.getInputsMap().forEach((argName, t) -> {
171+
Tensor<?> tensor = arguments.get(argName);
172+
if (tensor == null) {
173+
throw new IllegalArgumentException(String.format("Missing argument [%s]", argName));
174+
}
175+
runner.feed(t.getName(), tensor);
176+
});
177+
178+
Map<String, TensorInfo> outputToNode = signatureDef.getOutputsMap();
179+
outputToNode.values().forEach(t -> runner.fetch(t.getName()));
180+
181+
List<Tensor<?>> resultTensors = runner.run();
182+
try {
183+
ListIterator<Tensor<?>> resultTensorIter = resultTensors.listIterator();
184+
Map<String, Tensor<?>> returnMap = new HashMap<String, Tensor<?>>();
185+
186+
// Use the output names as present in the signature definition
187+
for (String nodeName: outputToNode.keySet()) {
188+
returnMap.put(nodeName, resultTensorIter.next());
189+
}
190+
return returnMap;
191+
192+
} catch (Exception e) {
193+
// Release tensors before throwing exception
194+
for (Tensor<?> t : resultTensors) {
195+
t.close();
196+
}
197+
throw e;
198+
}
199+
}
200+
201+
/**
202+
* Invokes a function with a single input and output.
203+
*
204+
* <p>Caller is responsible for closing all Tensors.
205+
*
206+
* @param tensor input tensor
207+
* @return output tensor
208+
* @throws IllegalArgumentException if there are multiple input or output parameters defined
209+
* in the function
210+
*/
211+
public Tensor<?> call(Tensor<?> tensor) throws IllegalArgumentException {
212+
final SignatureDef signatureDef = signature.asSignatureDef();
213+
214+
if (signatureDef.getInputsCount() != 1) {
215+
throw new IllegalArgumentException(
216+
String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName()));
217+
}
218+
String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName();
219+
220+
if (signatureDef.getOutputsCount() != 1) {
221+
throw new IllegalArgumentException(
222+
String.format("Function [%s] has multiple outputs", signatureDef.getMethodName()));
223+
}
224+
String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName();
225+
226+
return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0);
227+
}
228+
229+
/**
230+
* Export this function as a saved model.
231+
*
232+
* <p>This method is convenient shortcut equivalent to
233+
* {@code SavedModel.exporter(exportDir).withFunction(this).export()}
234+
*/
235+
public void save(String exportDir) throws IOException {
236+
SavedModelBundle.exporter(exportDir)
237+
.withFunction(this)
238+
.export();
239+
}
240+
241+
/**
242+
* Returns the session used to execute the graph when calling this function
243+
*
244+
* <p>In general, a user does not need to handle directly the session of a function and rely
245+
* on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to
246+
* the session might be necessary, as it allows more running options.
247+
*
248+
* @return the function session
249+
*/
250+
public Session session() {
251+
return session;
252+
}
253+
254+
/**
255+
* Returns the graph of this function
256+
*/
257+
public Graph graph() {
258+
return graph;
259+
}
260+
261+
@Override
262+
public void close() {
263+
if (ownership != Ownership.NONE) {
264+
session.close();
265+
if (ownership == Ownership.GRAPH_AND_SESSION) {
266+
graph.close();
267+
}
268+
}
269+
}
270+
271+
private enum Ownership {
272+
GRAPH_AND_SESSION, SESSION_ONLY, NONE;
273+
}
274+
275+
private final Graph graph;
276+
private final Session session;
277+
private final Signature signature;
278+
private final Ownership ownership;
279+
280+
ConcreteFunction(Signature signature, Graph graph, Session session, Ownership ownership) {
281+
this.graph = graph;
282+
this.session = session;
283+
this.signature = signature;
284+
this.ownership = ownership;
285+
}
286+
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,17 @@
4343
import org.tensorflow.internal.c_api.TF_Output;
4444
import org.tensorflow.internal.c_api.TF_Status;
4545
import org.tensorflow.internal.c_api.TF_WhileParams;
46+
import org.tensorflow.ndarray.StdArrays;
4647
import org.tensorflow.op.Op;
48+
import org.tensorflow.op.Ops;
49+
import org.tensorflow.op.core.Constant;
50+
import org.tensorflow.op.core.NoOp;
51+
import org.tensorflow.op.core.Placeholder;
52+
import org.tensorflow.op.train.Restore;
53+
import org.tensorflow.op.train.Save;
4754
import org.tensorflow.proto.framework.GraphDef;
55+
import org.tensorflow.proto.util.SaverDef;
56+
import org.tensorflow.types.TString;
4857

4958

5059
/**
@@ -67,6 +76,11 @@ public Graph() {
6776
this.nativeHandle = nativeHandle;
6877
}
6978

79+
/**
 * Creates a graph around {@code nativeHandle} and presets the {@link SaverDef} returned by
 * {@link #saverDef()}, so that no new saver is added to the graph on the first call.
 */
Graph(TF_Graph nativeHandle, SaverDef saverDef) {
  this(nativeHandle);
  this.saverDef = saverDef;
}
83+
7084
/**
7185
* Release resources associated with the Graph.
7286
*
@@ -402,9 +416,27 @@ public Output<?>[] whileLoop(
402416
}
403417
}
404418

419+
/**
420+
* Return the {@link SaverDef} instance used to save the state of all variables present in
421+
* this graph.
422+
*
423+
* <p/>On the first call of this method, all nodes necessary to save and restore the state of the
424+
* variables are added to the graph. Consequently, any variables that are added to the graph after
425+
* this call could not be saved nor restored using this {@link SaverDef}.
426+
*
427+
* @return a {@link SaverDef} instance
428+
*/
429+
synchronized SaverDef saverDef() {
430+
if (saverDef == null) {
431+
saverDef = addVariableSaver(this);
432+
}
433+
return saverDef;
434+
}
435+
405436
private final Object nativeHandleLock = new Object();
406437
private TF_Graph nativeHandle;
407438
private int refcount = 0;
439+
private SaverDef saverDef;
408440

409441
private final List<Op> initializers = new ArrayList<>();
410442

@@ -726,6 +758,53 @@ private static Object[] whileLoop(
726758
}
727759
}
728760

761+
private static SaverDef addVariableSaver(Graph graph) {
762+
Ops tf = Ops.create(graph).withSubScope("save");
763+
764+
List<String> varNames = new ArrayList<>();
765+
List<Operand<?>> varOutputs = new ArrayList<>();
766+
List<DataType<?>> varTypes = new ArrayList<>();
767+
768+
for (Iterator<Operation> iter = graph.operations(); iter.hasNext();) {
769+
Operation op = iter.next();
770+
if (op.type().equals("VariableV2")) {
771+
varNames.add(op.name());
772+
varOutputs.add(op.output(0));
773+
varTypes.add(op.output(0).dataType());
774+
}
775+
}
776+
777+
// FIXME Need an easier way to initialize an NdArray from a list
778+
String[] tmp = new String[varNames.size()];
779+
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
780+
Operand<TString> varSlices = tf.zerosLike(varNamesTensor);
781+
782+
Placeholder<TString> saveFilename = tf.placeholder(TString.DTYPE);
783+
Save saveVariables = tf.train.save(
784+
saveFilename,
785+
varNamesTensor,
786+
varSlices,
787+
varOutputs
788+
);
789+
Restore restoreVariables = tf.train.restore(
790+
saveFilename,
791+
varNamesTensor,
792+
varSlices,
793+
varTypes
794+
);
795+
List<Op> restoreOps = new ArrayList<>(varOutputs.size());
796+
for (int i = 0; i < varOutputs.size(); ++i) {
797+
restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
798+
}
799+
NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp();
800+
801+
return SaverDef.newBuilder()
802+
.setFilenameTensorName(saveFilename.op().name())
803+
.setSaveTensorName(saveVariables.op().name())
804+
.setRestoreOpName(restoreAll.op().name())
805+
.build();
806+
}
807+
729808
static {
730809
TensorFlow.init();
731810
}

0 commit comments

Comments
 (0)