# 
# This Python program was copied from the
# overview page of TensorFlow:
#        https://www.tensorflow.org/overview?hl=en
#
# It trains a TensorFlow model on the MNIST data set,
# which contains 70,000 images of hand-written digits
# from '0' to '9'.
#
# The output looks like this:                                                             (Phase)
#                                                                                          ----.
#Epoch 1/5                                                                                     :
#                                                                                              :
#1875/1875 [==============================] - 4s 2ms/step - loss: 0.2968 - accuracy: 0.9142    :
#Epoch 2/5                                                                                     :
#                                                                                              :
#1875/1875 [==============================] - 3s 2ms/step - loss: 0.1446 - accuracy: 0.9576    :
#Epoch 3/5                                                                                    (1)
#                                                                                              :
#1875/1875 [==============================] - 3s 2ms/step - loss: 0.1085 - accuracy: 0.9670    :
#Epoch 4/5                                                                                     :
#                                                                                              :
#1875/1875 [==============================] - 3s 1ms/step - loss: 0.0870 - accuracy: 0.9734    :
#Epoch 5/5                                                                                     :
#                                                                                              :
#1875/1875 [==============================] - 3s 1ms/step - loss: 0.0743 - accuracy: 0.9766    :
#                                                                                          ----+
#313/313 [==============================] - 1s 1ms/step - loss: 0.0741 - accuracy: 0.9760     (2)
#                                                                                          ----´
#
# Phase (1) is the call to model.fit(), which uses the training part of the MNIST data;
#           it does so by iterating five times (epochs=5) over all training data.
# Phase (2) then tests the trained model on a test data set, i.e. data the model
#           has not seen during training.
#           This last line shows the accuracy to expect for unknown images of digits.
#           In the example above, 97.6% of all input images were classified correctly,
#           leaving 2.4% which were not.
import tensorflow as tf
# Handle to the MNIST data set bundled with Keras.
mnist = tf.keras.datasets.mnist

# load_data() returns a (training, test) pair, each an (images, labels) tuple.
training_split, test_split = mnist.load_data()
x_train, y_train = training_split
x_test, y_test = test_split

# Rescale the image arrays by 1/255 -- presumably mapping 8-bit pixel
# intensities into the range [0, 1]; confirm against the data set docs.
x_train = x_train / 255.0
x_test = x_test / 255.0

# Layer stack: flatten each 28x28 image, one hidden ReLU layer of 128 units,
# dropout with rate 0.2, and a 10-way softmax output (one unit per digit).
layer_stack = [
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax'),
]
model = tf.keras.models.Sequential(layer_stack)

# Labels are passed as given by load_data() and paired with the sparse
# categorical cross-entropy loss; accuracy is reported alongside the loss.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Train for five passes over the training split, then score the model on
# the test split, which is not used for fitting.
model.fit(x=x_train, y=y_train, epochs=5)
model.evaluate(x=x_test, y=y_test)

#
# Bosch IoT Insights' addition to save trained model weights:
#
model.save_weights("src/.mnist")
#
# That last line writes the weights using the prefix "src/.mnist", which
# results in two files being created (named after that prefix):
# * .mnist.index
# * .mnist.data-00000-of-00001
#
# If the achieved accuracy of this training run is sufficient, those two
# files can be included in the actual model usage within Bosch IoT Insights
# by copying them into the ./src folder.
# The step.py there will recreate the same model and load the weights using
# code like:
# model = create_model()
# model.load_weights("src/.mnist")