import json
import requests
import tempfile
import uuid
import os
import SimpleITK as sitk
import base64
import traceback
import numpy as np
import shutil
try:
# For Python 3.0 and later
from urllib.request import urlopen
except ImportError:
# Fall back to Python 2's urllib2
from urllib2 import urlopen
from multiprocessing import Process, Manager, Lock
from klein import Klein
from twisted.internet.defer import inlineCallbacks, returnValue, DeferredLock
from twisted.internet import threads
from twisted.internet.task import LoopingCall
from twisted.internet import reactor
from twisted.logger import Logger
ANNOUNCEMENT_SERVER_URL = 'http://tomaat.cloud:8001/announce'
ANNOUNCEMENT_INTERVAL = 1600 # seconds
logger = Logger()
[docs]def is_base64(s):
try:
if base64.b64encode(base64.b64decode(s)) == s:
return True
except:
pass
return False
[docs]def do_announcement(announcement_server_url, message):
json_message = json.dumps(message)
try:
response = requests.post(announcement_server_url, data=json_message)
response_json = response.json()
if response_json['status'] != 0:
logger.error('status {}'.format(response_json['status']))
logger.error('errors: {}'.format(response_json['error']))
except:
logger.error('WARNING: ERROR while connecting to announcement service.')
pass
[docs]class TomaatApp(object):
"""
A TomaatApp is an object that implements the functionality of the user application. More specifically,
this object implements the workflow that is needed by the app. It requires the following arguments
"""
def __init__(self, preprocess_fun, inference_fun, postprocess_fun):
"""
To instantiate a TomaatApp the following arguments are needed
:type preprocess_fun: Callable function or callable object implementing pre-processing
:type inference_fun: Callable function or callable object implementing inference
:type postprocess_fun: Callable function or callable object implementing post-processing
"""
super(TomaatApp, self).__init__()
self.preprocess_fun = preprocess_fun
self.inference_fun = inference_fun
self.postprocess_fun = postprocess_fun
def __call__(self, data, gpu_lock=None):
"""
When a TomaatApp object is called it performs pre-processing, inference and post-processing
:type data: dict dictionary containing data. The dictionary must contain the fields expected by pre-processing
:type gpu_lock: DeferredLock optional lock to allow threads to safely use the GPU. No GPU => no lock needed
:return: dict containing inference results after post-processing
"""
transformed_data = self.preprocess_fun(data)
if gpu_lock is not None:
gpu_lock.acquire() # acquire GPU lock
result = self.inference_fun(transformed_data) # GPU call
if gpu_lock is not None:
gpu_lock.release() # release GPU lock
transformed_result = self.postprocess_fun(result)
return transformed_result
[docs]class TomaatService(object):
klein_app = Klein()
announcement_task = None
gpu_lock = DeferredLock()
def __init__(self, config, app, input_interface, output_interface):
"""
To instantitate a TomaatService the following arguments are needed
:type config: dict dictionary containing configuration for the service
:type app: TomaatApp implementing the user application to be served through TomaatService
:type input_interface: dict containing the specification for input interface (See documentation or example)
:type output_interface: dict containing the specification for output interface (See documentation or example)
"""
super(TomaatService, self).__init__()
self.config = config
self.app = app
self.input_interface = input_interface
self.output_interface = output_interface
[docs] @klein_app.route('/interface', methods=['GET'])
def interface(self, request):
request.setHeader('Access-Control-Allow-Origin', '*')
request.setHeader('Access-Control-Allow-Methods', 'GET')
request.setHeader('Access-Control-Allow-Headers', '*')
request.setHeader('Access-Control-Max-Age', 2520) # 42 hours
return json.dumps(self.input_interface)
[docs] @klein_app.route('/predict', methods=['POST'])
@inlineCallbacks
def predict(self, request):
request.setHeader('Access-Control-Allow-Origin', '*')
request.setHeader('Access-Control-Allow-Methods', 'POST')
request.setHeader('Access-Control-Allow-Headers', '*')
request.setHeader('Access-Control-Max-Age', 2520) # 42 hours
logger.info('predicting...')
result = yield threads.deferToThread(self.received_data_handler, request)
returnValue(result)
[docs] def start_service_announcement(
self,
fun=do_announcement,
announcement_server_url=ANNOUNCEMENT_SERVER_URL,
delay=ANNOUNCEMENT_INTERVAL
):
try:
api_key = self.config['api_key']
except KeyError:
raise ValueError('Api-key is missing')
try:
host = self.config['host']
except KeyError:
ip = urlopen('http://ip.42.pl/raw').read()
port = self.config['port']
host = 'http://' + str(ip) + ':' + str(port) + '/'
pass
message = {
'api_key': api_key,
'prediction_url': host+'/predict',
'interface_url': host+'/interface',
'name': self.config['name'],
'modality': self.config['modality'],
'task': self.config['task'],
'anatomy': self.config['anatomy'],
'description': self.config['description'],
}
self.announcement_task = LoopingCall(fun, *(announcement_server_url, message))
self.announcement_task.start(delay)
[docs] def stop_service_announcement(self):
self.announcement_task.stop()
[docs] def make_error_response(self, message):
"""
Create simple error message to be returned to the client as plain text
:type message: str error message to be returned to the client
:return: response to be returned to client
"""
response = [{'type': 'PlainText', 'content': message, 'label': 'Error!'}]
return response
[docs] def parse_request(self, request, savepath):
"""
This function takes in the content of the client message and creates a dictionary containing data.
The service interface, that was specified in the input_interface dictionary specified at init,
contains the specifications of the data that is needed to run this service and the fields of the dictionary
returned by this function where the client data should be stored.
:type request: dict request sent by the client
:return: dict containing data that can be fed to the pre-processing, inference, post-processing pipeline
"""
data = {}
for element in self.input_interface:
raw = request.args[element['destination'].encode('UTF-8')]
if element['type'] == 'volume':
uid = uuid.uuid4()
mha_file = str(uid).replace('-', '') + '.mha'
tmp_filename_mha = os.path.join(savepath, mha_file)
with open(tmp_filename_mha, 'wb') as f:
if is_base64(raw[0]):
f.write(base64.decodestring(raw[0]))
else:
f.write(raw[0])
print(
'Your client has passed RAW file content instead of base64 encoded string: '
'this is deprecated and will result in errors in future version of the server'
)
data[element['destination']] = [tmp_filename_mha]
elif element['type'] == 'slider':
data[element['destination']] = [float(raw[0])]
elif element['type'] == 'checkbox':
data[element['destination']] = [str(raw[0])]
elif element['type'] == 'radiobutton':
data[element['destination']] = [str(raw[0])]
elif element['type'] == 'fiducials':
# Each coordinate is separated by ';'
# Each coord value is separated by ','
fiducial_string = str(raw[0])
fiducial_list = [ [ float(val) for val in coords.split(',')] for coords in fiducial_string.split(';')]
data[element['destination']] = [np.asarray(fiducial_list)]
elif element['type'] == 'transform':
dtype = {
'nii.gz':'grid',
'h5': 'bspline',
'mat':'linear'
}
# transform encoding:
# <filetype> newline
# <base64 of file>
# determine file type
req = str(raw[0])
trf_file_type = ""
for trf_type in dtype.keys():
if req.startswith(trf_type+"\n"):
trf_file_type = '.' + trf_type
if not trf_file_type:
# invalid format
return
# store file
uid = uuid.uuid4()
trf_file = str(uid) + trf_file_type
tmp_transform = os.path.join(savepath, trf_file)
with open(tmp_transform, 'wb') as f:
# write base64 data
f.write(base64.decodestring(req[len(trf_file_type):]))
data[element['destination']] = [tmp_transform]
return data
[docs] def make_response(self, data, savepath):
"""
This function takes in the post-processed results of inference and creates a message for the client.
The message is created according to the directives specified in the output_interface dictionary passed
during instantiation of TomaatService object.
:type request: dict containing the inference results (stored in the appropriate fields)
:return: JSON containing response that can be returned to the client
"""
message = []
for element in self.output_interface:
type = element['type']
field = element['field']
if type == 'LabelVolume':
uid = uuid.uuid4()
mha_seg = str(uid).replace('-', '') + '_seg.mha'
tmp_label_volume = os.path.join(savepath, mha_seg)
writer = sitk.ImageFileWriter()
writer.SetFileName(tmp_label_volume)
writer.SetUseCompression(True)
writer.Execute(data[field][0])
with open(tmp_label_volume, 'rb') as f:
vol_string = base64.encodestring(f.read()).decode('utf-8')
message.append({'type': 'LabelVolume', 'content': vol_string, 'label': ''})
os.remove(tmp_label_volume)
elif type == 'VTKMesh':
import vtk
uid = uuid.uuid4()
vtk_mesh = str(uid).replace('-', '') + '_seg.vtk'
tmp_vtk_mesh = os.path.join(savepath, vtk_mesh)
writer = vtk.vtkPolyDataWriter()
writer.SetFileName(tmp_vtk_mesh)
writer.SetInput(data[field][0])
writer.SetFileTypeToASCII()
writer.Write()
with open(tmp_vtk_mesh, 'rb') as f:
mesh_string = base64.encodestring(f.read())
message.append({'type': 'VTKMesh', 'content': mesh_string, 'label': ''})
os.remove(tmp_vtk_mesh)
elif type == 'PlainText':
message.append({'type': 'PlainText', 'content': str(data[field][0]), 'label': ''})
elif type == 'Fiducials':
fiducial_array = data[field][0]
fiducial_str = ';'.join([','.join(map(str,fid_point)) for fid_point in fiducial_array])
message.append({'type': 'Fiducials', 'content': fiducial_str, 'label': ''})
elif type in ['TransformGrid','TransformBSpline','TransformLinear']:
trf_file_type = {
'TransformGrid':'nii.gz',
'TransformBSpline':'h5',
'TransformLinear':'mat',
}
uid = uuid.uuid4()
trf_file_name = str(uid) + '.' + trf_file_type[type]
trf_file_path = os.path.join(savepath, trf_file_name)
if type == "TransformGrid":
# Displacement fields are stored as regular volumes.
sitk.WriteImage(data[field][0],trf_file_path)
else:
sitk.WriteTransform(data[field][0],trf_file_path)
with open(trf_file_path, 'rb') as f:
vol_string = base64.encodestring(f.read()).decode("utf-8")
message.append({'type': type, 'content': vol_string, 'label': ''})
os.remove(trf_file_path)
return message
[docs] def received_data_handler(self, request):
savepath = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()).replace('-', ''))
os.mkdir(savepath)
try:
data = self.parse_request(request, savepath)
except:
traceback.print_exc()
logger.error('Server-side ERROR during request parsing')
response = self.make_error_response('Server-side ERROR during request parsing')
return json.dumps(response)
try:
transformed_result = self.app(data, gpu_lock=self.gpu_lock)
except:
traceback.print_exc()
logger.error('Server-side ERROR during processing')
response = self.make_error_response('Server-side ERROR during processing')
return json.dumps(response)
try:
response = self.make_response(transformed_result, savepath)
except:
traceback.print_exc()
logger.error('Server-side ERROR during response message creation')
response = self.make_error_response('Server-side ERROR during response message creation')
return json.dumps(response)
shutil.rmtree(savepath)
return json.dumps(response)
[docs] def run(self):
self.klein_app.run(port=self.config['port'], host='0.0.0.0')
reactor.run()
[docs]class TomaatServiceDelayedResponse(TomaatService):
announcement_task = None
gpu_lock = DeferredLock()
multiprocess_manager = Manager()
result_dict = multiprocess_manager.dict()
reqest_list = multiprocess_manager.list()
multiprocess_lock = Lock()
klein_app = Klein()
def __init__(self, no_concurrent_thread_execution=True, **kwargs):
super(TomaatServiceDelayedResponse, self).__init__(**kwargs)
self.no_concurrent_thread_execution = no_concurrent_thread_execution
[docs] def received_data_handler(self, request):
req_id = str(uuid.uuid4()).replace('-', '')
savepath = os.path.join(tempfile.gettempdir(), req_id)
os.mkdir(savepath)
def processing_thread():
if self.no_concurrent_thread_execution:
self.multiprocess_lock.acquire()
response = self.make_error_response('Server-side ERROR during processing')
try:
data = self.parse_request(request, savepath)
except:
traceback.print_exc()
logger.error('Server-side ERROR during request parsing')
try:
transformed_result = self.app(data, gpu_lock=self.gpu_lock)
except:
traceback.print_exc()
logger.error('Server-side ERROR during processing')
try:
response = self.make_response(transformed_result, savepath)
except:
traceback.print_exc()
logger.error('Server-side ERROR during response message creation')
response = [{
'type': 'PlainText',
'content': 'The results of your earlier request {} have been received'.format(req_id),
'label': ''
}] + response
self.result_dict[req_id] = response
shutil.rmtree(savepath)
if self.no_concurrent_thread_execution:
self.multiprocess_lock.release()
delegated_process = Process(target=processing_thread, args=())
delegated_process.start()
self.reqest_list.append(req_id)
response = [{'type': 'DelayedResponse', 'request_id': req_id}]
return json.dumps(response)
[docs] @klein_app.route('/interface', methods=['GET'])
def interface(self, request):
request.setHeader('Access-Control-Allow-Origin', '*')
request.setHeader('Access-Control-Allow-Methods', 'GET')
request.setHeader('Access-Control-Allow-Headers', '*')
request.setHeader('Access-Control-Max-Age', 2520) # 42 hours
return json.dumps(self.input_interface)
[docs] @klein_app.route('/predict', methods=['POST'])
@inlineCallbacks
def predict(self, request):
request.setHeader('Access-Control-Allow-Origin', '*')
request.setHeader('Access-Control-Allow-Methods', 'POST')
request.setHeader('Access-Control-Allow-Headers', '*')
request.setHeader('Access-Control-Max-Age', 2520) # 42 hours
logger.info('predicting...')
result = yield threads.deferToThread(self.received_data_handler, request)
returnValue(result)
[docs] @klein_app.route('/responses', methods=['POST'])
@inlineCallbacks
def responses(self, request):
request.setHeader('Access-Control-Allow-Origin', '*')
request.setHeader('Access-Control-Allow-Methods', 'POST')
request.setHeader('Access-Control-Allow-Headers', '*')
request.setHeader('Access-Control-Max-Age', 2520) # 42 hours
logger.info('getting responses...')
result = yield threads.deferToThread(self.responses_data_handler, request)
returnValue(result)
[docs] def responses_data_handler(self, request):
req_id = request.args['request_id'][0]
print(req_id)
print(self.reqest_list)
if req_id not in self.reqest_list:
response = [{
'type': 'PlainText',
'content': 'The results of request {} cannot be retrieved'.format(req_id),
'label': ''
}]
return json.dumps(response)
try:
response = self.result_dict[req_id]
#removing content of list and dict
del self.result_dict[req_id]
self.reqest_list.remove(req_id)
except KeyError:
response = [{'type': 'DelayedResponse', 'request_id': req_id}]
return json.dumps(response)