Skip to content

Commit e4a11b3

Browse files
authored
Nicer error messages for mode-forbidden ops (#169)
* start fobbiden ops checks Signed-off-by: Ryan Nett <[email protected]> * fix style Signed-off-by: Ryan Nett <[email protected]> * move checks to builder method Signed-off-by: Ryan Nett <[email protected]>
1 parent f85623e commit e4a11b3

File tree

5 files changed

+43
-9
lines changed

5 files changed

+43
-9
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,15 @@ public EagerOperation build() {
7575

7676
@Override
7777
public EagerOperationBuilder addInput(Output<?> input) {
78-
addInput(opHandle, (TFE_TensorHandle)input.getUnsafeNativeHandle());
78+
addInput(opHandle, (TFE_TensorHandle) input.getUnsafeNativeHandle());
7979
return this;
8080
}
8181

8282
@Override
8383
public EagerOperationBuilder addInputList(Output<?>[] inputs) {
8484
TFE_TensorHandle[] inputHandles = new TFE_TensorHandle[inputs.length];
8585
for (int i = 0; i < inputs.length; ++i) {
86-
inputHandles[i] = (TFE_TensorHandle)inputs[i].getUnsafeNativeHandle();
86+
inputHandles[i] = (TFE_TensorHandle) inputs[i].getUnsafeNativeHandle();
8787
}
8888
addInputList(opHandle, inputHandles);
8989
return this;
@@ -226,7 +226,9 @@ public EagerOperationBuilder setAttr(String name, Shape[] values) {
226226
private final String type;
227227
private final String name;
228228

229-
/** This value should be >= to the maximum number of outputs in any op */
229+
/**
230+
* This value should be >= to the maximum number of outputs in any op
231+
*/
230232
private static final int MAX_OUTPUTS_PER_OP = 1000;
231233

232234
private static void requireOp(TFE_Op handle) {
@@ -358,7 +360,7 @@ private static void setAttrFloatList(TFE_Op opHandle, String name, float[] value
358360

359361
private static void setAttrBool(TFE_Op opHandle, String name, boolean value) {
360362
requireOp(opHandle);
361-
TFE_OpSetAttrBool(opHandle, name, (byte)(value ? 1 : 0));
363+
TFE_OpSetAttrBool(opHandle, name, (byte) (value ? 1 : 0));
362364
}
363365

364366
private static void setAttrBoolList(TFE_Op opHandle, String name, boolean[] values) {
@@ -410,7 +412,7 @@ private static void setAttrShapeList(TFE_Op opHandle, String name, long[] shapes
410412
}
411413
TF_Status status = TF_Status.newStatus();
412414
TFE_OpSetAttrShapeList(opHandle, new BytePointer(name), shapesPointers, new IntPointer(numDims),
413-
numDims.length, status);
415+
numDims.length, status);
414416
}
415417
}
416418
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
import org.tensorflow.internal.c_api.TFE_Context;
2828
import org.tensorflow.internal.c_api.TFE_ContextOptions;
2929
import org.tensorflow.internal.c_api.TF_Status;
30+
import org.tensorflow.op.core.Assign;
31+
import org.tensorflow.op.core.Placeholder;
32+
import org.tensorflow.op.core.Variable;
3033
import org.tensorflow.proto.framework.ConfigProto;
3134

3235
/**
@@ -271,6 +274,9 @@ static void closeDefaultForTest() {
271274
@Override
272275
public OperationBuilder opBuilder(String type, String name) {
273276
checkSession();
277+
if (!isOpEnabled(type)) {
278+
throw new IllegalArgumentException("Op " + type + " is not valid in eager mode.");
279+
}
274280
return new EagerOperationBuilder(this, type, name);
275281
}
276282

@@ -279,6 +285,18 @@ public Types environmentType() {
279285
return Types.EAGER;
280286
}
281287

288+
@Override
289+
public boolean isOpEnabled(String opType) {
290+
switch (opType) {
291+
case Variable.OP_NAME:
292+
case Placeholder.OP_NAME:
293+
case Assign.OP_NAME:
294+
return false;
295+
default:
296+
return true;
297+
}
298+
}
299+
282300
TFE_Context nativeHandle() {
283301
checkSession();
284302
return nativeHandle;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ enum Types {
3434
*/
3535
OperationBuilder opBuilder(String type, String name);
3636

37+
/**
38+
* Returns true if the given operation is valid in this execution environment.
39+
* @param opType The op to check.
40+
* @return Whether the given operation is valid in this execution environment.
41+
*/
42+
default boolean isOpEnabled(String opType){
43+
return true;
44+
}
45+
3746
/**
3847
* Get the type of this environment (from the `Environments` enumeration.
3948
*

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ public Iterator<Operation> operations() {
147147
*/
148148
@Override
149149
public GraphOperationBuilder opBuilder(String type, String name) {
150+
if (!isOpEnabled(type)) {
151+
throw new IllegalArgumentException("Op " + type + " is not valid in graph mode.");
152+
}
150153
return new GraphOperationBuilder(this, type, name);
151154
}
152155

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@
5454
import org.tensorflow.ndarray.Shape;
5555
import org.tensorflow.proto.framework.DataType;
5656

57-
/** An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}. */
57+
/**
58+
* An {@link OperationBuilder} for adding {@link GraphOperation}s to a {@link Graph}.
59+
*/
5860
public final class GraphOperationBuilder implements OperationBuilder {
5961

6062
GraphOperationBuilder(Graph graph, String type, String name) {
@@ -103,7 +105,7 @@ public GraphOperationBuilder addControlInput(Operation control) {
103105
public GraphOperationBuilder addInput(Output<?> input) {
104106
Graph.Reference r = graph.ref();
105107
try {
106-
addInput(unsafeNativeHandle, (TF_Operation)input.getUnsafeNativeHandle(), input.index());
108+
addInput(unsafeNativeHandle, (TF_Operation) input.getUnsafeNativeHandle(), input.index());
107109
} finally {
108110
r.close();
109111
}
@@ -117,7 +119,7 @@ public GraphOperationBuilder addInputList(Output<?>[] inputs) {
117119
TF_Operation[] opHandles = new TF_Operation[inputs.length];
118120
int[] indices = new int[inputs.length];
119121
for (int i = 0; i < inputs.length; ++i) {
120-
opHandles[i] = (TF_Operation)inputs[i].getUnsafeNativeHandle();
122+
opHandles[i] = (TF_Operation) inputs[i].getUnsafeNativeHandle();
121123
indices[i] = inputs[i].index();
122124
}
123125
addInputList(unsafeNativeHandle, opHandles, indices);
@@ -444,7 +446,7 @@ private static void setAttrFloatList(TF_OperationDescription handle, String name
444446

445447
private static void setAttrBool(TF_OperationDescription handle, String name, boolean value) {
446448
requireHandle(handle);
447-
TF_SetAttrBool(handle, name, (byte)(value ? 1 : 0));
449+
TF_SetAttrBool(handle, name, (byte) (value ? 1 : 0));
448450
}
449451

450452
private static void setAttrBoolList(TF_OperationDescription handle, String name, boolean[] value) {

0 commit comments

Comments
 (0)