Skip to content

Commit ca8bd66

Browse files
committed
update default example
1 parent 5de5f3b commit ca8bd66

File tree

3 files changed

+98
-84
lines changed

3 files changed

+98
-84
lines changed

Content/ExampleAssets/Maps/Mnist.umap

405 Bytes
Binary file not shown.
Binary file not shown.

Content/Scripts/mnistSpawnSamples.py

Lines changed: 98 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,114 +1,128 @@
1-
#Converted to ue4 use from: https://www.tensorflow.org/get_started/mnist/beginners
2-
#mnist_softmax.py: https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/examples/tutorials/mnist/mnist_softmax.py
3-
4-
# Import data
5-
from tensorflow.examples.tutorials.mnist import input_data
1+
#converted for ue4 use from
2+
#https://github.com/tensorflow/docs/blob/master/site/en/tutorials/_index.ipynb
63

74
import tensorflow as tf
85
import unreal_engine as ue
96
from TFPluginAPI import TFPluginAPI
107

11-
import operator
8+
#additional includes
9+
from tensorflow.python.keras import backend as K #to ensure things work well with multi-threading
10+
import numpy as np #for reshaping input
11+
import operator #used for getting max prediction from 1x10 output array
12+
import random
1213

13-
class MnistSimple(TFPluginAPI):
14-
15-
#expected api: storedModel and session, json inputs
16-
def onJsonInput(self, jsonInput):
17-
#expect an image struct in json format
18-
pixelarray = jsonInput['pixels']
19-
ue.log('image len: ' + str(len(pixelarray)))
14+
class MnistTutorial(TFPluginAPI):
2015

21-
#embedd the input image pixels as 'x'
22-
feed_dict = {self.model['x']: [pixelarray]}
16+
#keras stop callback
17+
class StopCallback(tf.keras.callbacks.Callback):
18+
def __init__(self, outer):
19+
self.outer = outer
2320

24-
result = self.sess.run(self.model['y'], feed_dict)
21+
def on_train_begin(self, logs={}):
22+
self.losses = []
2523

26-
#convert our raw result to a prediction
27-
index, value = max(enumerate(result[0]), key=operator.itemgetter(1))
24+
def on_batch_end(self, batch, logs={}):
25+
if(self.outer.shouldStop):
26+
#notify on first call
27+
if not (self.model.stop_training):
28+
ue.log('Early stop called!')
29+
self.model.stop_training = True
2830

29-
ue.log('max: ' + str(value) + 'at: ' + str(index))
31+
else:
32+
if(batch % 5 == 0):
33+
#json convertible types are float64 not float32
34+
logs['acc'] = np.float64(logs['acc'])
35+
logs['loss'] = np.float64(logs['loss'])
36+
self.outer.callEvent('TrainingUpdateEvent', logs, True)
3037

31-
#set the prediction result in our json
32-
jsonInput['prediction'] = index
38+
#callback an example image from batch to see the actual data we're training on
39+
if((batch*self.outer.batch_size) % 100 == 0):
40+
index = random.randint(0,self.outer.batch_size)*batch
41+
self.outer.jsonPixels['pixels'] = self.outer.x_train[index].ravel().tolist()
42+
self.outer.callEvent('PixelEvent', self.outer.jsonPixels, True)
3343

34-
return jsonInput
3544

36-
#expected api: no params forwarded for training? TBC
37-
def onBeginTraining(self):
45+
#Called when TensorflowComponent sends Json input
46+
def onJsonInput(self, jsonInput):
47+
#build the result object
48+
result = {'prediction':-1}
3849

39-
ue.log("starting mnist simple training")
50+
#If we try to predict before training is complete
51+
if not hasattr(self, 'model'):
52+
ue.log_warning("Warning! No 'model' found, prediction invalid. Did training complete?")
53+
return result
4054

41-
self.scripts_path = ue.get_content_dir() + "Scripts"
42-
self.data_dir = self.scripts_path + '/dataset/mnist'
55+
#prepare the input, reshape 784 array to a 1x28x28 array
56+
x_raw = jsonInput['pixels']
57+
x = np.reshape(x_raw, (1, 28, 28))
4358

44-
mnist = input_data.read_data_sets(self.data_dir, one_hot=True)
59+
#run the input through our network using stored model and graph
60+
with self.graph.as_default():
61+
output = self.model.predict(x)
4562

46-
# Create the model
47-
x = tf.placeholder(tf.float32, [None, 784])
48-
W = tf.Variable(tf.zeros([784, 10]))
49-
b = tf.Variable(tf.zeros([10]))
50-
y = tf.matmul(x, W) + b
63+
#convert output array to max value prediction index (0-10)
64+
index, value = max(enumerate(output[0]), key=operator.itemgetter(1))
5165

52-
# Define loss and optimizer
53-
y_ = tf.placeholder(tf.float32, [None, 10])
66+
#Optionally log the output so you can see the weights for each value and final prediction
67+
ue.log('Output array: ' + str(output) + ',\nPrediction: ' + str(index))
5468

55-
# The raw formulation of cross-entropy,
56-
#
57-
# tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
58-
# reduction_indices=[1]))
59-
#
60-
# can be numerically unstable.
61-
#
62-
# So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
63-
# outputs of 'y', and then average across the batch.
64-
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y))
65-
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
69+
result['prediction'] = index
70+
71+
return result
72+
73+
#Called when TensorflowComponent signals begin training (default: begin play)
74+
def onBeginTraining(self):
75+
ue.log("starting MnistTutorial training")
6676

67-
#update session for this thread
68-
self.sess = tf.InteractiveSession()
69-
tf.global_variables_initializer().run()
77+
#training parameters
78+
self.batch_size = 128
79+
num_classes = 10
80+
epochs = 3
7081

71-
training_range = 1000
82+
#reset the session each time we get training calls
83+
self.kerasCallback = self.StopCallback(self)
84+
K.clear_session()
7285

73-
#pre-fill our callEvent data to minimize setting
86+
#load mnist data set
87+
mnist = tf.keras.datasets.mnist
88+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
89+
90+
#rescale 0-255 -> 0-1.0
91+
x_train, x_test = x_train / 255.0, x_test / 255.0
92+
93+
#define model
94+
model = tf.keras.models.Sequential([
95+
tf.keras.layers.Flatten(),
96+
tf.keras.layers.Dense(512, activation=tf.nn.relu),
97+
tf.keras.layers.Dropout(0.2),
98+
tf.keras.layers.Dense(num_classes, activation=tf.nn.softmax)
99+
])
100+
101+
model.compile( optimizer='adam',
102+
loss='sparse_categorical_crossentropy',
103+
metrics=['accuracy'])
104+
105+
#pre-fill our callEvent data to optimize callbacks
74106
jsonPixels = {}
75107
size = {'x':28, 'y':28}
76108
jsonPixels['size'] = size
109+
self.jsonPixels = jsonPixels
110+
self.x_train = x_train
111+
112+
#this will do the actual training
113+
model.fit(x_train, y_train,
114+
batch_size=self.batch_size,
115+
epochs=epochs,
116+
callbacks=[self.kerasCallback])
117+
model.evaluate(x_test, y_test)
118+
119+
ue.log("Training complete.")
77120

78-
# Train
79-
for i in range(training_range):
80-
batch_xs, batch_ys = mnist.train.next_batch(100)
81-
self.sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
82-
if i % 100 == 0:
83-
ue.log(i)
84-
85-
#send two pictures from our dataset per batch
86-
jsonPixels['pixels'] = batch_xs[0].tolist()
87-
self.callEvent('PixelEvent', jsonPixels, True)
88-
jsonPixels['pixels'] = batch_xs[49].tolist()
89-
self.callEvent('PixelEvent', jsonPixels, True)
90-
91-
if(self.shouldStop):
92-
ue.log('early break')
93-
break
94-
95-
# Test trained model
96-
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
97-
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
98-
finalAccuracy = self.sess.run(accuracy, feed_dict={x: mnist.test.images,
99-
y_: mnist.test.labels})
100-
ue.log('final training accuracy: ' + str(finalAccuracy))
101-
102-
#return trained model
103-
self.model = {'x':x, 'y':y, 'W':W,'b':b}
104-
105-
#store optional summary information
106-
self.summary = {'x':str(x), 'y':str(y), 'W':str(W), 'b':str(b)}
107-
108-
self.stored['summary'] = self.summary
109-
return self.stored
121+
#store our model and graph for prediction
122+
self.graph = tf.get_default_graph()
123+
self.model = model
110124

111125
#required function to get our api
112126
def getApi():
113127
#return CLASSNAME.getInstance()
114-
return MnistSimple.getInstance()
128+
return MnistTutorial.getInstance()

0 commit comments

Comments
 (0)