|
19 | 19 | import java.util.Map;
|
20 | 20 | import org.junit.jupiter.api.Test;
|
21 | 21 | import org.tensorflow.Signature.TensorDescription;
|
| 22 | +import org.tensorflow.ndarray.Shape; |
22 | 23 | import org.tensorflow.op.Ops;
|
| 24 | +import org.tensorflow.op.core.Placeholder; |
| 25 | +import org.tensorflow.op.math.Sign; |
23 | 26 | import org.tensorflow.proto.DataType;
|
| 27 | +import org.tensorflow.types.TInt32; |
24 | 28 |
|
25 | 29 | public class SignatureTest {
|
26 | 30 |
|
@@ -80,4 +84,28 @@ public void emptyMethodNameConvertedToNull() {
|
80 | 84 | signature = Signature.builder().key("f").methodName(null).build();
|
81 | 85 | assertNull(signature.methodName());
|
82 | 86 | }
|
| 87 | + |
| 88 | + @Test |
| 89 | + public void createTensorInfoFromOperandWithUnknownShape() { |
| 90 | + try (Graph g = new Graph()) { |
| 91 | + var tf = Ops.create(g); |
| 92 | + var placeholder = tf.placeholder(TInt32.class); |
| 93 | + var tensorInfo = Signature.Builder.toTensorInfo(placeholder.asOutput()); |
| 94 | + assertTrue(tensorInfo.getTensorShape().getUnknownRank()); |
| 95 | + assertEquals(0, tensorInfo.getTensorShape().getDimCount()); |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + @Test |
| 100 | + public void createTensorInfoFromOperandWithPartiallyUnknownShape() { |
| 101 | + try (Graph g = new Graph()) { |
| 102 | + var tf = Ops.create(g); |
| 103 | + var placeholder = tf.placeholder(TInt32.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE, 10))); |
| 104 | + var tensorInfo = Signature.Builder.toTensorInfo(placeholder.asOutput()); |
| 105 | + assertFalse(tensorInfo.getTensorShape().getUnknownRank()); |
| 106 | + assertEquals(2, tensorInfo.getTensorShape().getDimCount()); |
| 107 | + assertEquals(-1, tensorInfo.getTensorShape().getDim(0).getSize()); |
| 108 | + assertEquals(10, tensorInfo.getTensorShape().getDim(1).getSize()); |
| 109 | + } |
| 110 | + } |
83 | 111 | }
|
0 commit comments