Skip to content

Resync with origin/master #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
needs: prepare
strategy:
matrix:
ext: ["", -mkl] # , -gpu, -mkl-gpu]
ext: ["", -mkl, -gpu, -mkl-gpu]
steps:
- name: Install environment
run: |
Expand Down
4 changes: 4 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ in [IntelliJ](https://github.com/google/styleguide/blob/gh-pages/intellij-java-g
[Eclipse](https://github.com/google/styleguide/blob/gh-pages/eclipse-java-google-style.xml).
[Google's C++ style guide](https://google.github.io/styleguide/cppguide.html) should also be used for C++ code.

### Dependencies

For dependencies, we can use anything compliant with [this list](https://opensource.google/docs/thirdparty/licenses/#notice), but we want to keep the core libraries as dependency-free as possible.

### Code generation

Code generation for `Ops` and related classes is done during `tensorflow-core-api`'s `compile` phase, using the annotation processor in
Expand Down
4 changes: 3 additions & 1 deletion tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.0</version>
<version>3.0.0-M5</version>
<executions>
<execution>
<!--
Expand All @@ -389,6 +389,8 @@
</execution>
</executions>
<configuration>
<!-- Activate the use of TCP to transmit events to the plugin -->
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
<additionalClasspathElements>
<additionalClasspathElement>${project.build.directory}/${project.artifactId}-${project.version}-${native.classifier}.jar</additionalClasspathElement>
<!-- Note: the following path is not accessible in deploying profile, so other libraries like
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ class EagerOperation extends AbstractOperation {
this.name = name;
this.opHandle = opNativeHandle;
this.outputHandles = outputNativeHandles;
session.attach(opNativeHandle);
session.attach(outputNativeHandles);
this.outputTensors = new AtomicReferenceArray<>(outputNativeHandles.length);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,7 @@ final class EagerOperationBuilder implements OperationBuilder {
@Override
public EagerOperation build() {
TFE_TensorHandle[] tensorHandles = execute(opHandle, session);
EagerOperation operation =
new EagerOperation(session, opHandle, tensorHandles, type, name);
// Release our reference to the native op handle now that we transferred its
// ownership to the EagerOperation
session.detach(opHandle);
return operation;
return new EagerOperation(session, opHandle, tensorHandles, type, name);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.internal.WeakPointerScope;
import org.tensorflow.internal.c_api.TFE_Context;
import org.tensorflow.internal.c_api.TFE_ContextOptions;
import org.tensorflow.internal.c_api.TF_Status;
Expand Down Expand Up @@ -310,13 +311,45 @@ TFE_Context nativeHandle() {
return nativeHandle;
}

/**
 * Registers native resources with the scope of this eager session.
 *
 * <p>Closing the session, either through an explicit {@link #close()} or implicitly at the end of
 * a try-with-resources statement, also releases every native resource attached to it, unless no
 * other references are {@link Pointer#retainReference() retaining} them.
 *
 * <p>An attached resource may nevertheless be garbage collected as soon as its {@link Pointer} is
 * no longer reachable from Java, whatever its reference count is. It is therefore assumed that
 * the native library has no further use for a resource once the Java client has dropped it.
 *
 * <p>Attaching a resource that is already attached to this session is a no-op.
 *
 * @param resources resources to attach to the session
 */
void attach(Pointer... resources) {
  checkSession();
  for (int i = 0; i < resources.length; ++i) {
    nativeResources.attach(resources[i]);
  }
}

/**
* Detach a list of resources from this eager session scope.
*
* <p>Detaching native resources prevents them from being automatically released when the session
* is closed.</p>
*
* <p>Note though that this method will decrement the reference count of each resource being
* detached, which may automatically release it if that count reaches 0. Therefore,
* invoking {@link Pointer#retainReference()} prior to this call on any resource that must remain
* valid after being detached might be required.</p>
*
* <p>Detaching a resource that is not attached to this session will have no effect.</p>
*
* @param resources resources to detach from the session
*/
void detach(Pointer... resources) {
checkSession();
for (Pointer r : resources) {
Expand All @@ -326,14 +359,12 @@ void detach(Pointer... resources) {

private static volatile EagerSession defaultSession = null;

private final PointerScope nativeResources;
private final WeakPointerScope nativeResources;
private TFE_Context nativeHandle;

private EagerSession(Options options) {
try (PointerScope scope = new PointerScope()) {
this.nativeResources = scope.extend();
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
}
this.nativeResources = new WeakPointerScope();
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
}

private void checkSession() {
Expand Down Expand Up @@ -363,7 +394,7 @@ private static TFE_Context allocate(boolean async, int devicePlacementPolicy, Co
TFE_ContextOptionsSetDevicePlacementPolicy(opts, devicePlacementPolicy);
TFE_Context context = TFE_NewContext(opts, status);
status.throwExceptionIfNotOK();
return context;
return context.retainReference();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
Expand All @@ -47,6 +48,7 @@
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Identity;
import org.tensorflow.op.core.NoOp;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.train.Restore;
Expand Down Expand Up @@ -439,15 +441,32 @@ public Output<?>[] whileLoop(
* Return the {@link SaverDef} instance used to save the state of all variables present in
* this graph.
*
* <p/>On the first call of this method, all nodes necessary to save and restore the state of the
* variables are added to the graph. Consequently, any variables that are added to the graph after
* this call could not be saved nor restored using this {@link SaverDef}.
* <p/> The first time this method is called it builds the {@link SaverDef}. If this graph already
* contains a "save/restore_all" operation then it is assumed to contain all necessary saving and
* restoring operations. If that operation does not exist then the graph is mutated to add all
* the nodes necessary to save and restore the state of the graph. Consequently, any variables
* that are added to the graph after this call will not be saved nor restored using this
* {@link SaverDef}.
*
* @return a {@link SaverDef} instance
*/
synchronized SaverDef saverDef() {
if (saverDef == null) {
saverDef = addVariableSaver(this);
// Check to see if this graph has a restore operation
if (operation("save/restore_all") == null) {
// No saver, create one by mutating the graph
saverDef = addVariableSaver(this);
} else {
// This graph already has saving/restoring operations,
// regenerate SaverDef without mutating. The names mirror
// the python implementation for compatibility.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
saverDef = SaverDef.newBuilder()
.setFilenameTensorName("save/filename")
.setSaveTensorName("save/control_dependency")
.setRestoreOpName("save/restore_all")
.build();
}
}
return saverDef;
}
Expand Down Expand Up @@ -798,13 +817,15 @@ private static SaverDef addVariableSaver(Graph graph) {
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
Operand<TString> varSlices = tf.zerosLike(varNamesTensor);

Placeholder<TString> saveFilename = tf.placeholder(TString.class);
Placeholder<TString> saveFilename = tf.withName("filename").placeholder(TString.class);
Save saveVariables = tf.train.save(
saveFilename,
varNamesTensor,
varSlices,
varOutputs
);
Identity<TString> id = tf.withControlDependencies(Arrays.asList(saveFilename,saveVariables))
.withName("control_dependency").identity(saveFilename);
Restore restoreVariables = tf.train.restore(
saveFilename,
varNamesTensor,
Expand All @@ -815,11 +836,11 @@ private static SaverDef addVariableSaver(Graph graph) {
for (int i = 0; i < varOutputs.size(); ++i) {
restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
}
NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp();
NoOp restoreAll = tf.withControlDependencies(restoreOps).withName("restore_all").noOp();

return SaverDef.newBuilder()
.setFilenameTensorName(saveFilename.op().name())
.setSaveTensorName(saveVariables.op().name())
.setSaveTensorName(id.op().name())
.setRestoreOpName(restoreAll.op().name())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,16 +512,35 @@ public void runInit(){
* <i>mymodel/myvariables</i> and named <i>variables.data-*-of-*</i>
*
* <p>Note that this method might alter the underlying graph if it is the first time that one
* of its session is saved, see {@link Graph#saverDef()} for more details.
* of its sessions is saved, see {@link Graph#saverDef()} for more details.
*
* @param prefix prefix to the variable files to save
*/
public void save(String prefix) {
SaverDef saverDef = graph.saverDef();
runner()
.addTarget(saverDef.getSaveTensorName())
.feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
.run();
runner().addTarget(saverDef.getSaveTensorName())
.feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
.run();
}

/**
 * Restore the actual state of the variables of this session's graph.
 *
 * <p>{@code prefix} is the path where the files containing the variables state live,
 * followed by the filename prefix. For example, if {@code prefix} is set to
 * <i>mymodel/myvariables/variables</i>, then the files are loaded from
 * <i>mymodel/myvariables</i> and named <i>variables.data-*-of-*</i>
 *
 * <p>Note that this method might alter the underlying graph if it is the first time that one
 * of its sessions is saved or restored, see {@link Graph#saverDef()} for more details.
 *
 * @param prefix prefix to restore from
 */
public void restore(String prefix) {
  SaverDef saverDef = graph.saverDef();
  // The filename tensor is only needed for the duration of the run; close it right after
  // instead of leaving its native memory to be reclaimed at some later point.
  try (TString filename = TString.scalarOf(prefix)) {
    runner()
        .addTarget(saverDef.getRestoreOpName())
        .feed(saverDef.getFilenameTensorName(), filename)
        .run();
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.tensorflow.internal;

import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;
import org.bytedeco.javacpp.Pointer;

/**
 * A minimalist pointer scope only keeping weak references to its elements.
 *
 * <p>As opposed to {@link org.bytedeco.javacpp.PointerScope}, instances of this class will not
 * prevent the garbage collector from freeing the memory of a pointer that is no longer reachable,
 * even if it has been attached to the scope.
 *
 * <p>When the scope is closed, all pointers that are still valid will be automatically deallocated
 * while those already garbage-collected will be ignored.
 *
 * <p>{@link #attach(Pointer)}, {@link #detach(Pointer)} and {@link #close()} all synchronize on
 * the scope instance, so that the underlying set is never mutated while the scope is being closed.
 */
public class WeakPointerScope implements AutoCloseable {

  /**
   * Pointers currently attached to this scope, weakly referenced so that the scope alone never
   * keeps them reachable. Set to {@code null} once the scope has been closed.
   */
  private Set<Pointer> pointers = Collections.newSetFromMap(new WeakHashMap<>());

  /**
   * Attach a pointer to this scope.
   *
   * <p>Pointers attached to the scope will be automatically freed once the scope is closed, unless
   * they have already been released by the garbage collector.
   *
   * <p>If this {@code pointer} was already attached to this scope, this method has no effect.
   *
   * @param pointer pointer to attach
   * @throws IllegalStateException if this scope has already been closed
   */
  public synchronized void attach(Pointer pointer) {
    checkScope();
    // Only retain a reference the first time the pointer is seen, so that a single release on
    // detach/close balances it.
    if (pointers.add(pointer)) {
      pointer.retainReference();
    }
  }

  /**
   * Detach a pointer from this scope.
   *
   * <p>Detaching a pointer from the scope prevents its memory from being freed when closing the
   * scope. The reference retained by the scope when the pointer was attached is released.
   *
   * <p>If this {@code pointer} is not attached to this scope, this method has no effect.
   *
   * @param pointer pointer to detach
   * @throws IllegalStateException if this scope has already been closed
   */
  public synchronized void detach(Pointer pointer) {
    checkScope();
    if (pointers.remove(pointer)) {
      pointer.releaseReference();
    }
  }

  /**
   * Release the reference held on every pointer still attached to this scope, then mark the scope
   * as closed.
   *
   * @throws IllegalStateException if this scope has already been closed
   */
  @Override
  public synchronized void close() {
    checkScope();
    pointers.forEach(Pointer::releaseReference);
    pointers = null;
  }

  /** Fails if the scope has already been closed. Must be called while holding the scope lock. */
  private void checkScope() {
    if (pointers == null) {
      throw new IllegalStateException("Pointer scope has been closed");
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ public class EagerOperationTest {
public void failToCreateIfSessionIsClosed() {
EagerSession session = EagerSession.create();
session.close();
try {
new EagerOperation(session, null, null, "Add", "add");
try (TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) {
EagerOperation op =
opBuilder(session, "Const", "OutputAttrs")
.setAttr("dtype", t.dataType())
.setAttr("value", t)
.build();
fail();
} catch (IllegalStateException e) {
// expected
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ public void cleanupResourceInBackground() {
sleep(50); // allow some time to the background thread for cleaning up resources

long before = Pointer.totalBytes();
s.detach(ref.retainReference());
ref = null;
System.gc();
sleep(50); // allow some time to the background thread for cleaning up resources
Expand Down
Loading