@@ -103,6 +103,7 @@ class ModelState : public BackendModel {
103103
104104 bool EnabledWeightSharing () { return enable_weight_sharing_; }
105105 const std::vector<std::string>& ModelOutputs () { return output_names_; }
106+ const std::string& MethodToCall () { return method_to_call_; }
106107
107108 private:
108109 ModelState (TRITONBACKEND_Model* triton_model);
@@ -145,6 +146,10 @@ class ModelState : public BackendModel {
145146 // List of all the outputs specified in the output section of model
146147 // configuration.
147148 std::vector<std::string> output_names_;
149+
150+ // Method to call on PyTorch Module.
151+ // Defaults to "forward".
152+ std::string method_to_call_;
148153};
149154
150155TRITONSERVER_Error*
@@ -180,7 +185,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
180185 enable_weight_sharing_(false ), enable_tensor_fuser_pair_({false , true }),
181186 enable_jit_profiling_pair_({false , true }),
182187 enable_jit_executor_pair_({false , true }),
183- enable_nvfuser_pair_({false , false })
188+ enable_nvfuser_pair_({false , false }),
189+ method_to_call_(" forward" )
184190{
185191 output_names_.clear ();
186192
@@ -454,6 +460,29 @@ ModelState::ParseParameters()
454460 " for model instance '" + Name () + " '" )
455461 .c_str ());
456462 }
463+
464+ // If 'METHOD_TO_CALL' is not present in 'parameters' then no
465+ // update is made to 'method_to_call_' (it keeps its default, "forward").
466+ std::string method_to_call = " forward" ;
467+ err = GetParameterValue (params, " METHOD_TO_CALL" , &method_to_call);
468+ if (err != nullptr ) {
469+ if (TRITONSERVER_ErrorCode (err) != TRITONSERVER_ERROR_NOT_FOUND) {
470+ return err;
471+ } else {
472+ LOG_MESSAGE (
473+ TRITONSERVER_LOG_INFO, (std::string (" method_to_call is not specified" ) +
474+ " for model instance '" + Name () + " '" )
475+ .c_str ());
476+ TRITONSERVER_ErrorDelete (err);
477+ }
478+ } else {
479+ method_to_call_ = method_to_call;
480+ LOG_MESSAGE (
481+ TRITONSERVER_LOG_INFO, (std::string (" method_to_call is " ) +
482+ method_to_call_ +
483+ " for model instance '" + Name () + " '" )
484+ .c_str ());
485+ }
457486 }
458487
459488 return nullptr ;
@@ -764,7 +793,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
764793 // configuration specifies only those.
765794 std::vector<std::string> allowed_inputs;
766795
767- const torch::jit::Method& method = torch_model_->get_method (" forward " );
796+ const torch::jit::Method& method = torch_model_->get_method (model_state_-> MethodToCall () );
768797 const auto & schema = method.function ().getSchema ();
769798 const std::vector<c10::Argument>& arguments = schema.arguments ();
770799
@@ -1324,16 +1353,23 @@ ModelInstanceState::Execute(
13241353 torch::NoGradGuard no_grad;
13251354
13261355 // If input is a dictionary, prepare dictionary from 'input_tensors'.
1356+ std::string method_to_call = model_state_->MethodToCall ();
13271357 if (is_dict_input_) {
13281358 torch::Dict<std::string, torch::Tensor> input_dict;
13291359 for (auto & input_index : input_index_map_) {
13301360 torch::jit::IValue ival = (*input_tensors)[input_index.second ];
13311361 input_dict.insert (input_index.first , ival.toTensor ());
13321362 }
1333- std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
1334- model_outputs_ = torch_model_->forward (input_dict_ivalue);
1363+ auto typ = c10::DictType::create (c10::StringType::get (), c10::TensorType::get ());
1364+ auto inp = c10::impl::GenericList (typ);
1365+ inp.emplace_back (input_dict);
1366+ model_outputs_ = torch_model_->run_method (method_to_call, inp);
13351367 } else {
1336- model_outputs_ = torch_model_->forward (*input_tensors);
1368+ auto inp = c10::impl::GenericList (c10::TensorType::get ());
1369+ for (auto & input_tensor : *input_tensors) {
1370+ inp.emplace_back (input_tensor.toTensor ());
1371+ }
1372+ model_outputs_ = torch_model_->run_method (method_to_call, inp);
13371373 }
13381374
13391375 if (model_outputs_.isTuple ()) {
0 commit comments