@@ -61,6 +61,8 @@
 // https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133
 #include <ATen/Parallel.h>
 
+// Default forward method to call on PyTorch modules
+const std::string DEFAULT_MODULE_METHOD_NAME = "forward";
 
 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
@@ -111,6 +113,7 @@ class ModelState : public BackendModel {
   {
     return model_outputs_;
   }
+  const std::string& ModuleMethodName() { return module_method_name_; }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -153,6 +156,10 @@ class ModelState : public BackendModel {
   // is specified both in the output section and state section, it indicates
   // that the backend must return the output state to the client too.
   std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
+
+  // Method to call on PyTorch Module.
+  // Defaults to DEFAULT_MODULE_METHOD_NAME.
+  std::string module_method_name_;
 };
 
 TRITONSERVER_Error*
@@ -230,7 +237,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_inference_mode_(true), enable_cache_cleaning_(false),
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
-      enable_jit_executor_pair_({false, true})
+      enable_jit_executor_pair_({false, true}),
+      module_method_name_(DEFAULT_MODULE_METHOD_NAME)
 {
 }
 
@@ -519,6 +527,30 @@ ModelState::ParseParameters()
             .c_str());
      }
    }
+
+    // If 'MODULE_METHOD_NAME' is not present in 'parameters' then
+    // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
+    std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
+    err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("module_method_name is not specified") +
+             " for model instance '" + Name() + "'")
+                .c_str());
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      module_method_name_ = module_method_name;
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("module_method_name is ") + module_method_name_ +
+           " for model instance '" + Name() + "'")
+              .c_str());
+    }
   }
 
   return nullptr;
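Note: with this change, a model would opt in through the standard Triton 'parameters' field of its config.pbtxt. A hypothetical sketch (the method name 'custom_method' is illustrative, not part of this diff):

parameters: {
  key: "MODULE_METHOD_NAME"
  value: {
    string_value: "custom_method"
  }
}

If the parameter is omitted, the backend falls back to DEFAULT_MODULE_METHOD_NAME ('forward'), so existing model configurations keep working unchanged.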
@@ -940,7 +972,20 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;
 
-  const torch::jit::Method& method = torch_model_->get_method("forward");
+  // First check if method exists in the model and throw an error if absent
+  const auto methodNameToExecute = model_state_->ModuleMethodName();
+  const auto optionalMethodHandle =
+      torch_model_->find_method(methodNameToExecute);
+  if (!optionalMethodHandle.has_value()) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        (std::string("unable to find method '") + methodNameToExecute +
+         "' in model '" + model_path_ + "'")
+            .c_str());
+  }
+
+  // Get the method schema and validate the inputs
+  const torch::jit::Method& method = optionalMethodHandle.value();
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
 
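For context, 'find_method' returns a c10::optional<torch::jit::Method> rather than throwing the way 'get_method' does, which is what lets the backend surface a clean TRITONSERVER error here. A minimal standalone sketch of the same lookup-and-inspect pattern, assuming an exported method named 'predict' (the model path and method name are illustrative):

#include <torch/script.h>

#include <iostream>

int main() {
  torch::jit::script::Module module = torch::jit::load("model.pt");
  // find_method yields an empty optional if the model lacks the method.
  if (auto method = module.find_method("predict")) {
    const c10::FunctionSchema& schema = method->function().getSchema();
    // For TorchScript module methods the first schema argument is the
    // module itself ('self'), so user-facing inputs start at index 1.
    for (const c10::Argument& arg : schema.arguments()) {
      std::cout << arg.name() << "\n";
    }
  }
}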
|
@@ -1583,18 +1628,24 @@ ModelInstanceState::Execute(
     torch::NoGradGuard no_grad;
 
     // If input is a dictionary, prepare dictionary from 'input_tensors'.
+    std::string module_method_name = model_state_->ModuleMethodName();
+    std::vector<c10::IValue> inputs;
     if (is_dict_input_) {
-      torch::Dict<std::string, torch::Tensor> input_dict;
+      c10::Dict<std::string, at::Tensor> dict;
       for (auto& input_index : input_index_map_) {
         torch::jit::IValue ival = (*input_tensors)[input_index.second];
-        input_dict.insert(input_index.first, ival.toTensor());
+        dict.insert(input_index.first, ival.toTensor());
       }
-      std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
-      model_outputs_ = torch_model_->forward(input_dict_ivalue);
+      inputs.push_back(dict);
     } else {
-      model_outputs_ = torch_model_->forward(*input_tensors);
+      for (auto& input_tensor : *input_tensors) {
+        inputs.push_back(input_tensor.toTensor());
+      }
     }
 
+    // Actually run the method on the model.
+    model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
+
     if (model_outputs_.isTuple()) {
       auto model_outputs_tuple = model_outputs_.toTuple();
       size_t op_index = 0;
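For reference, LibTorch's Module::forward is itself implemented as a call to get_method("forward"), so routing every invocation through get_method(module_method_name) preserves the previous behavior whenever the default name is in use. A minimal standalone sketch of the new call path (the model path and tensor shape are illustrative):

#include <torch/script.h>

int main() {
  torch::jit::script::Module module = torch::jit::load("model.pt");
  std::vector<c10::IValue> inputs;
  inputs.emplace_back(torch::ones({2, 2}));
  // Equivalent to module.forward(inputs) when the configured name is
  // "forward"; any other exported method is reached the same way.
  c10::IValue output = module.get_method("forward")(inputs);
}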
|