#!/usr/bin/python3
from insights_protocol import StepBase, ExchangeDataHolder
from bson import json_util
import sys

import tensorflow as tf
import numpy as np
# print('python' + sys.executable)
def create_model():
    """Build and compile the MNIST digit classifier used by this step.

    The network flattens a 28x28 input, passes it through one 128-unit ReLU
    hidden layer and ends in a 10-way softmax. It is compiled with the Adam
    optimizer, sparse categorical cross-entropy loss and an accuracy metric.

    The layer stack must stay in sync with the architecture the saved weights
    were produced with, since the caller restores them via ``load_weights``.
    Dropout is deliberately left out: it only matters during training and this
    module runs inference.

    Returns:
        A compiled ``tf.keras.models.Sequential`` model with untrained weights.
    """
    layers = tf.keras.layers
    stack = [
        layers.Flatten(input_shape=(28, 28)),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ]
    classifier = tf.keras.models.Sequential(stack)
    classifier.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return classifier

# Build the network once, at import time, and restore its trained parameters
# so every invoke() call reuses the same in-memory model.
# NOTE(review): "src/.mnist" is resolved against the process's current working
# directory, not this file's directory -- confirm the step is always launched
# from the expected location.
model = create_model()
model.load_weights(filepath="src/.mnist")


class MyInsightsPythonStep(StepBase):
    """Pipeline step that classifies an encoded digit image with the MNIST model.

    Only the first document of each incoming ``ExchangeDataHolder`` is
    processed. Its ``payload`` is expected to hold an encoded image
    (presumably PNG/JPEG/GIF/BMP bytes, the formats ``tf.io.decode_image``
    understands -- confirm against the producers of this stream).
    """

    # Called for every input
    async def invoke(self, exchange_data_holder: ExchangeDataHolder) -> ExchangeDataHolder:
        """Classify the first document's image and record the result in place.

        The document at ``exchange_data_holder.documents[0]`` is rewritten:

        * ``_id`` becomes ``metaData['fileName']`` when that is truthy,
          otherwise ``exchange_data_holder.id.document``;
        * ``metaData['prediction']`` receives the model output as nested
          lists (one row of 10 softmax scores);
        * ``payload`` becomes ``{'img': <original payload>, 'a': <28x28
          float pixel grid fed to the model>}``.

        Args:
            exchange_data_holder: Incoming holder; mutated in place.

        Returns:
            The same ``exchange_data_holder`` instance.

        Raises:
            IndexError: if the holder carries no documents.
            KeyError: if the document lacks ``payload`` or ``metaData``.
        """
        print(f"\nStart processing {exchange_data_holder!s}")

        document = exchange_data_holder.documents[0]

        # prepare the data
        image = self._prepare_image(document['payload'])
        pixels = image.numpy()

        # run it through the NN; the model expects a leading batch axis
        prediction = model(np.expand_dims(pixels, axis=0)).numpy()

        # store the result
        filename = document['metaData'].get('fileName')
        document['_id'] = filename or exchange_data_holder.id.document
        document['metaData']['prediction'] = prediction.tolist()
        document['payload'] = {
            'img': document['payload'],
            'a': pixels.tolist(),
        }

        # Wrap the id in a 1-tuple so %-formatting still works if the id
        # happens to be a tuple/namedtuple.
        print("processed %s\n" % (exchange_data_holder.id,))

        return exchange_data_holder

    @staticmethod
    def _prepare_image(payload):
        """Decode an encoded image into a (28, 28) float32 single-channel tensor.

        ``tf.io.decode_image`` with ``dtype=tf.float32`` rescales integer
        pixels to [0, 1]. Multi-channel images are reduced to one channel,
        because the model's input layer is (28, 28).
        """
        # Use bytes() rather than np.array(): NumPy's fixed-width 'S' dtype
        # drops trailing NUL bytes, which would truncate binary image data
        # that happens to end in zeros.
        image = tf.io.decode_image(
            bytes(payload), dtype=tf.float32, expand_animations=False
        )
        # decode_image with expand_animations=False yields (H, W, C).
        if image.shape[-1] in (2, 4):
            # gray+alpha or RGBA: drop the alpha channel
            image = image[..., :-1]
        if image.shape[-1] == 3:
            image = tf.image.rgb_to_grayscale(image)
        image = tf.image.resize(image, [28, 28])
        # (28, 28, 1) -> (28, 28); remove only the channel axis.
        return tf.squeeze(image, axis=-1)

if __name__ == "__main__":
    # Capture the raw binary stdin/stdout streams for the step *before*
    # stdout is redirected below. Descriptive names avoid shadowing the
    # builtin input().
    stdin_stream = sys.stdin.buffer
    stdout_stream = sys.stdout.buffer
    # Send every print() (including those in invoke) to stderr, so that
    # stdout carries only what the step itself writes to stdout_stream.
    sys.stdout = sys.stderr
    MyInsightsPythonStep(stdin_stream, stdout_stream).process_stream()
