Source code for polo.marco.run_marco

import os
import sys
import operator
from polo import make_default_logger

# Module-level logger named after this module; run_model reports caught
# exceptions through it.
logger = make_default_logger(__name__)

# https://github.com/tensorflow/models/blob/master/research/marco/Automated_Marco.py
def run_model(tf_predictor, image_path):
    '''Classify a single image with the MARCO model.

    Reads the image at ``image_path`` as raw bytes, passes it to
    ``tf_predictor`` and returns the classification with the greatest
    confidence together with the confidence of every classification.
    The classification names used as keys of the returned dictionary are
    used universally throughout the program and are accessible through the
    :const:`IMAGE_CLASSIFICATIONS` constant.

    Anything printed to ``sys.stdout`` while the predictor runs is discarded;
    the stream that was active before the call is always put back afterwards,
    even if the predictor raises.

    :param tf_predictor: Loaded MARCO model. Called with
                         ``{'image_bytes': [<file contents>]}`` and expected
                         to return a mapping with ``'scores'`` and
                         ``'classes'`` entries whose first elements are
                         parallel sequences.
    :type tf_predictor: tensorflow model
    :param image_path: Path to the image to be classified by the model
    :type image_path: str
    :return: ``(prediction, confidences)`` where ``prediction`` is the
             classification with the highest score and ``confidences`` maps
             each classification to its score rounded to 3 decimal places.
             If anything goes wrong (unreadable file, predictor failure,
             malformed result) the error is logged and ``(None, {})`` is
             returned instead.
    :rtype: tuple
    '''
    def as_text(label):
        # Labels normally arrive as bytes; accept already-decoded labels too.
        if isinstance(label, bytes):
            return label.decode('utf-8')
        return label

    try:
        with open(image_path, 'rb') as image_file:
            data = {'image_bytes': [image_file.read()]}

        # Suppress predictor output. Restore whatever stream was installed
        # before (not sys.__stdout__), so a caller's own redirection survives,
        # and close the devnull handle instead of leaking it.
        previous_stdout = sys.stdout
        with open(os.devnull, 'w') as devnull:
            sys.stdout = devnull
            try:
                results = tf_predictor(data)
            finally:
                sys.stdout = previous_stdout

        scores = dict(zip(results['classes'][0], results['scores'][0]))

        # Pick the winner on the unrounded scores so rounding cannot
        # introduce artificial ties.
        prediction = max(scores.items(), key=operator.itemgetter(1))[0]

        confidences = {
            as_text(label): round(score, 3)
            for label, score in scores.items()
        }
        return as_text(prediction), confidences
    except Exception as e:
        # Best-effort contract: callers rely on (None, {}) rather than an
        # exception when an image cannot be classified.
        logger.error('Caught exception {}'.format(e))
        return None, {}