# Source code for tomaat.frameworks.pytorch

import torch
import torch.backends.cudnn as cudnn
import numpy as np


class Prediction(object):
    """Run a serialized PyTorch model over fields of a data dictionary.

    The model is loaded once with ``torch.load``. Each call does three things:
    it converts the configured input fields from numpy arrays to tensors,
    passes them to the model as keyword arguments, and writes the model
    outputs back into the same dictionary as ``float32`` numpy arrays.

    Args:
        model_path: Path of a whole model object saved with ``torch.save``.
            ``torch.load`` unpickles arbitrary objects, so only load trusted
            files.
        input_arg_names: Keyword argument names passed to the model call.
            They are paired positionally with ``input_fields``.
        input_fields: Keys of the data dictionary that hold the input arrays.
        output_fields: Keys of the data dictionary that receive the model
            outputs, in the order the model returns them.
        with_gpu: If true, move the model and every input tensor to CUDA.

    Raises:
        ValueError: If ``input_fields`` and ``input_arg_names`` differ in
            length.
    """

    def __init__(self, model_path, input_arg_names, input_fields,
                 output_fields, with_gpu=True):
        super(Prediction, self).__init__()
        # Validate the configuration before the (potentially slow) model load.
        if len(input_fields) != len(input_arg_names):
            raise ValueError(
                'input_fields ({}) and input_arg_names ({}) must have the '
                'same length'.format(len(input_fields), len(input_arg_names)))
        self.input_fields = input_fields
        self.input_arg_names = input_arg_names
        self.output_fields = output_fields
        self.with_gpu = with_gpu

        # Load onto the CPU first so that a model saved on a GPU can still be
        # served on a CPU-only host. It is moved to CUDA below if requested.
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to
        # weights_only=True, which rejects whole pickled modules. Confirm the
        # deployed torch version.
        self.model = torch.load(model_path, map_location='cpu')
        # Put the model in inference mode. This disables dropout and stops
        # batch-norm layers from updating their running statistics on every
        # request.
        self.model.eval()

        if self.with_gpu:
            self.model.cuda()
            # These are process-wide cuDNN flags. Benchmark mode lets cuDNN
            # time the candidate algorithms and keep the fastest one for each
            # input shape.
            cudnn.enabled = True
            cudnn.benchmark = True

    def __call__(self, data):
        """Run the model on ``data`` and store its outputs in place.

        Args:
            data: Mapping whose ``input_fields`` entries are numpy arrays
                that ``torch.from_numpy`` can convert.

        Returns:
            The same ``data`` object, with each ``output_fields`` key set to
            a ``float32`` numpy array.

        Raises:
            ValueError: If the model returns a different number of outputs
                than ``output_fields`` lists.
        """
        arg_dict = {}
        for arg_name, field_name in zip(self.input_arg_names,
                                        self.input_fields):
            tensor = torch.from_numpy(data[field_name])
            if self.with_gpu:
                # Tensor.cuda() returns a copy, so the result must be kept.
                tensor = tensor.cuda()
            arg_dict[arg_name] = tensor

        # Inference does not need an autograd graph.
        with torch.no_grad():
            outputs = self.model(**arg_dict)

        # Wrap a single output so it can be zipped with output_fields. A
        # tuple or list is treated as multiple outputs.
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]
        if len(outputs) != len(self.output_fields):
            raise ValueError(
                'model returned {} outputs but {} output_fields are '
                'configured'.format(len(outputs), len(self.output_fields)))

        for output, output_field in zip(outputs, self.output_fields):
            data[output_field] = (
                output.cpu().detach().numpy().astype(dtype=np.float32))
        return data