Source code for tomaat.frameworks.tf

import tensorflow as tf


class Prediction(object):
    """Callable wrapper that runs a TensorFlow SavedModel on dict-like data.

    On construction a ``tf.Session`` is created, the SavedModel found at
    ``model_path`` is loaded into it with the ``SERVING`` tag, and the
    requested input/output tensors are looked up by name in the default
    graph.  Calling the instance feeds values read from ``data`` into the
    input tensors, runs the output tensors, and writes the results back
    into ``data``.

    Args:
        model_path: Location of the SavedModel, passed unchanged to
            ``tf.saved_model.loader.load``.
        input_tensors_names: Iterable of graph tensor names to feed.
        input_fields: Keys of ``data`` supplying the value for each input
            tensor; paired position by position with
            ``input_tensors_names``.
        output_tensors_names: Iterable of graph tensor names to fetch.
        output_fields: Keys of ``data`` under which each fetched output is
            stored; paired position by position with
            ``output_tensors_names``.

    Raises:
        ValueError: If the number of tensor names and the number of fields
            differ, for either the inputs or the outputs.
    """

    def __init__(self, model_path, input_tensors_names, input_fields,
                 output_tensors_names, output_fields):
        super(Prediction, self).__init__()

        # Materialise the names so that any iterable is accepted and can be
        # both measured and iterated.
        input_tensors_names = list(input_tensors_names)
        output_tensors_names = list(output_tensors_names)

        # Validate before creating the session and loading the model, so a
        # misconfiguration fails immediately and does not leave an open
        # session behind.  Explicit raises are used instead of ``assert``
        # because asserts are removed under ``python -O``, and a mismatch
        # would otherwise be silently truncated by ``zip`` in ``__call__``.
        if len(input_tensors_names) != len(input_fields):
            raise ValueError(
                "got {} input tensor names but {} input fields".format(
                    len(input_tensors_names), len(input_fields)))
        if len(output_tensors_names) != len(output_fields):
            raise ValueError(
                "got {} output tensor names but {} output fields".format(
                    len(output_tensors_names), len(output_fields)))

        tf_conf = tf.ConfigProto()
        # Sets the GPU option ``allow_growth`` on the session configuration.
        tf_conf.gpu_options.allow_growth = True

        self.sess = tf.Session(config=tf_conf)

        # The return value of the loader is not needed here.
        tf.saved_model.loader.load(
            self.sess, [tf.saved_model.tag_constants.SERVING], model_path)

        self.graph = tf.get_default_graph()

        self.input_tensors = [
            self.graph.get_tensor_by_name(name)
            for name in input_tensors_names
        ]
        self.input_fields = input_fields

        self.output_tensors = [
            self.graph.get_tensor_by_name(name)
            for name in output_tensors_names
        ]
        self.output_fields = output_fields

    def __call__(self, data):
        """Run the model on ``data`` and store the outputs back into it.

        Args:
            data: Mapping-like object.  ``data[field]`` is read for every
                input field and ``data[field]`` is assigned for every
                output field.

        Returns:
            The same ``data`` object, modified in place.
        """
        feed_dict = {
            tensor: data[field]
            for tensor, field in zip(self.input_tensors, self.input_fields)
        }

        outputs = self.sess.run(self.output_tensors, feed_dict=feed_dict)

        for field, value in zip(self.output_fields, outputs):
            data[field] = value

        return data