Skip to content

Commit 26672b3

Browse files
committed
optional broadcasting + test
Signed-off-by: Ryan Nett <[email protected]>
1 parent b73209f commit 26672b3

File tree

3 files changed

+124
-52
lines changed

3 files changed

+124
-52
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,6 @@ public Options axis(Integer axis) {
149149
return this;
150150
}
151151

152-
/**
153-
* @param axis (Optional) The axis to mask from, or 0 if not set.
154-
*/
155-
public Options axis(int axis) {
156-
this.axis = axis;
157-
return this;
158-
}
159-
160152
private Integer axis;
161153

162154
private Options() {

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java

Lines changed: 82 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,74 @@
3232
@Operator
3333
public abstract class BooleanMaskUpdate {
3434

35+
/*
36+
Python:
37+
def boolean_mask_update(tensor, mask, update, axis=0, name="boolean_mask_update"):
38+
with tf.name_scope(name):
39+
tensor = tf.convert_to_tensor(tensor, name="tensor")
40+
mask = tf.convert_to_tensor(mask, name="mask")
41+
update = tf.convert_to_tensor(update, name="value")
42+
43+
shape_mask = mask.get_shape()
44+
ndims_mask = shape_mask.ndims
45+
shape_tensor = tensor.get_shape()
46+
if ndims_mask == 0:
47+
raise ValueError("mask cannot be scalar.")
48+
if ndims_mask is None:
49+
raise ValueError(
50+
"Number of mask dimensions must be specified, even if some dimensions"
51+
" are None. E.g. shape=[None] is ok, but shape=None is not.")
52+
axis = 0 if axis is None else axis
53+
axis_value = tf.constant(axis)
54+
if axis_value is not None:
55+
axis = axis_value
56+
shape_tensor[axis:axis + ndims_mask].assert_is_compatible_with(shape_mask)
57+
58+
leading_size = tf.reduce_prod(tf.shape(tensor)[:axis + ndims_mask], [0])
59+
innerShape = tf.shape(tensor)[axis + ndims_mask:]
60+
61+
tensor = tf.reshape(
62+
tensor,
63+
tf.concat([
64+
[leading_size],
65+
innerShape
66+
], 0))
67+
68+
indices = tf.where(mask)
69+
70+
updateShape = tf.concat([tf.shape(indices)[:-1], innerShape], 0)
71+
72+
update = tf.broadcast_to(update, updateShape)
73+
result = tf.tensor_scatter_nd_update(tensor, indices, update)
74+
return tf.reshape(result, shape_tensor)
75+
*/
76+
3577
/**
3678
* TODO
3779
*
38-
* @param scope
3980
* @param tensor The tensor to mask.
4081
* @param mask The mask to apply.
41-
* @param value the new values
82+
* @param updates the new values
4283
* @param options carries optional attributes values
4384
* @return The masked tensor.
4485
*/
4586
@Endpoint(name = "booleanMaskUpdate")
46-
public static <T extends TType> Operand<T> create(Scope scope, Operand<T> tensor, Operand<TBool> mask, Operand<T> value,
87+
public static <T extends TType> Operand<T> create(Scope scope, Operand<T> tensor, Operand<TBool> mask,
88+
Operand<T> updates,
4789
Options... options) {
4890

4991
scope = scope.withNameAsSubScope("BooleanMaskUpdate");
5092

5193
int axis = 0;
94+
boolean broadcast = true;
5295
if (options != null) {
5396
for (Options opts : options) {
5497
if (opts.axis != null) {
5598
axis = opts.axis;
5699
}
100+
if (opts.broadcast != null) {
101+
broadcast = opts.broadcast;
102+
}
57103
}
58104
}
59105

@@ -77,7 +123,7 @@ public static <T extends TType> Operand<T> create(Scope scope, Operand<T> tensor
77123
"Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + ".");
78124
}
79125

80-
org.tensorflow.op.core.Shape<TInt32> liveShape = org.tensorflow.op.core.Shape.create(scope, tensor);
126+
Operand<TInt32> liveShape = org.tensorflow.op.core.Shape.create(scope, tensor);
81127

82128
Operand<TInt32> leadingSize = ReduceProd.create(scope,
83129
StridedSliceHelper.stridedSlice(scope,
@@ -87,40 +133,55 @@ public static <T extends TType> Operand<T> create(Scope scope, Operand<T> tensor
87133
Constant.arrayOf(scope, 0)
88134
);
89135

136+
Operand<TInt32> innerShape = StridedSliceHelper
137+
.stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions()));
138+
90139
Operand<T> reshaped = Reshape.create(scope, tensor, Concat.create(
91140
scope,
92141
Arrays.asList(
93142
Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)),
94-
StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(axis + maskShape.numDimensions()))
143+
innerShape
95144
),
96145
Constant.scalarOf(scope, 0)
97146
));
98147

99148
Operand<TInt64> indices = Where.create(scope, mask);
100-
//TODO I'd like to broadcast value to the required shape. Need to figure out the shape first
101-
Operand<T> newValue = TensorScatterNdUpdate.create(scope, reshaped, indices, value);
149+
150+
if(broadcast) {
151+
Operand<TInt32> indicesShape = org.tensorflow.op.core.Shape.create(scope, indices);
152+
Operand<TInt32> batchShape = StridedSliceHelper.stridedSlice(scope, indicesShape, Indices.sliceTo(-1));
153+
154+
Operand<TInt32> updateShape = Concat.create(
155+
scope,
156+
Arrays.asList(
157+
batchShape,
158+
innerShape
159+
),
160+
Constant.scalarOf(scope, 0)
161+
);
162+
163+
updates = BroadcastTo.create(scope, updates, updateShape);
164+
}
165+
166+
Operand<T> newValue = TensorScatterNdUpdate.create(scope, reshaped, indices, updates);
102167
return Reshape.create(scope, newValue, liveShape);
103168
}
104169

105170
/**
106-
* Used to indicate the axis to mask from.
107-
* {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
171+
* Used to indicate the axis to mask from. {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
108172
* the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
173+
*
109174
* @param axis the axis to mask from. Uses 0 if null.
110175
*/
111-
public static Options axis(Integer axis){
176+
public static Options axis(Integer axis) {
112177
return new Options().axis(axis);
113178
}
114179

115-
116180
/**
117-
* Used to indicate the axis to mask from.
118-
* {@code axis + dim(mask) <= dim(tensor)} and {@code mask}'s shape must match
119-
* the first {@code axis + dim(mask)} dimensions of {@code tensor}'s shape.
120-
* @param axis the axis to mask from.
181+
* Whether to try broadcasting update. True by default.
121182
*/
122-
public static Options axis(int axis){
123-
return new Options().axis(axis);
183+
public static Options broadcast(Boolean broadcast){
184+
return new Options().broadcast(broadcast);
124185
}
125186

126187
/**
@@ -137,14 +198,15 @@ public Options axis(Integer axis) {
137198
}
138199

139200
/**
140-
* @param axis (Optional) The axis to mask from, or 0 if not set.
201+
* @param broadcast (Optional) Whether to try broadcasting update. True by default.
141202
*/
142-
public Options axis(int axis) {
143-
this.axis = axis;
203+
public Options broadcast(Boolean broadcast) {
204+
this.broadcast = broadcast;
144205
return this;
145206
}
146207

147208
private Integer axis;
209+
private Boolean broadcast;
148210

149211
private Options() {
150212
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
2020

21+
import java.util.List;
2122
import org.junit.Test;
2223
import org.tensorflow.Graph;
2324
import org.tensorflow.Operand;
2425
import org.tensorflow.Session;
26+
import org.tensorflow.Session.Run;
27+
import org.tensorflow.Tensor;
2528
import org.tensorflow.ndarray.Shape;
2629
import org.tensorflow.ndarray.index.Indices;
2730
import org.tensorflow.op.Scope;
@@ -37,26 +40,33 @@ public void testBooleanMaskUpdateSlice() {
3740
Session sess = new Session(g)) {
3841
Scope scope = new Scope(g);
3942

40-
Operand<TInt32> input = Constant.tensorOf(scope, new int[][]{ {0, 0, 0}, {1, 1, 1}, {2, 2, 2}});
43+
Operand<TInt32> input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}});
4144

4245
Operand<TBool> mask = Constant.arrayOf(scope, true, false, false);
4346

4447
Operand<TInt32> value = Constant.tensorOf(scope, new int[][]{{-1, -1, -1}});
4548

4649
Operand<TInt32> output = BooleanMaskUpdate.create(scope, input, mask, value);
4750

48-
try (TFloat32 result = (TFloat32) sess.runner().fetch(output).run().get(0)) {
49-
// expected shape from Python tensorflow
51+
Operand<TInt32> bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1));
52+
53+
List<Tensor> results = sess.runner().fetch(output).fetch(bcastOutput).run();
54+
try (TInt32 result = (TInt32) results.get(0);
55+
TInt32 bcastResult = (TInt32) results.get(1)) {
56+
5057
assertEquals(Shape.of(3, 3), result.shape());
51-
assertEquals(-1, result.getFloat(0, 0));
52-
assertEquals(-1, result.getFloat(0, 1));
53-
assertEquals(-1, result.getFloat(0, 2));
54-
assertEquals(1, result.getFloat(1, 0));
55-
assertEquals(1, result.getFloat(1, 1));
56-
assertEquals(1, result.getFloat(1, 2));
57-
assertEquals(2, result.getFloat(2, 0));
58-
assertEquals(2, result.getFloat(2, 1));
59-
assertEquals(2, result.getFloat(2, 2));
58+
59+
assertEquals(-1, result.getInt(0, 0));
60+
assertEquals(-1, result.getInt(0, 1));
61+
assertEquals(-1, result.getInt(0, 2));
62+
assertEquals(1, result.getInt(1, 0));
63+
assertEquals(1, result.getInt(1, 1));
64+
assertEquals(1, result.getInt(1, 2));
65+
assertEquals(2, result.getInt(2, 0));
66+
assertEquals(2, result.getInt(2, 1));
67+
assertEquals(2, result.getInt(2, 2));
68+
69+
assertEquals(result, bcastResult);
6070
}
6171
}
6272
}
@@ -75,19 +85,27 @@ public void testBooleanMaskUpdateAxis() {
7585

7686
Operand<TInt32> output = BooleanMaskUpdate.create(scope, input, mask, value, BooleanMaskUpdate.axis(2));
7787

78-
try (TFloat32 result = (TFloat32) sess.runner().fetch(output).run().get(0)) {
79-
// expected shape from Python tensorflow
88+
Operand<TInt32> bcastOutput = BooleanMaskUpdate
89+
.create(scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2));
90+
91+
List<Tensor> results = sess.runner().fetch(output).fetch(bcastOutput).run();
92+
try (TInt32 result = (TInt32) results.get(0);
93+
TInt32 bcastResult = (TInt32) results.get(1)) {
94+
8095
assertEquals(Shape.of(1, 1, 10), result.shape());
81-
assertEquals(-1, result.getFloat(0, 0, 0));
82-
assertEquals(-1, result.getFloat(0, 0, 1));
83-
assertEquals(2, result.getFloat(0, 0, 2));
84-
assertEquals(3, result.getFloat(0, 0, 3));
85-
assertEquals(-1, result.getFloat(0, 0, 4));
86-
assertEquals(-1, result.getFloat(0, 0, 5));
87-
assertEquals(-1, result.getFloat(0, 0, 6));
88-
assertEquals(7, result.getFloat(0, 0, 7));
89-
assertEquals(8, result.getFloat(0, 0, 8));
90-
assertEquals(9, result.getFloat(0, 0, 9));
96+
97+
assertEquals(-1, result.getInt(0, 0, 0));
98+
assertEquals(-1, result.getInt(0, 0, 1));
99+
assertEquals(2, result.getInt(0, 0, 2));
100+
assertEquals(3, result.getInt(0, 0, 3));
101+
assertEquals(-1, result.getInt(0, 0, 4));
102+
assertEquals(-1, result.getInt(0, 0, 5));
103+
assertEquals(-1, result.getInt(0, 0, 6));
104+
assertEquals(7, result.getInt(0, 0, 7));
105+
assertEquals(8, result.getInt(0, 0, 8));
106+
assertEquals(9, result.getInt(0, 0, 9));
107+
108+
assertEquals(result, bcastResult);
91109
}
92110
}
93111
}

0 commit comments

Comments
 (0)