From 9e9268791cddd0360cc646039994f20cb0bac808 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 5 Nov 2019 11:02:24 -0500 Subject: [PATCH 01/22] Initial commit of gradient descent optimizers. --- pom.xml | 1 + .../src/main/java/org/tensorflow/Graph.java | 31 ++ tensorflow-sandbox/pom.xml | 41 +++ .../org/tensorflow/sandbox/MNISTTest.java | 345 ++++++++++++++++++ .../sandbox/optimizers/AdaDelta.java | 83 +++++ .../sandbox/optimizers/AdaGrad.java | 71 ++++ .../sandbox/optimizers/AdaGradDA.java | 117 ++++++ .../tensorflow/sandbox/optimizers/Adam.java | 131 +++++++ .../sandbox/optimizers/GradientDescent.java | 45 +++ .../sandbox/optimizers/Momentum.java | 72 ++++ .../sandbox/optimizers/Optimizer.java | 207 +++++++++++ .../sandbox/optimizers/RMSProp.java | 105 ++++++ .../org/tensorflow/sandbox/util/Pair.java | 56 +++ 13 files changed, 1305 insertions(+) create mode 100644 tensorflow-sandbox/pom.xml create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java diff --git a/pom.xml b/pom.xml index c9093490563..a06ccc25a0c 100644 --- a/pom.xml +++ b/pom.xml @@ -31,6 +31,7 @@ tensorflow-tools tensorflow-core + tensorflow-sandbox diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 81bf0d52dbb..0d72d793ae1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -26,7 +26,9 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewGraph; import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewWhile; +import java.util.ArrayList; import java.util.Iterator; +import java.util.List; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerScope; @@ -38,6 +40,9 @@ import org.tensorflow.internal.c_api.TF_Output; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_WhileParams; +import org.tensorflow.op.Scope; +import org.tensorflow.op.core.NoOp; + /** * A data flow graph representing a TensorFlow computation. @@ -49,6 +54,8 @@ */ public final class Graph implements ExecutionEnvironment, AutoCloseable { + public static final String DEFAULT_INIT_NAME = "init"; + /** Create an empty Graph. */ public Graph() { nativeHandle = allocate(); @@ -166,6 +173,28 @@ public byte[] toGraphDef() { } } + /** + * Adds an initializer to the graph initializer list. + * @param initializer An initializer to add to the list. 
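+ *        Registered initializers are executed when the op returned by {@code variablesInitializer()} is run.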
+ */ + public synchronized void addInitializer(Operand initializer) { + initializers.add(initializer); + } + + /** + * Returns an op which initializers all the variables. + * @return The initializer operation. + */ + public NoOp variablesInitializer() { + return variablesInitializer(DEFAULT_INIT_NAME); + } + + public NoOp variablesInitializer(String name) { + Scope scope = new Scope(this); + scope = scope.withName(name).withControlDependencies(initializers); + return NoOp.create(scope); + } + /** * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} @@ -378,6 +407,8 @@ public Output[] whileLoop( private TF_Graph nativeHandle; private int refcount = 0; + private final List> initializers = new ArrayList<>(); + // Related native objects (such as the TF_Operation object backing an Operation instance) // have a validity tied to that of the Graph. The handles to those native objects are not // valid after Graph.close() has been invoked. diff --git a/tensorflow-sandbox/pom.xml b/tensorflow-sandbox/pom.xml new file mode 100644 index 00000000000..fd829f54097 --- /dev/null +++ b/tensorflow-sandbox/pom.xml @@ -0,0 +1,41 @@ + + + 4.0.0 + + + org.tensorflow + tensorflow-java + 0.1.0-SNAPSHOT + + tensorflow-sandbox + 0.1.0-SNAPSHOT + + + + org.tensorflow + tensorflow-core-api + 0.1.0-SNAPSHOT + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.0 + + 1.8 + 1.8 + 1.8 + 1.8 + + -Xlint:all + + + + + + + \ No newline at end of file diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java new file mode 100644 index 00000000000..76ffd102371 --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java @@ -0,0 +1,345 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. 
+ */ +package org.tensorflow.sandbox; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.nio.nd.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Mean; +import org.tensorflow.op.nn.Conv2d; +import org.tensorflow.op.nn.MaxPool; +import org.tensorflow.op.nn.Relu; +import org.tensorflow.op.nn.Softmax; +import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; +import org.tensorflow.op.random.TruncatedNormal; +import org.tensorflow.sandbox.optimizers.AdaDelta; +import org.tensorflow.sandbox.optimizers.AdaGrad; +import org.tensorflow.sandbox.optimizers.AdaGradDA; +import org.tensorflow.sandbox.optimizers.Adam; +import org.tensorflow.sandbox.optimizers.GradientDescent; +import org.tensorflow.sandbox.optimizers.Momentum; +import org.tensorflow.sandbox.optimizers.Optimizer; +import org.tensorflow.sandbox.optimizers.RMSProp; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.TInt32; + +import java.io.BufferedInputStream; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.util.Arrays; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Builds a LeNet-5 style CNN for MNIST. + */ +public class MNISTTest { + + private static final Logger logger = Logger.getLogger(MNISTTest.class.getName()); + + private static final int PIXEL_DEPTH = 255; + private static final int NUM_CHANNELS = 1; + private static final int IMAGE_SIZE = 28; + private static final int NUM_LABELS = 10; + private static final long SEED = 123456789L; + + private static final String PADDING_TYPE = "SAME"; + + public static final String INPUT_NAME = "input"; + public static final String OUTPUT_NAME = "output"; + public static final String TARGET = "target"; + public static final String TRAIN = "train"; + public static final String TRAINING_LOSS = "training_loss"; + public static final String EPOCH = "epoch"; + public static final String INIT = "init"; + + public static Graph build(String optimizerName) { + Graph graph = new Graph(); + + Ops tf = Ops.create(graph); + + // Inputs + Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat.DTYPE, Placeholder.shape(Shape.make(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); + Placeholder labels = tf.withName(TARGET).placeholder(TInt32.DTYPE); + + // Scaling the features + Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); + Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); + Operand scaledInput = tf.math.div(tf.math.add(input, centeringFactor), scalingFactor); + + // First conv layer + Variable conv1Weights = tf.variable(Shape.make(5, 5, NUM_CHANNELS, 32), TFloat.DTYPE); + Assign weights1Init = tf.assign(conv1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv1Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + graph.addInitializer(weights1Init); + Conv2d conv1 = tf.nn.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable conv1Biases = tf.variable(Shape.make(32), TFloat.DTYPE); + Assign biases1Init = tf.assign(conv1Biases, tf.fill(tf.shape(conv1Biases), tf.constant(0.0f))); + 
graph.addInitializer(biases1Init); + Relu relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); + + // First pooling layer + MaxPool pool1 = tf.nn.maxPool(relu1, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); + + // Second conv layer + Variable conv2Weights = tf.variable(Shape.make(5, 5, 32, 64), TFloat.DTYPE); + Assign weights2Init = tf.assign(conv2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv2Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + graph.addInitializer(weights2Init); + Conv2d conv2 = tf.nn.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable conv2Biases = tf.variable(Shape.make(64), TFloat.DTYPE); + Assign biases2Init = tf.assign(conv2Biases, tf.fill(tf.shape(conv2Biases), tf.constant(0.1f))); + graph.addInitializer(biases2Init); + Relu relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); + + // Second pooling layer + MaxPool pool2 = tf.nn.maxPool(relu2, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); + + // Flatten inputs + Reshape flatten = tf.reshape(pool2, tf.concat(Arrays.asList(tf.slice(tf.shape(pool2), tf.constant(new int[]{0}), tf.constant(new int[]{1})), tf.constant(new int[]{-1})), tf.constant(0))); + + // Fully connected layer + Variable fc1Weights = tf.variable(Shape.make(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat.DTYPE); + Assign weights3Init = tf.assign(fc1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc1Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + graph.addInitializer(weights3Init); + Variable fc1Biases = tf.variable(Shape.make(512), TFloat.DTYPE); + Assign biases3Init = tf.assign(fc1Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc1Biases))); + graph.addInitializer(biases3Init); + Relu relu3 = tf.nn.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); + + // Softmax layer + Variable fc2Weights = tf.variable(Shape.make(512, NUM_LABELS), TFloat.DTYPE); + Assign weights4Init = tf.assign(fc2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc2Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + graph.addInitializer(weights4Init); + Variable fc2Biases = tf.variable(Shape.make(NUM_LABELS), TFloat.DTYPE); + Assign biases4Init = tf.assign(fc2Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc2Biases))); + graph.addInitializer(biases4Init); + + Add logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); + + // Predicted outputs + Softmax prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits); + + // Loss function & regularization + OneHot oneHot = tf.oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); + SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); + Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); + Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math.add(tf.nn.l2Loss(fc1Biases), tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); + Add loss = tf.withName(TRAINING_LOSS).math.add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); + + // Optimizer + Optimizer optimizer; + switch (optimizerName) { + case "AdaDelta": + case "Adadelta": + case "adadelta": + optimizer = new AdaDelta(graph, 1f, 0.95f, 1e-8f); + break; + case "AdaGradDA": + case "AdagradDA": + case "adagradda": + optimizer = new AdaGradDA(graph, 0.01f); + break; + case "AdaGrad": + case "Adagrad": + case "adagrad": + optimizer = new 
AdaGrad(graph, 0.01f); + break; + case "Adam": + case "adam": + optimizer = new Adam(graph,0.001f,0.9f,0.999f,1e-8f); + break; + case "SGD": + case "sgd": + optimizer = new GradientDescent(graph,0.01f); + break; + case "Momentum": + case "momentum": + optimizer = new Momentum(graph, 0.01f, 0.9f, false); + break; + case "RMSProp": + case "rmsprop": + optimizer = new RMSProp(graph,0.01f, 0.9f, 0.0f, 1e-10f, false); + break; + default: + throw new IllegalArgumentException("Unknown optimizer " + optimizerName); + } + logger.info("Optimizer = " + optimizer.toString()); + Op minimize = optimizer.minimize(loss, TRAIN); + + Op init = graph.variablesInitializer(); + + return graph; + } + + public static void train(Session session, int epochs, int minibatchSize, float[][][][] data, int[] labels) { + // Initialises the parameters. + session.runner().addTarget(INIT).run(); + logger.info("Initialised the model parameters"); + + float[][][][] featureBatch = new float[minibatchSize][][][]; + int[] labelBatch = new int[minibatchSize]; + + int interval = 0; + for (int i = 0; i < epochs; i++) { + logger.log(Level.INFO, "Starting epoch " + i); + //Tensor epoch = Tensor.create(i); + for (int j = 0; j < data.length; j += minibatchSize) { + for (int k = j, m = 0; k < (j + minibatchSize) && k < data.length; k++, m++) { + featureBatch[m] = data[k]; + labelBatch[m] = labels[k]; + } + //logger.info("Batch = " + batch.size()); + Tensor input = Tensor.create(featureBatch); + Tensor target = Tensor.create(labelBatch); + Tensor loss = session.runner() + .feed(INPUT_NAME, input) + .feed(TARGET, target) + .addTarget(TRAIN) + .fetch(TRAINING_LOSS) + .run().get(0); + if (interval % 100 == 0) { + logger.log(Level.INFO, "Iteration = " + interval + ", training loss = " + loss.floatValue()); + } + input.close(); + target.close(); + loss.close(); + interval++; + } + //epoch.close(); + } + } + + /** + * Find the maximum probability and return it's index. + * + * @param probabilities The probabilites. + * @return The index of the max. 
+ */ + public static int pred(float[] probabilities) { + float maxVal = Float.NEGATIVE_INFINITY; + int idx = 0; + for (int i = 0; i < probabilities.length; i++) { + if (probabilities[i] > maxVal) { + maxVal = probabilities[i]; + idx = i; + } + } + return idx; + } + + public static DataTuple loadData(String path) throws IOException, ClassNotFoundException { + try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) { + float[][][][] data = (float[][][][]) ois.readObject(); + int[] labels = (int[]) ois.readObject(); + return new DataTuple(data, labels); + } + } + + private static class DataTuple { + public final float[][][][] features; + public final int[] labels; + + public DataTuple(float[][][][] features, int[] labels) { + this.features = features; + this.labels = labels; + } + } + + public static void main(String[] args) throws IOException, ClassNotFoundException { + logger.info("Usage: MNISTTest "); + + logger.info("Loading training data"); + DataTuple train = loadData(args[3]); + logger.info("Loading testing data"); + DataTuple test = loadData(args[4]); + + logger.info("Loaded data."); + + float[][][][] trainData = train.features; + int[] trainLabels = train.labels; + + float[][][][] testData = test.features; + int[] testLabels = test.labels; + + logger.info("Loaded " + trainLabels.length + " training labels"); + logger.info("Loaded " + testLabels.length + " testing labels"); + + int epochs = Integer.parseInt(args[0]); + int minibatchSize = Integer.parseInt(args[1]); + + Graph graph = build(args[2]); + + int correctCount = 0; + int[][] confusionMatrix = new int[10][10]; + + try (Session session = new Session(graph)) { + train(session, epochs, minibatchSize, trainData, trainLabels); + + logger.info("Trained model"); + + float[][][][] featureBatch = new float[minibatchSize][][][]; + int[] labelBatch = new int[minibatchSize]; + float[][] prediction; + + for (int j = 0; j < testData.length; j += minibatchSize) { + for (int k = j, m = 0; k < (j + minibatchSize) && k < testData.length; k++, m++) { + featureBatch[m] = testData[k]; + labelBatch[m] = testLabels[k]; + } + try (Tensor transformedInput = Tensor.create(featureBatch); + Tensor outputTensor = session.runner() + .feed(INPUT_NAME, transformedInput) + .fetch(OUTPUT_NAME).run().get(0)) { + prediction = outputTensor.copyTo(new float[minibatchSize][NUM_LABELS]); + } + + for (int k = 0; k < labelBatch.length; k++) { + int predLabel; + + predLabel = pred(prediction[k]); + if (predLabel == labelBatch[k]) { + correctCount++; + } + + confusionMatrix[labelBatch[k]][predLabel]++; + } + + if (j % 1000 == 0) { + logger.log(Level.INFO, "Cur accuracy = " + ((float) correctCount) / (j + minibatchSize)); + } + } + + logger.info("Final accuracy = " + ((float) correctCount) / testLabels.length); + + StringBuilder sb = new StringBuilder(); + sb.append("Label"); + for (int i = 0; i < confusionMatrix.length; i++) { + sb.append(String.format("%1$5s", "" + i)); + } + sb.append("\n"); + + for (int i = 0; i < confusionMatrix.length; i++) { + sb.append(String.format("%1$5s", "" + i)); + for (int j = 0; j < confusionMatrix[i].length; j++) { + sb.append(String.format("%1$5s", "" + confusionMatrix[i][j])); + } + sb.append("\n"); + } + + System.out.println(sb.toString()); + } + + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java new file mode 100644 index 00000000000..f0bc847bdc5 --- 
/dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java @@ -0,0 +1,83 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; + +/** + * Optimizer that implements the Adadelta algorithm. + * + * See the paper. + */ +public class AdaDelta extends Optimizer { + + public static final String ACCUMULATOR = "accum"; + public static final String ACCUMULATOR_UPDATE = "accum_update"; + + private final float learningRate; + + private final float rho; + + private final float epsilon; + + public AdaDelta(Graph graph, float learningRate) { + this(graph, learningRate, 0.95f, 1e-8f); + } + + public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { + super(graph); + this.learningRate = learningRate; + this.rho = rho; + this.epsilon = epsilon; + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createAdaDeltaSlot(v); + } + } + + private void createAdaDeltaSlot(Output v) { + Operand accumulatorInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); + Operand updateInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable accumSlot = (Variable) getSlot(variable,ACCUMULATOR).get(); + @SuppressWarnings("unchecked") + Variable accumUpdateSlot = (Variable) getSlot(variable,ACCUMULATOR_UPDATE).get(); + return tf.train.applyAdadelta(variable, accumSlot, accumUpdateSlot, + tf.constant(learningRate, gradient.dataType()), + tf.constant(rho, gradient.dataType()), + tf.constant(epsilon, gradient.dataType()), + gradient); + } + + @Override + public String toString() { + return "AdaDelta{" + + "learningRate=" + learningRate + + ", rho=" + rho + + ", epsilon=" + epsilon + + '}'; + } + + @Override + public String getOptimizerName() { + return "Adadelta"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java new file mode 100644 index 00000000000..6418e7b9d69 --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java @@ -0,0 +1,71 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; + +/** + * Optimizer that implements the Adagrad algorithm. + * + * See the paper + * or this intro. 
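+ * <p>Adagrad adapts the learning rate per parameter, dividing it by the square root of the
+ * accumulated sum of squared gradients held in the {@code accumulator} slot.
+ * <p>A minimal usage sketch, assuming an existing {@code Graph} and a scalar loss
+ * {@code Operand} (variable names here are illustrative):
+ * <pre>{@code
+ * Optimizer optimizer = new AdaGrad(graph, 0.01f);
+ * Op trainOp = optimizer.minimize(loss, "train");
+ * }</pre>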
+ */ +public class AdaGrad extends Optimizer { + + public static final String ACCUMULATOR = "accumulator"; + + private final float learningRate; + + private final float initialAccumulatorValue; + + public AdaGrad(Graph graph, float learningRate) { + this(graph, learningRate, 0.01f); + } + + public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { + super(graph); + this.learningRate = learningRate; + this.initialAccumulatorValue = initialAccumulatorValue; + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createAdaGradSlot(v); + } + } + + private void createAdaGradSlot(Output v) { + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), ACCUMULATOR, initializer); + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable slot = (Variable) getSlot(variable,ACCUMULATOR).get(); + return tf.train.applyAdagrad(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient); + } + + @Override + public String toString() { + return "AdaGrad{" + + "learningRate=" + learningRate + + ", initialAccumulatorValue=" + initialAccumulatorValue + + '}'; + } + + @Override + public String getOptimizerName() { + return "Adagrad"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java new file mode 100644 index 00000000000..1a4ff11d623 --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java @@ -0,0 +1,117 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.nio.nd.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +import java.util.List; +import java.util.Optional; + +/** + * Optimizer that implements the Adagrad Dual-Averaging algorithm. + * + * See the paper. 
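+ * <p>In addition to per-variable gradient and squared-gradient accumulator slots, this
+ * optimizer maintains a global step variable that the dual-averaging update reads on each
+ * application.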
+ */ +public class AdaGradDA extends Optimizer { + + public static final String ACCUMULATOR = "gradient_accumulator"; + public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator"; + + private Variable globalStep; + + private final float learningRate; + + private final float initialAccumulatorValue; + + private final float l1Strength; + + private final float l2Strength; + + public AdaGradDA(Graph graph, float learningRate) { + this(graph, learningRate, 0.1f, 0.0f, 0.0f); + } + + public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, float l1Strength, float l2Strength) { + super(graph); + this.learningRate = learningRate; + this.initialAccumulatorValue = initialAccumulatorValue; + this.l1Strength = l1Strength; + this.l2Strength = l2Strength; + } + + @Override + protected Optional prepare(String name) { + return Optional.of(tf.assignAdd(globalStep,tf.constant(1L))); + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createAdaGradDASlot(v); + } + globalStep = tf.withName("adagrad-da-global-step").variable(Shape.make(),TInt64.DTYPE); + Assign globalStepInitializer = tf.assign(globalStep, tf.constant(0L)); + graph.addInitializer(globalStepInitializer); + } + + private void createAdaGradDASlot(Output v) { + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), ACCUMULATOR, initializer); + Operand sqInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable gradSlot = (Variable) getSlot(variable,ACCUMULATOR).get(); + @SuppressWarnings("unchecked") + Variable gradSquaredSlot = (Variable) getSlot(variable,SQUARED_ACCUMULATOR).get(); + return tf.train.applyAdagradDa(variable, gradSlot, gradSquaredSlot, gradient, + tf.constant(learningRate, gradient.dataType()), + tf.constant(l1Strength, gradient.dataType()), + tf.constant(l2Strength, gradient.dataType()), + globalStep); + } + + /** + * Gathers up the update operations into a single op that can be used as a run target. + *
<p>
+ * Adds the global step update to the end of the updates list. + * @param updateOperations The update operations. + * @param name The name of the run target. + * @return A NoOp with a control dependency on each update operation. + */ + @Override + protected Op finish(List> updateOperations, String name) { + updateOperations.add(tf.assignAdd(globalStep,tf.constant(1L))); + return super.finish(updateOperations,name); + } + + @Override + public String toString() { + return "AdaGradDA{" + + "globalStep=" + globalStep + + ", learningRate=" + learningRate + + ", initialAccumulatorValue=" + initialAccumulatorValue + + ", l1Strength=" + l1Strength + + ", l2Strength=" + l2Strength + + '}'; + } + + @Override + public String getOptimizerName() { + return "adagrad-da"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java new file mode 100644 index 00000000000..bc0342e7c04 --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java @@ -0,0 +1,131 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.nio.nd.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; +import java.util.Optional; + +/** + * Optimizer that implements the Adam algorithm. + * + * See the paper. + */ +public class Adam extends Optimizer { + + public static final String FIRST_MOMENT = "m"; + public static final String SECOND_MOMENT = "v"; + + private final float learningRate; + + private final float betaOne; + + private final float betaTwo; + + private final float epsilon; + + private Constant learningRateConst; + private Constant epsilonConst; + private Constant betaOneConst; + private Constant betaTwoConst; + private Variable betaOnePower; + private Variable betaTwoPower; + + public Adam(Graph graph, float learningRate) { + this(graph, learningRate, 0.9f, 0.999f, 1e-8f); + } + + public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { + super(graph); + this.learningRate = learningRate; + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createAdamSlot(v); + } + betaOnePower = tf.withName("beta1_power").variable(Shape.make(),TFloat.DTYPE); + Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne, TFloat.DTYPE)); + graph.addInitializer(betaOnePowerInit); + betaTwoPower = tf.withName("beta2_power").variable(Shape.make(),TFloat.DTYPE); + Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo, TFloat.DTYPE)); + graph.addInitializer(betaTwoPowerInit); + } + + @Override + protected Optional prepare(String scopeName) { + betaOneConst = tf.constant(betaOne); + betaTwoConst = tf.constant(betaTwo); + learningRateConst = tf.constant(learningRate); + epsilonConst = tf.constant(epsilon); + return Optional.empty(); + } + + private void createAdamSlot(Output v) { + Operand firstMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), 
FIRST_MOMENT, firstMomentInitializer); + Operand secondMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable firstMomentSlot = (Variable) getSlot(variable,FIRST_MOMENT).get(); + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable secondMomentSlot = (Variable) getSlot(variable,SECOND_MOMENT).get(); + return tf.train.applyAdam(variable, firstMomentSlot, secondMomentSlot, + tf.dtypes.cast(betaOnePower,gradient.dataType()), + tf.dtypes.cast(betaTwoPower,gradient.dataType()), + tf.dtypes.cast(learningRateConst,gradient.dataType()), + tf.dtypes.cast(betaOneConst,gradient.dataType()), + tf.dtypes.cast(betaTwoConst,gradient.dataType()), + tf.dtypes.cast(epsilonConst,gradient.dataType()), + gradient); + } + + /** + * Gathers up the update operations into a single op that can be used as a run target. + *
<p>
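+ * The beta-power variables supply Adam's bias correction, so they are decayed once per step.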
+ * Adds the betaOne and betaTwo updates to the end of the updates list. + * @param updateOperations The update operations. + * @param name The name of the run target. + * @return A NoOp with a control dependency on each update operation. + */ + @Override + protected Op finish(List> updateOperations, String name) { + updateOperations.add(tf.assign(betaOnePower,tf.math.mul(betaOnePower,betaOneConst))); + updateOperations.add(tf.assign(betaTwoPower,tf.math.mul(betaTwoPower,betaTwoConst))); + return super.finish(updateOperations,name); + } + + @Override + public String toString() { + return "Adam{" + + "learningRate=" + learningRate + + ", betaOne=" + betaOne + + ", betaTwo=" + betaTwo + + ", epsilon=" + epsilon + + '}'; + } + + @Override + public String getOptimizerName() { + return "Adam"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java new file mode 100644 index 00000000000..e7aa095367a --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java @@ -0,0 +1,45 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.train.ApplyMomentum; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; + +/** + * Basic SGD. + */ +public class GradientDescent extends Optimizer { + + private final float learningRate; + + public GradientDescent(Graph graph, float learningRate) { + super(graph); + this.learningRate = learningRate; + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + return tf.train.applyGradientDescent(variable, tf.constant(learningRate, gradient.dataType()), gradient); + } + + @Override + public String toString() { + return "GradientDescent{" + + "learningRate=" + learningRate + + '}'; + } + + @Override + public String getOptimizerName() { + return "GradientDescent"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java new file mode 100644 index 00000000000..5feeb9faa1e --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java @@ -0,0 +1,72 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.train.ApplyMomentum; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; + +/** + * SGD plus momentum, either nesterov or traditional. + * + * See the paper for + * details of nesterov momentum. 
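+ * <p>A velocity slot is kept per variable, and each update moves the variable along the
+ * accumulated velocity rather than the raw gradient.
+ * <p>A minimal usage sketch, assuming an existing {@code Graph} and a scalar loss
+ * {@code Operand}:
+ * <pre>{@code
+ * Optimizer optimizer = new Momentum(graph, 0.01f, 0.9f, false);
+ * Op trainOp = optimizer.minimize(loss, "train");
+ * }</pre>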
+ */ +public class Momentum extends Optimizer { + + public static final String MOMENTUM = "momentum"; + + private final float learningRate; + + private final float momentum; + + private final boolean useNesterov; + + public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { + super(graph); + this.learningRate = learningRate; + this.momentum = momentum; + this.useNesterov = useNesterov; + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createMomentumSlot(v); + } + } + + private void createMomentumSlot(Output v) { + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), MOMENTUM, initializer); + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable slot = (Variable) getSlot(variable,MOMENTUM).get(); + return tf.train.applyMomentum(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient, tf.constant(momentum, gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); + } + + @Override + public String toString() { + return "Momentum{" + + "learningRate=" + learningRate + + ", momentum=" + momentum + + ", useNesterov=" + useNesterov + + '}'; + } + + @Override + public String getOptimizerName() { + return "Momentum"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java new file mode 100644 index 00000000000..494a2d650fb --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java @@ -0,0 +1,207 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.NoOp; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; +import org.tensorflow.sandbox.util.Pair; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * + */ +public abstract class Optimizer { + public static final String VARIABLE_V2 = "VariableV2"; + + /** + * Top level map is variable name, bottom map is slot name. + */ + private final Map>> slots; + + /** + * Global state variables + */ + //TODO make this be used. + protected final List globals; + + /** + * The Graph this optimizer is operating on. + */ + protected final Graph graph; + + /** + * The ops builder for the graph. 
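+ * Created via {@code Ops.create(graph).withName(getOptimizerName())} so the ops it builds
+ * carry the optimizer's name.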
+ */ + protected final Ops tf; + + protected Optimizer(Graph graph) { + this.graph = graph; + this.tf = Ops.create(graph).withName(getOptimizerName()); + this.slots = new HashMap<>(); + this.globals = new ArrayList<>(); + } + + public Op minimize(Operand loss) { + return minimize(loss, getOptimizerName()+"-minimize"); + } + + public Op minimize(Operand loss, String name) { + List, Output>> gradsAndVars = computeGradients(loss); + + return applyGradients(gradsAndVars, name); + } + + public List, Output>> computeGradients(Operand loss) { + List variables = new ArrayList<>(); + Iterator opItr = graph.operations(); + while (opItr.hasNext()) { + Operation op = opItr.next(); + if (op.type().equals(VARIABLE_V2)) { + variables.add(op); + } + } + + Output[] variableOutputArray = new Output[variables.size()]; + for (int i = 0; i < variables.size(); i++) { + // First output of a variable is it's output. + variableOutputArray[i] = variables.get(i).output(0); + } + + Output[] gradients = graph.addGradients(loss.asOutput(), variableOutputArray); + List, Output>> gradVarPairs = new ArrayList<>(); + + for (int i = 0; i < variableOutputArray.length; i++) { + gradVarPairs.add(new Pair<>(gradients[i], (Output)variableOutputArray[i])); + } + + return gradVarPairs; + } + + public Op applyGradients(List, Output>> gradsAndVars, String name) { + List> variables = gradsAndVars.stream().map(Pair::getB).collect(Collectors.toList()); + + createSlots(variables); + + Optional prepOp = prepare(name+"/prepare"); + + List> updateOps = new ArrayList<>(); + prepOp.ifPresent(updateOps::add); + for (Pair pair : gradsAndVars) { + updateOps.add(applyDense((Output)pair.getA(),(Output)pair.getB())); + } + + return finish(updateOps,name); + } + + /** + * Gets the slot associated with the specified variable and slot name. + * @param var The variable to lookup. + * @param slotName The slot name. + * @return The slot or {@link Optional#empty}. + */ + public Optional> getSlot(Output var, String slotName) { + return getSlot(var.op().name(),slotName); + } + + /** + * Gets the slot associated with the specified variable and slot name. + * @param varName The variable to lookup. + * @param slotName The slot name. + * @return The slot or {@link Optional#empty}. + */ + public Optional> getSlot(String varName, String slotName) { + Map> variables = slots.get(slotName); + if (variables != null) { + Variable slot = variables.get(varName); + if (slot != null) { + return Optional.of(slot); + } else { + return Optional.empty(); + } + } else { + return Optional.empty(); + } + } + + /** + * Creates a slot in the graph for the specified variable with the specified name. Adds the slot's initializer + * to the graph's initializers, and the slot to the Optimizer's slot map. + * @param variable The variable to create the slot for. + * @param slotName The name of the slot. + * @param initializer The initializer for the slot. + * @param The type of the variable. + */ + protected void createSlot(Output variable, String slotName, Operand initializer) { + Variable slot = (Variable) tf.withName(createName(variable, slotName)).variable(variable.shape(), TFloat.DTYPE); + Assign slotInit = tf.assign(slot, initializer); + graph.addInitializer(slotInit); + String varName = variable.op().name(); + Map> variables = slots.computeIfAbsent(slotName,(k) -> new HashMap<>()); + variables.put(varName,slot); + } + + /** + * No-op prepare method. + * + * @param scopeName The scope name to use for any variable creations. 
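+ * @return An optional op to run before the update ops; {@link Optional#empty} by default.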
+ */ + protected Optional prepare(String scopeName) { + return Optional.empty(); + } + + /** + * No-op slot creation method. + * @param variables The variables to create slots for. + */ + protected void createSlots(List> variables) { } + + /** + * Generates + * @param gradient + * @param variable + * @param + * @return + */ + protected abstract Operand applyDense(Output gradient, Output variable); + + /** + * Gathers up the update operations into a single op that can be used as a run target. + * @param updateOperations The update operations. + * @param name The name of the run target. + * @return A NoOp with a control dependency on each update operation. + */ + protected Op finish(List> updateOperations, String name) { + Scope scope = new Scope(graph); + scope = scope.withName(name); + scope = scope.withControlDependencies(updateOperations); + return NoOp.create(scope); + } + + /** + * Name of the optimizer. + * @return The optimizer name. + */ + public abstract String getOptimizerName(); + + public static String createName(Output variable, String slotName) { + return variable.op().name() + "-" + slotName; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java new file mode 100644 index 00000000000..0f1a1f4c85b --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java @@ -0,0 +1,105 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; + +/** + * Optimizer that implements the RMSProp algorithm. + * + * See the lecture notes + * that is inexplicably the canonical reference. + */ +public class RMSProp extends Optimizer { + + public static final String RMS = "rms"; + public static final String MG = "mg"; // mean gradient? 
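+ // The "mg" slot holds the moving average of the gradient; it is only created when centered is true.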
+ public static final String MOMENTUM = "momentum"; + + private final float learningRate; + private final float decay; + private final float momentum; + private final float epsilon; + private final boolean centered; + + public RMSProp(Graph graph, float learningRate) { + this(graph, learningRate, 0.9f, 0.0f, 1e-10f, false); + } + + public RMSProp(Graph graph, float learningRate, float decay, float momentum, float epsilon, boolean centered) { + super(graph); + this.learningRate = learningRate; + this.decay = decay; + this.momentum = momentum; + this.epsilon = epsilon; + this.centered = centered; + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createRMSPropSlot(v); + } + } + + private void createRMSPropSlot(Output v) { + Operand rmsInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(1.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), RMS, rmsInitializer); + Operand momentumInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), MOMENTUM, momentumInitializer); + if (centered) { + Operand mgInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), MG, mgInitializer); + } + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable rmsSlot = (Variable) getSlot(variable,RMS).get(); + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable momentumSlot = (Variable) getSlot(variable,MOMENTUM).get(); + if (centered) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable mgSlot = (Variable) getSlot(variable, MG).get(); + return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, + tf.constant(learningRate, gradient.dataType()), + tf.constant(decay, gradient.dataType()), + tf.constant(momentum, gradient.dataType()), + tf.constant(epsilon, gradient.dataType()), + gradient); + } else { + return tf.train.applyRmsProp(variable, rmsSlot, momentumSlot, + tf.constant(learningRate, gradient.dataType()), + tf.constant(decay, gradient.dataType()), + tf.constant(momentum, gradient.dataType()), + tf.constant(epsilon, gradient.dataType()), + gradient); + } + } + + @Override + public String toString() { + return "RMSProp{" + + "learningRate=" + learningRate + + ", decay=" + decay + + ", momentum=" + momentum + + ", epsilon=" + epsilon + + ", centered=" + centered + + '}'; + } + + @Override + public String getOptimizerName() { + return "RMSProp"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java new file mode 100644 index 00000000000..8e6e1a0b3ea --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java @@ -0,0 +1,56 @@ +package org.tensorflow.sandbox.util; + +import java.io.Serializable; +import java.util.Objects; + +/** + * An immutable pair of things. + * + * @param The type of the first object. + * @param The type of the second object. 
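+ *        (Used by the optimizers to carry gradient-variable pairs.)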
+ */ +public class Pair implements Serializable { + private static final long serialVersionUID = 1L; + + private final T1 a; + + private final T2 b; + + public Pair(T1 a, T2 b) { + this.a = a; + this.b = b; + } + + public T1 getA() { + return a; + } + + public T2 getB() { + return b; + } + + @Override + public int hashCode() { + return a.hashCode() ^ b.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (!(obj instanceof Pair)) { + return false; + } + final Pair other = (Pair) obj; + if (!Objects.equals(this.a, other.a)) { + return false; + } + return Objects.equals(this.b, other.b); + } + + @Override + public String toString() { + return "Pair{" + "a=" + a + ", b=" + b + '}'; + } +} From 68a353c1fb555ac8a4e39f0ed65ac75de77864f4 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 5 Nov 2019 11:14:39 -0500 Subject: [PATCH 02/22] Adding Apache 2.0 license header to all optimizer files. --- .../java/org/tensorflow/sandbox/MNISTTest.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/AdaDelta.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/AdaGrad.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/AdaGradDA.java | 12 ++++++++++++ .../org/tensorflow/sandbox/optimizers/Adam.java | 12 ++++++++++++ .../sandbox/optimizers/GradientDescent.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/Momentum.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/Optimizer.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/RMSProp.java | 12 ++++++++++++ .../java/org/tensorflow/sandbox/util/Pair.java | 15 +++++++++++++++ 10 files changed, 123 insertions(+) diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java index 76ffd102371..73932c109ec 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java index f0bc847bdc5..687687a0661 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java index 6418e7b9d69..6d3240a2ffe 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java index 1a4ff11d623..c71735ad35d 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java index bc0342e7c04..1337163abf6 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java index e7aa095367a..efb067f68e4 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java index 5feeb9faa1e..34b94ed060b 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java index 494a2d650fb..e7e7e87f968 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java index 0f1a1f4c85b..b34b22c9a5c 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java index 8e6e1a0b3ea..07560d1e56d 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java @@ -1,3 +1,18 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.sandbox.util; import java.io.Serializable; From b3f4be8ee32caeeb503ff625fcd8c95e4d0f33cd Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 6 Dec 2019 16:40:19 -0500 Subject: [PATCH 03/22] Bug fix for the MNISTTest. 
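The feature-centering step in MNISTTest.build used tf.math.add where it should have used
tf.math.sub, so inputs were shifted up by half the pixel depth instead of being centered
around zero before scaling.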
--- .../src/main/java/org/tensorflow/sandbox/MNISTTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java index 73932c109ec..c25147f7134 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java @@ -90,7 +90,7 @@ public static Graph build(String optimizerName) { // Scaling the features Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); - Operand scaledInput = tf.math.div(tf.math.add(input, centeringFactor), scalingFactor); + Operand scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor); // First conv layer Variable conv1Weights = tf.variable(Shape.make(5, 5, NUM_CHANNELS, 32), TFloat.DTYPE); From d1868ea31b7d0ca9b01cc0ad01bafb6f22c5536f Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 31 Jan 2020 12:32:22 -0500 Subject: [PATCH 04/22] Refactor to uptake latest tensorflow-core changes. --- .../org/tensorflow/sandbox/MNISTTest.java | 76 +++++++++---------- .../sandbox/optimizers/AdaDelta.java | 8 +- .../sandbox/optimizers/AdaGrad.java | 6 +- .../sandbox/optimizers/AdaGradDA.java | 12 +-- .../tensorflow/sandbox/optimizers/Adam.java | 30 ++++---- .../sandbox/optimizers/GradientDescent.java | 7 +- .../sandbox/optimizers/Momentum.java | 6 +- .../sandbox/optimizers/Optimizer.java | 20 ++--- .../sandbox/optimizers/RMSProp.java | 10 +-- .../org/tensorflow/sandbox/util/Pair.java | 2 +- 10 files changed, 86 insertions(+), 91 deletions(-) diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java index c25147f7134..a508046f5d7 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,7 +19,6 @@ import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; -import org.tensorflow.nio.nd.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Assign; @@ -44,7 +43,8 @@ import org.tensorflow.sandbox.optimizers.Momentum; import org.tensorflow.sandbox.optimizers.Optimizer; import org.tensorflow.sandbox.optimizers.RMSProp; -import org.tensorflow.types.TFloat; +import org.tensorflow.tools.Shape; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import java.io.BufferedInputStream; @@ -84,71 +84,71 @@ public static Graph build(String optimizerName) { Ops tf = Ops.create(graph); // Inputs - Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat.DTYPE, Placeholder.shape(Shape.make(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); + Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.make(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); Placeholder labels = tf.withName(TARGET).placeholder(TInt32.DTYPE); // Scaling the features - Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); - Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); - Operand scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor); + Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); + Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); + Operand scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor); // First conv layer - Variable conv1Weights = tf.variable(Shape.make(5, 5, NUM_CHANNELS, 32), TFloat.DTYPE); - Assign weights1Init = tf.assign(conv1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv1Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable conv1Weights = tf.variable(Shape.make(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE); + Assign weights1Init = tf.assign(conv1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv1Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); graph.addInitializer(weights1Init); - Conv2d conv1 = tf.nn.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); - Variable conv1Biases = tf.variable(Shape.make(32), TFloat.DTYPE); - Assign biases1Init = tf.assign(conv1Biases, tf.fill(tf.shape(conv1Biases), tf.constant(0.0f))); + Conv2d conv1 = tf.nn.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable conv1Biases = tf.variable(Shape.make(32), TFloat32.DTYPE); + Assign biases1Init = tf.assign(conv1Biases, tf.fill(tf.shape(conv1Biases), tf.constant(0.0f))); graph.addInitializer(biases1Init); - Relu relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); + Relu relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); // First pooling layer - MaxPool pool1 = tf.nn.maxPool(relu1, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); + MaxPool pool1 = tf.nn.maxPool(relu1, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); // Second conv layer - Variable conv2Weights = tf.variable(Shape.make(5, 5, 32, 64), TFloat.DTYPE); - Assign weights2Init = tf.assign(conv2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv2Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable conv2Weights = tf.variable(Shape.make(5, 5, 32, 64), TFloat32.DTYPE); + Assign weights2Init = tf.assign(conv2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv2Weights), 
TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); graph.addInitializer(weights2Init); - Conv2d conv2 = tf.nn.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); - Variable conv2Biases = tf.variable(Shape.make(64), TFloat.DTYPE); - Assign biases2Init = tf.assign(conv2Biases, tf.fill(tf.shape(conv2Biases), tf.constant(0.1f))); + Conv2d conv2 = tf.nn.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable conv2Biases = tf.variable(Shape.make(64), TFloat32.DTYPE); + Assign biases2Init = tf.assign(conv2Biases, tf.fill(tf.shape(conv2Biases), tf.constant(0.1f))); graph.addInitializer(biases2Init); - Relu relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); + Relu relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); // Second pooling layer - MaxPool pool2 = tf.nn.maxPool(relu2, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); + MaxPool pool2 = tf.nn.maxPool(relu2, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); // Flatten inputs - Reshape flatten = tf.reshape(pool2, tf.concat(Arrays.asList(tf.slice(tf.shape(pool2), tf.constant(new int[]{0}), tf.constant(new int[]{1})), tf.constant(new int[]{-1})), tf.constant(0))); + Reshape flatten = tf.reshape(pool2, tf.concat(Arrays.asList(tf.slice(tf.shape(pool2), tf.constant(new int[]{0}), tf.constant(new int[]{1})), tf.constant(new int[]{-1})), tf.constant(0))); // Fully connected layer - Variable fc1Weights = tf.variable(Shape.make(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat.DTYPE); - Assign weights3Init = tf.assign(fc1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc1Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc1Weights = tf.variable(Shape.make(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE); + Assign weights3Init = tf.assign(fc1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc1Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); graph.addInitializer(weights3Init); - Variable fc1Biases = tf.variable(Shape.make(512), TFloat.DTYPE); - Assign biases3Init = tf.assign(fc1Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc1Biases))); + Variable fc1Biases = tf.variable(Shape.make(512), TFloat32.DTYPE); + Assign biases3Init = tf.assign(fc1Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc1Biases))); graph.addInitializer(biases3Init); - Relu relu3 = tf.nn.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); + Relu relu3 = tf.nn.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); // Softmax layer - Variable fc2Weights = tf.variable(Shape.make(512, NUM_LABELS), TFloat.DTYPE); - Assign weights4Init = tf.assign(fc2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc2Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc2Weights = tf.variable(Shape.make(512, NUM_LABELS), TFloat32.DTYPE); + Assign weights4Init = tf.assign(fc2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc2Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); graph.addInitializer(weights4Init); - Variable fc2Biases = tf.variable(Shape.make(NUM_LABELS), TFloat.DTYPE); - Assign biases4Init = tf.assign(fc2Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc2Biases))); + Variable fc2Biases = tf.variable(Shape.make(NUM_LABELS), TFloat32.DTYPE); + Assign biases4Init = tf.assign(fc2Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc2Biases))); 
graph.addInitializer(biases4Init); - Add logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); + Add logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); // Predicted outputs - Softmax prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits); + Softmax prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits); // Loss function & regularization - OneHot oneHot = tf.oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); - SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); - Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); - Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math.add(tf.nn.l2Loss(fc1Biases), tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); - Add loss = tf.withName(TRAINING_LOSS).math.add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); + OneHot oneHot = tf.oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); + SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); + Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); + Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math.add(tf.nn.l2Loss(fc1Biases), tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); + Add loss = tf.withName(TRAINING_LOSS).math.add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); // Optimizer Optimizer optimizer; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java index 687687a0661..415e62c176d 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,7 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -60,9 +60,9 @@ protected void createSlots(List> variables) { } private void createAdaDeltaSlot(Output v) { - Operand accumulatorInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand accumulatorInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); - Operand updateInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand updateInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java index 6d3240a2ffe..11cf7ba867a 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -57,7 +57,7 @@ protected void createSlots(List> variables) { } private void createAdaGradSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java index c71735ad35d..e2c591cd9ea 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,12 +18,12 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.nio.nd.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.tools.Shape; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -72,15 +72,15 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdaGradDASlot(v); } - globalStep = tf.withName("adagrad-da-global-step").variable(Shape.make(),TInt64.DTYPE); + globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(),TInt64.DTYPE); Assign globalStepInitializer = tf.assign(globalStep, tf.constant(0L)); graph.addInitializer(globalStepInitializer); } private void createAdaGradDASlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); - Operand sqInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat.DTYPE));//v.dataType())); + Operand sqInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java index 1337163abf6..c06ac96a467 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,12 +18,12 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.nio.nd.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.tools.Shape; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -47,12 +47,12 @@ public class Adam extends Optimizer { private final float epsilon; - private Constant learningRateConst; - private Constant epsilonConst; - private Constant betaOneConst; - private Constant betaTwoConst; - private Variable betaOnePower; - private Variable betaTwoPower; + private Constant learningRateConst; + private Constant epsilonConst; + private Constant betaOneConst; + private Constant betaTwoConst; + private Variable betaOnePower; + private Variable betaTwoPower; public Adam(Graph graph, float learningRate) { this(graph, learningRate, 0.9f, 0.999f, 1e-8f); @@ -71,11 +71,11 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdamSlot(v); } - betaOnePower = tf.withName("beta1_power").variable(Shape.make(),TFloat.DTYPE); - Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne, TFloat.DTYPE)); + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(),TFloat32.DTYPE); + Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne, TFloat32.DTYPE)); graph.addInitializer(betaOnePowerInit); - betaTwoPower = tf.withName("beta2_power").variable(Shape.make(),TFloat.DTYPE); - Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo, TFloat.DTYPE)); + betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(),TFloat32.DTYPE); + Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo, TFloat32.DTYPE)); graph.addInitializer(betaTwoPowerInit); } @@ -89,9 +89,9 @@ protected Optional prepare(String scopeName) { } private void createAdamSlot(Output v) { - Operand firstMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand firstMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); - Operand secondMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand secondMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java index efb067f68e4..fd0f264d664 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,13 +18,8 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.Variable; -import org.tensorflow.op.train.ApplyMomentum; -import org.tensorflow.types.TFloat; import org.tensorflow.types.family.TType; -import java.util.List; /** * Basic SGD. diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java index 34b94ed060b..e0b4fbbac49 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; -import org.tensorflow.types.TFloat; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -57,7 +57,7 @@ protected void createSlots(List> variables) { } private void createMomentumSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), MOMENTUM, initializer); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java index e7e7e87f968..1ac3ef1ceed 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.NoOp; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import org.tensorflow.sandbox.util.Pair; @@ -44,7 +44,7 @@ public abstract class Optimizer { public static final String VARIABLE_V2 = "VariableV2"; /** - * Top level map is variable name, bottom map is slot name. + * Top level map key is the variable name, lower level map key is the slot name. */ private final Map>> slots; @@ -101,7 +101,7 @@ public List, Output>> computeGradients(Operand, Output>> gradVarPairs = new ArrayList<>(); for (int i = 0; i < variableOutputArray.length; i++) { - gradVarPairs.add(new Pair<>(gradients[i], (Output)variableOutputArray[i])); + gradVarPairs.add(new Pair<>(gradients[i], variableOutputArray[i])); } return gradVarPairs; @@ -162,7 +162,7 @@ public Optional> getSlot(String varName, String slotName) { * @param The type of the variable. 
*/ protected void createSlot(Output variable, String slotName, Operand initializer) { - Variable slot = (Variable) tf.withName(createName(variable, slotName)).variable(variable.shape(), TFloat.DTYPE); + Variable slot = (Variable) tf.withName(createName(variable, slotName)).variable(variable.shape(), TFloat32.DTYPE); Assign slotInit = tf.assign(slot, initializer); graph.addInitializer(slotInit); String varName = variable.op().name(); @@ -186,11 +186,11 @@ protected Optional prepare(String scopeName) { protected void createSlots(List> variables) { } /** - * Generates - * @param gradient - * @param variable - * @param - * @return + * Generates the gradient update operations for the specific variable and gradient. + * @param gradient The gradient to use. + * @param variable The variable to update. + * @param The type of the variable. + * @return An operand which applies the desired optimizer update to the variable. */ protected abstract Operand applyDense(Output gradient, Output variable); diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java index b34b22c9a5c..9ba293b46f6 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -64,12 +64,12 @@ protected void createSlots(List> variables) { } private void createRMSPropSlot(Output v) { - Operand rmsInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(1.0f, TFloat.DTYPE));//v.dataType())); + Operand rmsInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(1.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), RMS, rmsInitializer); - Operand momentumInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand momentumInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand mgInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand mgInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), MG, mgInitializer); } } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java index 07560d1e56d..3160b0f2b8d 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 6d189cc79d58d968bd004c7c0c425c4dccfe802b Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 31 Jan 2020 15:30:22 -0500 Subject: [PATCH 05/22] Added type safety and updates for new api. --- .../sandbox/optimizers/AdaDelta.java | 11 ++- .../sandbox/optimizers/AdaGrad.java | 6 +- .../sandbox/optimizers/AdaGradDA.java | 13 ++-- .../tensorflow/sandbox/optimizers/Adam.java | 14 ++-- .../sandbox/optimizers/GradientDescent.java | 3 +- .../sandbox/optimizers/Momentum.java | 6 +- .../sandbox/optimizers/Optimizer.java | 70 ++++++++++++------ .../sandbox/optimizers/RMSProp.java | 16 ++--- .../org/tensorflow/sandbox/util/Pair.java | 71 ------------------- 9 files changed, 75 insertions(+), 135 deletions(-) delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java index 415e62c176d..c4cd4079df2 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java @@ -18,7 +18,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -60,18 +59,16 @@ protected void createSlots(List> variables) { } private void createAdaDeltaSlot(Output v) { - Operand accumulatorInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand accumulatorInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); - Operand updateInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand updateInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. 
- Variable accumSlot = (Variable) getSlot(variable,ACCUMULATOR).get(); - @SuppressWarnings("unchecked") - Variable accumUpdateSlot = (Variable) getSlot(variable,ACCUMULATOR_UPDATE).get(); + Variable accumSlot = getSlot(variable,ACCUMULATOR).get(); + Variable accumUpdateSlot = getSlot(variable,ACCUMULATOR_UPDATE).get(); return tf.train.applyAdadelta(variable, accumSlot, accumUpdateSlot, tf.constant(learningRate, gradient.dataType()), tf.constant(rho, gradient.dataType()), diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java index 11cf7ba867a..00a56f5853a 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java @@ -18,7 +18,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -57,14 +56,13 @@ protected void createSlots(List> variables) { } private void createAdaGradSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat32.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. - Variable slot = (Variable) getSlot(variable,ACCUMULATOR).get(); + Variable slot = getSlot(variable,ACCUMULATOR).get(); return tf.train.applyAdagrad(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java index e2c591cd9ea..753e104c8a1 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java @@ -20,7 +20,6 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.tools.Shape; import org.tensorflow.types.TFloat32; @@ -63,7 +62,7 @@ public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, } @Override - protected Optional prepare(String name) { + protected Optional> prepare(String name) { return Optional.of(tf.assignAdd(globalStep,tf.constant(1L))); } @@ -78,18 +77,16 @@ protected void createSlots(List> variables) { } private void createAdaGradDASlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); - Operand sqInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat32.DTYPE));//v.dataType())); + Operand sqInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE),v.dataType())); 
createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. - Variable gradSlot = (Variable) getSlot(variable,ACCUMULATOR).get(); - @SuppressWarnings("unchecked") - Variable gradSquaredSlot = (Variable) getSlot(variable,SQUARED_ACCUMULATOR).get(); + Variable gradSlot = getSlot(variable,ACCUMULATOR).get(); + Variable gradSquaredSlot = getSlot(variable,SQUARED_ACCUMULATOR).get(); return tf.train.applyAdagradDa(variable, gradSlot, gradSquaredSlot, gradient, tf.constant(learningRate, gradient.dataType()), tf.constant(l1Strength, gradient.dataType()), diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java index c06ac96a467..7b18fab0354 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java @@ -69,7 +69,7 @@ public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float @Override protected void createSlots(List> variables) { for (Output v : variables) { - createAdamSlot(v); + createAdamSlot(v.asOutput()); } betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(),TFloat32.DTYPE); Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne, TFloat32.DTYPE)); @@ -80,7 +80,7 @@ protected void createSlots(List> variables) { } @Override - protected Optional prepare(String scopeName) { + protected Optional> prepare(String scopeName) { betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); learningRateConst = tf.constant(learningRate); @@ -89,18 +89,16 @@ protected Optional prepare(String scopeName) { } private void createAdamSlot(Output v) { - Operand firstMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand firstMomentInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); - Operand secondMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand secondMomentInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. - Variable firstMomentSlot = (Variable) getSlot(variable,FIRST_MOMENT).get(); - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. 
- Variable secondMomentSlot = (Variable) getSlot(variable,SECOND_MOMENT).get(); + Variable firstMomentSlot = getSlot(variable,FIRST_MOMENT).get(); + Variable secondMomentSlot = getSlot(variable,SECOND_MOMENT).get(); return tf.train.applyAdam(variable, firstMomentSlot, secondMomentSlot, tf.dtypes.cast(betaOnePower,gradient.dataType()), tf.dtypes.cast(betaTwoPower,gradient.dataType()), diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java index fd0f264d664..c95398abe6f 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java @@ -18,6 +18,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -35,7 +36,7 @@ public GradientDescent(Graph graph, float learningRate) { @Override protected Operand applyDense(Output gradient, Output variable) { - return tf.train.applyGradientDescent(variable, tf.constant(learningRate, gradient.dataType()), gradient); + return tf.train.applyGradientDescent(variable, tf.dtypes.cast(tf.constant(learningRate, TFloat32.DTYPE), gradient.dataType()), gradient); } @Override diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java index e0b4fbbac49..60f3497d570 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java @@ -18,7 +18,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; import org.tensorflow.types.TFloat32; @@ -57,14 +56,13 @@ protected void createSlots(List> variables) { } private void createMomentumSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), MOMENTUM, initializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. 
- Variable slot = (Variable) getSlot(variable,MOMENTUM).get(); + Variable slot = getSlot(variable,MOMENTUM).get(); return tf.train.applyMomentum(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient, tf.constant(momentum, gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java index 1ac3ef1ceed..37467a50f34 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java @@ -25,9 +25,7 @@ import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.NoOp; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import org.tensorflow.sandbox.util.Pair; import java.util.ArrayList; import java.util.HashMap; @@ -52,7 +50,7 @@ public abstract class Optimizer { * Global state variables */ //TODO make this be used. - protected final List globals; + protected final List> globals; /** * The Graph this optimizer is operating on. @@ -76,12 +74,12 @@ public Op minimize(Operand loss) { } public Op minimize(Operand loss, String name) { - List, Output>> gradsAndVars = computeGradients(loss); + List> gradsAndVars = computeGradients(loss); return applyGradients(gradsAndVars, name); } - public List, Output>> computeGradients(Operand loss) { + public List> computeGradients(Operand loss) { List variables = new ArrayList<>(); Iterator opItr = graph.operations(); while (opItr.hasNext()) { @@ -98,26 +96,30 @@ public List, Output>> computeGradients(Operand[] gradients = graph.addGradients(loss.asOutput(), variableOutputArray); - List, Output>> gradVarPairs = new ArrayList<>(); + List> gradVarPairs = new ArrayList<>(); for (int i = 0; i < variableOutputArray.length; i++) { - gradVarPairs.add(new Pair<>(gradients[i], variableOutputArray[i])); + @SuppressWarnings("unchecked") + Output typedGrad = (Output) gradients[i]; + @SuppressWarnings("unchecked") + Output typedVar = (Output) variableOutputArray[i]; + gradVarPairs.add(new GradAndVar<>(typedGrad, typedVar)); } return gradVarPairs; } - public Op applyGradients(List, Output>> gradsAndVars, String name) { - List> variables = gradsAndVars.stream().map(Pair::getB).collect(Collectors.toList()); + public Op applyGradients(List> gradsAndVars, String name) { + List> variables = gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList()); createSlots(variables); - Optional prepOp = prepare(name+"/prepare"); + Optional> prepOp = prepare(name+"/prepare"); - List> updateOps = new ArrayList<>(); + List> updateOps = new ArrayList<>(); prepOp.ifPresent(updateOps::add); - for (Pair pair : gradsAndVars) { - updateOps.add(applyDense((Output)pair.getA(),(Output)pair.getB())); + for (GradAndVar pair : gradsAndVars) { + updateOps.add(applyDense(pair)); } return finish(updateOps,name); @@ -129,7 +131,7 @@ public Op applyGradients(List, Output>> gradsAnd * @param slotName The slot name. * @return The slot or {@link Optional#empty}. */ - public Optional> getSlot(Output var, String slotName) { + public Optional> getSlot(Output var, String slotName) { return getSlot(var.op().name(),slotName); } @@ -139,12 +141,14 @@ public Optional> getSlot(Output var, String slotName) { * @param slotName The slot name. * @return The slot or {@link Optional#empty}. 
*/ - public Optional> getSlot(String varName, String slotName) { - Map> variables = slots.get(slotName); + private Optional> getSlot(String varName, String slotName) { + Map> variables = slots.get(slotName); if (variables != null) { - Variable slot = variables.get(varName); + Variable slot = variables.get(varName); if (slot != null) { - return Optional.of(slot); + @SuppressWarnings("unchecked") // This method should only be called when the type is known. + Optional> opt = Optional.of((Variable)slot); + return opt; } else { return Optional.empty(); } @@ -162,11 +166,11 @@ public Optional> getSlot(String varName, String slotName) { * @param The type of the variable. */ protected void createSlot(Output variable, String slotName, Operand initializer) { - Variable slot = (Variable) tf.withName(createName(variable, slotName)).variable(variable.shape(), TFloat32.DTYPE); + Variable slot = tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.dataType()); Assign slotInit = tf.assign(slot, initializer); graph.addInitializer(slotInit); String varName = variable.op().name(); - Map> variables = slots.computeIfAbsent(slotName,(k) -> new HashMap<>()); + Map> variables = slots.computeIfAbsent(slotName,(k) -> new HashMap<>()); variables.put(varName,slot); } @@ -175,7 +179,7 @@ protected void createSlot(Output variable, String slotName, * * @param scopeName The scope name to use for any variable creations. */ - protected Optional prepare(String scopeName) { + protected Optional> prepare(String scopeName) { return Optional.empty(); } @@ -185,6 +189,10 @@ protected Optional prepare(String scopeName) { */ protected void createSlots(List> variables) { } + private Operand applyDense(GradAndVar gradVarPair) { + return applyDense(gradVarPair.getGradient(),gradVarPair.getVariable()); + } + /** * Generates the gradient update operations for the specific variable and gradient. * @param gradient The gradient to use. 
@@ -213,7 +221,25 @@ protected Op finish(List> updateOperations, String name) { */ public abstract String getOptimizerName(); - public static String createName(Output variable, String slotName) { + public static String createName(Output variable, String slotName) { return variable.op().name() + "-" + slotName; } + + public static class GradAndVar { + private final Output gradient; + private final Output variable; + + public GradAndVar(Output gradient, Output variable) { + this.gradient = gradient; + this.variable = variable; + } + + public Output getGradient() { + return gradient; + } + + public Output getVariable() { + return variable; + } + } } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java index 9ba293b46f6..a20996f0018 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java @@ -18,7 +18,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -64,25 +63,22 @@ protected void createSlots(List> variables) { } private void createRMSPropSlot(Output v) { - Operand rmsInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(1.0f, TFloat32.DTYPE));//v.dataType())); + Operand rmsInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), RMS, rmsInitializer); - Operand momentumInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand momentumInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand mgInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand mgInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), MG, mgInitializer); } } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. - Variable rmsSlot = (Variable) getSlot(variable,RMS).get(); - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. - Variable momentumSlot = (Variable) getSlot(variable,MOMENTUM).get(); + Variable rmsSlot = getSlot(variable,RMS).get(); + Variable momentumSlot = getSlot(variable,MOMENTUM).get(); if (centered) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. 
- Variable mgSlot = (Variable) getSlot(variable, MG).get(); + Variable mgSlot = getSlot(variable, MG).get(); return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, tf.constant(learningRate, gradient.dataType()), tf.constant(decay, gradient.dataType()), diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java deleted file mode 100644 index 3160b0f2b8d..00000000000 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.sandbox.util; - -import java.io.Serializable; -import java.util.Objects; - -/** - * An immutable pair of things. - * - * @param The type of the first object. - * @param The type of the second object. - */ -public class Pair implements Serializable { - private static final long serialVersionUID = 1L; - - private final T1 a; - - private final T2 b; - - public Pair(T1 a, T2 b) { - this.a = a; - this.b = b; - } - - public T1 getA() { - return a; - } - - public T2 getB() { - return b; - } - - @Override - public int hashCode() { - return a.hashCode() ^ b.hashCode(); - } - - @Override - public boolean equals(Object obj) { - if (obj == null) { - return false; - } - if (!(obj instanceof Pair)) { - return false; - } - final Pair other = (Pair) obj; - if (!Objects.equals(this.a, other.a)) { - return false; - } - return Objects.equals(this.b, other.b); - } - - @Override - public String toString() { - return "Pair{" + "a=" + a + ", b=" + b + '}'; - } -} From 53e438a76ea768341ae8bc4bd5aa8c1db4cca77a Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 31 Jan 2020 15:32:57 -0500 Subject: [PATCH 06/22] Small changes, plus a fix for DataTypes to include references to the type. 
--- pom.xml | 10 ++++++++++ tensorflow-core/tensorflow-core-api/pom.xml | 9 +++++++++ .../src/main/java/org/tensorflow/DataTypes.java | 1 + .../src/main/java/org/tensorflow/Tensor.java | 3 +-- .../java/org/tensorflow/tools/ndarray/NdArrays.java | 4 ++-- 5 files changed, 23 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index a06ccc25a0c..057811ef0e5 100644 --- a/pom.xml +++ b/pom.xml @@ -40,6 +40,7 @@ 1.8 4.12 1.21 + true @@ -121,6 +122,15 @@ + + + + org.apache.maven.plugins + maven-jar-plugin + 3.2.0 + + + diff --git a/tensorflow-core/tensorflow-core-api/pom.xml b/tensorflow-core/tensorflow-core-api/pom.xml index 3ecd648e60b..15ed9c9bd56 100644 --- a/tensorflow-core/tensorflow-core-api/pom.xml +++ b/tensorflow-core/tensorflow-core-api/pom.xml @@ -332,6 +332,15 @@ + + maven-assembly-plugin + 3.2.0 + + + jar-with-dependencies + + + diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java index 1e2ad6ec427..468d5111c36 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/DataTypes.java @@ -70,5 +70,6 @@ static DataType fromNativeCode(int nativeCode) { // to allow user to register custom data types? private static void register(DataType dataType) { DATA_TYPE_REGISTRY.put(dataType.nativeCode(), dataType); + DATA_TYPE_REGISTRY.put(dataType.nativeCode() + 100, dataType); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 585442ea559..5939db9ead9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -146,7 +146,6 @@ public final class Tensor implements AutoCloseable { * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type * system. */ - @SuppressWarnings("unchecked") public static Tensor create(Object obj, DataType dtype) { if (!objectCompatWithType(obj, dtype)) { throw new IllegalArgumentException( @@ -158,7 +157,7 @@ public static Tensor create(Object obj, DataType dtype) } long[] dimSizes = new long[numDimensions(obj, dtype)]; fillShape(obj, 0, dimSizes); - Tensor t = new Tensor(dtype, Shape.of(dimSizes)); + Tensor t = new Tensor<>(dtype, Shape.make(dimSizes)); TF_Tensor nativeHandle; if (t.dtype != TString.DTYPE) { long byteSize = elemByteSize(t.dtype) * t.shape.size(); diff --git a/tensorflow-tools/src/main/java/org/tensorflow/tools/ndarray/NdArrays.java b/tensorflow-tools/src/main/java/org/tensorflow/tools/ndarray/NdArrays.java index fef35f923ff..91fae5892c1 100644 --- a/tensorflow-tools/src/main/java/org/tensorflow/tools/ndarray/NdArrays.java +++ b/tensorflow-tools/src/main/java/org/tensorflow/tools/ndarray/NdArrays.java @@ -456,8 +456,8 @@ public static NdArray scalarOfObject(T value) { */ @SafeVarargs public static NdArray vectorOfObjects(T... 
values) { - if (values == null) { - throw new IllegalArgumentException(); + if (values == null || values.length == 0) { + throw new IllegalArgumentException("Null or zero length input supplied to vectorOfObjects."); } return wrap(Shape.of(values.length), DataBuffers.from(values, false, false)); } From 83140b46bc1d4a712a81c4ac7c7b13776770eb0c Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 7 Feb 2020 16:55:37 -0500 Subject: [PATCH 07/22] Repackaging the optimizers into tensorflow-training, org.tensorflow.training. --- pom.xml | 2 +- tensorflow-training/pom.xml | 77 +++++++++++++++++++ .../training/examples}/MNISTTest.java | 18 ++--- .../training}/optimizers/AdaDelta.java | 2 +- .../training}/optimizers/AdaGrad.java | 2 +- .../training}/optimizers/AdaGradDA.java | 2 +- .../tensorflow/training}/optimizers/Adam.java | 2 +- .../training}/optimizers/GradientDescent.java | 2 +- .../training}/optimizers/Momentum.java | 2 +- .../training}/optimizers/Optimizer.java | 2 +- .../training}/optimizers/RMSProp.java | 2 +- 11 files changed, 95 insertions(+), 18 deletions(-) create mode 100644 tensorflow-training/pom.xml rename {tensorflow-sandbox/src/main/java/org/tensorflow/sandbox => tensorflow-training/src/main/java/org/tensorflow/training/examples}/MNISTTest.java (96%) rename {tensorflow-sandbox/src/main/java/org/tensorflow/sandbox => tensorflow-training/src/main/java/org/tensorflow/training}/optimizers/AdaDelta.java (98%) rename {tensorflow-sandbox/src/main/java/org/tensorflow/sandbox => tensorflow-training/src/main/java/org/tensorflow/training}/optimizers/AdaGrad.java (98%) rename {tensorflow-sandbox/src/main/java/org/tensorflow/sandbox => tensorflow-training/src/main/java/org/tensorflow/training}/optimizers/AdaGradDA.java (99%) rename {tensorflow-sandbox/src/main/java/org/tensorflow/sandbox => tensorflow-training/src/main/java/org/tensorflow/training}/optimizers/Adam.java (99%) rename {tensorflow-sandbox/src/main/java/org/tensorflow/sandbox => tensorflow-training/src/main/java/org/tensorflow/training}/optimizers/GradientDescent.java (97%) rename {tensorflow-sandbox/src/main/java/org/tensorflow/sandbox => tensorflow-training/src/main/java/org/tensorflow/training}/optimizers/Momentum.java (98%) rename {tensorflow-sandbox/src/main/java/org/tensorflow/sandbox => tensorflow-training/src/main/java/org/tensorflow/training}/optimizers/Optimizer.java (99%) rename {tensorflow-sandbox/src/main/java/org/tensorflow/sandbox => tensorflow-training/src/main/java/org/tensorflow/training}/optimizers/RMSProp.java (98%) diff --git a/pom.xml b/pom.xml index 057811ef0e5..5112f662760 100644 --- a/pom.xml +++ b/pom.xml @@ -31,7 +31,7 @@ tensorflow-tools tensorflow-core - tensorflow-sandbox + tensorflow-training diff --git a/tensorflow-training/pom.xml b/tensorflow-training/pom.xml new file mode 100644 index 00000000000..2b3448dfd0f --- /dev/null +++ b/tensorflow-training/pom.xml @@ -0,0 +1,77 @@ + + + 4.0.0 + + + org.tensorflow + tensorflow-java + 0.1.0-SNAPSHOT + + tensorflow-training + jar + + TensorFlow Training Library + + Operations for training Tensorflow models. 
+ + + + + org.tensorflow + tensorflow-core-api + ${project.version} + + + junit + junit + test + + + org.openjdk.jmh + jmh-core + test + + + org.openjdk.jmh + jmh-generator-annprocess + test + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.22.2 + + 1 + false + -Xmx2G -XX:MaxPermSize=256m + false + + **/*Test.java + + + + + + + diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java b/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java similarity index 96% rename from tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java rename to tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java index a508046f5d7..38f38370bce 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.tensorflow.sandbox; +package org.tensorflow.training.examples; import org.tensorflow.Graph; import org.tensorflow.Operand; @@ -35,14 +35,14 @@ import org.tensorflow.op.nn.Softmax; import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.op.random.TruncatedNormal; -import org.tensorflow.sandbox.optimizers.AdaDelta; -import org.tensorflow.sandbox.optimizers.AdaGrad; -import org.tensorflow.sandbox.optimizers.AdaGradDA; -import org.tensorflow.sandbox.optimizers.Adam; -import org.tensorflow.sandbox.optimizers.GradientDescent; -import org.tensorflow.sandbox.optimizers.Momentum; -import org.tensorflow.sandbox.optimizers.Optimizer; -import org.tensorflow.sandbox.optimizers.RMSProp; +import org.tensorflow.training.optimizers.AdaDelta; +import org.tensorflow.training.optimizers.AdaGrad; +import org.tensorflow.training.optimizers.AdaGradDA; +import org.tensorflow.training.optimizers.Adam; +import org.tensorflow.training.optimizers.GradientDescent; +import org.tensorflow.training.optimizers.Momentum; +import org.tensorflow.training.optimizers.Optimizer; +import org.tensorflow.training.optimizers.RMSProp; import org.tensorflow.tools.Shape; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java similarity index 98% rename from tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java rename to tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java index c4cd4079df2..edc19b2aa47 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.tensorflow.sandbox.optimizers; +package org.tensorflow.training.optimizers; import org.tensorflow.Graph; import org.tensorflow.Operand; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java similarity index 98% rename from tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java rename to tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java index 00a56f5853a..9146cc1b5c5 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.tensorflow.sandbox.optimizers; +package org.tensorflow.training.optimizers; import org.tensorflow.Graph; import org.tensorflow.Operand; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java similarity index 99% rename from tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java rename to tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java index 753e104c8a1..61b3bef21a4 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.tensorflow.sandbox.optimizers; +package org.tensorflow.training.optimizers; import org.tensorflow.Graph; import org.tensorflow.Operand; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java similarity index 99% rename from tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java rename to tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java index 7b18fab0354..87e5263cd3a 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.tensorflow.sandbox.optimizers; +package org.tensorflow.training.optimizers; import org.tensorflow.Graph; import org.tensorflow.Operand; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java similarity index 97% rename from tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java rename to tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java index c95398abe6f..82de29e737a 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.tensorflow.sandbox.optimizers; +package org.tensorflow.training.optimizers; import org.tensorflow.Graph; import org.tensorflow.Operand; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java similarity index 98% rename from tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java rename to tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java index 60f3497d570..f925150b561 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.tensorflow.sandbox.optimizers; +package org.tensorflow.training.optimizers; import org.tensorflow.Graph; import org.tensorflow.Operand; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java similarity index 99% rename from tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java rename to tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java index 37467a50f34..8c2733abb15 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.tensorflow.sandbox.optimizers; +package org.tensorflow.training.optimizers; import org.tensorflow.Graph; import org.tensorflow.Operand; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java similarity index 98% rename from tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java rename to tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java index a20996f0018..73ff1777923 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.tensorflow.sandbox.optimizers; +package org.tensorflow.training.optimizers; import org.tensorflow.Graph; import org.tensorflow.Operand; From b2ac923b5c61b00f04bac75b0735ebfc3177d4cc Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 5 Nov 2019 11:02:24 -0500 Subject: [PATCH 08/22] Initial commit of gradient descent optimizers. 
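
This adds an abstract Optimizer that collects the graph's VariableV2 ops,
attaches gradient ops for them, and groups the per-variable update ops under
a single NoOp run target, plus concrete AdaDelta, AdaGrad, AdaGradDA, Adam,
GradientDescent, Momentum and RMSProp implementations. A minimal usage
sketch (loss, inputTensor and targetTensor below are stand-ins for whatever
the surrounding model defines; the feed and target names match the MNIST
example in this patch):

    Graph graph = new Graph();
    Ops tf = Ops.create(graph);
    // ... build a model that produces a scalar loss operand ...
    Optimizer optimizer = new GradientDescent(graph, 0.01f);
    // Adds the gradient computation plus one apply op per variable,
    // grouped under a NoOp named "train".
    Op minimize = optimizer.minimize(loss, "train");
    // Groups every Assign registered via graph.addInitializer(...).
    Op init = graph.variablesInitializer();
    try (Session session = new Session(graph)) {
      session.runner().addTarget("init").run();
      session.runner()
          .feed("input", inputTensor)
          .feed("target", targetTensor)
          .addTarget("train")
          .run();
    }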
--- .../org/tensorflow/sandbox/MNISTTest.java | 345 ++++++++++++++++++ .../sandbox/optimizers/AdaDelta.java | 83 +++++ .../sandbox/optimizers/AdaGrad.java | 71 ++++ .../sandbox/optimizers/AdaGradDA.java | 117 ++++++ .../tensorflow/sandbox/optimizers/Adam.java | 131 +++++++ .../sandbox/optimizers/GradientDescent.java | 45 +++ .../sandbox/optimizers/Momentum.java | 72 ++++ .../sandbox/optimizers/Optimizer.java | 207 +++++++++++ .../sandbox/optimizers/RMSProp.java | 105 ++++++ .../org/tensorflow/sandbox/util/Pair.java | 56 +++ 10 files changed, 1232 insertions(+) create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java create mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java new file mode 100644 index 00000000000..76ffd102371 --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java @@ -0,0 +1,345 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.nio.nd.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Mean; +import org.tensorflow.op.nn.Conv2d; +import org.tensorflow.op.nn.MaxPool; +import org.tensorflow.op.nn.Relu; +import org.tensorflow.op.nn.Softmax; +import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; +import org.tensorflow.op.random.TruncatedNormal; +import org.tensorflow.sandbox.optimizers.AdaDelta; +import org.tensorflow.sandbox.optimizers.AdaGrad; +import org.tensorflow.sandbox.optimizers.AdaGradDA; +import org.tensorflow.sandbox.optimizers.Adam; +import org.tensorflow.sandbox.optimizers.GradientDescent; +import org.tensorflow.sandbox.optimizers.Momentum; +import org.tensorflow.sandbox.optimizers.Optimizer; +import org.tensorflow.sandbox.optimizers.RMSProp; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.TInt32; + +import java.io.BufferedInputStream; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.util.Arrays; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Builds a LeNet-5 style CNN for MNIST. 
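+ * <p>The network is two 5x5 convolution layers (32 then 64 filters), each
+ * followed by 2x2 max pooling, a 512 unit fully connected ReLU layer, and a
+ * 10-way softmax output, with L2 regularization applied to the fully
+ * connected weights and biases.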
+ */ +public class MNISTTest { + + private static final Logger logger = Logger.getLogger(MNISTTest.class.getName()); + + private static final int PIXEL_DEPTH = 255; + private static final int NUM_CHANNELS = 1; + private static final int IMAGE_SIZE = 28; + private static final int NUM_LABELS = 10; + private static final long SEED = 123456789L; + + private static final String PADDING_TYPE = "SAME"; + + public static final String INPUT_NAME = "input"; + public static final String OUTPUT_NAME = "output"; + public static final String TARGET = "target"; + public static final String TRAIN = "train"; + public static final String TRAINING_LOSS = "training_loss"; + public static final String EPOCH = "epoch"; + public static final String INIT = "init"; + + public static Graph build(String optimizerName) { + Graph graph = new Graph(); + + Ops tf = Ops.create(graph); + + // Inputs + Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat.DTYPE, Placeholder.shape(Shape.make(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); + Placeholder labels = tf.withName(TARGET).placeholder(TInt32.DTYPE); + + // Scaling the features + Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); + Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); + Operand scaledInput = tf.math.div(tf.math.add(input, centeringFactor), scalingFactor); + + // First conv layer + Variable conv1Weights = tf.variable(Shape.make(5, 5, NUM_CHANNELS, 32), TFloat.DTYPE); + Assign weights1Init = tf.assign(conv1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv1Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + graph.addInitializer(weights1Init); + Conv2d conv1 = tf.nn.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable conv1Biases = tf.variable(Shape.make(32), TFloat.DTYPE); + Assign biases1Init = tf.assign(conv1Biases, tf.fill(tf.shape(conv1Biases), tf.constant(0.0f))); + graph.addInitializer(biases1Init); + Relu relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); + + // First pooling layer + MaxPool pool1 = tf.nn.maxPool(relu1, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); + + // Second conv layer + Variable conv2Weights = tf.variable(Shape.make(5, 5, 32, 64), TFloat.DTYPE); + Assign weights2Init = tf.assign(conv2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv2Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + graph.addInitializer(weights2Init); + Conv2d conv2 = tf.nn.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable conv2Biases = tf.variable(Shape.make(64), TFloat.DTYPE); + Assign biases2Init = tf.assign(conv2Biases, tf.fill(tf.shape(conv2Biases), tf.constant(0.1f))); + graph.addInitializer(biases2Init); + Relu relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); + + // Second pooling layer + MaxPool pool2 = tf.nn.maxPool(relu2, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); + + // Flatten inputs + Reshape flatten = tf.reshape(pool2, tf.concat(Arrays.asList(tf.slice(tf.shape(pool2), tf.constant(new int[]{0}), tf.constant(new int[]{1})), tf.constant(new int[]{-1})), tf.constant(0))); + + // Fully connected layer + Variable fc1Weights = tf.variable(Shape.make(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat.DTYPE); + Assign weights3Init = tf.assign(fc1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc1Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + 
graph.addInitializer(weights3Init); + Variable fc1Biases = tf.variable(Shape.make(512), TFloat.DTYPE); + Assign biases3Init = tf.assign(fc1Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc1Biases))); + graph.addInitializer(biases3Init); + Relu relu3 = tf.nn.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); + + // Softmax layer + Variable fc2Weights = tf.variable(Shape.make(512, NUM_LABELS), TFloat.DTYPE); + Assign weights4Init = tf.assign(fc2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc2Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + graph.addInitializer(weights4Init); + Variable fc2Biases = tf.variable(Shape.make(NUM_LABELS), TFloat.DTYPE); + Assign biases4Init = tf.assign(fc2Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc2Biases))); + graph.addInitializer(biases4Init); + + Add logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); + + // Predicted outputs + Softmax prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits); + + // Loss function & regularization + OneHot oneHot = tf.oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); + SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); + Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); + Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math.add(tf.nn.l2Loss(fc1Biases), tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); + Add loss = tf.withName(TRAINING_LOSS).math.add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); + + // Optimizer + Optimizer optimizer; + switch (optimizerName) { + case "AdaDelta": + case "Adadelta": + case "adadelta": + optimizer = new AdaDelta(graph, 1f, 0.95f, 1e-8f); + break; + case "AdaGradDA": + case "AdagradDA": + case "adagradda": + optimizer = new AdaGradDA(graph, 0.01f); + break; + case "AdaGrad": + case "Adagrad": + case "adagrad": + optimizer = new AdaGrad(graph, 0.01f); + break; + case "Adam": + case "adam": + optimizer = new Adam(graph,0.001f,0.9f,0.999f,1e-8f); + break; + case "SGD": + case "sgd": + optimizer = new GradientDescent(graph,0.01f); + break; + case "Momentum": + case "momentum": + optimizer = new Momentum(graph, 0.01f, 0.9f, false); + break; + case "RMSProp": + case "rmsprop": + optimizer = new RMSProp(graph,0.01f, 0.9f, 0.0f, 1e-10f, false); + break; + default: + throw new IllegalArgumentException("Unknown optimizer " + optimizerName); + } + logger.info("Optimizer = " + optimizer.toString()); + Op minimize = optimizer.minimize(loss, TRAIN); + + Op init = graph.variablesInitializer(); + + return graph; + } + + public static void train(Session session, int epochs, int minibatchSize, float[][][][] data, int[] labels) { + // Initialises the parameters. 
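+    // INIT names the NoOp built by graph.variablesInitializer(), which carries
+    // a control dependency on every Assign registered with graph.addInitializer
+    // while the graph was constructed.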
+    session.runner().addTarget(INIT).run();
+    logger.info("Initialised the model parameters");
+
+    float[][][][] featureBatch = new float[minibatchSize][][][];
+    int[] labelBatch = new int[minibatchSize];
+
+    int interval = 0;
+    for (int i = 0; i < epochs; i++) {
+      logger.log(Level.INFO, "Starting epoch " + i);
+      //Tensor epoch = Tensor.create(i);
+      for (int j = 0; j < data.length; j += minibatchSize) {
+        for (int k = j, m = 0; k < (j + minibatchSize) && k < data.length; k++, m++) {
+          featureBatch[m] = data[k];
+          labelBatch[m] = labels[k];
+        }
+        //logger.info("Batch = " + batch.size());
+        Tensor input = Tensor.create(featureBatch);
+        Tensor target = Tensor.create(labelBatch);
+        Tensor loss = session.runner()
+            .feed(INPUT_NAME, input)
+            .feed(TARGET, target)
+            .addTarget(TRAIN)
+            .fetch(TRAINING_LOSS)
+            .run().get(0);
+        if (interval % 100 == 0) {
+          logger.log(Level.INFO, "Iteration = " + interval + ", training loss = " + loss.floatValue());
+        }
+        input.close();
+        target.close();
+        loss.close();
+        interval++;
+      }
+      //epoch.close();
+    }
+  }
+
+  /**
+   * Find the maximum probability and return its index.
+   *
+   * @param probabilities The probabilities.
+   * @return The index of the max.
+   */
+  public static int pred(float[] probabilities) {
+    float maxVal = Float.NEGATIVE_INFINITY;
+    int idx = 0;
+    for (int i = 0; i < probabilities.length; i++) {
+      if (probabilities[i] > maxVal) {
+        maxVal = probabilities[i];
+        idx = i;
+      }
+    }
+    return idx;
+  }
+
+  public static DataTuple loadData(String path) throws IOException, ClassNotFoundException {
+    try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) {
+      float[][][][] data = (float[][][][]) ois.readObject();
+      int[] labels = (int[]) ois.readObject();
+      return new DataTuple(data, labels);
+    }
+  }
+
+  private static class DataTuple {
+    public final float[][][][] features;
+    public final int[] labels;
+
+    public DataTuple(float[][][][] features, int[] labels) {
+      this.features = features;
+      this.labels = labels;
+    }
+  }
+
+  public static void main(String[] args) throws IOException, ClassNotFoundException {
+    logger.info("Usage: MNISTTest <epochs> <minibatch-size> <optimizer-name> <training-data> <testing-data>");
+
+    logger.info("Loading training data");
+    DataTuple train = loadData(args[3]);
+    logger.info("Loading testing data");
+    DataTuple test = loadData(args[4]);
+
+    logger.info("Loaded data.");
+
+    float[][][][] trainData = train.features;
+    int[] trainLabels = train.labels;
+
+    float[][][][] testData = test.features;
+    int[] testLabels = test.labels;
+
+    logger.info("Loaded " + trainLabels.length + " training labels");
+    logger.info("Loaded " + testLabels.length + " testing labels");
+
+    int epochs = Integer.parseInt(args[0]);
+    int minibatchSize = Integer.parseInt(args[1]);
+
+    Graph graph = build(args[2]);
+
+    int correctCount = 0;
+    int[][] confusionMatrix = new int[10][10];
+
+    try (Session session = new Session(graph)) {
+      train(session, epochs, minibatchSize, trainData, trainLabels);
+
+      logger.info("Trained model");
+
+      float[][][][] featureBatch = new float[minibatchSize][][][];
+      int[] labelBatch = new int[minibatchSize];
+      float[][] prediction;
+
+      for (int j = 0; j < testData.length; j += minibatchSize) {
+        for (int k = j, m = 0; k < (j + minibatchSize) && k < testData.length; k++, m++) {
+          featureBatch[m] = testData[k];
+          labelBatch[m] = testLabels[k];
+        }
+        try (Tensor transformedInput = Tensor.create(featureBatch);
+             Tensor outputTensor = session.runner()
+                 .feed(INPUT_NAME, transformedInput)
+                 .fetch(OUTPUT_NAME).run().get(0)) {
+          prediction =
outputTensor.copyTo(new float[minibatchSize][NUM_LABELS]); + } + + for (int k = 0; k < labelBatch.length; k++) { + int predLabel; + + predLabel = pred(prediction[k]); + if (predLabel == labelBatch[k]) { + correctCount++; + } + + confusionMatrix[labelBatch[k]][predLabel]++; + } + + if (j % 1000 == 0) { + logger.log(Level.INFO, "Cur accuracy = " + ((float) correctCount) / (j + minibatchSize)); + } + } + + logger.info("Final accuracy = " + ((float) correctCount) / testLabels.length); + + StringBuilder sb = new StringBuilder(); + sb.append("Label"); + for (int i = 0; i < confusionMatrix.length; i++) { + sb.append(String.format("%1$5s", "" + i)); + } + sb.append("\n"); + + for (int i = 0; i < confusionMatrix.length; i++) { + sb.append(String.format("%1$5s", "" + i)); + for (int j = 0; j < confusionMatrix[i].length; j++) { + sb.append(String.format("%1$5s", "" + confusionMatrix[i][j])); + } + sb.append("\n"); + } + + System.out.println(sb.toString()); + } + + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java new file mode 100644 index 00000000000..f0bc847bdc5 --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java @@ -0,0 +1,83 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; + +/** + * Optimizer that implements the Adadelta algorithm. + * + * See the paper. + */ +public class AdaDelta extends Optimizer { + + public static final String ACCUMULATOR = "accum"; + public static final String ACCUMULATOR_UPDATE = "accum_update"; + + private final float learningRate; + + private final float rho; + + private final float epsilon; + + public AdaDelta(Graph graph, float learningRate) { + this(graph, learningRate, 0.95f, 1e-8f); + } + + public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { + super(graph); + this.learningRate = learningRate; + this.rho = rho; + this.epsilon = epsilon; + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createAdaDeltaSlot(v); + } + } + + private void createAdaDeltaSlot(Output v) { + Operand accumulatorInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); + Operand updateInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. 
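+    // Per the TensorFlow documentation, ApplyAdadelta performs (approximately):
+    //   accum        = rho * accum + (1 - rho) * grad^2
+    //   update       = sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
+    //   accum_update = rho * accum_update + (1 - rho) * update^2
+    //   var         -= lr * update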
+ Variable accumSlot = (Variable) getSlot(variable,ACCUMULATOR).get(); + @SuppressWarnings("unchecked") + Variable accumUpdateSlot = (Variable) getSlot(variable,ACCUMULATOR_UPDATE).get(); + return tf.train.applyAdadelta(variable, accumSlot, accumUpdateSlot, + tf.constant(learningRate, gradient.dataType()), + tf.constant(rho, gradient.dataType()), + tf.constant(epsilon, gradient.dataType()), + gradient); + } + + @Override + public String toString() { + return "AdaDelta{" + + "learningRate=" + learningRate + + ", rho=" + rho + + ", epsilon=" + epsilon + + '}'; + } + + @Override + public String getOptimizerName() { + return "Adadelta"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java new file mode 100644 index 00000000000..6418e7b9d69 --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java @@ -0,0 +1,71 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; + +/** + * Optimizer that implements the Adagrad algorithm. + * + * See the paper + * or this intro. + */ +public class AdaGrad extends Optimizer { + + public static final String ACCUMULATOR = "accumulator"; + + private final float learningRate; + + private final float initialAccumulatorValue; + + public AdaGrad(Graph graph, float learningRate) { + this(graph, learningRate, 0.01f); + } + + public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { + super(graph); + this.learningRate = learningRate; + this.initialAccumulatorValue = initialAccumulatorValue; + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createAdaGradSlot(v); + } + } + + private void createAdaGradSlot(Output v) { + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), ACCUMULATOR, initializer); + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable slot = (Variable) getSlot(variable,ACCUMULATOR).get(); + return tf.train.applyAdagrad(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient); + } + + @Override + public String toString() { + return "AdaGrad{" + + "learningRate=" + learningRate + + ", initialAccumulatorValue=" + initialAccumulatorValue + + '}'; + } + + @Override + public String getOptimizerName() { + return "Adagrad"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java new file mode 100644 index 00000000000..1a4ff11d623 --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java @@ -0,0 +1,117 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. 
+ */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.nio.nd.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +import java.util.List; +import java.util.Optional; + +/** + * Optimizer that implements the Adagrad Dual-Averaging algorithm. + * + * See the paper. + */ +public class AdaGradDA extends Optimizer { + + public static final String ACCUMULATOR = "gradient_accumulator"; + public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator"; + + private Variable globalStep; + + private final float learningRate; + + private final float initialAccumulatorValue; + + private final float l1Strength; + + private final float l2Strength; + + public AdaGradDA(Graph graph, float learningRate) { + this(graph, learningRate, 0.1f, 0.0f, 0.0f); + } + + public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, float l1Strength, float l2Strength) { + super(graph); + this.learningRate = learningRate; + this.initialAccumulatorValue = initialAccumulatorValue; + this.l1Strength = l1Strength; + this.l2Strength = l2Strength; + } + + @Override + protected Optional prepare(String name) { + return Optional.of(tf.assignAdd(globalStep,tf.constant(1L))); + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createAdaGradDASlot(v); + } + globalStep = tf.withName("adagrad-da-global-step").variable(Shape.make(),TInt64.DTYPE); + Assign globalStepInitializer = tf.assign(globalStep, tf.constant(0L)); + graph.addInitializer(globalStepInitializer); + } + + private void createAdaGradDASlot(Output v) { + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), ACCUMULATOR, initializer); + Operand sqInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable gradSlot = (Variable) getSlot(variable,ACCUMULATOR).get(); + @SuppressWarnings("unchecked") + Variable gradSquaredSlot = (Variable) getSlot(variable,SQUARED_ACCUMULATOR).get(); + return tf.train.applyAdagradDa(variable, gradSlot, gradSquaredSlot, gradient, + tf.constant(learningRate, gradient.dataType()), + tf.constant(l1Strength, gradient.dataType()), + tf.constant(l2Strength, gradient.dataType()), + globalStep); + } + + /** + * Gathers up the update operations into a single op that can be used as a run target. + *
+ * Adds the global step update to the end of the updates list. + * @param updateOperations The update operations. + * @param name The name of the run target. + * @return A NoOp with a control dependency on each update operation. + */ + @Override + protected Op finish(List> updateOperations, String name) { + updateOperations.add(tf.assignAdd(globalStep,tf.constant(1L))); + return super.finish(updateOperations,name); + } + + @Override + public String toString() { + return "AdaGradDA{" + + "globalStep=" + globalStep + + ", learningRate=" + learningRate + + ", initialAccumulatorValue=" + initialAccumulatorValue + + ", l1Strength=" + l1Strength + + ", l2Strength=" + l2Strength + + '}'; + } + + @Override + public String getOptimizerName() { + return "adagrad-da"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java new file mode 100644 index 00000000000..bc0342e7c04 --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java @@ -0,0 +1,131 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.nio.nd.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; +import java.util.Optional; + +/** + * Optimizer that implements the Adam algorithm. + * + * See the paper. + */ +public class Adam extends Optimizer { + + public static final String FIRST_MOMENT = "m"; + public static final String SECOND_MOMENT = "v"; + + private final float learningRate; + + private final float betaOne; + + private final float betaTwo; + + private final float epsilon; + + private Constant learningRateConst; + private Constant epsilonConst; + private Constant betaOneConst; + private Constant betaTwoConst; + private Variable betaOnePower; + private Variable betaTwoPower; + + public Adam(Graph graph, float learningRate) { + this(graph, learningRate, 0.9f, 0.999f, 1e-8f); + } + + public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { + super(graph); + this.learningRate = learningRate; + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createAdamSlot(v); + } + betaOnePower = tf.withName("beta1_power").variable(Shape.make(),TFloat.DTYPE); + Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne, TFloat.DTYPE)); + graph.addInitializer(betaOnePowerInit); + betaTwoPower = tf.withName("beta2_power").variable(Shape.make(),TFloat.DTYPE); + Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo, TFloat.DTYPE)); + graph.addInitializer(betaTwoPowerInit); + } + + @Override + protected Optional prepare(String scopeName) { + betaOneConst = tf.constant(betaOne); + betaTwoConst = tf.constant(betaTwo); + learningRateConst = tf.constant(learningRate); + epsilonConst = tf.constant(epsilon); + return Optional.empty(); + } + + private void createAdamSlot(Output v) { + Operand firstMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), 
FIRST_MOMENT, firstMomentInitializer); + Operand secondMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable firstMomentSlot = (Variable) getSlot(variable,FIRST_MOMENT).get(); + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable secondMomentSlot = (Variable) getSlot(variable,SECOND_MOMENT).get(); + return tf.train.applyAdam(variable, firstMomentSlot, secondMomentSlot, + tf.dtypes.cast(betaOnePower,gradient.dataType()), + tf.dtypes.cast(betaTwoPower,gradient.dataType()), + tf.dtypes.cast(learningRateConst,gradient.dataType()), + tf.dtypes.cast(betaOneConst,gradient.dataType()), + tf.dtypes.cast(betaTwoConst,gradient.dataType()), + tf.dtypes.cast(epsilonConst,gradient.dataType()), + gradient); + } + + /** + * Gathers up the update operations into a single op that can be used as a run target. + *
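+   * ApplyAdam consumes the running beta powers for bias correction, scaling
+   * each step by roughly sqrt(1 - betaTwo^t) / (1 - betaOne^t).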
+ * Adds the betaOne and betaTwo updates to the end of the updates list. + * @param updateOperations The update operations. + * @param name The name of the run target. + * @return A NoOp with a control dependency on each update operation. + */ + @Override + protected Op finish(List> updateOperations, String name) { + updateOperations.add(tf.assign(betaOnePower,tf.math.mul(betaOnePower,betaOneConst))); + updateOperations.add(tf.assign(betaTwoPower,tf.math.mul(betaTwoPower,betaTwoConst))); + return super.finish(updateOperations,name); + } + + @Override + public String toString() { + return "Adam{" + + "learningRate=" + learningRate + + ", betaOne=" + betaOne + + ", betaTwo=" + betaTwo + + ", epsilon=" + epsilon + + '}'; + } + + @Override + public String getOptimizerName() { + return "Adam"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java new file mode 100644 index 00000000000..e7aa095367a --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java @@ -0,0 +1,45 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.train.ApplyMomentum; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; + +/** + * Basic SGD. + */ +public class GradientDescent extends Optimizer { + + private final float learningRate; + + public GradientDescent(Graph graph, float learningRate) { + super(graph); + this.learningRate = learningRate; + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + return tf.train.applyGradientDescent(variable, tf.constant(learningRate, gradient.dataType()), gradient); + } + + @Override + public String toString() { + return "GradientDescent{" + + "learningRate=" + learningRate + + '}'; + } + + @Override + public String getOptimizerName() { + return "GradientDescent"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java new file mode 100644 index 00000000000..5feeb9faa1e --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java @@ -0,0 +1,72 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.op.train.ApplyMomentum; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; + +/** + * SGD plus momentum, either nesterov or traditional. + * + * See the paper for + * details of nesterov momentum. 
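+ * <p>The underlying ApplyMomentum op computes accum = momentum * accum + grad
+ * and then var -= lr * accum, or var -= lr * (grad + momentum * accum) when
+ * the nesterov flag is set.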
+ */ +public class Momentum extends Optimizer { + + public static final String MOMENTUM = "momentum"; + + private final float learningRate; + + private final float momentum; + + private final boolean useNesterov; + + public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { + super(graph); + this.learningRate = learningRate; + this.momentum = momentum; + this.useNesterov = useNesterov; + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createMomentumSlot(v); + } + } + + private void createMomentumSlot(Output v) { + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), MOMENTUM, initializer); + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable slot = (Variable) getSlot(variable,MOMENTUM).get(); + return tf.train.applyMomentum(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient, tf.constant(momentum, gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); + } + + @Override + public String toString() { + return "Momentum{" + + "learningRate=" + learningRate + + ", momentum=" + momentum + + ", useNesterov=" + useNesterov + + '}'; + } + + @Override + public String getOptimizerName() { + return "Momentum"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java new file mode 100644 index 00000000000..494a2d650fb --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java @@ -0,0 +1,207 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.NoOp; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; +import org.tensorflow.sandbox.util.Pair; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * + */ +public abstract class Optimizer { + public static final String VARIABLE_V2 = "VariableV2"; + + /** + * Top level map is variable name, bottom map is slot name. + */ + private final Map>> slots; + + /** + * Global state variables + */ + //TODO make this be used. + protected final List globals; + + /** + * The Graph this optimizer is operating on. + */ + protected final Graph graph; + + /** + * The ops builder for the graph. 
+ */ + protected final Ops tf; + + protected Optimizer(Graph graph) { + this.graph = graph; + this.tf = Ops.create(graph).withName(getOptimizerName()); + this.slots = new HashMap<>(); + this.globals = new ArrayList<>(); + } + + public Op minimize(Operand loss) { + return minimize(loss, getOptimizerName()+"-minimize"); + } + + public Op minimize(Operand loss, String name) { + List, Output>> gradsAndVars = computeGradients(loss); + + return applyGradients(gradsAndVars, name); + } + + public List, Output>> computeGradients(Operand loss) { + List variables = new ArrayList<>(); + Iterator opItr = graph.operations(); + while (opItr.hasNext()) { + Operation op = opItr.next(); + if (op.type().equals(VARIABLE_V2)) { + variables.add(op); + } + } + + Output[] variableOutputArray = new Output[variables.size()]; + for (int i = 0; i < variables.size(); i++) { + // First output of a variable is it's output. + variableOutputArray[i] = variables.get(i).output(0); + } + + Output[] gradients = graph.addGradients(loss.asOutput(), variableOutputArray); + List, Output>> gradVarPairs = new ArrayList<>(); + + for (int i = 0; i < variableOutputArray.length; i++) { + gradVarPairs.add(new Pair<>(gradients[i], (Output)variableOutputArray[i])); + } + + return gradVarPairs; + } + + public Op applyGradients(List, Output>> gradsAndVars, String name) { + List> variables = gradsAndVars.stream().map(Pair::getB).collect(Collectors.toList()); + + createSlots(variables); + + Optional prepOp = prepare(name+"/prepare"); + + List> updateOps = new ArrayList<>(); + prepOp.ifPresent(updateOps::add); + for (Pair pair : gradsAndVars) { + updateOps.add(applyDense((Output)pair.getA(),(Output)pair.getB())); + } + + return finish(updateOps,name); + } + + /** + * Gets the slot associated with the specified variable and slot name. + * @param var The variable to lookup. + * @param slotName The slot name. + * @return The slot or {@link Optional#empty}. + */ + public Optional> getSlot(Output var, String slotName) { + return getSlot(var.op().name(),slotName); + } + + /** + * Gets the slot associated with the specified variable and slot name. + * @param varName The variable to lookup. + * @param slotName The slot name. + * @return The slot or {@link Optional#empty}. + */ + public Optional> getSlot(String varName, String slotName) { + Map> variables = slots.get(slotName); + if (variables != null) { + Variable slot = variables.get(varName); + if (slot != null) { + return Optional.of(slot); + } else { + return Optional.empty(); + } + } else { + return Optional.empty(); + } + } + + /** + * Creates a slot in the graph for the specified variable with the specified name. Adds the slot's initializer + * to the graph's initializers, and the slot to the Optimizer's slot map. + * @param variable The variable to create the slot for. + * @param slotName The name of the slot. + * @param initializer The initializer for the slot. + * @param The type of the variable. + */ + protected void createSlot(Output variable, String slotName, Operand initializer) { + Variable slot = (Variable) tf.withName(createName(variable, slotName)).variable(variable.shape(), TFloat.DTYPE); + Assign slotInit = tf.assign(slot, initializer); + graph.addInitializer(slotInit); + String varName = variable.op().name(); + Map> variables = slots.computeIfAbsent(slotName,(k) -> new HashMap<>()); + variables.put(varName,slot); + } + + /** + * No-op prepare method. + * + * @param scopeName The scope name to use for any variable creations. 
+ */ + protected Optional prepare(String scopeName) { + return Optional.empty(); + } + + /** + * No-op slot creation method. + * @param variables The variables to create slots for. + */ + protected void createSlots(List> variables) { } + + /** + * Generates + * @param gradient + * @param variable + * @param + * @return + */ + protected abstract Operand applyDense(Output gradient, Output variable); + + /** + * Gathers up the update operations into a single op that can be used as a run target. + * @param updateOperations The update operations. + * @param name The name of the run target. + * @return A NoOp with a control dependency on each update operation. + */ + protected Op finish(List> updateOperations, String name) { + Scope scope = new Scope(graph); + scope = scope.withName(name); + scope = scope.withControlDependencies(updateOperations); + return NoOp.create(scope); + } + + /** + * Name of the optimizer. + * @return The optimizer name. + */ + public abstract String getOptimizerName(); + + public static String createName(Output variable, String slotName) { + return variable.op().name() + "-" + slotName; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java new file mode 100644 index 00000000000..0f1a1f4c85b --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java @@ -0,0 +1,105 @@ +/* + * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + */ +package org.tensorflow.sandbox.optimizers; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat; +import org.tensorflow.types.family.TType; + +import java.util.List; + +/** + * Optimizer that implements the RMSProp algorithm. + * + * See the lecture notes + * that is inexplicably the canonical reference. + */ +public class RMSProp extends Optimizer { + + public static final String RMS = "rms"; + public static final String MG = "mg"; // mean gradient? 
+ public static final String MOMENTUM = "momentum"; + + private final float learningRate; + private final float decay; + private final float momentum; + private final float epsilon; + private final boolean centered; + + public RMSProp(Graph graph, float learningRate) { + this(graph, learningRate, 0.9f, 0.0f, 1e-10f, false); + } + + public RMSProp(Graph graph, float learningRate, float decay, float momentum, float epsilon, boolean centered) { + super(graph); + this.learningRate = learningRate; + this.decay = decay; + this.momentum = momentum; + this.epsilon = epsilon; + this.centered = centered; + } + + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createRMSPropSlot(v); + } + } + + private void createRMSPropSlot(Output v) { + Operand rmsInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(1.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), RMS, rmsInitializer); + Operand momentumInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), MOMENTUM, momentumInitializer); + if (centered) { + Operand mgInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + createSlot(v.asOutput(), MG, mgInitializer); + } + } + + @Override + protected Operand applyDense(Output gradient, Output variable) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable rmsSlot = (Variable) getSlot(variable,RMS).get(); + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable momentumSlot = (Variable) getSlot(variable,MOMENTUM).get(); + if (centered) { + @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. + Variable mgSlot = (Variable) getSlot(variable, MG).get(); + return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, + tf.constant(learningRate, gradient.dataType()), + tf.constant(decay, gradient.dataType()), + tf.constant(momentum, gradient.dataType()), + tf.constant(epsilon, gradient.dataType()), + gradient); + } else { + return tf.train.applyRmsProp(variable, rmsSlot, momentumSlot, + tf.constant(learningRate, gradient.dataType()), + tf.constant(decay, gradient.dataType()), + tf.constant(momentum, gradient.dataType()), + tf.constant(epsilon, gradient.dataType()), + gradient); + } + } + + @Override + public String toString() { + return "RMSProp{" + + "learningRate=" + learningRate + + ", decay=" + decay + + ", momentum=" + momentum + + ", epsilon=" + epsilon + + ", centered=" + centered + + '}'; + } + + @Override + public String getOptimizerName() { + return "RMSProp"; + } +} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java new file mode 100644 index 00000000000..8e6e1a0b3ea --- /dev/null +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java @@ -0,0 +1,56 @@ +package org.tensorflow.sandbox.util; + +import java.io.Serializable; +import java.util.Objects; + +/** + * An immutable pair of things. + * + * @param The type of the first object. + * @param The type of the second object. 
+ */ +public class Pair implements Serializable { + private static final long serialVersionUID = 1L; + + private final T1 a; + + private final T2 b; + + public Pair(T1 a, T2 b) { + this.a = a; + this.b = b; + } + + public T1 getA() { + return a; + } + + public T2 getB() { + return b; + } + + @Override + public int hashCode() { + return a.hashCode() ^ b.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (!(obj instanceof Pair)) { + return false; + } + final Pair other = (Pair) obj; + if (!Objects.equals(this.a, other.a)) { + return false; + } + return Objects.equals(this.b, other.b); + } + + @Override + public String toString() { + return "Pair{" + "a=" + a + ", b=" + b + '}'; + } +} From e7eb2e8c4e3597f4e05184af9ae70181f346e744 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 5 Nov 2019 11:14:39 -0500 Subject: [PATCH 09/22] Adding Apache 2.0 license header to all optimizer files. --- .../java/org/tensorflow/sandbox/MNISTTest.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/AdaDelta.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/AdaGrad.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/AdaGradDA.java | 12 ++++++++++++ .../org/tensorflow/sandbox/optimizers/Adam.java | 12 ++++++++++++ .../sandbox/optimizers/GradientDescent.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/Momentum.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/Optimizer.java | 12 ++++++++++++ .../tensorflow/sandbox/optimizers/RMSProp.java | 12 ++++++++++++ .../java/org/tensorflow/sandbox/util/Pair.java | 15 +++++++++++++++ 10 files changed, 123 insertions(+) diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java index 76ffd102371..73932c109ec 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java index f0bc847bdc5..687687a0661 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java index 6418e7b9d69..6d3240a2ffe 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java index 1a4ff11d623..c71735ad35d 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java index bc0342e7c04..1337163abf6 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java index e7aa095367a..efb067f68e4 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java index 5feeb9faa1e..34b94ed060b 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package org.tensorflow.sandbox.optimizers; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java index 494a2d650fb..e7e7e87f968 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java @@ -1,5 +1,17 @@ /* * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 package org.tensorflow.sandbox.optimizers;
diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java
index 0f1a1f4c85b..b34b22c9a5c 100644
--- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java
+++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java
@@ -1,5 +1,17 @@
 /*
  * Copyright © 2019, Oracle and/or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 package org.tensorflow.sandbox.optimizers;
diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java
index 8e6e1a0b3ea..07560d1e56d 100644
--- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java
+++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java
@@ -1,3 +1,18 @@
+/*
+ * Copyright © 2019, Oracle and/or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 package org.tensorflow.sandbox.util;
 
 import java.io.Serializable;

From 3d63564458081e5c5bca6405aed0952ef346d383 Mon Sep 17 00:00:00 2001
From: Adam Pocock
Date: Fri, 6 Dec 2019 16:40:19 -0500
Subject: [PATCH 10/22] Bug fix for the MNISTTest.
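
The feature scaling in MNISTTest.build() added the centering factor instead
of subtracting it. With PIXEL_DEPTH = 255 and centeringFactor =
PIXEL_DEPTH / 2, the corrected expression

    tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor)

computes (x - 127.5) / 255 and maps pixel values from [0, 255] into
[-0.5, 0.5] (0 -> -0.5, 255 -> 0.5), whereas the previous tf.math.add form
computed (x + 127.5) / 255 and shifted the inputs into [0.5, 1.5], so the
network never saw centered data.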
---
 .../src/main/java/org/tensorflow/sandbox/MNISTTest.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java
index 73932c109ec..c25147f7134 100644
--- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java
+++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java
@@ -90,7 +90,7 @@ public static Graph build(String optimizerName) {
     // Scaling the features
     Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
     Constant scalingFactor = tf.constant((float) PIXEL_DEPTH);
-    Operand scaledInput = tf.math.div(tf.math.add(input, centeringFactor), scalingFactor);
+    Operand scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor);
 
     // First conv layer
     Variable conv1Weights = tf.variable(Shape.make(5, 5, NUM_CHANNELS, 32), TFloat.DTYPE);

From ed71dc55ca5cad1d365f548166e0e853b62ba8e2 Mon Sep 17 00:00:00 2001
From: Adam Pocock
Date: Fri, 31 Jan 2020 12:32:22 -0500
Subject: [PATCH 11/22] Refactor to adopt the latest tensorflow-core changes.

Shape now comes from org.tensorflow.tools rather than org.tensorflow.nio.nd,
TFloat32 replaces TFloat, and scalar shapes are built with Shape.scalar()
instead of Shape.make().

---
 .../org/tensorflow/sandbox/MNISTTest.java | 76 +++++++++----------
 .../sandbox/optimizers/AdaDelta.java | 8 +-
 .../sandbox/optimizers/AdaGrad.java | 6 +-
 .../sandbox/optimizers/AdaGradDA.java | 12 +--
 .../tensorflow/sandbox/optimizers/Adam.java | 30 ++++----
 .../sandbox/optimizers/GradientDescent.java | 7 +-
 .../sandbox/optimizers/Momentum.java | 6 +-
 .../sandbox/optimizers/Optimizer.java | 20 ++---
 .../sandbox/optimizers/RMSProp.java | 10 +--
 .../org/tensorflow/sandbox/util/Pair.java | 2 +-
 10 files changed, 86 insertions(+), 91 deletions(-)

diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java
index c25147f7134..a508046f5d7 100644
--- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java
+++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java
@@ -1,5 +1,5 @@
 /*
- * Copyright © 2019, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,7 +19,6 @@ import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; -import org.tensorflow.nio.nd.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Assign; @@ -44,7 +43,8 @@ import org.tensorflow.sandbox.optimizers.Momentum; import org.tensorflow.sandbox.optimizers.Optimizer; import org.tensorflow.sandbox.optimizers.RMSProp; -import org.tensorflow.types.TFloat; +import org.tensorflow.tools.Shape; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import java.io.BufferedInputStream; @@ -84,71 +84,71 @@ public static Graph build(String optimizerName) { Ops tf = Ops.create(graph); // Inputs - Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat.DTYPE, Placeholder.shape(Shape.make(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); + Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.make(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); Placeholder labels = tf.withName(TARGET).placeholder(TInt32.DTYPE); // Scaling the features - Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); - Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); - Operand scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor); + Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); + Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); + Operand scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor); // First conv layer - Variable conv1Weights = tf.variable(Shape.make(5, 5, NUM_CHANNELS, 32), TFloat.DTYPE); - Assign weights1Init = tf.assign(conv1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv1Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable conv1Weights = tf.variable(Shape.make(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE); + Assign weights1Init = tf.assign(conv1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv1Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); graph.addInitializer(weights1Init); - Conv2d conv1 = tf.nn.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); - Variable conv1Biases = tf.variable(Shape.make(32), TFloat.DTYPE); - Assign biases1Init = tf.assign(conv1Biases, tf.fill(tf.shape(conv1Biases), tf.constant(0.0f))); + Conv2d conv1 = tf.nn.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable conv1Biases = tf.variable(Shape.make(32), TFloat32.DTYPE); + Assign biases1Init = tf.assign(conv1Biases, tf.fill(tf.shape(conv1Biases), tf.constant(0.0f))); graph.addInitializer(biases1Init); - Relu relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); + Relu relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); // First pooling layer - MaxPool pool1 = tf.nn.maxPool(relu1, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); + MaxPool pool1 = tf.nn.maxPool(relu1, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); // Second conv layer - Variable conv2Weights = tf.variable(Shape.make(5, 5, 32, 64), TFloat.DTYPE); - Assign weights2Init = tf.assign(conv2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv2Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable conv2Weights = tf.variable(Shape.make(5, 5, 32, 64), TFloat32.DTYPE); + Assign weights2Init = tf.assign(conv2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv2Weights), 
TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); graph.addInitializer(weights2Init); - Conv2d conv2 = tf.nn.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); - Variable conv2Biases = tf.variable(Shape.make(64), TFloat.DTYPE); - Assign biases2Init = tf.assign(conv2Biases, tf.fill(tf.shape(conv2Biases), tf.constant(0.1f))); + Conv2d conv2 = tf.nn.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable conv2Biases = tf.variable(Shape.make(64), TFloat32.DTYPE); + Assign biases2Init = tf.assign(conv2Biases, tf.fill(tf.shape(conv2Biases), tf.constant(0.1f))); graph.addInitializer(biases2Init); - Relu relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); + Relu relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); // Second pooling layer - MaxPool pool2 = tf.nn.maxPool(relu2, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); + MaxPool pool2 = tf.nn.maxPool(relu2, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); // Flatten inputs - Reshape flatten = tf.reshape(pool2, tf.concat(Arrays.asList(tf.slice(tf.shape(pool2), tf.constant(new int[]{0}), tf.constant(new int[]{1})), tf.constant(new int[]{-1})), tf.constant(0))); + Reshape flatten = tf.reshape(pool2, tf.concat(Arrays.asList(tf.slice(tf.shape(pool2), tf.constant(new int[]{0}), tf.constant(new int[]{1})), tf.constant(new int[]{-1})), tf.constant(0))); // Fully connected layer - Variable fc1Weights = tf.variable(Shape.make(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat.DTYPE); - Assign weights3Init = tf.assign(fc1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc1Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc1Weights = tf.variable(Shape.make(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE); + Assign weights3Init = tf.assign(fc1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc1Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); graph.addInitializer(weights3Init); - Variable fc1Biases = tf.variable(Shape.make(512), TFloat.DTYPE); - Assign biases3Init = tf.assign(fc1Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc1Biases))); + Variable fc1Biases = tf.variable(Shape.make(512), TFloat32.DTYPE); + Assign biases3Init = tf.assign(fc1Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc1Biases))); graph.addInitializer(biases3Init); - Relu relu3 = tf.nn.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); + Relu relu3 = tf.nn.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); // Softmax layer - Variable fc2Weights = tf.variable(Shape.make(512, NUM_LABELS), TFloat.DTYPE); - Assign weights4Init = tf.assign(fc2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc2Weights), TFloat.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc2Weights = tf.variable(Shape.make(512, NUM_LABELS), TFloat32.DTYPE); + Assign weights4Init = tf.assign(fc2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc2Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); graph.addInitializer(weights4Init); - Variable fc2Biases = tf.variable(Shape.make(NUM_LABELS), TFloat.DTYPE); - Assign biases4Init = tf.assign(fc2Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc2Biases))); + Variable fc2Biases = tf.variable(Shape.make(NUM_LABELS), TFloat32.DTYPE); + Assign biases4Init = tf.assign(fc2Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc2Biases))); 
graph.addInitializer(biases4Init); - Add logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); + Add logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); // Predicted outputs - Softmax prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits); + Softmax prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits); // Loss function & regularization - OneHot oneHot = tf.oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); - SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); - Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); - Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math.add(tf.nn.l2Loss(fc1Biases), tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); - Add loss = tf.withName(TRAINING_LOSS).math.add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); + OneHot oneHot = tf.oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); + SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); + Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); + Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math.add(tf.nn.l2Loss(fc1Biases), tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); + Add loss = tf.withName(TRAINING_LOSS).math.add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); // Optimizer Optimizer optimizer; diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java index 687687a0661..415e62c176d 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,7 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -60,9 +60,9 @@ protected void createSlots(List> variables) { } private void createAdaDeltaSlot(Output v) { - Operand accumulatorInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand accumulatorInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); - Operand updateInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand updateInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java index 6d3240a2ffe..11cf7ba867a 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -57,7 +57,7 @@ protected void createSlots(List> variables) { } private void createAdaGradSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java index c71735ad35d..e2c591cd9ea 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,12 +18,12 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.nio.nd.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.tools.Shape; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -72,15 +72,15 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdaGradDASlot(v); } - globalStep = tf.withName("adagrad-da-global-step").variable(Shape.make(),TInt64.DTYPE); + globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(),TInt64.DTYPE); Assign globalStepInitializer = tf.assign(globalStep, tf.constant(0L)); graph.addInitializer(globalStepInitializer); } private void createAdaGradDASlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); - Operand sqInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat.DTYPE));//v.dataType())); + Operand sqInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java index 1337163abf6..c06ac96a467 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,12 +18,12 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.nio.nd.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.tools.Shape; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -47,12 +47,12 @@ public class Adam extends Optimizer { private final float epsilon; - private Constant learningRateConst; - private Constant epsilonConst; - private Constant betaOneConst; - private Constant betaTwoConst; - private Variable betaOnePower; - private Variable betaTwoPower; + private Constant learningRateConst; + private Constant epsilonConst; + private Constant betaOneConst; + private Constant betaTwoConst; + private Variable betaOnePower; + private Variable betaTwoPower; public Adam(Graph graph, float learningRate) { this(graph, learningRate, 0.9f, 0.999f, 1e-8f); @@ -71,11 +71,11 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdamSlot(v); } - betaOnePower = tf.withName("beta1_power").variable(Shape.make(),TFloat.DTYPE); - Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne, TFloat.DTYPE)); + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(),TFloat32.DTYPE); + Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne, TFloat32.DTYPE)); graph.addInitializer(betaOnePowerInit); - betaTwoPower = tf.withName("beta2_power").variable(Shape.make(),TFloat.DTYPE); - Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo, TFloat.DTYPE)); + betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(),TFloat32.DTYPE); + Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo, TFloat32.DTYPE)); graph.addInitializer(betaTwoPowerInit); } @@ -89,9 +89,9 @@ protected Optional prepare(String scopeName) { } private void createAdamSlot(Output v) { - Operand firstMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand firstMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); - Operand secondMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand secondMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java index efb067f68e4..fd0f264d664 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,13 +18,8 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.Variable; -import org.tensorflow.op.train.ApplyMomentum; -import org.tensorflow.types.TFloat; import org.tensorflow.types.family.TType; -import java.util.List; /** * Basic SGD. diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java index 34b94ed060b..e0b4fbbac49 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; -import org.tensorflow.types.TFloat; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -57,7 +57,7 @@ protected void createSlots(List> variables) { } private void createMomentumSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), MOMENTUM, initializer); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java index e7e7e87f968..1ac3ef1ceed 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.NoOp; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import org.tensorflow.sandbox.util.Pair; @@ -44,7 +44,7 @@ public abstract class Optimizer { public static final String VARIABLE_V2 = "VariableV2"; /** - * Top level map is variable name, bottom map is slot name. + * Top level map key is the variable name, lower level map key is the slot name. */ private final Map>> slots; @@ -101,7 +101,7 @@ public List, Output>> computeGradients(Operand, Output>> gradVarPairs = new ArrayList<>(); for (int i = 0; i < variableOutputArray.length; i++) { - gradVarPairs.add(new Pair<>(gradients[i], (Output)variableOutputArray[i])); + gradVarPairs.add(new Pair<>(gradients[i], variableOutputArray[i])); } return gradVarPairs; @@ -162,7 +162,7 @@ public Optional> getSlot(String varName, String slotName) { * @param The type of the variable. 
*/ protected void createSlot(Output variable, String slotName, Operand initializer) { - Variable slot = (Variable) tf.withName(createName(variable, slotName)).variable(variable.shape(), TFloat.DTYPE); + Variable slot = (Variable) tf.withName(createName(variable, slotName)).variable(variable.shape(), TFloat32.DTYPE); Assign slotInit = tf.assign(slot, initializer); graph.addInitializer(slotInit); String varName = variable.op().name(); @@ -186,11 +186,11 @@ protected Optional prepare(String scopeName) { protected void createSlots(List> variables) { } /** - * Generates - * @param gradient - * @param variable - * @param - * @return + * Generates the gradient update operations for the specific variable and gradient. + * @param gradient The gradient to use. + * @param variable The variable to update. + * @param The type of the variable. + * @return An operand which applies the desired optimizer update to the variable. */ protected abstract Operand applyDense(Output gradient, Output variable); diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java index b34b22c9a5c..9ba293b46f6 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -64,12 +64,12 @@ protected void createSlots(List> variables) { } private void createRMSPropSlot(Output v) { - Operand rmsInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(1.0f, TFloat.DTYPE));//v.dataType())); + Operand rmsInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(1.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), RMS, rmsInitializer); - Operand momentumInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand momentumInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand mgInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat.DTYPE));//v.dataType())); + Operand mgInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); createSlot(v.asOutput(), MG, mgInitializer); } } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java index 07560d1e56d..3160b0f2b8d 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java @@ -1,5 +1,5 @@ /* - * Copyright © 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From b0544494775ede7f856f454cfc2f3f8c20303f24 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 31 Jan 2020 15:30:22 -0500 Subject: [PATCH 12/22] Added type safety and updates for new api. --- .../sandbox/optimizers/AdaDelta.java | 11 ++- .../sandbox/optimizers/AdaGrad.java | 6 +- .../sandbox/optimizers/AdaGradDA.java | 13 ++-- .../tensorflow/sandbox/optimizers/Adam.java | 14 ++-- .../sandbox/optimizers/GradientDescent.java | 3 +- .../sandbox/optimizers/Momentum.java | 6 +- .../sandbox/optimizers/Optimizer.java | 70 ++++++++++++------ .../sandbox/optimizers/RMSProp.java | 16 ++--- .../org/tensorflow/sandbox/util/Pair.java | 71 ------------------- 9 files changed, 75 insertions(+), 135 deletions(-) delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java index 415e62c176d..c4cd4079df2 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java @@ -18,7 +18,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -60,18 +59,16 @@ protected void createSlots(List> variables) { } private void createAdaDeltaSlot(Output v) { - Operand accumulatorInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand accumulatorInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); - Operand updateInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand updateInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. 
- Variable accumSlot = (Variable) getSlot(variable,ACCUMULATOR).get(); - @SuppressWarnings("unchecked") - Variable accumUpdateSlot = (Variable) getSlot(variable,ACCUMULATOR_UPDATE).get(); + Variable accumSlot = getSlot(variable,ACCUMULATOR).get(); + Variable accumUpdateSlot = getSlot(variable,ACCUMULATOR_UPDATE).get(); return tf.train.applyAdadelta(variable, accumSlot, accumUpdateSlot, tf.constant(learningRate, gradient.dataType()), tf.constant(rho, gradient.dataType()), diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java index 11cf7ba867a..00a56f5853a 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java @@ -18,7 +18,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -57,14 +56,13 @@ protected void createSlots(List> variables) { } private void createAdaGradSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat32.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. - Variable slot = (Variable) getSlot(variable,ACCUMULATOR).get(); + Variable slot = getSlot(variable,ACCUMULATOR).get(); return tf.train.applyAdagrad(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java index e2c591cd9ea..753e104c8a1 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java @@ -20,7 +20,6 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.tools.Shape; import org.tensorflow.types.TFloat32; @@ -63,7 +62,7 @@ public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, } @Override - protected Optional prepare(String name) { + protected Optional> prepare(String name) { return Optional.of(tf.assignAdd(globalStep,tf.constant(1L))); } @@ -78,18 +77,16 @@ protected void createSlots(List> variables) { } private void createAdaGradDASlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); - Operand sqInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(initialAccumulatorValue, TFloat32.DTYPE));//v.dataType())); + Operand sqInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE),v.dataType())); 
createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. - Variable gradSlot = (Variable) getSlot(variable,ACCUMULATOR).get(); - @SuppressWarnings("unchecked") - Variable gradSquaredSlot = (Variable) getSlot(variable,SQUARED_ACCUMULATOR).get(); + Variable gradSlot = getSlot(variable,ACCUMULATOR).get(); + Variable gradSquaredSlot = getSlot(variable,SQUARED_ACCUMULATOR).get(); return tf.train.applyAdagradDa(variable, gradSlot, gradSquaredSlot, gradient, tf.constant(learningRate, gradient.dataType()), tf.constant(l1Strength, gradient.dataType()), diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java index c06ac96a467..7b18fab0354 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java @@ -69,7 +69,7 @@ public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float @Override protected void createSlots(List> variables) { for (Output v : variables) { - createAdamSlot(v); + createAdamSlot(v.asOutput()); } betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(),TFloat32.DTYPE); Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne, TFloat32.DTYPE)); @@ -80,7 +80,7 @@ protected void createSlots(List> variables) { } @Override - protected Optional prepare(String scopeName) { + protected Optional> prepare(String scopeName) { betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); learningRateConst = tf.constant(learningRate); @@ -89,18 +89,16 @@ protected Optional prepare(String scopeName) { } private void createAdamSlot(Output v) { - Operand firstMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand firstMomentInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); - Operand secondMomentInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand secondMomentInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. - Variable firstMomentSlot = (Variable) getSlot(variable,FIRST_MOMENT).get(); - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. 
- Variable secondMomentSlot = (Variable) getSlot(variable,SECOND_MOMENT).get(); + Variable firstMomentSlot = getSlot(variable,FIRST_MOMENT).get(); + Variable secondMomentSlot = getSlot(variable,SECOND_MOMENT).get(); return tf.train.applyAdam(variable, firstMomentSlot, secondMomentSlot, tf.dtypes.cast(betaOnePower,gradient.dataType()), tf.dtypes.cast(betaTwoPower,gradient.dataType()), diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java index fd0f264d664..c95398abe6f 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java @@ -18,6 +18,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -35,7 +36,7 @@ public GradientDescent(Graph graph, float learningRate) { @Override protected Operand applyDense(Output gradient, Output variable) { - return tf.train.applyGradientDescent(variable, tf.constant(learningRate, gradient.dataType()), gradient); + return tf.train.applyGradientDescent(variable, tf.dtypes.cast(tf.constant(learningRate, TFloat32.DTYPE), gradient.dataType()), gradient); } @Override diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java index e0b4fbbac49..60f3497d570 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java @@ -18,7 +18,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; import org.tensorflow.types.TFloat32; @@ -57,14 +56,13 @@ protected void createSlots(List> variables) { } private void createMomentumSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), MOMENTUM, initializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. 
- Variable slot = (Variable) getSlot(variable,MOMENTUM).get(); + Variable slot = getSlot(variable,MOMENTUM).get(); return tf.train.applyMomentum(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient, tf.constant(momentum, gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java index 1ac3ef1ceed..37467a50f34 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java @@ -25,9 +25,7 @@ import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.NoOp; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import org.tensorflow.sandbox.util.Pair; import java.util.ArrayList; import java.util.HashMap; @@ -52,7 +50,7 @@ public abstract class Optimizer { * Global state variables */ //TODO make this be used. - protected final List globals; + protected final List> globals; /** * The Graph this optimizer is operating on. @@ -76,12 +74,12 @@ public Op minimize(Operand loss) { } public Op minimize(Operand loss, String name) { - List, Output>> gradsAndVars = computeGradients(loss); + List> gradsAndVars = computeGradients(loss); return applyGradients(gradsAndVars, name); } - public List, Output>> computeGradients(Operand loss) { + public List> computeGradients(Operand loss) { List variables = new ArrayList<>(); Iterator opItr = graph.operations(); while (opItr.hasNext()) { @@ -98,26 +96,30 @@ public List, Output>> computeGradients(Operand[] gradients = graph.addGradients(loss.asOutput(), variableOutputArray); - List, Output>> gradVarPairs = new ArrayList<>(); + List> gradVarPairs = new ArrayList<>(); for (int i = 0; i < variableOutputArray.length; i++) { - gradVarPairs.add(new Pair<>(gradients[i], variableOutputArray[i])); + @SuppressWarnings("unchecked") + Output typedGrad = (Output) gradients[i]; + @SuppressWarnings("unchecked") + Output typedVar = (Output) variableOutputArray[i]; + gradVarPairs.add(new GradAndVar<>(typedGrad, typedVar)); } return gradVarPairs; } - public Op applyGradients(List, Output>> gradsAndVars, String name) { - List> variables = gradsAndVars.stream().map(Pair::getB).collect(Collectors.toList()); + public Op applyGradients(List> gradsAndVars, String name) { + List> variables = gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList()); createSlots(variables); - Optional prepOp = prepare(name+"/prepare"); + Optional> prepOp = prepare(name+"/prepare"); - List> updateOps = new ArrayList<>(); + List> updateOps = new ArrayList<>(); prepOp.ifPresent(updateOps::add); - for (Pair pair : gradsAndVars) { - updateOps.add(applyDense((Output)pair.getA(),(Output)pair.getB())); + for (GradAndVar pair : gradsAndVars) { + updateOps.add(applyDense(pair)); } return finish(updateOps,name); @@ -129,7 +131,7 @@ public Op applyGradients(List, Output>> gradsAnd * @param slotName The slot name. * @return The slot or {@link Optional#empty}. */ - public Optional> getSlot(Output var, String slotName) { + public Optional> getSlot(Output var, String slotName) { return getSlot(var.op().name(),slotName); } @@ -139,12 +141,14 @@ public Optional> getSlot(Output var, String slotName) { * @param slotName The slot name. * @return The slot or {@link Optional#empty}. 
*/ - public Optional> getSlot(String varName, String slotName) { - Map> variables = slots.get(slotName); + private Optional> getSlot(String varName, String slotName) { + Map> variables = slots.get(slotName); if (variables != null) { - Variable slot = variables.get(varName); + Variable slot = variables.get(varName); if (slot != null) { - return Optional.of(slot); + @SuppressWarnings("unchecked") // This method should only be called when the type is known. + Optional> opt = Optional.of((Variable)slot); + return opt; } else { return Optional.empty(); } @@ -162,11 +166,11 @@ public Optional> getSlot(String varName, String slotName) { * @param The type of the variable. */ protected void createSlot(Output variable, String slotName, Operand initializer) { - Variable slot = (Variable) tf.withName(createName(variable, slotName)).variable(variable.shape(), TFloat32.DTYPE); + Variable slot = tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.dataType()); Assign slotInit = tf.assign(slot, initializer); graph.addInitializer(slotInit); String varName = variable.op().name(); - Map> variables = slots.computeIfAbsent(slotName,(k) -> new HashMap<>()); + Map> variables = slots.computeIfAbsent(slotName,(k) -> new HashMap<>()); variables.put(varName,slot); } @@ -175,7 +179,7 @@ protected void createSlot(Output variable, String slotName, * * @param scopeName The scope name to use for any variable creations. */ - protected Optional prepare(String scopeName) { + protected Optional> prepare(String scopeName) { return Optional.empty(); } @@ -185,6 +189,10 @@ protected Optional prepare(String scopeName) { */ protected void createSlots(List> variables) { } + private Operand applyDense(GradAndVar gradVarPair) { + return applyDense(gradVarPair.getGradient(),gradVarPair.getVariable()); + } + /** * Generates the gradient update operations for the specific variable and gradient. * @param gradient The gradient to use. 
@@ -213,7 +221,25 @@ protected Op finish(List> updateOperations, String name) { */ public abstract String getOptimizerName(); - public static String createName(Output variable, String slotName) { + public static String createName(Output variable, String slotName) { return variable.op().name() + "-" + slotName; } + + public static class GradAndVar { + private final Output gradient; + private final Output variable; + + public GradAndVar(Output gradient, Output variable) { + this.gradient = gradient; + this.variable = variable; + } + + public Output getGradient() { + return gradient; + } + + public Output getVariable() { + return variable; + } + } } diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java index 9ba293b46f6..a20996f0018 100644 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java +++ b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java @@ -18,7 +18,6 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -64,25 +63,22 @@ protected void createSlots(List> variables) { } private void createRMSPropSlot(Output v) { - Operand rmsInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(1.0f, TFloat32.DTYPE));//v.dataType())); + Operand rmsInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), RMS, rmsInitializer); - Operand momentumInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand momentumInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand mgInitializer = tf.fill(tf.shape(v), (Constant) tf.constant(0.0f, TFloat32.DTYPE));//v.dataType())); + Operand mgInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); createSlot(v.asOutput(), MG, mgInitializer); } } @Override protected Operand applyDense(Output gradient, Output variable) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. - Variable rmsSlot = (Variable) getSlot(variable,RMS).get(); - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. - Variable momentumSlot = (Variable) getSlot(variable,MOMENTUM).get(); + Variable rmsSlot = getSlot(variable,RMS).get(); + Variable momentumSlot = getSlot(variable,MOMENTUM).get(); if (centered) { - @SuppressWarnings("unchecked") // suppressed as the slots are created to have the dtype of the variable. 
- Variable mgSlot = (Variable) getSlot(variable, MG).get(); + Variable mgSlot = getSlot(variable, MG).get(); return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, tf.constant(learningRate, gradient.dataType()), tf.constant(decay, gradient.dataType()), diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java deleted file mode 100644 index 3160b0f2b8d..00000000000 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/util/Pair.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.sandbox.util; - -import java.io.Serializable; -import java.util.Objects; - -/** - * An immutable pair of things. - * - * @param The type of the first object. - * @param The type of the second object. - */ -public class Pair implements Serializable { - private static final long serialVersionUID = 1L; - - private final T1 a; - - private final T2 b; - - public Pair(T1 a, T2 b) { - this.a = a; - this.b = b; - } - - public T1 getA() { - return a; - } - - public T2 getB() { - return b; - } - - @Override - public int hashCode() { - return a.hashCode() ^ b.hashCode(); - } - - @Override - public boolean equals(Object obj) { - if (obj == null) { - return false; - } - if (!(obj instanceof Pair)) { - return false; - } - final Pair other = (Pair) obj; - if (!Objects.equals(this.a, other.a)) { - return false; - } - return Objects.equals(this.b, other.b); - } - - @Override - public String toString() { - return "Pair{" + "a=" + a + ", b=" + b + '}'; - } -} From b29be50c4a61bcd43afea1a163fd669f6246562a Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 7 Feb 2020 16:55:37 -0500 Subject: [PATCH 13/22] Repackaging the optimizers into tensorflow-training, org.tensorflow.training. 
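
The optimizer classes move out of org.tensorflow.sandbox and into the new
tensorflow-training module under org.tensorflow.training; this patch holds
the corresponding deletions from tensorflow-sandbox. As a rough sketch of a
call site after the move (the optimizers subpackage name is an assumption
here, since only the deletions are visible in this patch):

    import org.tensorflow.training.optimizers.Adam;      // assumed new location
    import org.tensorflow.training.optimizers.Optimizer; // assumed new location

    Optimizer optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f);
    Op minimize = optimizer.minimize(loss, TRAIN);

The constructor arguments and the minimize(loss, TRAIN) call match the ones
the sandbox MNISTTest used before its removal.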
--- .../org/tensorflow/sandbox/MNISTTest.java | 357 ------------------ .../sandbox/optimizers/AdaDelta.java | 92 ----- .../sandbox/optimizers/AdaGrad.java | 81 ---- .../sandbox/optimizers/AdaGradDA.java | 126 ------- .../tensorflow/sandbox/optimizers/Adam.java | 141 ------- .../sandbox/optimizers/GradientDescent.java | 53 --- .../sandbox/optimizers/Momentum.java | 82 ---- .../sandbox/optimizers/Optimizer.java | 245 ------------ .../sandbox/optimizers/RMSProp.java | 113 ------ 9 files changed, 1290 deletions(-) delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java delete mode 100644 tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java deleted file mode 100644 index a508046f5d7..00000000000 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/MNISTTest.java +++ /dev/null @@ -1,357 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.tensorflow.sandbox; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Session; -import org.tensorflow.Tensor; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.OneHot; -import org.tensorflow.op.core.Placeholder; -import org.tensorflow.op.core.Reshape; -import org.tensorflow.op.core.Variable; -import org.tensorflow.op.math.Add; -import org.tensorflow.op.math.Mean; -import org.tensorflow.op.nn.Conv2d; -import org.tensorflow.op.nn.MaxPool; -import org.tensorflow.op.nn.Relu; -import org.tensorflow.op.nn.Softmax; -import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; -import org.tensorflow.op.random.TruncatedNormal; -import org.tensorflow.sandbox.optimizers.AdaDelta; -import org.tensorflow.sandbox.optimizers.AdaGrad; -import org.tensorflow.sandbox.optimizers.AdaGradDA; -import org.tensorflow.sandbox.optimizers.Adam; -import org.tensorflow.sandbox.optimizers.GradientDescent; -import org.tensorflow.sandbox.optimizers.Momentum; -import org.tensorflow.sandbox.optimizers.Optimizer; -import org.tensorflow.sandbox.optimizers.RMSProp; -import org.tensorflow.tools.Shape; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt32; - -import java.io.BufferedInputStream; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.util.Arrays; -import java.util.logging.Level; -import java.util.logging.Logger; - -/** - * Builds a LeNet-5 style CNN for MNIST. - */ -public class MNISTTest { - - private static final Logger logger = Logger.getLogger(MNISTTest.class.getName()); - - private static final int PIXEL_DEPTH = 255; - private static final int NUM_CHANNELS = 1; - private static final int IMAGE_SIZE = 28; - private static final int NUM_LABELS = 10; - private static final long SEED = 123456789L; - - private static final String PADDING_TYPE = "SAME"; - - public static final String INPUT_NAME = "input"; - public static final String OUTPUT_NAME = "output"; - public static final String TARGET = "target"; - public static final String TRAIN = "train"; - public static final String TRAINING_LOSS = "training_loss"; - public static final String EPOCH = "epoch"; - public static final String INIT = "init"; - - public static Graph build(String optimizerName) { - Graph graph = new Graph(); - - Ops tf = Ops.create(graph); - - // Inputs - Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.make(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); - Placeholder labels = tf.withName(TARGET).placeholder(TInt32.DTYPE); - - // Scaling the features - Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); - Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); - Operand scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor); - - // First conv layer - Variable conv1Weights = tf.variable(Shape.make(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE); - Assign weights1Init = tf.assign(conv1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv1Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); - graph.addInitializer(weights1Init); - Conv2d conv1 = tf.nn.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); - Variable conv1Biases = tf.variable(Shape.make(32), TFloat32.DTYPE); - Assign biases1Init = tf.assign(conv1Biases, tf.fill(tf.shape(conv1Biases), tf.constant(0.0f))); - 
graph.addInitializer(biases1Init); - Relu relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); - - // First pooling layer - MaxPool pool1 = tf.nn.maxPool(relu1, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); - - // Second conv layer - Variable conv2Weights = tf.variable(Shape.make(5, 5, 32, 64), TFloat32.DTYPE); - Assign weights2Init = tf.assign(conv2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv2Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); - graph.addInitializer(weights2Init); - Conv2d conv2 = tf.nn.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); - Variable conv2Biases = tf.variable(Shape.make(64), TFloat32.DTYPE); - Assign biases2Init = tf.assign(conv2Biases, tf.fill(tf.shape(conv2Biases), tf.constant(0.1f))); - graph.addInitializer(biases2Init); - Relu relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); - - // Second pooling layer - MaxPool pool2 = tf.nn.maxPool(relu2, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); - - // Flatten inputs - Reshape flatten = tf.reshape(pool2, tf.concat(Arrays.asList(tf.slice(tf.shape(pool2), tf.constant(new int[]{0}), tf.constant(new int[]{1})), tf.constant(new int[]{-1})), tf.constant(0))); - - // Fully connected layer - Variable fc1Weights = tf.variable(Shape.make(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE); - Assign weights3Init = tf.assign(fc1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc1Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); - graph.addInitializer(weights3Init); - Variable fc1Biases = tf.variable(Shape.make(512), TFloat32.DTYPE); - Assign biases3Init = tf.assign(fc1Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc1Biases))); - graph.addInitializer(biases3Init); - Relu relu3 = tf.nn.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); - - // Softmax layer - Variable fc2Weights = tf.variable(Shape.make(512, NUM_LABELS), TFloat32.DTYPE); - Assign weights4Init = tf.assign(fc2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc2Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); - graph.addInitializer(weights4Init); - Variable fc2Biases = tf.variable(Shape.make(NUM_LABELS), TFloat32.DTYPE); - Assign biases4Init = tf.assign(fc2Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc2Biases))); - graph.addInitializer(biases4Init); - - Add logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); - - // Predicted outputs - Softmax prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits); - - // Loss function & regularization - OneHot oneHot = tf.oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); - SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); - Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); - Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math.add(tf.nn.l2Loss(fc1Biases), tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); - Add loss = tf.withName(TRAINING_LOSS).math.add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); - - // Optimizer - Optimizer optimizer; - switch (optimizerName) { - case "AdaDelta": - case "Adadelta": - case "adadelta": - optimizer = new AdaDelta(graph, 1f, 0.95f, 1e-8f); - break; - case "AdaGradDA": - case "AdagradDA": - case "adagradda": - optimizer = new AdaGradDA(graph, 0.01f); - break; - case "AdaGrad": - case "Adagrad": - case "adagrad": - 
optimizer = new AdaGrad(graph, 0.01f); - break; - case "Adam": - case "adam": - optimizer = new Adam(graph,0.001f,0.9f,0.999f,1e-8f); - break; - case "SGD": - case "sgd": - optimizer = new GradientDescent(graph,0.01f); - break; - case "Momentum": - case "momentum": - optimizer = new Momentum(graph, 0.01f, 0.9f, false); - break; - case "RMSProp": - case "rmsprop": - optimizer = new RMSProp(graph,0.01f, 0.9f, 0.0f, 1e-10f, false); - break; - default: - throw new IllegalArgumentException("Unknown optimizer " + optimizerName); - } - logger.info("Optimizer = " + optimizer.toString()); - Op minimize = optimizer.minimize(loss, TRAIN); - - Op init = graph.variablesInitializer(); - - return graph; - } - - public static void train(Session session, int epochs, int minibatchSize, float[][][][] data, int[] labels) { - // Initialises the parameters. - session.runner().addTarget(INIT).run(); - logger.info("Initialised the model parameters"); - - float[][][][] featureBatch = new float[minibatchSize][][][]; - int[] labelBatch = new int[minibatchSize]; - - int interval = 0; - for (int i = 0; i < epochs; i++) { - logger.log(Level.INFO, "Starting epoch " + i); - //Tensor epoch = Tensor.create(i); - for (int j = 0; j < data.length; j += minibatchSize) { - for (int k = j, m = 0; k < (j + minibatchSize) && k < data.length; k++, m++) { - featureBatch[m] = data[k]; - labelBatch[m] = labels[k]; - } - //logger.info("Batch = " + batch.size()); - Tensor input = Tensor.create(featureBatch); - Tensor target = Tensor.create(labelBatch); - Tensor loss = session.runner() - .feed(INPUT_NAME, input) - .feed(TARGET, target) - .addTarget(TRAIN) - .fetch(TRAINING_LOSS) - .run().get(0); - if (interval % 100 == 0) { - logger.log(Level.INFO, "Iteration = " + interval + ", training loss = " + loss.floatValue()); - } - input.close(); - target.close(); - loss.close(); - interval++; - } - //epoch.close(); - } - } - - /** - * Find the maximum probability and return its index. - * - * @param probabilities The probabilities. - * @return The index of the max.
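The training loop deleted above is the standard TF-Java feed/fetch pattern. As a minimal sketch, using only the constants and fields defined in this file and assuming `session`, `featureBatch` and `labelBatch` are already populated as in that loop, one training step is:

    // One training step (illustrative); try-with-resources closes the tensors.
    try (Tensor input = Tensor.create(featureBatch);
         Tensor target = Tensor.create(labelBatch);
         Tensor loss = session.runner()
             .feed(INPUT_NAME, input)    // bind the image placeholder
             .feed(TARGET, target)       // bind the label placeholder
             .addTarget(TRAIN)           // execute the optimizer's update op
             .fetch(TRAINING_LOSS)       // read back the scalar loss
             .run().get(0)) {
      logger.info("loss = " + loss.floatValue());
    }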
- */ - public static int pred(float[] probabilities) { - float maxVal = Float.NEGATIVE_INFINITY; - int idx = 0; - for (int i = 0; i < probabilities.length; i++) { - if (probabilities[i] > maxVal) { - maxVal = probabilities[i]; - idx = i; - } - } - return idx; - } - - public static DataTuple loadData(String path) throws IOException, ClassNotFoundException { - try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) { - float[][][][] data = (float[][][][]) ois.readObject(); - int[] labels = (int[]) ois.readObject(); - return new DataTuple(data, labels); - } - } - - private static class DataTuple { - public final float[][][][] features; - public final int[] labels; - - public DataTuple(float[][][][] features, int[] labels) { - this.features = features; - this.labels = labels; - } - } - - public static void main(String[] args) throws IOException, ClassNotFoundException { - logger.info("Usage: MNISTTest "); - - logger.info("Loading training data"); - DataTuple train = loadData(args[3]); - logger.info("Loading testing data"); - DataTuple test = loadData(args[4]); - - logger.info("Loaded data."); - - float[][][][] trainData = train.features; - int[] trainLabels = train.labels; - - float[][][][] testData = test.features; - int[] testLabels = test.labels; - - logger.info("Loaded " + trainLabels.length + " training labels"); - logger.info("Loaded " + testLabels.length + " testing labels"); - - int epochs = Integer.parseInt(args[0]); - int minibatchSize = Integer.parseInt(args[1]); - - Graph graph = build(args[2]); - - int correctCount = 0; - int[][] confusionMatrix = new int[10][10]; - - try (Session session = new Session(graph)) { - train(session, epochs, minibatchSize, trainData, trainLabels); - - logger.info("Trained model"); - - float[][][][] featureBatch = new float[minibatchSize][][][]; - int[] labelBatch = new int[minibatchSize]; - float[][] prediction; - - for (int j = 0; j < testData.length; j += minibatchSize) { - for (int k = j, m = 0; k < (j + minibatchSize) && k < testData.length; k++, m++) { - featureBatch[m] = testData[k]; - labelBatch[m] = testLabels[k]; - } - try (Tensor transformedInput = Tensor.create(featureBatch); - Tensor outputTensor = session.runner() - .feed(INPUT_NAME, transformedInput) - .fetch(OUTPUT_NAME).run().get(0)) { - prediction = outputTensor.copyTo(new float[minibatchSize][NUM_LABELS]); - } - - for (int k = 0; k < labelBatch.length; k++) { - int predLabel; - - predLabel = pred(prediction[k]); - if (predLabel == labelBatch[k]) { - correctCount++; - } - - confusionMatrix[labelBatch[k]][predLabel]++; - } - - if (j % 1000 == 0) { - logger.log(Level.INFO, "Cur accuracy = " + ((float) correctCount) / (j + minibatchSize)); - } - } - - logger.info("Final accuracy = " + ((float) correctCount) / testLabels.length); - - StringBuilder sb = new StringBuilder(); - sb.append("Label"); - for (int i = 0; i < confusionMatrix.length; i++) { - sb.append(String.format("%1$5s", "" + i)); - } - sb.append("\n"); - - for (int i = 0; i < confusionMatrix.length; i++) { - sb.append(String.format("%1$5s", "" + i)); - for (int j = 0; j < confusionMatrix[i].length; j++) { - sb.append(String.format("%1$5s", "" + confusionMatrix[i][j])); - } - sb.append("\n"); - } - - System.out.println(sb.toString()); - } - - } -} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java deleted file mode 100644 index c4cd4079df2..00000000000 --- 
a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaDelta.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.sandbox.optimizers; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.family.TType; - -import java.util.List; - -/** - * Optimizer that implements the Adadelta algorithm. - * - * See the paper. - */ -public class AdaDelta extends Optimizer { - - public static final String ACCUMULATOR = "accum"; - public static final String ACCUMULATOR_UPDATE = "accum_update"; - - private final float learningRate; - - private final float rho; - - private final float epsilon; - - public AdaDelta(Graph graph, float learningRate) { - this(graph, learningRate, 0.95f, 1e-8f); - } - - public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { - super(graph); - this.learningRate = learningRate; - this.rho = rho; - this.epsilon = epsilon; - } - - @Override - protected void createSlots(List> variables) { - for (Output v : variables) { - createAdaDeltaSlot(v); - } - } - - private void createAdaDeltaSlot(Output v) { - Operand accumulatorInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); - Operand updateInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); - } - - @Override - protected Operand applyDense(Output gradient, Output variable) { - Variable accumSlot = getSlot(variable,ACCUMULATOR).get(); - Variable accumUpdateSlot = getSlot(variable,ACCUMULATOR_UPDATE).get(); - return tf.train.applyAdadelta(variable, accumSlot, accumUpdateSlot, - tf.constant(learningRate, gradient.dataType()), - tf.constant(rho, gradient.dataType()), - tf.constant(epsilon, gradient.dataType()), - gradient); - } - - @Override - public String toString() { - return "AdaDelta{" + - "learningRate=" + learningRate + - ", rho=" + rho + - ", epsilon=" + epsilon + - '}'; - } - - @Override - public String getOptimizerName() { - return "Adadelta"; - } -} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java deleted file mode 100644 index 00a56f5853a..00000000000 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGrad.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
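The AdaDelta class above keeps two moving averages per variable: one of squared gradients (the ACCUMULATOR slot) and one of squared updates (the ACCUMULATOR_UPDATE slot). A plain-Java sketch of the per-element rule that tf.train.applyAdadelta implements, written from the equations in Zeiler's paper rather than the kernel source, so illustrative only:

    // Per-element AdaDelta update (illustrative). accum and accumUpdate stand
    // in for the ACCUMULATOR and ACCUMULATOR_UPDATE slots created above.
    static float adaDeltaDelta(float grad, float lr, float rho, float epsilon,
                               float[] accum, float[] accumUpdate, int i) {
      accum[i] = rho * accum[i] + (1 - rho) * grad * grad;         // E[g^2]
      float update = (float) (Math.sqrt(accumUpdate[i] + epsilon)
          / Math.sqrt(accum[i] + epsilon)) * grad;                 // RMS-scaled step
      accumUpdate[i] = rho * accumUpdate[i]
          + (1 - rho) * update * update;                           // E[dx^2]
      return lr * update;                                          // subtracted from the variable
    }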
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.sandbox.optimizers; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.family.TType; - -import java.util.List; - -/** - * Optimizer that implements the Adagrad algorithm. - * - * See the paper - * or this intro. - */ -public class AdaGrad extends Optimizer { - - public static final String ACCUMULATOR = "accumulator"; - - private final float learningRate; - - private final float initialAccumulatorValue; - - public AdaGrad(Graph graph, float learningRate) { - this(graph, learningRate, 0.01f); - } - - public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { - super(graph); - this.learningRate = learningRate; - this.initialAccumulatorValue = initialAccumulatorValue; - } - - @Override - protected void createSlots(List> variables) { - for (Output v : variables) { - createAdaGradSlot(v); - } - } - - private void createAdaGradSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), ACCUMULATOR, initializer); - } - - @Override - protected Operand applyDense(Output gradient, Output variable) { - Variable slot = getSlot(variable,ACCUMULATOR).get(); - return tf.train.applyAdagrad(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient); - } - - @Override - public String toString() { - return "AdaGrad{" + - "learningRate=" + learningRate + - ", initialAccumulatorValue=" + initialAccumulatorValue + - '}'; - } - - @Override - public String getOptimizerName() { - return "Adagrad"; - } -} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java deleted file mode 100644 index 753e104c8a1..00000000000 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/AdaGradDA.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
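The AdaGrad class above is the simplest of the slotted optimizers: a single accumulator of squared gradients, seeded with initialAccumulatorValue. A per-element sketch of what tf.train.applyAdagrad computes, taken from the paper's update rule rather than the kernel, so illustrative only:

    // Per-element AdaGrad update (illustrative). accum is the ACCUMULATOR slot.
    static float adaGradDelta(float grad, float lr, float[] accum, int i) {
      accum[i] += grad * grad;                           // lifetime sum of g^2
      return (float) (lr * grad / Math.sqrt(accum[i]));  // subtracted from the variable
    }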
- */ -package org.tensorflow.sandbox.optimizers; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.op.Op; -import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.Variable; -import org.tensorflow.tools.Shape; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TType; - -import java.util.List; -import java.util.Optional; - -/** - * Optimizer that implements the Adagrad Dual-Averaging algorithm. - * - * See the paper. - */ -public class AdaGradDA extends Optimizer { - - public static final String ACCUMULATOR = "gradient_accumulator"; - public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator"; - - private Variable globalStep; - - private final float learningRate; - - private final float initialAccumulatorValue; - - private final float l1Strength; - - private final float l2Strength; - - public AdaGradDA(Graph graph, float learningRate) { - this(graph, learningRate, 0.1f, 0.0f, 0.0f); - } - - public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, float l1Strength, float l2Strength) { - super(graph); - this.learningRate = learningRate; - this.initialAccumulatorValue = initialAccumulatorValue; - this.l1Strength = l1Strength; - this.l2Strength = l2Strength; - } - - @Override - protected Optional> prepare(String name) { - return Optional.of(tf.assignAdd(globalStep,tf.constant(1L))); - } - - @Override - protected void createSlots(List> variables) { - for (Output v : variables) { - createAdaGradDASlot(v); - } - globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(),TInt64.DTYPE); - Assign globalStepInitializer = tf.assign(globalStep, tf.constant(0L)); - graph.addInitializer(globalStepInitializer); - } - - private void createAdaGradDASlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), ACCUMULATOR, initializer); - Operand sqInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); - } - - @Override - protected Operand applyDense(Output gradient, Output variable) { - Variable gradSlot = getSlot(variable,ACCUMULATOR).get(); - Variable gradSquaredSlot = getSlot(variable,SQUARED_ACCUMULATOR).get(); - return tf.train.applyAdagradDa(variable, gradSlot, gradSquaredSlot, gradient, - tf.constant(learningRate, gradient.dataType()), - tf.constant(l1Strength, gradient.dataType()), - tf.constant(l2Strength, gradient.dataType()), - globalStep); - } - - /** - * Gathers up the update operations into a single op that can be used as a run target. - *
- * Adds the global step update to the end of the updates list. - * @param updateOperations The update operations. - * @param name The name of the run target. - * @return A NoOp with a control dependency on each update operation. - */ - @Override - protected Op finish(List> updateOperations, String name) { - updateOperations.add(tf.assignAdd(globalStep,tf.constant(1L))); - return super.finish(updateOperations,name); - } - - @Override - public String toString() { - return "AdaGradDA{" + - "globalStep=" + globalStep + - ", learningRate=" + learningRate + - ", initialAccumulatorValue=" + initialAccumulatorValue + - ", l1Strength=" + l1Strength + - ", l2Strength=" + l2Strength + - '}'; - } - - @Override - public String getOptimizerName() { - return "adagrad-da"; - } -} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java deleted file mode 100644 index 7b18fab0354..00000000000 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Adam.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.sandbox.optimizers; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.op.Op; -import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.Variable; -import org.tensorflow.tools.Shape; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.family.TType; - -import java.util.List; -import java.util.Optional; - -/** - * Optimizer that implements the Adam algorithm. - * - * See the paper. 
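Adam's two slots hold exponential moving averages of the gradient (m) and of its square (v), while the betaOnePower and betaTwoPower variables created below supply the bias correction. A per-element sketch following Algorithm 1 of the Kingma and Ba paper (illustrative, not the fused kernel):

    // Per-element Adam update (illustrative). m and v are the FIRST_MOMENT and
    // SECOND_MOMENT slots; b1p and b2p are the current beta1/beta2 powers.
    static float adamDelta(float grad, float lr, float b1, float b2, float eps,
                           float b1p, float b2p, float[] m, float[] v, int i) {
      m[i] = b1 * m[i] + (1 - b1) * grad;                  // first moment estimate
      v[i] = b2 * v[i] + (1 - b2) * grad * grad;           // second moment estimate
      float alpha = (float) (lr * Math.sqrt(1 - b2p) / (1 - b1p)); // bias correction
      return (float) (alpha * m[i] / (Math.sqrt(v[i]) + eps));     // subtracted from the variable
    }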
- */ -public class Adam extends Optimizer { - - public static final String FIRST_MOMENT = "m"; - public static final String SECOND_MOMENT = "v"; - - private final float learningRate; - - private final float betaOne; - - private final float betaTwo; - - private final float epsilon; - - private Constant learningRateConst; - private Constant epsilonConst; - private Constant betaOneConst; - private Constant betaTwoConst; - private Variable betaOnePower; - private Variable betaTwoPower; - - public Adam(Graph graph, float learningRate) { - this(graph, learningRate, 0.9f, 0.999f, 1e-8f); - } - - public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph); - this.learningRate = learningRate; - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; - } - - @Override - protected void createSlots(List> variables) { - for (Output v : variables) { - createAdamSlot(v.asOutput()); - } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(),TFloat32.DTYPE); - Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne, TFloat32.DTYPE)); - graph.addInitializer(betaOnePowerInit); - betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(),TFloat32.DTYPE); - Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo, TFloat32.DTYPE)); - graph.addInitializer(betaTwoPowerInit); - } - - @Override - protected Optional> prepare(String scopeName) { - betaOneConst = tf.constant(betaOne); - betaTwoConst = tf.constant(betaTwo); - learningRateConst = tf.constant(learningRate); - epsilonConst = tf.constant(epsilon); - return Optional.empty(); - } - - private void createAdamSlot(Output v) { - Operand firstMomentInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); - Operand secondMomentInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); - } - - @Override - protected Operand applyDense(Output gradient, Output variable) { - Variable firstMomentSlot = getSlot(variable,FIRST_MOMENT).get(); - Variable secondMomentSlot = getSlot(variable,SECOND_MOMENT).get(); - return tf.train.applyAdam(variable, firstMomentSlot, secondMomentSlot, - tf.dtypes.cast(betaOnePower,gradient.dataType()), - tf.dtypes.cast(betaTwoPower,gradient.dataType()), - tf.dtypes.cast(learningRateConst,gradient.dataType()), - tf.dtypes.cast(betaOneConst,gradient.dataType()), - tf.dtypes.cast(betaTwoConst,gradient.dataType()), - tf.dtypes.cast(epsilonConst,gradient.dataType()), - gradient); - } - - /** - * Gathers up the update operations into a single op that can be used as a run target. - *
- * Adds the betaOne and betaTwo updates to the end of the updates list. - * @param updateOperations The update operations. - * @param name The name of the run target. - * @return A NoOp with a control dependency on each update operation. - */ - @Override - protected Op finish(List> updateOperations, String name) { - updateOperations.add(tf.assign(betaOnePower,tf.math.mul(betaOnePower,betaOneConst))); - updateOperations.add(tf.assign(betaTwoPower,tf.math.mul(betaTwoPower,betaTwoConst))); - return super.finish(updateOperations,name); - } - - @Override - public String toString() { - return "Adam{" + - "learningRate=" + learningRate + - ", betaOne=" + betaOne + - ", betaTwo=" + betaTwo + - ", epsilon=" + epsilon + - '}'; - } - - @Override - public String getOptimizerName() { - return "Adam"; - } -} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java deleted file mode 100644 index c95398abe6f..00000000000 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/GradientDescent.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.sandbox.optimizers; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.family.TType; - - -/** - * Basic SGD. - */ -public class GradientDescent extends Optimizer { - - private final float learningRate; - - public GradientDescent(Graph graph, float learningRate) { - super(graph); - this.learningRate = learningRate; - } - - @Override - protected Operand applyDense(Output gradient, Output variable) { - return tf.train.applyGradientDescent(variable, tf.dtypes.cast(tf.constant(learningRate, TFloat32.DTYPE), gradient.dataType()), gradient); - } - - @Override - public String toString() { - return "GradientDescent{" + - "learningRate=" + learningRate + - '}'; - } - - @Override - public String getOptimizerName() { - return "GradientDescent"; - } -} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java deleted file mode 100644 index 60f3497d570..00000000000 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Momentum.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
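GradientDescent above is the minimal Optimizer subclass: no slots, no prepare step, one applyGradientDescent call per variable. Wiring it into a graph takes two calls, all of which appear elsewhere in this patch; `graph` and a scalar float `loss` operand are assumed to be in scope:

    Optimizer optimizer = new GradientDescent(graph, 0.01f);
    Op trainOp = optimizer.minimize(loss, "train");  // adds gradient + update ops
    Op init = graph.variablesInitializer();          // gathers the registered initializers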
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.sandbox.optimizers; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.op.core.Variable; -import org.tensorflow.op.train.ApplyMomentum; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.family.TType; - -import java.util.List; - -/** - * SGD plus momentum, either nesterov or traditional. - * - * See the paper for - * details of nesterov momentum. - */ -public class Momentum extends Optimizer { - - public static final String MOMENTUM = "momentum"; - - private final float learningRate; - - private final float momentum; - - private final boolean useNesterov; - - public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { - super(graph); - this.learningRate = learningRate; - this.momentum = momentum; - this.useNesterov = useNesterov; - } - - @Override - protected void createSlots(List> variables) { - for (Output v : variables) { - createMomentumSlot(v); - } - } - - private void createMomentumSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), MOMENTUM, initializer); - } - - @Override - protected Operand applyDense(Output gradient, Output variable) { - Variable slot = getSlot(variable,MOMENTUM).get(); - return tf.train.applyMomentum(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient, tf.constant(momentum, gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); - } - - @Override - public String toString() { - return "Momentum{" + - "learningRate=" + learningRate + - ", momentum=" + momentum + - ", useNesterov=" + useNesterov + - '}'; - } - - @Override - public String getOptimizerName() { - return "Momentum"; - } -} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java deleted file mode 100644 index 37467a50f34..00000000000 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/Optimizer.java +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
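Momentum above keeps one velocity slot per variable and supports both the classic heavy-ball update and Nesterov's look-ahead variant. A per-element sketch matching ApplyMomentum's documented semantics (illustrative only):

    // Per-element momentum update (illustrative). slot is the MOMENTUM slot.
    static float momentumDelta(float grad, float lr, float mu, boolean nesterov,
                               float[] slot, int i) {
      slot[i] = mu * slot[i] + grad;                  // velocity accumulation
      return nesterov
          ? lr * (grad + mu * slot[i])                // look-ahead (Nesterov) step
          : lr * slot[i];                             // classic heavy-ball step
    }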
- */ -package org.tensorflow.sandbox.optimizers; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.Output; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.op.Scope; -import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.NoOp; -import org.tensorflow.op.core.Variable; -import org.tensorflow.types.family.TType; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * - */ -public abstract class Optimizer { - public static final String VARIABLE_V2 = "VariableV2"; - - /** - * Top level map key is the variable name, lower level map key is the slot name. - */ - private final Map>> slots; - - /** - * Global state variables - */ - //TODO make this be used. - protected final List> globals; - - /** - * The Graph this optimizer is operating on. - */ - protected final Graph graph; - - /** - * The ops builder for the graph. - */ - protected final Ops tf; - - protected Optimizer(Graph graph) { - this.graph = graph; - this.tf = Ops.create(graph).withName(getOptimizerName()); - this.slots = new HashMap<>(); - this.globals = new ArrayList<>(); - } - - public Op minimize(Operand loss) { - return minimize(loss, getOptimizerName()+"-minimize"); - } - - public Op minimize(Operand loss, String name) { - List> gradsAndVars = computeGradients(loss); - - return applyGradients(gradsAndVars, name); - } - - public List> computeGradients(Operand loss) { - List variables = new ArrayList<>(); - Iterator opItr = graph.operations(); - while (opItr.hasNext()) { - Operation op = opItr.next(); - if (op.type().equals(VARIABLE_V2)) { - variables.add(op); - } - } - - Output[] variableOutputArray = new Output[variables.size()]; - for (int i = 0; i < variables.size(); i++) { - // First output of a variable is it's output. - variableOutputArray[i] = variables.get(i).output(0); - } - - Output[] gradients = graph.addGradients(loss.asOutput(), variableOutputArray); - List> gradVarPairs = new ArrayList<>(); - - for (int i = 0; i < variableOutputArray.length; i++) { - @SuppressWarnings("unchecked") - Output typedGrad = (Output) gradients[i]; - @SuppressWarnings("unchecked") - Output typedVar = (Output) variableOutputArray[i]; - gradVarPairs.add(new GradAndVar<>(typedGrad, typedVar)); - } - - return gradVarPairs; - } - - public Op applyGradients(List> gradsAndVars, String name) { - List> variables = gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList()); - - createSlots(variables); - - Optional> prepOp = prepare(name+"/prepare"); - - List> updateOps = new ArrayList<>(); - prepOp.ifPresent(updateOps::add); - for (GradAndVar pair : gradsAndVars) { - updateOps.add(applyDense(pair)); - } - - return finish(updateOps,name); - } - - /** - * Gets the slot associated with the specified variable and slot name. - * @param var The variable to lookup. - * @param slotName The slot name. - * @return The slot or {@link Optional#empty}. - */ - public Optional> getSlot(Output var, String slotName) { - return getSlot(var.op().name(),slotName); - } - - /** - * Gets the slot associated with the specified variable and slot name. - * @param varName The variable to lookup. - * @param slotName The slot name. - * @return The slot or {@link Optional#empty}. 
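The public surface of Optimizer composes in a fixed order, and minimize is just the two phases spelled out. A usage sketch, with the generic types written out as the stripped signatures imply; `optimizer` and a scalar `loss` operand are assumed in scope:

    // Equivalent to optimizer.minimize(loss, "train"):
    List<GradAndVar<? extends TType>> gradsAndVars = optimizer.computeGradients(loss);
    Op trainOp = optimizer.applyGradients(gradsAndVars, "train");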
- */ - private Optional> getSlot(String varName, String slotName) { - Map> variables = slots.get(slotName); - if (variables != null) { - Variable slot = variables.get(varName); - if (slot != null) { - @SuppressWarnings("unchecked") // This method should only be called when the type is known. - Optional> opt = Optional.of((Variable)slot); - return opt; - } else { - return Optional.empty(); - } - } else { - return Optional.empty(); - } - } - - /** - * Creates a slot in the graph for the specified variable with the specified name. Adds the slot's initializer - * to the graph's initializers, and the slot to the Optimizer's slot map. - * @param variable The variable to create the slot for. - * @param slotName The name of the slot. - * @param initializer The initializer for the slot. - * @param The type of the variable. - */ - protected void createSlot(Output variable, String slotName, Operand initializer) { - Variable slot = tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.dataType()); - Assign slotInit = tf.assign(slot, initializer); - graph.addInitializer(slotInit); - String varName = variable.op().name(); - Map> variables = slots.computeIfAbsent(slotName,(k) -> new HashMap<>()); - variables.put(varName,slot); - } - - /** - * No-op prepare method. - * - * @param scopeName The scope name to use for any variable creations. - */ - protected Optional> prepare(String scopeName) { - return Optional.empty(); - } - - /** - * No-op slot creation method. - * @param variables The variables to create slots for. - */ - protected void createSlots(List> variables) { } - - private Operand applyDense(GradAndVar gradVarPair) { - return applyDense(gradVarPair.getGradient(),gradVarPair.getVariable()); - } - - /** - * Generates the gradient update operations for the specific variable and gradient. - * @param gradient The gradient to use. - * @param variable The variable to update. - * @param The type of the variable. - * @return An operand which applies the desired optimizer update to the variable. - */ - protected abstract Operand applyDense(Output gradient, Output variable); - - /** - * Gathers up the update operations into a single op that can be used as a run target. - * @param updateOperations The update operations. - * @param name The name of the run target. - * @return A NoOp with a control dependency on each update operation. - */ - protected Op finish(List> updateOperations, String name) { - Scope scope = new Scope(graph); - scope = scope.withName(name); - scope = scope.withControlDependencies(updateOperations); - return NoOp.create(scope); - } - - /** - * Name of the optimizer. - * @return The optimizer name. 
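prepare, createSlots, applyDense and finish form a template-method contract: a subclass normally overrides applyDense plus whichever hooks it needs. A hypothetical minimal subclass as a sketch; SignSGD is invented here purely for illustration, tf.math.sign is assumed to exist in this API surface, and imports match GradientDescent above:

    /** Hypothetical sign-gradient optimizer, showing the minimal subclass shape. */
    public class SignSGD extends Optimizer {

      private final float learningRate;

      public SignSGD(Graph graph, float learningRate) {
        super(graph);
        this.learningRate = learningRate;
      }

      @Override
      protected <T extends TType> Operand<T> applyDense(Output<T> gradient, Output<T> variable) {
        // var -= lr * sign(grad): reuses applyGradientDescent with a transformed gradient.
        return tf.train.applyGradientDescent(variable,
            tf.dtypes.cast(tf.constant(learningRate, TFloat32.DTYPE), gradient.dataType()),
            tf.math.sign(gradient));
      }

      @Override
      public String getOptimizerName() {
        return "SignSGD";
      }
    }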
- */ - public abstract String getOptimizerName(); - - public static String createName(Output variable, String slotName) { - return variable.op().name() + "-" + slotName; - } - - public static class GradAndVar { - private final Output gradient; - private final Output variable; - - public GradAndVar(Output gradient, Output variable) { - this.gradient = gradient; - this.variable = variable; - } - - public Output getGradient() { - return gradient; - } - - public Output getVariable() { - return variable; - } - } -} diff --git a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java b/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java deleted file mode 100644 index a20996f0018..00000000000 --- a/tensorflow-sandbox/src/main/java/org/tensorflow/sandbox/optimizers/RMSProp.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.sandbox.optimizers; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.family.TType; - -import java.util.List; - -/** - * Optimizer that implements the RMSProp algorithm. - * - * See the lecture notes - * that are inexplicably the canonical reference. - */ -public class RMSProp extends Optimizer { - - public static final String RMS = "rms"; - public static final String MG = "mg"; // mean gradient?
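RMSProp's slots map onto the usual recurrences: RMS holds the moving average of squared gradients, MOMENTUM the velocity, and MG (centered variant only) the moving average of the raw gradient. A per-element sketch of the non-centered update, written from the cited lecture notes rather than the kernel, so illustrative only:

    // Per-element RMSProp update (illustrative, non-centered). rms and mom are
    // the RMS and MOMENTUM slots; the centered form additionally subtracts
    // mg[i]^2 from rms[i] before the square root.
    static float rmsPropDelta(float grad, float lr, float decay, float mu,
                              float eps, float[] rms, float[] mom, int i) {
      rms[i] = decay * rms[i] + (1 - decay) * grad * grad;  // moving avg of g^2
      mom[i] = mu * mom[i]
          + (float) (lr * grad / Math.sqrt(rms[i] + eps));  // velocity
      return mom[i];                                        // subtracted from the variable
    }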
- public static final String MOMENTUM = "momentum"; - - private final float learningRate; - private final float decay; - private final float momentum; - private final float epsilon; - private final boolean centered; - - public RMSProp(Graph graph, float learningRate) { - this(graph, learningRate, 0.9f, 0.0f, 1e-10f, false); - } - - public RMSProp(Graph graph, float learningRate, float decay, float momentum, float epsilon, boolean centered) { - super(graph); - this.learningRate = learningRate; - this.decay = decay; - this.momentum = momentum; - this.epsilon = epsilon; - this.centered = centered; - } - - @Override - protected void createSlots(List> variables) { - for (Output v : variables) { - createRMSPropSlot(v); - } - } - - private void createRMSPropSlot(Output v) { - Operand rmsInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), RMS, rmsInitializer); - Operand momentumInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), MOMENTUM, momentumInitializer); - if (centered) { - Operand mgInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); - createSlot(v.asOutput(), MG, mgInitializer); - } - } - - @Override - protected Operand applyDense(Output gradient, Output variable) { - Variable rmsSlot = getSlot(variable,RMS).get(); - Variable momentumSlot = getSlot(variable,MOMENTUM).get(); - if (centered) { - Variable mgSlot = getSlot(variable, MG).get(); - return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, - tf.constant(learningRate, gradient.dataType()), - tf.constant(decay, gradient.dataType()), - tf.constant(momentum, gradient.dataType()), - tf.constant(epsilon, gradient.dataType()), - gradient); - } else { - return tf.train.applyRmsProp(variable, rmsSlot, momentumSlot, - tf.constant(learningRate, gradient.dataType()), - tf.constant(decay, gradient.dataType()), - tf.constant(momentum, gradient.dataType()), - tf.constant(epsilon, gradient.dataType()), - gradient); - } - } - - @Override - public String toString() { - return "RMSProp{" + - "learningRate=" + learningRate + - ", decay=" + decay + - ", momentum=" + momentum + - ", epsilon=" + epsilon + - ", centered=" + centered + - '}'; - } - - @Override - public String getOptimizerName() { - return "RMSProp"; - } -} From 6cdb55c16b4437e5396f8959ccb1348212b0ec1f Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 7 Feb 2020 23:35:38 -0500 Subject: [PATCH 14/22] Delete pom.xml --- tensorflow-sandbox/pom.xml | 41 -------------------------------------- 1 file changed, 41 deletions(-) delete mode 100644 tensorflow-sandbox/pom.xml diff --git a/tensorflow-sandbox/pom.xml b/tensorflow-sandbox/pom.xml deleted file mode 100644 index fd829f54097..00000000000 --- a/tensorflow-sandbox/pom.xml +++ /dev/null @@ -1,41 +0,0 @@ - - - 4.0.0 - - - org.tensorflow - tensorflow-java - 0.1.0-SNAPSHOT - - tensorflow-sandbox - 0.1.0-SNAPSHOT - - - - org.tensorflow - tensorflow-core-api - 0.1.0-SNAPSHOT - - - - - - org.apache.maven.plugins - maven-compiler-plugin - 3.8.0 - - 1.8 - 1.8 - 1.8 - 1.8 - - -Xlint:all - - - - - - - \ No newline at end of file From b9d64c5885e169d5202f335b70992e9afb2af124 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 11 Feb 2020 08:40:26 -0500 Subject: [PATCH 15/22] Googlify with IntelliJ's Google Java Style Guide formatter. 
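Beyond whitespace, the hunks below also fold the create/assign/addInitializer triple into a single tf.variableWithInit call that derives shape and dtype from its initial value. Schematically, for one of the bias variables, with both forms taken from the hunks that follow:

    // Before: three steps, with the initializer registered by hand.
    Variable conv1Biases = tf.variable(Shape.make(32), TFloat32.DTYPE);
    Assign biases1Init = tf.assign(conv1Biases,
        tf.fill(tf.shape(conv1Biases), tf.constant(0.0f)));
    graph.addInitializer(biases1Init);

    // After: one step; the initializer is registered internally.
    Variable conv1Biases = tf.variableWithInit(
        tf.fill(tf.constant(new int[]{32}), tf.constant(0.0f)));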
--- .../training/examples/MNISTTest.java | 106 ++++++++++-------- .../training/optimizers/AdaDelta.java | 12 +- .../training/optimizers/AdaGrad.java | 14 ++- .../training/optimizers/AdaGradDA.java | 26 +++-- .../tensorflow/training/optimizers/Adam.java | 43 +++---- .../training/optimizers/GradientDescent.java | 3 +- .../training/optimizers/Momentum.java | 15 ++- .../training/optimizers/Optimizer.java | 63 +++++++---- .../training/optimizers/RMSProp.java | 22 ++-- 9 files changed, 178 insertions(+), 126 deletions(-) diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java b/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java index 38f38370bce..e1097804227 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java @@ -84,7 +84,8 @@ public static Graph build(String optimizerName) { Ops tf = Ops.create(graph); // Inputs - Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.make(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); + Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat32.DTYPE, + Placeholder.shape(Shape.make(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); Placeholder labels = tf.withName(TARGET).placeholder(TInt32.DTYPE); // Scaling the features @@ -93,50 +94,55 @@ public static Graph build(String optimizerName) { Operand scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor); // First conv layer - Variable conv1Weights = tf.variable(Shape.make(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE); - Assign weights1Init = tf.assign(conv1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv1Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); - graph.addInitializer(weights1Init); - Conv2d conv1 = tf.nn.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); - Variable conv1Biases = tf.variable(Shape.make(32), TFloat32.DTYPE); - Assign biases1Init = tf.assign(conv1Biases, tf.fill(tf.shape(conv1Biases), tf.constant(0.0f))); - graph.addInitializer(biases1Init); + Variable conv1Weights = tf.variableWithInit(tf.math.mul(tf.random + .truncatedNormal(tf.constant(new int[]{5, 5, NUM_CHANNELS, 32}), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Conv2d conv1 = tf.nn + .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable conv1Biases = tf + .variableWithInit(tf.fill(tf.constant(new int[]{32}), tf.constant(0.0f))); Relu relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); // First pooling layer - MaxPool pool1 = tf.nn.maxPool(relu1, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); + MaxPool pool1 = tf.nn + .maxPool(relu1, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), + PADDING_TYPE); // Second conv layer - Variable conv2Weights = tf.variable(Shape.make(5, 5, 32, 64), TFloat32.DTYPE); - Assign weights2Init = tf.assign(conv2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(conv2Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); - graph.addInitializer(weights2Init); - Conv2d conv2 = tf.nn.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); - Variable conv2Biases = tf.variable(Shape.make(64), TFloat32.DTYPE); - Assign biases2Init = tf.assign(conv2Biases, tf.fill(tf.shape(conv2Biases), tf.constant(0.1f))); - 
graph.addInitializer(biases2Init); + Variable conv2Weights = tf.variableWithInit(tf.math.mul(tf.random + .truncatedNormal(tf.constant(new int[]{5, 5, 32, 64}), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Conv2d conv2 = tf.nn + .conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); + Variable conv2Biases = tf + .variableWithInit(tf.fill(tf.constant(new int[]{64}), tf.constant(0.1f))); Relu relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); // Second pooling layer - MaxPool pool2 = tf.nn.maxPool(relu2, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), PADDING_TYPE); + MaxPool pool2 = tf.nn + .maxPool(relu2, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), + PADDING_TYPE); // Flatten inputs - Reshape flatten = tf.reshape(pool2, tf.concat(Arrays.asList(tf.slice(tf.shape(pool2), tf.constant(new int[]{0}), tf.constant(new int[]{1})), tf.constant(new int[]{-1})), tf.constant(0))); + Reshape flatten = tf.reshape(pool2, tf.concat(Arrays + .asList(tf.slice(tf.shape(pool2), tf.constant(new int[]{0}), tf.constant(new int[]{1})), + tf.constant(new int[]{-1})), tf.constant(0))); // Fully connected layer - Variable fc1Weights = tf.variable(Shape.make(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE); - Assign weights3Init = tf.assign(fc1Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc1Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); - graph.addInitializer(weights3Init); - Variable fc1Biases = tf.variable(Shape.make(512), TFloat32.DTYPE); - Assign biases3Init = tf.assign(fc1Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc1Biases))); - graph.addInitializer(biases3Init); - Relu relu3 = tf.nn.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); + Variable fc1Weights = tf.variableWithInit(tf.math.mul(tf.random + .truncatedNormal(tf.constant(new int[]{IMAGE_SIZE * IMAGE_SIZE * 4, 512}), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc1Biases = tf + .variableWithInit(tf.fill(tf.constant(new int[]{512}), tf.constant(0.1f))); + Relu relu3 = tf.nn + .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); // Softmax layer - Variable fc2Weights = tf.variable(Shape.make(512, NUM_LABELS), TFloat32.DTYPE); - Assign weights4Init = tf.assign(fc2Weights, tf.math.mul(tf.random.truncatedNormal(tf.shape(fc2Weights), TFloat32.DTYPE, TruncatedNormal.seed(SEED)), tf.constant(0.1f))); - graph.addInitializer(weights4Init); - Variable fc2Biases = tf.variable(Shape.make(NUM_LABELS), TFloat32.DTYPE); - Assign biases4Init = tf.assign(fc2Biases, tf.broadcastTo(tf.constant(0.1f), tf.shape(fc2Biases))); - graph.addInitializer(biases4Init); + Variable fc2Weights = tf.variableWithInit(tf.math.mul(tf.random + .truncatedNormal(tf.constant(new int[]{512, NUM_LABELS}), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc2Biases = tf + .variableWithInit(tf.fill(tf.constant(new int[]{NUM_LABELS}), tf.constant(0.1f))); Add logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); @@ -144,11 +150,16 @@ public static Graph build(String optimizerName) { Softmax prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits); // Loss function & regularization - OneHot oneHot = tf.oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); - SoftmaxCrossEntropyWithLogits batchLoss = tf.nn.softmaxCrossEntropyWithLogits(logits, oneHot); + OneHot oneHot = tf + .oneHot(labels, tf.constant(10), tf.constant(1.0f), 
tf.constant(0.0f)); + SoftmaxCrossEntropyWithLogits batchLoss = tf.nn + .softmaxCrossEntropyWithLogits(logits, oneHot); Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); - Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math.add(tf.nn.l2Loss(fc1Biases), tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); - Add loss = tf.withName(TRAINING_LOSS).math.add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); + Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math + .add(tf.nn.l2Loss(fc1Biases), + tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); + Add loss = tf.withName(TRAINING_LOSS).math + .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); // Optimizer Optimizer optimizer; @@ -170,11 +181,11 @@ public static Graph build(String optimizerName) { break; case "Adam": case "adam": - optimizer = new Adam(graph,0.001f,0.9f,0.999f,1e-8f); + optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f); break; case "SGD": case "sgd": - optimizer = new GradientDescent(graph,0.01f); + optimizer = new GradientDescent(graph, 0.01f); break; case "Momentum": case "momentum": @@ -182,7 +193,7 @@ public static Graph build(String optimizerName) { break; case "RMSProp": case "rmsprop": - optimizer = new RMSProp(graph,0.01f, 0.9f, 0.0f, 1e-10f, false); + optimizer = new RMSProp(graph, 0.01f, 0.9f, 0.0f, 1e-10f, false); break; default: throw new IllegalArgumentException("Unknown optimizer " + optimizerName); @@ -195,7 +206,8 @@ public static Graph build(String optimizerName) { return graph; } - public static void train(Session session, int epochs, int minibatchSize, float[][][][] data, int[] labels) { + public static void train(Session session, int epochs, int minibatchSize, float[][][][] data, + int[] labels) { // Initialises the parameters. 
session.runner().addTarget(INIT).run(); logger.info("Initialised the model parameters"); @@ -222,7 +234,8 @@ public static void train(Session session, int epochs, int minibatchSize, float[] .fetch(TRAINING_LOSS) .run().get(0); if (interval % 100 == 0) { - logger.log(Level.INFO, "Iteration = " + interval + ", training loss = " + loss.floatValue()); + logger.log(Level.INFO, + "Iteration = " + interval + ", training loss = " + loss.floatValue()); } input.close(); target.close(); @@ -252,7 +265,8 @@ public static int pred(float[] probabilities) { } public static DataTuple loadData(String path) throws IOException, ClassNotFoundException { - try (ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream(path)))) { + try (ObjectInputStream ois = new ObjectInputStream( + new BufferedInputStream(new FileInputStream(path)))) { float[][][][] data = (float[][][][]) ois.readObject(); int[] labels = (int[]) ois.readObject(); return new DataTuple(data, labels); @@ -260,6 +274,7 @@ public static DataTuple loadData(String path) throws IOException, ClassNotFoundE } private static class DataTuple { + public final float[][][][] features; public final int[] labels; @@ -270,7 +285,8 @@ public DataTuple(float[][][][] features, int[] labels) { } public static void main(String[] args) throws IOException, ClassNotFoundException { - logger.info("Usage: MNISTTest "); + logger.info( + "Usage: MNISTTest "); logger.info("Loading training data"); DataTuple train = loadData(args[3]); @@ -311,9 +327,9 @@ public static void main(String[] args) throws IOException, ClassNotFoundExceptio labelBatch[m] = testLabels[k]; } try (Tensor transformedInput = Tensor.create(featureBatch); - Tensor outputTensor = session.runner() - .feed(INPUT_NAME, transformedInput) - .fetch(OUTPUT_NAME).run().get(0)) { + Tensor outputTensor = session.runner() + .feed(INPUT_NAME, transformedInput) + .fetch(OUTPUT_NAME).run().get(0)) { prediction = outputTensor.copyTo(new float[minibatchSize][NUM_LABELS]); } diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java index edc19b2aa47..ba52cc18b44 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java @@ -26,7 +26,7 @@ /** * Optimizer that implements the Adadelta algorithm. - * + *
* See the paper. */ public class AdaDelta extends Optimizer { @@ -59,16 +59,18 @@ protected void createSlots(List> variables) { } private void createAdaDeltaSlot(Output v) { - Operand accumulatorInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); + Operand accumulatorInitializer = tf + .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); - Operand updateInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); + Operand updateInitializer = tf + .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - Variable accumSlot = getSlot(variable,ACCUMULATOR).get(); - Variable accumUpdateSlot = getSlot(variable,ACCUMULATOR_UPDATE).get(); + Variable accumSlot = getSlot(variable, ACCUMULATOR).get(); + Variable accumUpdateSlot = getSlot(variable, ACCUMULATOR_UPDATE).get(); return tf.train.applyAdadelta(variable, accumSlot, accumUpdateSlot, tf.constant(learningRate, gradient.dataType()), tf.constant(rho, gradient.dataType()), diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java index 9146cc1b5c5..45d3de82fd1 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java @@ -26,9 +26,9 @@ /** * Optimizer that implements the Adagrad algorithm. - * - * See the paper - * or this intro. + *
+ * See the paper or this intro. */ public class AdaGrad extends Optimizer { @@ -56,14 +56,16 @@ protected void createSlots(List> variables) { } private void createAdaGradSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE),v.dataType())); + Operand initializer = tf.fill(tf.shape(v), + tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - Variable slot = getSlot(variable,ACCUMULATOR).get(); - return tf.train.applyAdagrad(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient); + Variable slot = getSlot(variable, ACCUMULATOR).get(); + return tf.train + .applyAdagrad(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient); } @Override diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java index 61b3bef21a4..f701488929a 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java @@ -31,7 +31,7 @@ /** * Optimizer that implements the Adagrad Dual-Averaging algorithm. - * + *
* See the paper. */ public class AdaGradDA extends Optimizer { @@ -53,7 +53,8 @@ public AdaGradDA(Graph graph, float learningRate) { this(graph, learningRate, 0.1f, 0.0f, 0.0f); } - public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, float l1Strength, float l2Strength) { + public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, float l1Strength, + float l2Strength) { super(graph); this.learningRate = learningRate; this.initialAccumulatorValue = initialAccumulatorValue; @@ -63,7 +64,7 @@ public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, @Override protected Optional> prepare(String name) { - return Optional.of(tf.assignAdd(globalStep,tf.constant(1L))); + return Optional.of(tf.assignAdd(globalStep, tf.constant(1L))); } @Override @@ -71,22 +72,24 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdaGradDASlot(v); } - globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(),TInt64.DTYPE); + globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(), TInt64.DTYPE); Assign globalStepInitializer = tf.assign(globalStep, tf.constant(0L)); graph.addInitializer(globalStepInitializer); } private void createAdaGradDASlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); + Operand initializer = tf + .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); - Operand sqInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE),v.dataType())); + Operand sqInitializer = tf.fill(tf.shape(v), + tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - Variable gradSlot = getSlot(variable,ACCUMULATOR).get(); - Variable gradSquaredSlot = getSlot(variable,SQUARED_ACCUMULATOR).get(); + Variable gradSlot = getSlot(variable, ACCUMULATOR).get(); + Variable gradSquaredSlot = getSlot(variable, SQUARED_ACCUMULATOR).get(); return tf.train.applyAdagradDa(variable, gradSlot, gradSquaredSlot, gradient, tf.constant(learningRate, gradient.dataType()), tf.constant(l1Strength, gradient.dataType()), @@ -98,14 +101,15 @@ protected Operand applyDense(Output gradient, Output * Gathers up the update operations into a single op that can be used as a run target. *
<p>
* Adds the global step update to the end of the updates list. + * * @param updateOperations The update operations. - * @param name The name of the run target. + * @param name The name of the run target. * @return A NoOp with a control dependency on each update operation. */ @Override protected Op finish(List> updateOperations, String name) { - updateOperations.add(tf.assignAdd(globalStep,tf.constant(1L))); - return super.finish(updateOperations,name); + updateOperations.add(tf.assignAdd(globalStep, tf.constant(1L))); + return super.finish(updateOperations, name); } @Override diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java index 87e5263cd3a..7acea714063 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java @@ -31,7 +31,7 @@ /** * Optimizer that implements the Adam algorithm. - * + *
<p>
* See the paper. */ public class Adam extends Optimizer { @@ -71,11 +71,13 @@ protected void createSlots(List> variables) { for (Output v : variables) { createAdamSlot(v.asOutput()); } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(),TFloat32.DTYPE); - Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne, TFloat32.DTYPE)); + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); + Assign betaOnePowerInit = tf + .assign(betaOnePower, tf.constant(betaOne, TFloat32.DTYPE)); graph.addInitializer(betaOnePowerInit); - betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(),TFloat32.DTYPE); - Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo, TFloat32.DTYPE)); + betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE); + Assign betaTwoPowerInit = tf + .assign(betaTwoPower, tf.constant(betaTwo, TFloat32.DTYPE)); graph.addInitializer(betaTwoPowerInit); } @@ -89,23 +91,25 @@ protected Optional> prepare(String scopeName) { } private void createAdamSlot(Output v) { - Operand firstMomentInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); + Operand firstMomentInitializer = tf + .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); - Operand secondMomentInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); + Operand secondMomentInitializer = tf + .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - Variable firstMomentSlot = getSlot(variable,FIRST_MOMENT).get(); - Variable secondMomentSlot = getSlot(variable,SECOND_MOMENT).get(); + Variable firstMomentSlot = getSlot(variable, FIRST_MOMENT).get(); + Variable secondMomentSlot = getSlot(variable, SECOND_MOMENT).get(); return tf.train.applyAdam(variable, firstMomentSlot, secondMomentSlot, - tf.dtypes.cast(betaOnePower,gradient.dataType()), - tf.dtypes.cast(betaTwoPower,gradient.dataType()), - tf.dtypes.cast(learningRateConst,gradient.dataType()), - tf.dtypes.cast(betaOneConst,gradient.dataType()), - tf.dtypes.cast(betaTwoConst,gradient.dataType()), - tf.dtypes.cast(epsilonConst,gradient.dataType()), + tf.dtypes.cast(betaOnePower, gradient.dataType()), + tf.dtypes.cast(betaTwoPower, gradient.dataType()), + tf.dtypes.cast(learningRateConst, gradient.dataType()), + tf.dtypes.cast(betaOneConst, gradient.dataType()), + tf.dtypes.cast(betaTwoConst, gradient.dataType()), + tf.dtypes.cast(epsilonConst, gradient.dataType()), gradient); } @@ -113,15 +117,16 @@ protected Operand applyDense(Output gradient, Output * Gathers up the update operations into a single op that can be used as a run target. *
<p>
* Adds the betaOne and betaTwo updates to the end of the updates list. + * * @param updateOperations The update operations. - * @param name The name of the run target. + * @param name The name of the run target. * @return A NoOp with a control dependency on each update operation. */ @Override protected Op finish(List> updateOperations, String name) { - updateOperations.add(tf.assign(betaOnePower,tf.math.mul(betaOnePower,betaOneConst))); - updateOperations.add(tf.assign(betaTwoPower,tf.math.mul(betaTwoPower,betaTwoConst))); - return super.finish(updateOperations,name); + updateOperations.add(tf.assign(betaOnePower, tf.math.mul(betaOnePower, betaOneConst))); + updateOperations.add(tf.assign(betaTwoPower, tf.math.mul(betaTwoPower, betaTwoConst))); + return super.finish(updateOperations, name); } @Override diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java index 82de29e737a..e06d2d06340 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java @@ -36,7 +36,8 @@ public GradientDescent(Graph graph, float learningRate) { @Override protected Operand applyDense(Output gradient, Output variable) { - return tf.train.applyGradientDescent(variable, tf.dtypes.cast(tf.constant(learningRate, TFloat32.DTYPE), gradient.dataType()), gradient); + return tf.train.applyGradientDescent(variable, + tf.dtypes.cast(tf.constant(learningRate, TFloat32.DTYPE), gradient.dataType()), gradient); } @Override diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java index f925150b561..8ec78a5693a 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java @@ -27,9 +27,9 @@ /** * SGD plus momentum, either nesterov or traditional. - * - * See the paper for - * details of nesterov momentum. + *
<p>
+ * See the paper for details of + * nesterov momentum. */ public class Momentum extends Optimizer { @@ -56,14 +56,17 @@ protected void createSlots(List> variables) { } private void createMomentumSlot(Output v) { - Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); + Operand initializer = tf + .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), MOMENTUM, initializer); } @Override protected Operand applyDense(Output gradient, Output variable) { - Variable slot = getSlot(variable,MOMENTUM).get(); - return tf.train.applyMomentum(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient, tf.constant(momentum, gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); + Variable slot = getSlot(variable, MOMENTUM).get(); + return tf.train + .applyMomentum(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient, + tf.constant(momentum, gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); } @Override diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java index 8c2733abb15..5a330bc80a4 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java @@ -39,6 +39,7 @@ * */ public abstract class Optimizer { + public static final String VARIABLE_V2 = "VariableV2"; /** @@ -70,7 +71,7 @@ protected Optimizer(Graph graph) { } public Op minimize(Operand loss) { - return minimize(loss, getOptimizerName()+"-minimize"); + return minimize(loss, getOptimizerName() + "-minimize"); } public Op minimize(Operand loss, String name) { @@ -110,11 +111,12 @@ public List> computeGradients(Operand loss) { } public Op applyGradients(List> gradsAndVars, String name) { - List> variables = gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList()); + List> variables = gradsAndVars.stream().map(GradAndVar::getVariable) + .collect(Collectors.toList()); createSlots(variables); - Optional> prepOp = prepare(name+"/prepare"); + Optional> prepOp = prepare(name + "/prepare"); List> updateOps = new ArrayList<>(); prepOp.ifPresent(updateOps::add); @@ -122,32 +124,34 @@ public Op applyGradients(List> gradsAndVars, String updateOps.add(applyDense(pair)); } - return finish(updateOps,name); + return finish(updateOps, name); } /** * Gets the slot associated with the specified variable and slot name. - * @param var The variable to lookup. + * + * @param var The variable to lookup. * @param slotName The slot name. * @return The slot or {@link Optional#empty}. */ public Optional> getSlot(Output var, String slotName) { - return getSlot(var.op().name(),slotName); + return getSlot(var.op().name(), slotName); } /** * Gets the slot associated with the specified variable and slot name. - * @param varName The variable to lookup. + * + * @param varName The variable to lookup. * @param slotName The slot name. * @return The slot or {@link Optional#empty}. */ private Optional> getSlot(String varName, String slotName) { - Map> variables = slots.get(slotName); + Map> variables = slots.get(slotName); if (variables != null) { Variable slot = variables.get(varName); if (slot != null) { @SuppressWarnings("unchecked") // This method should only be called when the type is known. 
- Optional> opt = Optional.of((Variable)slot); + Optional> opt = Optional.of((Variable) slot); return opt; } else { return Optional.empty(); @@ -158,20 +162,24 @@ private Optional> getSlot(String varName, String s } /** - * Creates a slot in the graph for the specified variable with the specified name. Adds the slot's initializer - * to the graph's initializers, and the slot to the Optimizer's slot map. - * @param variable The variable to create the slot for. - * @param slotName The name of the slot. + * Creates a slot in the graph for the specified variable with the specified name. Adds the slot's + * initializer to the graph's initializers, and the slot to the Optimizer's slot map. + * + * @param variable The variable to create the slot for. + * @param slotName The name of the slot. * @param initializer The initializer for the slot. - * @param The type of the variable. + * @param The type of the variable. */ - protected void createSlot(Output variable, String slotName, Operand initializer) { - Variable slot = tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.dataType()); + protected void createSlot(Output variable, String slotName, + Operand initializer) { + Variable slot = tf.withName(createName(variable, slotName)) + .variable(variable.shape(), variable.dataType()); Assign slotInit = tf.assign(slot, initializer); graph.addInitializer(slotInit); String varName = variable.op().name(); - Map> variables = slots.computeIfAbsent(slotName,(k) -> new HashMap<>()); - variables.put(varName,slot); + Map> variables = slots + .computeIfAbsent(slotName, (k) -> new HashMap<>()); + variables.put(varName, slot); } /** @@ -185,27 +193,32 @@ protected Optional> prepare(String scopeName) { /** * No-op slot creation method. + * * @param variables The variables to create slots for. */ - protected void createSlots(List> variables) { } + protected void createSlots(List> variables) { + } private Operand applyDense(GradAndVar gradVarPair) { - return applyDense(gradVarPair.getGradient(),gradVarPair.getVariable()); + return applyDense(gradVarPair.getGradient(), gradVarPair.getVariable()); } /** * Generates the gradient update operations for the specific variable and gradient. + * * @param gradient The gradient to use. * @param variable The variable to update. - * @param The type of the variable. + * @param The type of the variable. * @return An operand which applies the desired optimizer update to the variable. */ - protected abstract Operand applyDense(Output gradient, Output variable); + protected abstract Operand applyDense(Output gradient, + Output variable); /** * Gathers up the update operations into a single op that can be used as a run target. + * * @param updateOperations The update operations. - * @param name The name of the run target. + * @param name The name of the run target. * @return A NoOp with a control dependency on each update operation. */ protected Op finish(List> updateOperations, String name) { @@ -217,6 +230,7 @@ protected Op finish(List> updateOperations, String name) { /** * Name of the optimizer. + * * @return The optimizer name. 
*/ public abstract String getOptimizerName(); @@ -226,6 +240,7 @@ public static String createName(Output variable, String slotNam } public static class GradAndVar { + private final Output gradient; private final Output variable; @@ -239,7 +254,7 @@ public Output getGradient() { } public Output getVariable() { - return variable; + return variable; } } } diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java index 73ff1777923..9b572ddbfd2 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java @@ -26,9 +26,9 @@ /** * Optimizer that implements the RMSProp algorithm. - * - * See the lecture notes - * that is inexplicably the canonical reference. + *
<p>
+ * See the lecture + * notes that is inexplicably the canonical reference. */ public class RMSProp extends Optimizer { @@ -46,7 +46,8 @@ public RMSProp(Graph graph, float learningRate) { this(graph, learningRate, 0.9f, 0.0f, 1e-10f, false); } - public RMSProp(Graph graph, float learningRate, float decay, float momentum, float epsilon, boolean centered) { + public RMSProp(Graph graph, float learningRate, float decay, float momentum, float epsilon, + boolean centered) { super(graph); this.learningRate = learningRate; this.decay = decay; @@ -63,20 +64,23 @@ protected void createSlots(List> variables) { } private void createRMSPropSlot(Output v) { - Operand rmsInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f, TFloat32.DTYPE),v.dataType())); + Operand rmsInitializer = tf + .fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), RMS, rmsInitializer); - Operand momentumInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); + Operand momentumInitializer = tf + .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand mgInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE),v.dataType())); + Operand mgInitializer = tf + .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); createSlot(v.asOutput(), MG, mgInitializer); } } @Override protected Operand applyDense(Output gradient, Output variable) { - Variable rmsSlot = getSlot(variable,RMS).get(); - Variable momentumSlot = getSlot(variable,MOMENTUM).get(); + Variable rmsSlot = getSlot(variable, RMS).get(); + Variable momentumSlot = getSlot(variable, MOMENTUM).get(); if (centered) { Variable mgSlot = getSlot(variable, MG).get(); return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, From 6ae5ace082234c9c78e8d42efe550be42032edad Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 12 Feb 2020 11:09:06 -0500 Subject: [PATCH 16/22] Bumping the copyright year, and switching to try-with-resources in the MNISTTest. --- .../training/examples/MNISTTest.java | 32 +++++++++---------- .../training/optimizers/AdaDelta.java | 2 +- .../training/optimizers/AdaGrad.java | 2 +- .../training/optimizers/AdaGradDA.java | 2 +- .../tensorflow/training/optimizers/Adam.java | 2 +- .../training/optimizers/GradientDescent.java | 2 +- .../training/optimizers/Momentum.java | 2 +- .../training/optimizers/Optimizer.java | 2 +- .../training/optimizers/RMSProp.java | 2 +- 9 files changed, 23 insertions(+), 25 deletions(-) diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java b/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java index e1097804227..d4bbd7a7127 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
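The switch to try-with-resources in the next hunk is more than tidying: Tensor wraps natively allocated memory, and the old sequence of manual close() calls leaked all three tensors whenever runner().run() threw between creation and cleanup. A minimal sketch of the pattern, reusing the constants and locals of the surrounding train() method (the logging line is illustrative):

    try (Tensor input = Tensor.create(featureBatch);
         Tensor target = Tensor.create(labelBatch);
         Tensor loss = session.runner()
             .feed(INPUT_NAME, input)
             .feed(TARGET, target)
             .addTarget(TRAIN)
             .fetch(TRAINING_LOSS)
             .run().get(0)) {
      // Read the fetched value while its native memory is still live.
      System.out.println("training loss = " + loss.floatValue());
    } // Every tensor acquired so far is closed here, even if run() throws.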
@@ -225,21 +225,19 @@ public static void train(Session session, int epochs, int minibatchSize, float[] labelBatch[m] = labels[k]; } //logger.info("Batch = " + batch.size()); - Tensor input = Tensor.create(featureBatch); - Tensor target = Tensor.create(labelBatch); - Tensor loss = session.runner() - .feed(INPUT_NAME, input) - .feed(TARGET, target) - .addTarget(TRAIN) - .fetch(TRAINING_LOSS) - .run().get(0); - if (interval % 100 == 0) { - logger.log(Level.INFO, - "Iteration = " + interval + ", training loss = " + loss.floatValue()); + try (Tensor input = Tensor.create(featureBatch); + Tensor target = Tensor.create(labelBatch); + Tensor loss = session.runner() + .feed(INPUT_NAME, input) + .feed(TARGET, target) + .addTarget(TRAIN) + .fetch(TRAINING_LOSS) + .run().get(0)) { + if (interval % 100 == 0) { + logger.log(Level.INFO, + "Iteration = " + interval + ", training loss = " + loss.floatValue()); + } } - input.close(); - target.close(); - loss.close(); interval++; } //epoch.close(); @@ -307,12 +305,12 @@ public static void main(String[] args) throws IOException, ClassNotFoundExceptio int epochs = Integer.parseInt(args[0]); int minibatchSize = Integer.parseInt(args[1]); - Graph graph = build(args[2]); int correctCount = 0; int[][] confusionMatrix = new int[10][10]; - try (Session session = new Session(graph)) { + try (Graph graph = build(args[2]); + Session session = new Session(graph)) { train(session, epochs, minibatchSize, trainData, trainLabels); logger.info("Trained model"); diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java index ba52cc18b44..fcd7a5813c6 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java index 45d3de82fd1..fca073102c5 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java index f701488929a..77348ce1f34 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java index 7acea714063..e5960782a99 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java index e06d2d06340..bc3d85d3ac0 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java index 8ec78a5693a..fc3c01906eb 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java index 5a330bc80a4..684529cd277 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java index 9b572ddbfd2..aad141c069c 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 56429e80522c205b2327a5c838dfe98894bcfe23 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 25 Feb 2020 10:08:23 -0500 Subject: [PATCH 17/22] Updating variableWithInit to use @Endpoint. --- .../annotations/org/tensorflow/op/Ops.java | 16 ++++++- .../src/main/java/org/tensorflow/Tensor.java | 2 +- .../java/org/tensorflow/op/core/CoreOps.java | 45 +++++++++++++++++++ .../tensorflow/training/optimizers/Adam.java | 23 ++++++++++ .../training/optimizers/Optimizer.java | 19 ++++++++ 5 files changed, 103 insertions(+), 2 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/CoreOps.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 3bf50107a61..84cc8b3d78a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -57,6 +57,7 @@ import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.ConsumeMutexLock; import org.tensorflow.op.core.ControlTrigger; +import org.tensorflow.op.core.CoreOps; import org.tensorflow.op.core.CountUpTo; import org.tensorflow.op.core.DeepCopy; import org.tensorflow.op.core.DeleteSessionTensor; @@ -1846,7 +1847,7 @@ public Gradients gradients(Iterable> y, Iterable{@code * Gradients gradients = tf.gradients(loss, Arrays.asList(w, b)); - * Scalar alpha = ops.scalar(1.0f); + * Constant alpha = tf.val(1.0f); * tf.train.applyGradientDescent(w, alpha, gradients.dy(0)); * tf.train.applyGradientDescent(b, alpha, gradients.dy(1)); * } @@ -7389,6 +7390,19 @@ public VariableShape variableShape(Operand input, Data return VariableShape.create(scope, input, outType); } + /** + * Factory method to create a new Variable with it's initializer. + * + * @param scope current scope + * @param init The op to use to initialise this variable. + * @param options carries optional attributes values + * @return a new instance of Variable + */ + public Variable variableWithInit(Operand init, + Variable.Options... options) { + return CoreOps.createVariableWithInit(scope, init, options); + } + /** * Returns locations of nonzero / true values in a tensor. *
<p>
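A usage sketch for the new endpoint, mirroring the call sites this patch introduces in MNISTTest (the shape and fill value are illustrative): the variable's shape and data type are inferred from the initializing operand, and the generated Assign is registered on the graph's initializer list as a side effect, so the caller never writes an explicit tf.assign.

    Ops tf = Ops.create(graph);  // graph-backed Ops, assumed from the surrounding examples
    Variable<TFloat32> biases =
        tf.variableWithInit(tf.fill(tf.constant(new int[]{64}), tf.constant(0.1f)));
    // biases is now covered by the graph's collected initializers; running the
    // graph's init target assigns the filled value before training starts.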
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 5939db9ead9..da072f3f473 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -157,7 +157,7 @@ public static Tensor create(Object obj, DataType dtype) } long[] dimSizes = new long[numDimensions(obj, dtype)]; fillShape(obj, 0, dimSizes); - Tensor t = new Tensor<>(dtype, Shape.make(dimSizes)); + Tensor t = new Tensor<>(dtype, Shape.of(dimSizes)); TF_Tensor nativeHandle; if (t.dtype != TString.DTYPE) { long byteSize = elemByteSize(t.dtype) * t.shape.size(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/CoreOps.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/CoreOps.java new file mode 100644 index 00000000000..fae281c8461 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/CoreOps.java @@ -0,0 +1,45 @@ +package org.tensorflow.op.core; + +import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Container class for core methods which add or perform several operations + * and return one of them. + */ +@Operator +public abstract class CoreOps { + + /** + * This class contains static factories. + */ + private CoreOps() {} + + /** + * Factory method to create a new Variable with it's initializer. + * + * @param scope current scope + * @param init The op to use to initialise this variable. + * @param options carries optional attributes values + * @return a new instance of Variable + */ + @Endpoint(name="variableWithInit") + public static Variable createVariableWithInit(Scope scope, Operand init, Variable.Options... options) { + Output initOutput = init.asOutput(); + Variable newVar = Variable.create(scope,initOutput.shape(),initOutput.dataType(),options); + Assign assignOp = Assign.create(scope,newVar,init); + ExecutionEnvironment exEnv = scope.env(); + if (exEnv instanceof Graph) { + Graph graph = (Graph) exEnv; + graph.addInitializer(assignOp); + } + + return newVar; + } +} diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java index e5960782a99..3659c21b0b2 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java @@ -19,6 +19,9 @@ import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; @@ -34,6 +37,7 @@ *
<p>
* See the paper. */ +@Operator public class Adam extends Optimizer { public static final String FIRST_MOMENT = "m"; @@ -66,6 +70,25 @@ public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float this.epsilon = epsilon; } + @Endpoint(name="adam_minimize") + public static Op createAdamMinimize(Scope scope, Operand loss, float learningRate, float betaOne, float betaTwo, float epsilon, Optimizer.Options... options) { + if (!(scope.env() instanceof Graph)) { + throw new IllegalArgumentException("Optimizers are only supported on Graphs"); + } + Adam adam = new Adam((Graph)scope.env(),learningRate,betaOne,betaTwo,epsilon); + String name = null; + for (Options o : options) { + if (o.sharedName != null) { + name = o.sharedName; + } + } + if (name == null) { + return adam.minimize(loss); + } else { + return adam.minimize(loss,name); + } + } + @Override protected void createSlots(List> variables) { for (Output v : variables) { diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java index 684529cd277..ca727c8650d 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java @@ -42,6 +42,25 @@ public abstract class Optimizer { public static final String VARIABLE_V2 = "VariableV2"; + /** + * Optional attributes for {@link org.tensorflow.training.optimizers.Optimizer} + */ + public static class Options { + + /** + * @param sharedName If non-empty, this variable is named in the given bucket + * with this shared_name. Otherwise, the node name is used instead. + */ + public Optimizer.Options sharedName(String sharedName) { + this.sharedName = sharedName; + return this; + } + + protected String sharedName; + + private Options() { + } + } /** * Top level map key is the variable name, lower level map key is the slot name. */ From 5d8cb690974d9af71331dafa25891cd06fa8d123 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 25 Feb 2020 10:40:56 -0500 Subject: [PATCH 18/22] Refactorings after code review. 
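The adam_minimize @Endpoint added above can be driven directly through its static factory. A rough usage sketch, in which the Scope construction, the toy loss, and every constant are illustrative assumptions rather than part of this series:

    Graph graph = new Graph();
    Ops tf = Ops.create(graph);
    // A one-variable quadratic loss so minimize() has a gradient to apply.
    Variable<TFloat32> w = tf.variableWithInit(tf.val(5.0f));
    Operand<TFloat32> loss = tf.math.mul(w, w);
    // Static form of the endpoint; with no Options the default run-target name
    // from minimize(loss), derived from the optimizer's name, is used.
    Op trainStep = Adam.createAdamMinimize(new Scope(graph), loss,
        0.001f, 0.9f, 0.999f, 1e-8f);

Note that as these hunks stand, Options has a private constructor and no factory method, so the sharedName path is not reachable from outside Optimizer; only the default branch above can run.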
--- .../annotations/org/tensorflow/op/Ops.java | 27 +++---- .../op/core/{CoreOps.java => Helpers.java} | 21 ++++- .../training/examples/MNISTTest.java | 54 ++++++------- .../training/optimizers/AdaDelta.java | 14 ++-- .../training/optimizers/AdaGrad.java | 9 +-- .../training/optimizers/AdaGradDA.java | 29 +++---- .../tensorflow/training/optimizers/Adam.java | 31 ++++---- .../training/optimizers/GradientDescent.java | 4 +- .../training/optimizers/Momentum.java | 12 +-- .../training/optimizers/Optimizer.java | 79 ++++++++----------- .../training/optimizers/RMSProp.java | 31 ++++---- 11 files changed, 150 insertions(+), 161 deletions(-) rename tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/{CoreOps.java => Helpers.java} (65%) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 84cc8b3d78a..96730d9cd14 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -57,7 +57,6 @@ import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.ConsumeMutexLock; import org.tensorflow.op.core.ControlTrigger; -import org.tensorflow.op.core.CoreOps; import org.tensorflow.op.core.CountUpTo; import org.tensorflow.op.core.DeepCopy; import org.tensorflow.op.core.DeleteSessionTensor; @@ -80,6 +79,7 @@ import org.tensorflow.op.core.Gradients; import org.tensorflow.op.core.GuaranteeConst; import org.tensorflow.op.core.HashTable; +import org.tensorflow.op.core.Helpers; import org.tensorflow.op.core.HistogramFixedWidth; import org.tensorflow.op.core.Identity; import org.tensorflow.op.core.IdentityN; @@ -7333,6 +7333,18 @@ public VarIsInitializedOp varIsInitializedOp(Operand resource) { return VarIsInitializedOp.create(scope, resource); } + /** + * Factory method to create a new Variable with it's initializer. + * + * @param scope current scope + * @param init The op to use to initialise this variable. + * @param options carries optional attributes values + * @return a new instance of Variable + */ + public Variable variable(Operand init, Variable.Options... options) { + return Helpers.createVariableWithInit(scope, init, options); + } + /** * Holds state in the form of a tensor that persists across steps. *
<p>
@@ -7390,19 +7402,6 @@ public VariableShape variableShape(Operand input, Data return VariableShape.create(scope, input, outType); } - /** - * Factory method to create a new Variable with it's initializer. - * - * @param scope current scope - * @param init The op to use to initialise this variable. - * @param options carries optional attributes values - * @return a new instance of Variable - */ - public Variable variableWithInit(Operand init, - Variable.Options... options) { - return CoreOps.createVariableWithInit(scope, init, options); - } - /** * Returns locations of nonzero / true values in a tensor. *
<p>
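The rename folds the initializer-taking factory into the existing variable name; overload resolution separates the Operand initializer from the Shape-plus-DataType form, so both kinds of call site now read identically. A before-and-after sketch (shape and value taken from the MNISTTest hunks below):

    // Patch 17 spelling:
    //   Variable<TFloat32> b = tf.variableWithInit(tf.fill(tf.array(new int[]{64}), tf.val(0.1f)));
    // After this patch, the same graph nodes under the overloaded name:
    Variable<TFloat32> b = tf.variable(tf.fill(tf.array(new int[]{64}), tf.val(0.1f)));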
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/CoreOps.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java similarity index 65% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/CoreOps.java rename to tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java index fae281c8461..738aea50000 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/CoreOps.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.tensorflow.op.core; import org.tensorflow.ExecutionEnvironment; @@ -14,12 +29,12 @@ * and return one of them. */ @Operator -public abstract class CoreOps { +public abstract class Helpers { /** * This class contains static factories. */ - private CoreOps() {} + private Helpers() {} /** * Factory method to create a new Variable with it's initializer. @@ -29,7 +44,7 @@ private CoreOps() {} * @param options carries optional attributes values * @return a new instance of Variable */ - @Endpoint(name="variableWithInit") + @Endpoint(name="variable") public static Variable createVariableWithInit(Scope scope, Operand init, Variable.Options... 
options) { Output initOutput = init.asOutput(); Variable newVar = Variable.create(scope,initOutput.shape(),initOutput.dataType(),options); diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java b/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java index d4bbd7a7127..30e530efd39 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java @@ -21,7 +21,6 @@ import org.tensorflow.Tensor; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.OneHot; import org.tensorflow.op.core.Placeholder; @@ -85,64 +84,64 @@ public static Graph build(String optimizerName) { // Inputs Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat32.DTYPE, - Placeholder.shape(Shape.make(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); + Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); Placeholder labels = tf.withName(TARGET).placeholder(TInt32.DTYPE); // Scaling the features - Constant centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f); - Constant scalingFactor = tf.constant((float) PIXEL_DEPTH); + Constant centeringFactor = tf.val(PIXEL_DEPTH / 2.0f); + Constant scalingFactor = tf.val((float) PIXEL_DEPTH); Operand scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor); // First conv layer - Variable conv1Weights = tf.variableWithInit(tf.math.mul(tf.random - .truncatedNormal(tf.constant(new int[]{5, 5, NUM_CHANNELS, 32}), TFloat32.DTYPE, - TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable conv1Weights = tf.variable(tf.math.mul(tf.random + .truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.val(0.1f))); Conv2d conv1 = tf.nn .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); Variable conv1Biases = tf - .variableWithInit(tf.fill(tf.constant(new int[]{32}), tf.constant(0.0f))); + .variable(tf.fill(tf.array(new int[]{32}), tf.val(0.0f))); Relu relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); // First pooling layer MaxPool pool1 = tf.nn - .maxPool(relu1, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), + .maxPool(relu1, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), PADDING_TYPE); // Second conv layer - Variable conv2Weights = tf.variableWithInit(tf.math.mul(tf.random - .truncatedNormal(tf.constant(new int[]{5, 5, 32, 64}), TFloat32.DTYPE, - TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable conv2Weights = tf.variable(tf.math.mul(tf.random + .truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.val(0.1f))); Conv2d conv2 = tf.nn .conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); Variable conv2Biases = tf - .variableWithInit(tf.fill(tf.constant(new int[]{64}), tf.constant(0.1f))); + .variable(tf.fill(tf.array(new int[]{64}), tf.val(0.1f))); Relu relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); // Second pooling layer MaxPool pool2 = tf.nn - .maxPool(relu2, tf.constant(new int[]{1, 2, 2, 1}), tf.constant(new int[]{1, 2, 2, 1}), + .maxPool(relu2, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), PADDING_TYPE); // Flatten inputs Reshape flatten = tf.reshape(pool2, tf.concat(Arrays - .asList(tf.slice(tf.shape(pool2), tf.constant(new int[]{0}), tf.constant(new int[]{1})), - tf.constant(new 
int[]{-1})), tf.constant(0))); + .asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})), + tf.array(new int[]{-1})), tf.val(0))); // Fully connected layer - Variable fc1Weights = tf.variableWithInit(tf.math.mul(tf.random - .truncatedNormal(tf.constant(new int[]{IMAGE_SIZE * IMAGE_SIZE * 4, 512}), TFloat32.DTYPE, - TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc1Weights = tf.variable(tf.math.mul(tf.random + .truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.val(0.1f))); Variable fc1Biases = tf - .variableWithInit(tf.fill(tf.constant(new int[]{512}), tf.constant(0.1f))); + .variable(tf.fill(tf.array(new int[]{512}), tf.val(0.1f))); Relu relu3 = tf.nn .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); // Softmax layer - Variable fc2Weights = tf.variableWithInit(tf.math.mul(tf.random - .truncatedNormal(tf.constant(new int[]{512, NUM_LABELS}), TFloat32.DTYPE, - TruncatedNormal.seed(SEED)), tf.constant(0.1f))); + Variable fc2Weights = tf.variable(tf.math.mul(tf.random + .truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.DTYPE, + TruncatedNormal.seed(SEED)), tf.val(0.1f))); Variable fc2Biases = tf - .variableWithInit(tf.fill(tf.constant(new int[]{NUM_LABELS}), tf.constant(0.1f))); + .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.val(0.1f))); Add logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); @@ -151,15 +150,15 @@ public static Graph build(String optimizerName) { // Loss function & regularization OneHot oneHot = tf - .oneHot(labels, tf.constant(10), tf.constant(1.0f), tf.constant(0.0f)); + .oneHot(labels, tf.val(10), tf.val(1.0f), tf.val(0.0f)); SoftmaxCrossEntropyWithLogits batchLoss = tf.nn .softmaxCrossEntropyWithLogits(logits, oneHot); - Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0)); + Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.val(0)); Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math .add(tf.nn.l2Loss(fc1Biases), tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); Add loss = tf.withName(TRAINING_LOSS).math - .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f))); + .add(labelLoss, tf.math.mul(regularizers, tf.val(5e-4f))); // Optimizer Optimizer optimizer; @@ -305,7 +304,6 @@ public static void main(String[] args) throws IOException, ClassNotFoundExceptio int epochs = Integer.parseInt(args[0]); int minibatchSize = Integer.parseInt(args[1]); - int correctCount = 0; int[][] confusionMatrix = new int[10][10]; diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java index fcd7a5813c6..f1cd9ae8e37 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java @@ -15,15 +15,13 @@ */ package org.tensorflow.training.optimizers; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.List; - /** * Optimizer that implements the Adadelta algorithm. *
<p>
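The optimizer hunks in the rest of this patch all make the same two substitutions: tf.constant(x, TFloat32.DTYPE) becomes the shorter tf.val(x), and scalars that were previously built directly in the gradient's type now go through tf.dtypes.cast(tf.val(x), gradient.dataType()), so each optimizer works for any floating-point variable type. For AdaDelta itself, a scalar sketch of what tf.train.applyAdadelta computes per element, with the ACCUMULATOR and ACCUMULATOR_UPDATE slots as accum and accumUpdate (an illustration of the algorithm, not the native kernel):

    static void adadeltaStep(float[] var, float[] accum, float[] accumUpdate,
                             float[] grad, float lr, float rho, float eps) {
      for (int i = 0; i < var.length; i++) {
        // Decaying average of squared gradients.
        accum[i] = rho * accum[i] + (1 - rho) * grad[i] * grad[i];
        // Step size scaled by the ratio of the two running averages.
        float update = (float) (Math.sqrt(accumUpdate[i] + eps)
            / Math.sqrt(accum[i] + eps)) * grad[i];
        // Decaying average of squared updates.
        accumUpdate[i] = rho * accumUpdate[i] + (1 - rho) * update * update;
        var[i] -= lr * update;
      }
    }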
@@ -60,10 +58,10 @@ protected void createSlots(List> variables) { private void createAdaDeltaSlot(Output v) { Operand accumulatorInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); + .fill(tf.shape(v), tf.dtypes.cast(tf.val(0.0f), v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, accumulatorInitializer); Operand updateInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); + .fill(tf.shape(v), tf.dtypes.cast(tf.val(0.0f), v.dataType())); createSlot(v.asOutput(), ACCUMULATOR_UPDATE, updateInitializer); } @@ -72,9 +70,9 @@ protected Operand applyDense(Output gradient, Output Variable accumSlot = getSlot(variable, ACCUMULATOR).get(); Variable accumUpdateSlot = getSlot(variable, ACCUMULATOR_UPDATE).get(); return tf.train.applyAdadelta(variable, accumSlot, accumUpdateSlot, - tf.constant(learningRate, gradient.dataType()), - tf.constant(rho, gradient.dataType()), - tf.constant(epsilon, gradient.dataType()), + tf.dtypes.cast(tf.val(learningRate), gradient.dataType()), + tf.dtypes.cast(tf.val(rho), gradient.dataType()), + tf.dtypes.cast(tf.val(epsilon), gradient.dataType()), gradient); } diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java index fca073102c5..d34e331aa6b 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java @@ -15,15 +15,13 @@ */ package org.tensorflow.training.optimizers; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.List; - /** * Optimizer that implements the Adagrad algorithm. *
<p>
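For reference, the per-element update behind tf.train.applyAdagrad, with the ACCUMULATOR slot as accum; the initialAccumulatorValue used above seeds accum so early steps do not divide by zero (scalar illustration, not the native kernel):

    static void adagradStep(float[] var, float[] accum, float[] grad, float lr) {
      for (int i = 0; i < var.length; i++) {
        accum[i] += grad[i] * grad[i];  // lifetime sum of squared gradients
        // Each weight's step shrinks as its own gradient history grows.
        var[i] -= lr * grad[i] / (float) Math.sqrt(accum[i]);
      }
    }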
@@ -57,7 +55,7 @@ protected void createSlots(List> variables) { private void createAdaGradSlot(Output v) { Operand initializer = tf.fill(tf.shape(v), - tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE), v.dataType())); + tf.dtypes.cast(tf.val(initialAccumulatorValue), v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); } @@ -65,7 +63,8 @@ private void createAdaGradSlot(Output v) { protected Operand applyDense(Output gradient, Output variable) { Variable slot = getSlot(variable, ACCUMULATOR).get(); return tf.train - .applyAdagrad(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient); + .applyAdagrad(variable, slot, tf.dtypes.cast(tf.val(learningRate), gradient.dataType()), + gradient); } @Override diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java index 77348ce1f34..a1cd38d6bb2 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java @@ -15,6 +15,8 @@ */ package org.tensorflow.training.optimizers; +import java.util.List; +import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; @@ -22,13 +24,9 @@ import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Variable; import org.tensorflow.tools.Shape; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; -import java.util.List; -import java.util.Optional; - /** * Optimizer that implements the Adagrad Dual-Averaging algorithm. *
<p>
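The dual-averaging variant accumulates raw gradients alongside squared ones and rewrites the variable from that whole history on every step, which is why the adagrad-da-global-step variable is threaded through prepare and finish above. With l1Strength and l2Strength at zero, the per-element update reduces to roughly the following sketch; the regularized form, where the global step enters the thresholding and the denominator, is left to the native kernel:

    static void adagradDaStep(float[] var, float[] gradAccum, float[] gradSqAccum,
                              float[] grad, float lr) {
      for (int i = 0; i < var.length; i++) {
        gradAccum[i] += grad[i];              // ACCUMULATOR slot
        gradSqAccum[i] += grad[i] * grad[i];  // SQUARED_ACCUMULATOR slot
        // The variable is recomputed from the accumulators, not incremented.
        var[i] = -lr * gradAccum[i] / (float) Math.sqrt(gradSqAccum[i]);
      }
    }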
@@ -38,16 +36,11 @@ public class AdaGradDA extends Optimizer { public static final String ACCUMULATOR = "gradient_accumulator"; public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator"; - - private Variable globalStep; - private final float learningRate; - private final float initialAccumulatorValue; - private final float l1Strength; - private final float l2Strength; + private Variable globalStep; public AdaGradDA(Graph graph, float learningRate) { this(graph, learningRate, 0.1f, 0.0f, 0.0f); @@ -64,7 +57,7 @@ public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, @Override protected Optional> prepare(String name) { - return Optional.of(tf.assignAdd(globalStep, tf.constant(1L))); + return Optional.of(tf.assignAdd(globalStep, tf.val(1L))); } @Override @@ -73,16 +66,16 @@ protected void createSlots(List> variables) { createAdaGradDASlot(v); } globalStep = tf.withName("adagrad-da-global-step").variable(Shape.scalar(), TInt64.DTYPE); - Assign globalStepInitializer = tf.assign(globalStep, tf.constant(0L)); + Assign globalStepInitializer = tf.assign(globalStep, tf.val(0L)); graph.addInitializer(globalStepInitializer); } private void createAdaGradDASlot(Output v) { Operand initializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); + .fill(tf.shape(v), tf.dtypes.cast(tf.val(0.0f), v.dataType())); createSlot(v.asOutput(), ACCUMULATOR, initializer); Operand sqInitializer = tf.fill(tf.shape(v), - tf.dtypes.cast(tf.constant(initialAccumulatorValue, TFloat32.DTYPE), v.dataType())); + tf.dtypes.cast(tf.val(initialAccumulatorValue), v.dataType())); createSlot(v.asOutput(), SQUARED_ACCUMULATOR, sqInitializer); } @@ -91,9 +84,9 @@ protected Operand applyDense(Output gradient, Output Variable gradSlot = getSlot(variable, ACCUMULATOR).get(); Variable gradSquaredSlot = getSlot(variable, SQUARED_ACCUMULATOR).get(); return tf.train.applyAdagradDa(variable, gradSlot, gradSquaredSlot, gradient, - tf.constant(learningRate, gradient.dataType()), - tf.constant(l1Strength, gradient.dataType()), - tf.constant(l2Strength, gradient.dataType()), + tf.dtypes.cast(tf.val(learningRate), gradient.dataType()), + tf.dtypes.cast(tf.val(l1Strength), gradient.dataType()), + tf.dtypes.cast(tf.val(l2Strength), gradient.dataType()), globalStep); } @@ -108,7 +101,7 @@ protected Operand applyDense(Output gradient, Output */ @Override protected Op finish(List> updateOperations, String name) { - updateOperations.add(tf.assignAdd(globalStep, tf.constant(1L))); + updateOperations.add(tf.assignAdd(globalStep, tf.val(1L))); return super.finish(updateOperations, name); } diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java index 3659c21b0b2..1a5b757d24a 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java @@ -15,6 +15,8 @@ */ package org.tensorflow.training.optimizers; +import java.util.List; +import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; @@ -29,9 +31,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.List; -import java.util.Optional; - /** * Optimizer that implements the Adam algorithm. *
<p>
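The beta1_power and beta2_power variables created above hold the running powers of the decay rates that finish() multiplies up after each step; they drive the bias correction in tf.train.applyAdam's per-element update, sketched here on plain floats (an illustration, not the native kernel):

    static void adamStep(float[] var, float[] m, float[] v, float[] grad,
                         float lr, float b1, float b2, float eps,
                         float b1Power, float b2Power) {
      // Bias-corrected step size; b1Power and b2Power are betaOne^t and betaTwo^t.
      float alpha = lr * (float) Math.sqrt(1 - b2Power) / (1 - b1Power);
      for (int i = 0; i < var.length; i++) {
        m[i] = b1 * m[i] + (1 - b1) * grad[i];            // FIRST_MOMENT slot
        v[i] = b2 * v[i] + (1 - b2) * grad[i] * grad[i];  // SECOND_MOMENT slot
        var[i] -= alpha * m[i] / ((float) Math.sqrt(v[i]) + eps);
      }
    }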
@@ -70,12 +69,14 @@ public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float this.epsilon = epsilon; } - @Endpoint(name="adam_minimize") - public static Op createAdamMinimize(Scope scope, Operand loss, float learningRate, float betaOne, float betaTwo, float epsilon, Optimizer.Options... options) { + @Endpoint(name = "adam_minimize") + public static Op createAdamMinimize(Scope scope, Operand loss, + float learningRate, float betaOne, float betaTwo, float epsilon, + Optimizer.Options... options) { if (!(scope.env() instanceof Graph)) { throw new IllegalArgumentException("Optimizers are only supported on Graphs"); } - Adam adam = new Adam((Graph)scope.env(),learningRate,betaOne,betaTwo,epsilon); + Adam adam = new Adam((Graph) scope.env(), learningRate, betaOne, betaTwo, epsilon); String name = null; for (Options o : options) { if (o.sharedName != null) { @@ -85,7 +86,7 @@ public static Op createAdamMinimize(Scope scope, Operand lo if (name == null) { return adam.minimize(loss); } else { - return adam.minimize(loss,name); + return adam.minimize(loss, name); } } @@ -96,29 +97,29 @@ protected void createSlots(List> variables) { } betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); Assign betaOnePowerInit = tf - .assign(betaOnePower, tf.constant(betaOne, TFloat32.DTYPE)); + .assign(betaOnePower, tf.val(betaOne)); graph.addInitializer(betaOnePowerInit); betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE); Assign betaTwoPowerInit = tf - .assign(betaTwoPower, tf.constant(betaTwo, TFloat32.DTYPE)); + .assign(betaTwoPower, tf.val(betaTwo)); graph.addInitializer(betaTwoPowerInit); } @Override protected Optional> prepare(String scopeName) { - betaOneConst = tf.constant(betaOne); - betaTwoConst = tf.constant(betaTwo); - learningRateConst = tf.constant(learningRate); - epsilonConst = tf.constant(epsilon); + betaOneConst = tf.val(betaOne); + betaTwoConst = tf.val(betaTwo); + learningRateConst = tf.val(learningRate); + epsilonConst = tf.val(epsilon); return Optional.empty(); } private void createAdamSlot(Output v) { Operand firstMomentInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); + .fill(tf.shape(v), tf.dtypes.cast(tf.val(0.0f), v.dataType())); createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); Operand secondMomentInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); + .fill(tf.shape(v), tf.dtypes.cast(tf.val(0.0f), v.dataType())); createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); } diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java index bc3d85d3ac0..08876b9c702 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java @@ -18,10 +18,8 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; - /** * Basic SGD. 
*/ @@ -37,7 +35,7 @@ public GradientDescent(Graph graph, float learningRate) { @Override protected Operand applyDense(Output gradient, Output variable) { return tf.train.applyGradientDescent(variable, - tf.dtypes.cast(tf.constant(learningRate, TFloat32.DTYPE), gradient.dataType()), gradient); + tf.dtypes.cast(tf.val(learningRate), gradient.dataType()), gradient); } @Override diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java index fc3c01906eb..9ae866b99c6 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java @@ -15,16 +15,14 @@ */ package org.tensorflow.training.optimizers; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.List; - /** * SGD plus momentum, either nesterov or traditional. *
<p>
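The single MOMENTUM slot initialized above carries a decaying sum of past gradients; the useNesterov flag only changes where the step is evaluated. A scalar sketch of tf.train.applyMomentum's update (illustrative, not the native kernel):

    static void momentumStep(float[] var, float[] accum, float[] grad,
                             float lr, float momentum, boolean nesterov) {
      for (int i = 0; i < var.length; i++) {
        accum[i] = momentum * accum[i] + grad[i];     // MOMENTUM slot
        var[i] -= nesterov
            ? lr * (grad[i] + momentum * accum[i])    // step from the lookahead point
            : lr * accum[i];                          // classical heavy-ball step
      }
    }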
@@ -57,7 +55,7 @@ protected void createSlots(List> variables) { private void createMomentumSlot(Output v) { Operand initializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); + .fill(tf.shape(v), tf.dtypes.cast(tf.val(0.0f), v.dataType())); createSlot(v.asOutput(), MOMENTUM, initializer); } @@ -65,8 +63,10 @@ private void createMomentumSlot(Output v) { protected Operand applyDense(Output gradient, Output variable) { Variable slot = getSlot(variable, MOMENTUM).get(); return tf.train - .applyMomentum(variable, slot, tf.constant(learningRate, gradient.dataType()), gradient, - tf.constant(momentum, gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); + .applyMomentum(variable, slot, tf.dtypes.cast(tf.val(learningRate), gradient.dataType()), + gradient, + tf.dtypes.cast(tf.val(momentum), gradient.dataType()), + ApplyMomentum.useNesterov(useNesterov)); } @Override diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java index ca727c8650d..2c9fece10d3 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java @@ -15,6 +15,12 @@ */ package org.tensorflow.training.optimizers; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -27,60 +33,29 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - /** * */ public abstract class Optimizer { public static final String VARIABLE_V2 = "VariableV2"; - - /** - * Optional attributes for {@link org.tensorflow.training.optimizers.Optimizer} - */ - public static class Options { - - /** - * @param sharedName If non-empty, this variable is named in the given bucket - * with this shared_name. Otherwise, the node name is used instead. - */ - public Optimizer.Options sharedName(String sharedName) { - this.sharedName = sharedName; - return this; - } - - protected String sharedName; - - private Options() { - } - } - /** - * Top level map key is the variable name, lower level map key is the slot name. - */ - private final Map>> slots; - /** * Global state variables */ //TODO make this be used. protected final List> globals; - /** * The Graph this optimizer is operating on. */ protected final Graph graph; - /** * The ops builder for the graph. */ protected final Ops tf; + /** + * Top level map key is the variable name, lower level map key is the slot name. 
+ */ + private final Map>> slots; protected Optimizer(Graph graph) { this.graph = graph; @@ -89,6 +64,10 @@ protected Optimizer(Graph graph) { this.globals = new ArrayList<>(); } + public static String createName(Output variable, String slotName) { + return variable.op().name() + "-" + slotName; + } + public Op minimize(Operand loss) { return minimize(loss, getOptimizerName() + "-minimize"); } @@ -101,13 +80,11 @@ public Op minimize(Operand loss, String name) { public List> computeGradients(Operand loss) { List variables = new ArrayList<>(); - Iterator opItr = graph.operations(); - while (opItr.hasNext()) { - Operation op = opItr.next(); + graph.operations().forEachRemaining((Operation op) -> { if (op.type().equals(VARIABLE_V2)) { variables.add(op); } - } + }); Output[] variableOutputArray = new Output[variables.size()]; for (int i = 0; i < variables.size(); i++) { @@ -172,12 +149,10 @@ private Optional> getSlot(String varName, String s @SuppressWarnings("unchecked") // This method should only be called when the type is known. Optional> opt = Optional.of((Variable) slot); return opt; - } else { - return Optional.empty(); } - } else { return Optional.empty(); } + return Optional.empty(); } /** @@ -254,8 +229,24 @@ protected Op finish(List> updateOperations, String name) { */ public abstract String getOptimizerName(); - public static String createName(Output variable, String slotName) { - return variable.op().name() + "-" + slotName; + /** + * Optional attributes for {@link org.tensorflow.training.optimizers.Optimizer} + */ + public static class Options { + + protected String sharedName; + + private Options() { + } + + /** + * @param sharedName If non-empty, this variable is named in the given bucket with this + * shared_name. Otherwise, the node name is used instead. + */ + public Optimizer.Options sharedName(String sharedName) { + this.sharedName = sharedName; + return this; + } } public static class GradAndVar { diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java index aad141c069c..9722fd3ba65 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java @@ -15,15 +15,13 @@ */ package org.tensorflow.training.optimizers; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.List; - /** * Optimizer that implements the RMSProp algorithm. *
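The RMSProp hunks below switch every scalar hyperparameter to the same cast-from-float32 pattern used throughout this patch. For reference, the update rule the ApplyRmsProp kernel implements, in its usual textbook form (background restating standard TensorFlow semantics; $g$ is the gradient, $\eta$ the learning rate, $\rho$ the decay, $\mu$ the momentum):

$$\begin{aligned}
ms &\leftarrow \rho \, ms + (1 - \rho) \, g^2 \\
mom &\leftarrow \mu \, mom + \frac{\eta \, g}{\sqrt{ms + \epsilon}} \\
w &\leftarrow w - mom
\end{aligned}$$

The centered variant additionally tracks a mean gradient $mg \leftarrow \rho \, mg + (1 - \rho) \, g$ and divides by $\sqrt{ms - mg^2 + \epsilon}$ instead, which is why the MG slot below is only created when centered is set.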
@@ -65,14 +63,14 @@ protected void createSlots(List> variables) { private void createRMSPropSlot(Output v) { Operand rmsInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f, TFloat32.DTYPE), v.dataType())); + .fill(tf.shape(v), tf.dtypes.cast(tf.val(1.0f), v.dataType())); createSlot(v.asOutput(), RMS, rmsInitializer); Operand momentumInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); + .fill(tf.shape(v), tf.dtypes.cast(tf.val(0.0f), v.dataType())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { Operand mgInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f, TFloat32.DTYPE), v.dataType())); + .fill(tf.shape(v), tf.dtypes.cast(tf.val(0.0f), v.dataType())); createSlot(v.asOutput(), MG, mgInitializer); } } @@ -84,19 +82,18 @@ protected Operand applyDense(Output gradient, Output if (centered) { Variable mgSlot = getSlot(variable, MG).get(); return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, - tf.constant(learningRate, gradient.dataType()), - tf.constant(decay, gradient.dataType()), - tf.constant(momentum, gradient.dataType()), - tf.constant(epsilon, gradient.dataType()), - gradient); - } else { - return tf.train.applyRmsProp(variable, rmsSlot, momentumSlot, - tf.constant(learningRate, gradient.dataType()), - tf.constant(decay, gradient.dataType()), - tf.constant(momentum, gradient.dataType()), - tf.constant(epsilon, gradient.dataType()), + tf.dtypes.cast(tf.val(learningRate), gradient.dataType()), + tf.dtypes.cast(tf.val(decay), gradient.dataType()), + tf.dtypes.cast(tf.val(momentum), gradient.dataType()), + tf.dtypes.cast(tf.val(epsilon), gradient.dataType()), gradient); } + return tf.train.applyRmsProp(variable, rmsSlot, momentumSlot, + tf.dtypes.cast(tf.val(learningRate), gradient.dataType()), + tf.dtypes.cast(tf.val(decay), gradient.dataType()), + tf.dtypes.cast(tf.val(momentum), gradient.dataType()), + tf.dtypes.cast(tf.val(epsilon), gradient.dataType()), + gradient); } @Override From 51f5d47e87de878c16cf252a1e8e2578717ecbab Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 25 Feb 2020 10:41:20 -0500 Subject: [PATCH 19/22] Adding a couple of lines to the gitignore. --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 4a14794504c..2063545f295 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,10 @@ xcuserdata/** /estimator_api_init_files_list.txt *.whl +# Patch files +*.orig +*.rej + # Android .gradle .idea From 66876eddb2403d5ea60a96b42f071545b45802bd Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 25 Feb 2020 12:11:57 -0500 Subject: [PATCH 20/22] Adding a bit of documentation, threading the named operations through the constructors, removing the MNISTTest.
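As a sketch of what the threaded names enable: two optimizers of the same class can now share a graph without colliding over operation names, since each instance uses the supplied base name instead of the shared default from getOptimizerName(). The constructor signatures are the ones added below; the hyperparameters and base names are illustrative:

    // Assumes a Graph `graph` and imports from org.tensorflow.training.optimizers.
    Momentum fast = new Momentum(graph, "momentum-fast", 0.1f, 0.9f, false);
    Momentum slow = new Momentum(graph, "momentum-slow", 0.001f, 0.9f, true);
    AdaDelta ada = new AdaDelta(graph, "adadelta", 1.0f, 0.95f, 1e-8f);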
--- .../training/examples/MNISTTest.java | 369 ------------------ .../training/optimizers/AdaDelta.java | 11 + .../training/optimizers/AdaGrad.java | 10 + .../training/optimizers/AdaGradDA.java | 13 + .../tensorflow/training/optimizers/Adam.java | 12 + .../training/optimizers/GradientDescent.java | 5 + .../training/optimizers/Momentum.java | 7 + .../training/optimizers/Optimizer.java | 20 +- .../training/optimizers/RMSProp.java | 14 + 9 files changed, 91 insertions(+), 370 deletions(-) delete mode 100644 tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java b/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java deleted file mode 100644 index 30e530efd39..00000000000 --- a/tensorflow-training/src/main/java/org/tensorflow/training/examples/MNISTTest.java +++ /dev/null @@ -1,369 +0,0 @@ -/* - * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.tensorflow.training.examples; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Session; -import org.tensorflow.Tensor; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.OneHot; -import org.tensorflow.op.core.Placeholder; -import org.tensorflow.op.core.Reshape; -import org.tensorflow.op.core.Variable; -import org.tensorflow.op.math.Add; -import org.tensorflow.op.math.Mean; -import org.tensorflow.op.nn.Conv2d; -import org.tensorflow.op.nn.MaxPool; -import org.tensorflow.op.nn.Relu; -import org.tensorflow.op.nn.Softmax; -import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; -import org.tensorflow.op.random.TruncatedNormal; -import org.tensorflow.training.optimizers.AdaDelta; -import org.tensorflow.training.optimizers.AdaGrad; -import org.tensorflow.training.optimizers.AdaGradDA; -import org.tensorflow.training.optimizers.Adam; -import org.tensorflow.training.optimizers.GradientDescent; -import org.tensorflow.training.optimizers.Momentum; -import org.tensorflow.training.optimizers.Optimizer; -import org.tensorflow.training.optimizers.RMSProp; -import org.tensorflow.tools.Shape; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt32; - -import java.io.BufferedInputStream; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.util.Arrays; -import java.util.logging.Level; -import java.util.logging.Logger; - -/** - * Builds a LeNet-5 style CNN for MNIST. 
- */ -public class MNISTTest { - - private static final Logger logger = Logger.getLogger(MNISTTest.class.getName()); - - private static final int PIXEL_DEPTH = 255; - private static final int NUM_CHANNELS = 1; - private static final int IMAGE_SIZE = 28; - private static final int NUM_LABELS = 10; - private static final long SEED = 123456789L; - - private static final String PADDING_TYPE = "SAME"; - - public static final String INPUT_NAME = "input"; - public static final String OUTPUT_NAME = "output"; - public static final String TARGET = "target"; - public static final String TRAIN = "train"; - public static final String TRAINING_LOSS = "training_loss"; - public static final String EPOCH = "epoch"; - public static final String INIT = "init"; - - public static Graph build(String optimizerName) { - Graph graph = new Graph(); - - Ops tf = Ops.create(graph); - - // Inputs - Placeholder input = tf.withName(INPUT_NAME).placeholder(TFloat32.DTYPE, - Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))); - Placeholder labels = tf.withName(TARGET).placeholder(TInt32.DTYPE); - - // Scaling the features - Constant centeringFactor = tf.val(PIXEL_DEPTH / 2.0f); - Constant scalingFactor = tf.val((float) PIXEL_DEPTH); - Operand scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor); - - // First conv layer - Variable conv1Weights = tf.variable(tf.math.mul(tf.random - .truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE, - TruncatedNormal.seed(SEED)), tf.val(0.1f))); - Conv2d conv1 = tf.nn - .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); - Variable conv1Biases = tf - .variable(tf.fill(tf.array(new int[]{32}), tf.val(0.0f))); - Relu relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases)); - - // First pooling layer - MaxPool pool1 = tf.nn - .maxPool(relu1, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), - PADDING_TYPE); - - // Second conv layer - Variable conv2Weights = tf.variable(tf.math.mul(tf.random - .truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.DTYPE, - TruncatedNormal.seed(SEED)), tf.val(0.1f))); - Conv2d conv2 = tf.nn - .conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE); - Variable conv2Biases = tf - .variable(tf.fill(tf.array(new int[]{64}), tf.val(0.1f))); - Relu relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases)); - - // Second pooling layer - MaxPool pool2 = tf.nn - .maxPool(relu2, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), - PADDING_TYPE); - - // Flatten inputs - Reshape flatten = tf.reshape(pool2, tf.concat(Arrays - .asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})), - tf.array(new int[]{-1})), tf.val(0))); - - // Fully connected layer - Variable fc1Weights = tf.variable(tf.math.mul(tf.random - .truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE, - TruncatedNormal.seed(SEED)), tf.val(0.1f))); - Variable fc1Biases = tf - .variable(tf.fill(tf.array(new int[]{512}), tf.val(0.1f))); - Relu relu3 = tf.nn - .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases)); - - // Softmax layer - Variable fc2Weights = tf.variable(tf.math.mul(tf.random - .truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.DTYPE, - TruncatedNormal.seed(SEED)), tf.val(0.1f))); - Variable fc2Biases = tf - .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.val(0.1f))); - - Add logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases); - - // Predicted outputs - Softmax prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits); - - // Loss 
function & regularization - OneHot oneHot = tf - .oneHot(labels, tf.val(10), tf.val(1.0f), tf.val(0.0f)); - SoftmaxCrossEntropyWithLogits batchLoss = tf.nn - .softmaxCrossEntropyWithLogits(logits, oneHot); - Mean labelLoss = tf.math.mean(batchLoss.loss(), tf.val(0)); - Add regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math - .add(tf.nn.l2Loss(fc1Biases), - tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases)))); - Add loss = tf.withName(TRAINING_LOSS).math - .add(labelLoss, tf.math.mul(regularizers, tf.val(5e-4f))); - - // Optimizer - Optimizer optimizer; - switch (optimizerName) { - case "AdaDelta": - case "Adadelta": - case "adadelta": - optimizer = new AdaDelta(graph, 1f, 0.95f, 1e-8f); - break; - case "AdaGradDA": - case "AdagradDA": - case "adagradda": - optimizer = new AdaGradDA(graph, 0.01f); - break; - case "AdaGrad": - case "Adagrad": - case "adagrad": - optimizer = new AdaGrad(graph, 0.01f); - break; - case "Adam": - case "adam": - optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f); - break; - case "SGD": - case "sgd": - optimizer = new GradientDescent(graph, 0.01f); - break; - case "Momentum": - case "momentum": - optimizer = new Momentum(graph, 0.01f, 0.9f, false); - break; - case "RMSProp": - case "rmsprop": - optimizer = new RMSProp(graph, 0.01f, 0.9f, 0.0f, 1e-10f, false); - break; - default: - throw new IllegalArgumentException("Unknown optimizer " + optimizerName); - } - logger.info("Optimizer = " + optimizer.toString()); - Op minimize = optimizer.minimize(loss, TRAIN); - - Op init = graph.variablesInitializer(); - - return graph; - } - - public static void train(Session session, int epochs, int minibatchSize, float[][][][] data, - int[] labels) { - // Initialises the parameters. - session.runner().addTarget(INIT).run(); - logger.info("Initialised the model parameters"); - - float[][][][] featureBatch = new float[minibatchSize][][][]; - int[] labelBatch = new int[minibatchSize]; - - int interval = 0; - for (int i = 0; i < epochs; i++) { - logger.log(Level.INFO, "Starting epoch " + i); - //Tensor epoch = Tensor.create(i); - for (int j = 0; j < data.length; j += minibatchSize) { - for (int k = j, m = 0; k < (j + minibatchSize) && k < data.length; k++, m++) { - featureBatch[m] = data[k]; - labelBatch[m] = labels[k]; - } - //logger.info("Batch = " + batch.size()); - try (Tensor input = Tensor.create(featureBatch); - Tensor target = Tensor.create(labelBatch); - Tensor loss = session.runner() - .feed(INPUT_NAME, input) - .feed(TARGET, target) - .addTarget(TRAIN) - .fetch(TRAINING_LOSS) - .run().get(0)) { - if (interval % 100 == 0) { - logger.log(Level.INFO, - "Iteration = " + interval + ", training loss = " + loss.floatValue()); - } - } - interval++; - } - //epoch.close(); - } - } - - /** - * Find the maximum probability and return it's index. - * - * @param probabilities The probabilites. - * @return The index of the max. 
- */ - public static int pred(float[] probabilities) { - float maxVal = Float.NEGATIVE_INFINITY; - int idx = 0; - for (int i = 0; i < probabilities.length; i++) { - if (probabilities[i] > maxVal) { - maxVal = probabilities[i]; - idx = i; - } - } - return idx; - } - - public static DataTuple loadData(String path) throws IOException, ClassNotFoundException { - try (ObjectInputStream ois = new ObjectInputStream( - new BufferedInputStream(new FileInputStream(path)))) { - float[][][][] data = (float[][][][]) ois.readObject(); - int[] labels = (int[]) ois.readObject(); - return new DataTuple(data, labels); - } - } - - private static class DataTuple { - - public final float[][][][] features; - public final int[] labels; - - public DataTuple(float[][][][] features, int[] labels) { - this.features = features; - this.labels = labels; - } - } - - public static void main(String[] args) throws IOException, ClassNotFoundException { - logger.info( - "Usage: MNISTTest "); - - logger.info("Loading training data"); - DataTuple train = loadData(args[3]); - logger.info("Loading testing data"); - DataTuple test = loadData(args[4]); - - logger.info("Loaded data."); - - float[][][][] trainData = train.features; - int[] trainLabels = train.labels; - - float[][][][] testData = test.features; - int[] testLabels = test.labels; - - logger.info("Loaded " + trainLabels.length + " training labels"); - logger.info("Loaded " + testLabels.length + " testing labels"); - - int epochs = Integer.parseInt(args[0]); - int minibatchSize = Integer.parseInt(args[1]); - - int correctCount = 0; - int[][] confusionMatrix = new int[10][10]; - - try (Graph graph = build(args[2]); - Session session = new Session(graph)) { - train(session, epochs, minibatchSize, trainData, trainLabels); - - logger.info("Trained model"); - - float[][][][] featureBatch = new float[minibatchSize][][][]; - int[] labelBatch = new int[minibatchSize]; - float[][] prediction; - - for (int j = 0; j < testData.length; j += minibatchSize) { - for (int k = j, m = 0; k < (j + minibatchSize) && k < testData.length; k++, m++) { - featureBatch[m] = testData[k]; - labelBatch[m] = testLabels[k]; - } - try (Tensor transformedInput = Tensor.create(featureBatch); - Tensor outputTensor = session.runner() - .feed(INPUT_NAME, transformedInput) - .fetch(OUTPUT_NAME).run().get(0)) { - prediction = outputTensor.copyTo(new float[minibatchSize][NUM_LABELS]); - } - - for (int k = 0; k < labelBatch.length; k++) { - int predLabel; - - predLabel = pred(prediction[k]); - if (predLabel == labelBatch[k]) { - correctCount++; - } - - confusionMatrix[labelBatch[k]][predLabel]++; - } - - if (j % 1000 == 0) { - logger.log(Level.INFO, "Cur accuracy = " + ((float) correctCount) / (j + minibatchSize)); - } - } - - logger.info("Final accuracy = " + ((float) correctCount) / testLabels.length); - - StringBuilder sb = new StringBuilder(); - sb.append("Label"); - for (int i = 0; i < confusionMatrix.length; i++) { - sb.append(String.format("%1$5s", "" + i)); - } - sb.append("\n"); - - for (int i = 0; i < confusionMatrix.length; i++) { - sb.append(String.format("%1$5s", "" + i)); - for (int j = 0; j < confusionMatrix[i].length; j++) { - sb.append(String.format("%1$5s", "" + confusionMatrix[i][j])); - } - sb.append("\n"); - } - - System.out.println(sb.toString()); - } - - } -} diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java index f1cd9ae8e37..1267a6ac001 100644 --- 
a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaDelta.java @@ -49,6 +49,17 @@ public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { this.epsilon = epsilon; } + public AdaDelta(Graph graph, String name, float learningRate) { + this(graph, name, learningRate, 0.95f, 1e-8f); + } + + public AdaDelta(Graph graph, String name, float learningRate, float rho, float epsilon) { + super(graph, name); + this.learningRate = learningRate; + this.rho = rho; + this.epsilon = epsilon; + } + @Override protected void createSlots(List> variables) { for (Output v : variables) { diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java index d34e331aa6b..a320153fab5 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGrad.java @@ -46,6 +46,16 @@ public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { this.initialAccumulatorValue = initialAccumulatorValue; } + public AdaGrad(Graph graph, String name, float learningRate) { + this(graph, name, learningRate, 0.01f); + } + + public AdaGrad(Graph graph, String name, float learningRate, float initialAccumulatorValue) { + super(graph, name); + this.learningRate = learningRate; + this.initialAccumulatorValue = initialAccumulatorValue; + } + @Override protected void createSlots(List> variables) { for (Output v : variables) { diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java index a1cd38d6bb2..4c4fc8d24ef 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/AdaGradDA.java @@ -55,6 +55,19 @@ public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, this.l2Strength = l2Strength; } + public AdaGradDA(Graph graph, String name, float learningRate) { + this(graph, name, learningRate, 0.1f, 0.0f, 0.0f); + } + + public AdaGradDA(Graph graph, String name, float learningRate, float initialAccumulatorValue, float l1Strength, + float l2Strength) { + super(graph, name); + this.learningRate = learningRate; + this.initialAccumulatorValue = initialAccumulatorValue; + this.l1Strength = l1Strength; + this.l2Strength = l2Strength; + } + @Override protected Optional> prepare(String name) { return Optional.of(tf.assignAdd(globalStep, tf.val(1L))); diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java index 1a5b757d24a..4a0afb6ae2f 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Adam.java @@ -69,6 +69,18 @@ public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float this.epsilon = epsilon; } + public Adam(Graph graph, String name, float learningRate) { + this(graph, name, learningRate, 0.9f, 0.999f, 1e-8f); + } + + public Adam(Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { + super(graph, name); + this.learningRate = 
learningRate; + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + @Endpoint(name = "adam_minimize") public static Op createAdamMinimize(Scope scope, Operand loss, float learningRate, float betaOne, float betaTwo, float epsilon, diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java index 08876b9c702..58267bc2534 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/GradientDescent.java @@ -32,6 +32,11 @@ public GradientDescent(Graph graph, float learningRate) { this.learningRate = learningRate; } + public GradientDescent(Graph graph, String name, float learningRate) { + super(graph, name); + this.learningRate = learningRate; + } + @Override protected Operand applyDense(Output gradient, Output variable) { return tf.train.applyGradientDescent(variable, diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java index 9ae866b99c6..fcec40bf9d3 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Momentum.java @@ -46,6 +46,13 @@ public Momentum(Graph graph, float learningRate, float momentum, boolean useNest this.useNesterov = useNesterov; } + public Momentum(Graph graph, String name, float learningRate, float momentum, boolean useNesterov) { + super(graph, name); + this.learningRate = learningRate; + this.momentum = momentum; + this.useNesterov = useNesterov; + } + @Override protected void createSlots(List> variables) { for (Output v : variables) { diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java index 2c9fece10d3..fc540f43ffc 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/Optimizer.java @@ -34,7 +34,7 @@ import org.tensorflow.types.family.TType; /** - * + * Base class for gradient optimizers. */ public abstract class Optimizer { @@ -57,6 +57,12 @@ public abstract class Optimizer { */ private final Map>> slots; + /** + * Builds an optimizer for the supplied graph. + *
+ * Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. + * @param graph The graph to optimize. + */ protected Optimizer(Graph graph) { this.graph = graph; this.tf = Ops.create(graph).withName(getOptimizerName()); @@ -64,6 +70,18 @@ protected Optimizer(Graph graph) { this.globals = new ArrayList<>(); } + /** + * Builds an optimizer for the supplied graph. + * @param graph The graph to optimize. + * @param name The base name for the operations. + */ + protected Optimizer(Graph graph, String name) { + this.graph = graph; + this.tf = Ops.create(graph).withName(name); + this.slots = new HashMap<>(); + this.globals = new ArrayList<>(); + } + public static String createName(Output variable, String slotName) { return variable.op().name() + "-" + slotName; } diff --git a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java index 9722fd3ba65..5ac18bd7163 100644 --- a/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java +++ b/tensorflow-training/src/main/java/org/tensorflow/training/optimizers/RMSProp.java @@ -54,6 +54,20 @@ public RMSProp(Graph graph, float learningRate, float decay, float momentum, flo this.centered = centered; } + public RMSProp(Graph graph, String name, float learningRate) { + this(graph, name, learningRate, 0.9f, 0.0f, 1e-10f, false); + } + + public RMSProp(Graph graph, String name, float learningRate, float decay, float momentum, float epsilon, + boolean centered) { + super(graph, name); + this.learningRate = learningRate; + this.decay = decay; + this.momentum = momentum; + this.epsilon = epsilon; + this.centered = centered; + } + @Override protected void createSlots(List> variables) { for (Output v : variables) { From 7a2fd256ff60786165f6f14772b9af05a0d3c9f2 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 25 Feb 2020 12:47:49 -0500 Subject: [PATCH 21/22] Adding a guard to prevent variableWithInit being called on an EagerSession. --- .../src/gen/annotations/org/tensorflow/op/Ops.java | 3 +++ .../src/main/java/org/tensorflow/op/core/Helpers.java | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 96730d9cd14..77bef553018 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -7335,6 +7335,9 @@ public VarIsInitializedOp varIsInitializedOp(Operand resource) { /** * Factory method to create a new Variable with it's initializer. + *
+ * Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op + * does not work in an EagerSession. * * @param scope current scope * @param init The op to use to initialise this variable. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java index 738aea50000..170fef6eb0e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Helpers.java @@ -38,7 +38,9 @@ private Helpers() {} /** * Factory method to create a new Variable with it's initializer. - * + *
+ * Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op + * does not work in an EagerSession. * @param scope current scope * @param init The op to use to initialise this variable. * @param options carries optional attributes values @@ -53,6 +55,8 @@ public static Variable createVariableWithInit(Scope scope, if (exEnv instanceof Graph) { Graph graph = (Graph) exEnv; graph.addInitializer(assignOp); + } else { + throw new IllegalArgumentException("variable with init is only supported on Graph sessions."); } return newVar; From 1b98f5227169bd09abd1390039ea6c41cf09075c Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Sun, 1 Mar 2020 22:29:09 -0500 Subject: [PATCH 22/22] Update Ops.java --- .../src/gen/annotations/org/tensorflow/op/Ops.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 77bef553018..6be58021dc1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -7334,7 +7334,7 @@ public VarIsInitializedOp varIsInitializedOp(Operand resource) { } /** - * Factory method to create a new Variable with it's initializer. + * Factory method to create a new Variable with its initializer. *
* Only supported on Graph sessions as the {@link org.tensorflow.op.core.Assign} op * does not work in an EagerSession.
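Taken together, the last two patches make the contract of the variable-with-initializer factory explicit: in graph mode it registers its Assign op on the graph initializer list, and in eager mode it is rejected outright. A hedged sketch of the resulting behavior; the EagerSession.create() call is an assumption about the eager API at this commit, while everything else mirrors the guard added above:

    import org.tensorflow.EagerSession;
    import org.tensorflow.Graph;
    import org.tensorflow.op.Ops;
    import org.tensorflow.op.core.Variable;
    import org.tensorflow.types.TFloat32;

    public class VariableInitGuardSketch {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);
          // Graph mode: allowed; the Assign op lands on the graph initializer
          // list, so graph.variablesInitializer() will cover this variable.
          Variable<TFloat32> ok = tf.variable(tf.val(1.0f));
        }
        try (EagerSession eager = EagerSession.create()) { // assumed constructor
          Ops tf = Ops.create(eager);
          // Eager mode: rejected by the new guard in Helpers.createVariableWithInit.
          tf.variable(tf.val(1.0f)); // throws IllegalArgumentException
        }
      }
    }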