Skip to content

Commit ebdcc0a

Browse files
committed
Fix unknown shapes in signature proto
1 parent 87fcbec commit ebdcc0a

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,15 @@ public Signature build() {
158158
return new Signature(key, signatureBuilder.build());
159159
}
160160

161-
private static TensorInfo toTensorInfo(Output<?> operand) {
161+
static TensorInfo toTensorInfo(Output<?> operand) {
162162
Shape shape = operand.shape();
163163
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
164-
for (int i = 0; i < shape.numDimensions(); ++i) {
165-
tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i)));
164+
if (shape.isUnknown()) {
165+
tensorShapeBuilder.setUnknownRank(true);
166+
} else {
167+
for (int i = 0; i < shape.numDimensions(); ++i) {
168+
tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.get(i)));
169+
}
166170
}
167171
return TensorInfo.newBuilder()
168172
.setDtype(operand.dataType())

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@
1919
import java.util.Map;
2020
import org.junit.jupiter.api.Test;
2121
import org.tensorflow.Signature.TensorDescription;
22+
import org.tensorflow.ndarray.Shape;
2223
import org.tensorflow.op.Ops;
24+
import org.tensorflow.op.core.Placeholder;
25+
import org.tensorflow.op.math.Sign;
2326
import org.tensorflow.proto.DataType;
27+
import org.tensorflow.types.TInt32;
2428

2529
public class SignatureTest {
2630

@@ -80,4 +84,28 @@ public void emptyMethodNameConvertedToNull() {
8084
signature = Signature.builder().key("f").methodName(null).build();
8185
assertNull(signature.methodName());
8286
}
87+
88+
@Test
89+
public void createTensorInfoFromOperandWithUnknownShape() {
90+
try (Graph g = new Graph()) {
91+
var tf = Ops.create(g);
92+
var placeholder = tf.placeholder(TInt32.class);
93+
var tensorInfo = Signature.Builder.toTensorInfo(placeholder.asOutput());
94+
assertTrue(tensorInfo.getTensorShape().getUnknownRank());
95+
assertEquals(0, tensorInfo.getTensorShape().getDimCount());
96+
}
97+
}
98+
99+
@Test
100+
public void createTensorInfoFromOperandWithPartiallyUnknownShape() {
101+
try (Graph g = new Graph()) {
102+
var tf = Ops.create(g);
103+
var placeholder = tf.placeholder(TInt32.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE, 10)));
104+
var tensorInfo = Signature.Builder.toTensorInfo(placeholder.asOutput());
105+
assertFalse(tensorInfo.getTensorShape().getUnknownRank());
106+
assertEquals(2, tensorInfo.getTensorShape().getDimCount());
107+
assertEquals(-1, tensorInfo.getTensorShape().getDim(0).getSize());
108+
assertEquals(10, tensorInfo.getTensorShape().getDim(1).getSize());
109+
}
110+
}
83111
}

0 commit comments

Comments
 (0)