Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.12</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-generator-annprocess</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
Expand Down
9,506 changes: 4,889 additions & 4,617 deletions tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public Shape shape(int outputIndex) {
for (int i = 0; i < shape.length; ++i) {
shape[i] = dim(outputNativeHandle, i);
}
return Shape.make(shape);
return Shape.of(shape);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,16 @@ public void close() {
doClose();
}

// Cleanup default session context for unit tests
static void closeDefaultForTest() {
synchronized (EagerSession.class) {
if (defaultSession != null) {
defaultSession.doClose();
defaultSession = null;
}
}
}

@Override
public OperationBuilder opBuilder(String type, String name) {
if (resourceCleanupStrategy == ResourceCleanupStrategy.ON_SAFE_POINTS) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ Shape shape(int outputIdx) {
Graph.Reference r = graph.ref();
try {
long[] shape = shape(r.nativeHandle(), getUnsafeNativeHandle(), outputIdx);
return shape == null ? Shape.unknown() : Shape.make(shape);
return shape == null ? Shape.unknown() : Shape.of(shape);
} finally {
r.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,19 @@
* <p>Example usage:
*
* <pre>{@code
* Ops tf = Ops.create();
*
* // The "decodeJpeg" operation can be used as an operand to the "cast" operation
* Operand<TUint8> decodeJpeg = ops.image.decodeJpeg(...);
* ops.dtypes.cast(decodeJpeg, TFloat32.DTYPE);
* Operand<TUint8> decodeJpeg = tf.image.decodeJpeg(...);
* tf.dtypes.cast(decodeJpeg, TFloat32.DTYPE);
*
* // The output "y" of the "unique" operation can be used as an operand to the "cast" operation
* Output<TInt32> y = ops.unique(...).y();
* ops.dtypes.cast(y, TFloat32.DTYPE);
* Output<TInt32> y = tf.unique(...).y();
* tf.dtypes.cast(y, TFloat32.DTYPE);
*
* // The "split" operation can be used as operand list to the "concat" operation
* Iterable<? extends Operand<TFloat32>> split = ops.split(...);
* ops.concat(split, ops.constant(0));
* Iterable<? extends Operand<TFloat32>> split = tf.split(...);
* tf.concat(split, tf.val(0));
* }</pre>
*/
public interface Operand<T extends TType> {
Expand All @@ -49,10 +51,25 @@ public interface Operand<T extends TType> {
Output<T> asOutput();

/**
* Returns the data of the tensor.
* Returns this operand as a tensor.
*
* <i>Only works when running in an eager execution</i>
* <p>This helper method is equivalent to {@code asOutput().tensor()}
*
* @return the tensor
* @throws IllegalStateException if this is an operand of a graph
*/
default Tensor<T> asTensor() {
return asOutput().tensor();
}

/**
* Returns the data of this operand.
*
* <i>This only works when running in an eager execution</i>
* <i>Only works when running in an eager execution</i>
* <p>This helper method is equivalent to {@code asTensor().data()}
*
* @return the tensor data
* @throws IllegalStateException if this is an operand of a graph
*/
default T data() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;
Expand Down Expand Up @@ -158,7 +158,7 @@ public static <T extends TType> Tensor<T> create(Object obj, DataType<T> dtype)
}
long[] dimSizes = new long[numDimensions(obj, dtype)];
fillShape(obj, 0, dimSizes);
Tensor<T> t = new Tensor(dtype, Shape.make(dimSizes));
Tensor<T> t = new Tensor(dtype, Shape.of(dimSizes));
TF_Tensor nativeHandle;
if (t.dtype != TString.DTYPE) {
long byteSize = elemByteSize(t.dtype) * t.shape.size();
Expand Down Expand Up @@ -290,25 +290,25 @@ public static <T extends TType> Tensor<T> create(DataType<T> dtype, long[] shape
return t;
}

public static <T extends TType> Tensor<T> allocate(DataType<T> dtype, Shape shape) {
return allocate(dtype, shape, shape.size() * dtype.byteSize());
public static <T extends TType> Tensor<T> of(DataType<T> dtype, Shape shape) {
return of(dtype, shape, shape.size() * dtype.byteSize());
}

public static <T extends TType> Tensor<T> allocate(DataType<T> dtype, Shape shape, long size) {
public static <T extends TType> Tensor<T> of(DataType<T> dtype, Shape shape, long size) {
Tensor<T> t = new Tensor<>(dtype, shape);
TF_Tensor nativeHandle = allocate(t.dtype.nativeCode(), shape.asArray(), size);
t.nativeRef = new NativeReference(nativeHandle);
return t;
}

public static <T extends TType> Tensor<T> allocate(DataType<T> dtype, Shape shape,
public static <T extends TType> Tensor<T> of(DataType<T> dtype, Shape shape,
Consumer<T> dataInitializer) {
return allocate(dtype, shape, shape.size() * dtype.byteSize(), dataInitializer);
return of(dtype, shape, shape.size() * dtype.byteSize(), dataInitializer);
}

public static <T extends TType> Tensor<T> allocate(DataType<T> dtype, Shape shape, long size,
public static <T extends TType> Tensor<T> of(DataType<T> dtype, Shape shape, long size,
Consumer<T> dataInitializer) {
Tensor<T> tensor = allocate(dtype, shape, size);
Tensor<T> tensor = of(dtype, shape, size);
try {
dataInitializer.accept(tensor.data());
return tensor;
Expand Down Expand Up @@ -350,7 +350,7 @@ private static <T extends TType> Tensor<T> allocateForBuffer(DataType<T> dataTyp
// DT_STRING tensor encoded in a ByteBuffer.
nbytes = nBuffered;
}
Tensor<T> t = new Tensor<>(dataType, Shape.make(dimSizes));
Tensor<T> t = new Tensor<>(dataType, Shape.of(dimSizes));
TF_Tensor nativeHandle = allocate(t.dtype.nativeCode(), dimSizes, nbytes);
t.nativeRef = new NativeReference(nativeHandle);
return t;
Expand Down Expand Up @@ -582,7 +582,7 @@ public String toString() {
* <p>Takes ownership of the handle.
*/
static Tensor<?> fromHandle(TF_Tensor handle) {
Tensor<?> t = new Tensor<>(DataTypes.fromNativeCode(dtype(handle)), Shape.make(shape(handle)));
Tensor<?> t = new Tensor<>(DataTypes.fromNativeCode(dtype(handle)), Shape.of(shape(handle)));
t.nativeRef = new NativeReference(handle);
return t;
}
Expand Down
Loading