Skip to content

Commit 5532bf0

Browse files
authored
Constraints - 1 (#215)
* Initial Checkin * Clean up JavaDoc Change float attributes to double * Refactor Constraint to only have Generic parameter on call method. Add norm method on Constraint that is leveraged by the xxxxNorm constraints. Fix unit test cases to properly test the actual classes (oops). Fix Javadoc
1 parent cdd0298 commit 5532bf0

File tree

11 files changed

+827
-10
lines changed

11 files changed

+827
-10
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.constraints;
16+
17+
import org.tensorflow.Operand;
18+
import org.tensorflow.op.Ops;
19+
import org.tensorflow.op.core.ReduceSum;
20+
import org.tensorflow.types.family.TNumber;
21+
22+
import static org.tensorflow.framework.utils.CastHelper.cast;
23+
24+
/** Base class for Constraints. Constraint subclasses impose constraints on weight values */
25+
public abstract class Constraint {
26+
27+
public static final float EPSILON = 1e-7f;
28+
29+
private final Ops tf;
30+
31+
/**
32+
* Creates a Constraint
33+
*
34+
* @param tf the TensorFlow Ops
35+
*/
36+
public Constraint(Ops tf) {
37+
this.tf = tf;
38+
}
39+
40+
/**
41+
* Applies the constraint against the provided weights
42+
*
43+
* @param weights the weights
44+
* @return the constrained weights
45+
*/
46+
public abstract <T extends TNumber> Operand<T> call(Operand<T> weights);
47+
48+
/**
49+
* Gets the TensorFlow Ops
50+
*
51+
* @return the TensorFlow Ops
52+
*/
53+
public Ops getTF() {
54+
return tf;
55+
}
56+
57+
/**
58+
* Gets the element-wise square root.
59+
*
60+
* @param x the input Operand.
61+
* @return the element-wise square root.
62+
* @param <T> The data type for the operand and result.
63+
* @throws IllegalArgumentException if x is null
64+
*/
65+
protected <T extends TNumber> Operand<T> sqrt(Operand<T> x) {
66+
if (x == null) throw new IllegalArgumentException("Operand x must not be null");
67+
Class<T> type = x.type();
68+
Operand<T> zero = cast(tf, tf.constant(0), type);
69+
Operand<T> inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type);
70+
return tf.math.sqrt(tf.clipByValue(x, zero, inf));
71+
}
72+
73+
/**
74+
* Gets the element-wise value clipping.
75+
*
76+
* @param x the Operand to clip
77+
* @param minValue the minimum value
78+
* @param maxValue the maximum value
79+
* @return the operand with clipped values
80+
* @param <T> The data type for the operand and result.
81+
* @throws IllegalArgumentException if x is null
82+
*/
83+
protected <T extends TNumber> Operand<T> clip(Operand<T> x, double minValue, double maxValue) {
84+
if (x == null) throw new IllegalArgumentException("Operand x must not be null");
85+
Ops tf = getTF();
86+
Class<T> type = x.type();
87+
88+
double min = Math.min(minValue, maxValue);
89+
double max = Math.max(minValue, maxValue);
90+
91+
Operand<T> minValueConstant = cast(tf, tf.constant(min), type);
92+
Operand<T> maxValueConstant = cast(tf, tf.constant(max), type);
93+
return tf.clipByValue(x, minValueConstant, maxValueConstant);
94+
}
95+
96+
/**
97+
* Calculates the norm of the weights along the axes
98+
*
99+
* @param weights the weights used to calculate the norms
100+
* @param axes the axes along which to calculate weight norms.
101+
* @param <T> the data type for the weights and the result
102+
* @return the norms
103+
* @throws IllegalArgumentException if weights is null
104+
*/
105+
protected <T extends TNumber> Operand<T> norm(Operand<T> weights, int[] axes) {
106+
if (weights == null) throw new IllegalArgumentException("weights must not be null");
107+
return sqrt(
108+
tf.reduceSum(tf.math.square(weights), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)));
109+
}
110+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.constraints;
16+
17+
import org.tensorflow.Operand;
18+
import org.tensorflow.op.Ops;
19+
import org.tensorflow.types.family.TNumber;
20+
21+
import static org.tensorflow.framework.utils.CastHelper.cast;
22+
23+
/**
24+
* Constrains the weights incident to each hidden unit to have a norm less than or equal to a
25+
* desired value.
26+
*/
27+
public class MaxNorm extends Constraint {
28+
public static final double MAX_VALUE_DEFAULT = 2.0;
29+
public static final int AXIS_DEFAULT = 0;
30+
31+
/** the maximum norm for the incoming weights. */
32+
private final double maxValue;
33+
/** integer, axis along which to calculate weight norms. */
34+
private final int[] axes;
35+
36+
/**
37+
* Create a MaxNorm constraint using {@link #MAX_VALUE_DEFAULT} for the max value and {@link
38+
* #AXIS_DEFAULT} for the axis.
39+
*
40+
* @param tf the TensorFlow Ops
41+
*/
42+
public MaxNorm(Ops tf) {
43+
this(tf, MAX_VALUE_DEFAULT, AXIS_DEFAULT);
44+
}
45+
46+
/**
47+
* Create a MaxNorm constraint using {@link #AXIS_DEFAULT} for the axis.
48+
*
49+
* @param tf the TensorFlow Ops
50+
* @param maxValue the maximum norm for the incoming weights.
51+
*/
52+
public MaxNorm(Ops tf, double maxValue) {
53+
this(tf, maxValue, AXIS_DEFAULT);
54+
}
55+
56+
/**
57+
* Create a MaxNorm constraint
58+
*
59+
* @param tf the TensorFlow Ops
60+
* @param maxValue the maximum norm for the incoming weights.
61+
* @param axis axis along which to calculate weight norms.
62+
*/
63+
public MaxNorm(Ops tf, double maxValue, int axis) {
64+
this(tf, maxValue, new int[] {axis});
65+
}
66+
67+
/**
68+
* Create a MaxNorm constraint
69+
*
70+
* @param tf the TensorFlow Ops
71+
* @param maxValue the maximum norm for the incoming weights.
72+
* @param axes axes along which to calculate weight norms.
73+
*/
74+
public MaxNorm(Ops tf, double maxValue, int[] axes) {
75+
super(tf);
76+
this.maxValue = maxValue;
77+
this.axes = axes;
78+
}
79+
80+
/** {@inheritDoc} */
81+
@Override
82+
public <T extends TNumber> Operand<T> call(Operand<T> weights) {
83+
Ops tf = getTF();
84+
Class<T> type = weights.type();
85+
Operand<T> norms = norm(weights, getAxes());
86+
Operand<T> desired = clip(norms, 0f, this.getMaxValue());
87+
88+
return tf.math.mul(
89+
weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms)));
90+
}
91+
92+
/**
93+
* Gets the max value
94+
*
95+
* @return the maxValue
96+
*/
97+
public double getMaxValue() {
98+
return maxValue;
99+
}
100+
101+
/**
102+
* Gets the axes
103+
*
104+
* @return the axes
105+
*/
106+
public int[] getAxes() {
107+
return axes;
108+
}
109+
}
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.constraints;
16+
17+
import org.tensorflow.Operand;
18+
import org.tensorflow.op.Ops;
19+
import org.tensorflow.types.family.TNumber;
20+
21+
import static org.tensorflow.framework.utils.CastHelper.cast;
22+
23+
/** Constrains the weights to have the norm between a lower bound and an upper bound. */
24+
public class MinMaxNorm extends Constraint {
25+
public static final double MIN_VALUE_DEFAULT = 0.0;
26+
public static final double MAX_VALUE_DEFAULT = 1.0;
27+
public static final double RATE_DEFAULT = 1.0;
28+
public static final int AXIS_DEFAULT = 0;
29+
30+
/** the minimum norm for the incoming weights. */
31+
private final double minValue;
32+
/** the maximum norm for the incoming weights. */
33+
private final double maxValue;
34+
35+
/**
36+
* rate for enforcing the constraint: weights will be rescaled to yield (1 - rate) * norm + rate *
37+
* norm.clip(min_value, max_value). Effectively, this means that rate=1.0 stands for strict
38+
* enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step
39+
* to slowly move towards a value inside the desired interval.
40+
*/
41+
private final double rate;
42+
43+
/** axis along which to calculate weight norms. */
44+
private final int[] axes;
45+
46+
/**
47+
* Create a MinMaxNorm constraint using {@link #MIN_VALUE_DEFAULT} for the min value, {@link
48+
* #MAX_VALUE_DEFAULT} for the max value, {@link #RATE_DEFAULT} for the rate and {@link
49+
* #AXIS_DEFAULT} for the axis
50+
*
51+
* @param tf the TensorFlow Ops
52+
*/
53+
public MinMaxNorm(Ops tf) {
54+
this(tf, MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT);
55+
}
56+
57+
/**
58+
* Create a MinMaxNorm constraint using {@link #RATE_DEFAULT} for the rate and {@link
59+
* #AXIS_DEFAULT} for the axis
60+
*
61+
* @param tf the TensorFlow Ops
62+
* @param minValue the minimum norm for the incoming weights.
63+
* @param maxValue the maximum norm for the incoming weights.
64+
*/
65+
public MinMaxNorm(Ops tf, double minValue, double maxValue) {
66+
this(tf, minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT);
67+
}
68+
69+
/**
70+
* Create a MinMaxNorm constraint
71+
*
72+
* @param tf the TensorFlow Ops
73+
* @param minValue the minimum norm for the incoming weights.
74+
* @param maxValue the maximum norm for the incoming weights.
75+
* @param rate the rate for enforcing the constraint.
76+
* @param axis integer, axis along which to calculate weight norms.
77+
*/
78+
public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int axis) {
79+
this(tf, minValue, maxValue, rate, new int[] {axis});
80+
}
81+
/**
82+
* Create a MinMaxNorm constraint
83+
*
84+
* @param tf the TensorFlow Ops
85+
* @param minValue the minimum norm for the incoming weights.
86+
* @param maxValue the maximum norm for the incoming weights.
87+
* @param rate the rate for enforcing the constraint.
88+
* @param axes integer, axis along which to calculate weight norms.
89+
*/
90+
public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] axes) {
91+
super(tf);
92+
this.minValue = minValue;
93+
this.maxValue = maxValue;
94+
this.rate = rate;
95+
this.axes = axes;
96+
}
97+
98+
/** {@inheritDoc} */
99+
@Override
100+
public <T extends TNumber> Operand<T> call(Operand<T> weights) {
101+
Class<T> type = weights.type();
102+
Ops tf = getTF();
103+
Operand<T> norms = norm(weights, getAxes());
104+
Operand<T> desired =
105+
tf.math.add(
106+
tf.math.mul(
107+
tf.dtypes.cast(tf.constant(this.getRate()), type),
108+
clip(norms, this.getMinValue(), this.getMaxValue())),
109+
tf.math.mul(
110+
tf.math.sub(
111+
tf.dtypes.cast(tf.constant(1), type),
112+
tf.dtypes.cast(tf.constant(this.getRate()), type)),
113+
norms));
114+
115+
return tf.math.mul(
116+
weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms)));
117+
}
118+
119+
/**
120+
* Gets the minValue
121+
*
122+
* @return the minValue
123+
*/
124+
public double getMinValue() {
125+
return minValue;
126+
}
127+
128+
/**
129+
* Gets the maxValue
130+
*
131+
* @return the maxValue
132+
*/
133+
public double getMaxValue() {
134+
return maxValue;
135+
}
136+
137+
/**
138+
* Gets the rate
139+
*
140+
* @return the rate
141+
*/
142+
public double getRate() {
143+
return rate;
144+
}
145+
146+
/**
147+
* Gets the axes
148+
*
149+
* @return the axes
150+
*/
151+
public int[] getAxes() {
152+
return axes;
153+
}
154+
}

0 commit comments

Comments
 (0)