English | 中文

How to Integrate a New Model in FastDeploy

How do you add a new model to FastDeploy, including both C++ and Python deployment? Here we take the ResNet50 model from torchvision v0.12.0 as an example to introduce external model integration in FastDeploy. The whole process takes only 3 steps.

Step | Description | Files to create or modify
1 | Add the model implementation to the corresponding task module in FastDeploy/vision | resnet.h, resnet.cc, vision.h
2 | Bind the Python interface via pybind | resnet_pybind.cc, classification_pybind.cc
3 | Call the interface from Python | resnet.py, __init__.py

After completing the above 3 steps, an external model is integrated.

If you would like to contribute your code to FastDeploy, please also add test code, instructions (README), and code annotations for the new model, as described in the Test section.

Model Integration

Prepare the models

Before integrating an external model, it is important to convert the trained model (.pt, .pdparams, etc.) to a model format that FastDeploy supports for deployment (.onnx, .pdmodel). Most open-source repositories provide model conversion scripts for developers. As torchvision does not provide one, developers can write a conversion script manually. In this demo, we convert torchvision.models.resnet50 to resnet50.onnx with the following code for reference.

import torch
import torchvision.models as models
model = models.resnet50(pretrained=True)
batch_size = 1  #batch size
input_shape = (3, 224, 224)   #Input data, change to your own input shape
model.eval()
x = torch.randn(batch_size, *input_shape)    # Generate Tensor
export_onnx_file = "resnet50.onnx"            # ONNX file name
torch.onnx.export(model,
                  x,
                  export_onnx_file,
                  opset_version=12,
                  input_names=["input"],     # Input names
                  output_names=["output"],   # Output names
                  dynamic_axes={"input": {0: "batch_size"},    # Variable batch size
                                "output": {0: "batch_size"}})

Running the above script will generate a resnet50.onnx file.
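As an optional sanity check (not part of the original steps), the exported file can be validated with the onnx package before moving on; the expected input and output names match those set in the export script above.

import onnx

# Load the exported model and run the ONNX structural checker.
model = onnx.load("resnet50.onnx")
onnx.checker.check_model(model)

# Confirm the input/output names that the deployment code will rely on.
print([i.name for i in model.graph.input])    # expected: ['input']
print([o.name for o in model.graph.output])   # expected: ['output']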

C++

  • Create the resnet.h file
    • Create path
      • FastDeploy/fastdeploy/vision/classification/contrib/resnet.h (FastDeploy/C++ code/vision/task name/external model name/model name.h)
    • Create content
      • First, create the ResNet class in resnet.h and inherit from the FastDeployModel parent class, then declare Predict, Initialize, Preprocess, Postprocess, the constructor, and the necessary variables. Please refer to resnet.h for details.
class FASTDEPLOY_DECL ResNet : public FastDeployModel {
 public:
  ResNet(...);
  virtual bool Predict(...);
 private:
  bool Initialize();
  bool Preprocess(...);
  bool Postprocess(...);
};
  • Create the resnet.cc file
    • Create path
      • FastDeploy/fastdeploy/vision/classification/contrib/resnet.cc (FastDeploy/C++ code/vision/task name/external model name/model name.cc)
    • Create content
      • Implement in resnet.cc the functions declared in resnet.h. The Preprocess and Postprocess implementations should reproduce the pre- and post-processing logic of the original source repository. The specific logic of each ResNet function is outlined below; for the complete code, please refer to resnet.cc.
ResNet::ResNet(...) {
  // Constructor logic
  // 1. Specify the backend 2. Set the RuntimeOption 3. Call the Initialize() function
}
bool ResNet::Initialize() {
  // Initialization logic
  // 1. Assign values to the global variables 2. Call the InitRuntime() function
  return true;
}
bool ResNet::Preprocess(Mat* mat, FDTensor* output) {
  // Preprocess logic
  // 1. Resize 2. BGR2RGB 3. Normalize 4. HWC2CHW 5. Save the results to an FDTensor
  return true;
}
bool ResNet::Postprocess(FDTensor& infer_result, ClassifyResult* result, int topk) {
  // Postprocess logic
  // 1. Softmax 2. Choose the topk labels 3. Save the results to a ClassifyResult
  return true;
}
bool ResNet::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
  Preprocess(...);
  Infer(...);
  Postprocess(...);
  return true;
}

  • Add the new model file to vision.h
    • Modify path
      • FastDeploy/fastdeploy/vision.h
    • Modify content
#ifdef ENABLE_VISION
#include "fastdeploy/vision/classification/contrib/resnet.h"
#endif

Pybind

  • Create Pybind file

    • Create path

      • FastDeploy/fastdeploy/vision/classification/contrib/resnet_pybind.cc (FastDeploy/C++ code/vision model/task name/external model/model name_pybind.cc)
    • Create content

      • Use pybind11 to bind the C++ functions and variables to Python; please refer to resnet_pybind.cc for more details.

        void BindResNet(pybind11::module& m) {
          pybind11::class_<vision::classification::ResNet, FastDeployModel>(
              m, "ResNet")
              .def(pybind11::init<std::string, std::string, RuntimeOption, ModelFormat>())
              .def("predict", ...)
              .def_readwrite("size", &vision::classification::ResNet::size)
              .def_readwrite("mean_vals", &vision::classification::ResNet::mean_vals)
              .def_readwrite("std_vals", &vision::classification::ResNet::std_vals);
        }
  • Call Pybind function

    • modify path

      • FastDeploy/fastdeploy/vision/classification/classification_pybind.cc (FastDeploy/C++ code/vision model/task name/task name_pybind.cc)
    • modify content

      void BindResNet(pybind11::module& m);

      void BindClassification(pybind11::module& m) {
        auto classification_module =
            m.def_submodule("classification", "Image classification models.");
        BindResNet(classification_module);
      }

Python

  • Create the resnet.py file
    • Create path
      • FastDeploy/python/fastdeploy/vision/classification/contrib/resnet.py (FastDeploy/Python code/fastdeploy/vision model/task name/external model/model name.py)
    • Create content
      • Create the ResNet class inheriting from FastDeployModel, and implement __init__, the Pybind-bound functions (such as predict()), and the getters/setters for the global variables bound via Pybind. Please refer to resnet.py for details.
class ResNet(FastDeployModel):
    def __init__(self, ...):
        self._model = C.vision.classification.ResNet(...)
    def predict(self, input_image, topk=1):
        return self._model.predict(input_image, topk)
    @property
    def size(self):
        return self._model.size
    @size.setter
    def size(self, wh):
        ...

  • Import the ResNet class
    • Modify path
      • FastDeploy/python/fastdeploy/vision/classification/__init__.py (FastDeploy/Python code/fastdeploy/vision model/task name/__init__.py)
    • Modify content
from .contrib.resnet import ResNet

Test

Compile

  • C++
    • Path: FastDeploy/
mkdir build && cd build
cmake .. -DENABLE_ORT_BACKEND=ON -DENABLE_VISION=ON -DCMAKE_INSTALL_PREFIX=${PWD}/fastdeploy-0.0.3 \
      -DENABLE_PADDLE_BACKEND=ON -DENABLE_TRT_BACKEND=ON -DWITH_GPU=ON -DTRT_DIRECTORY=/PATH/TO/TensorRT/
make -j8
make install

Compilation produces build/fastdeploy-0.0.3/.

  • Python
    • Path: FastDeploy/python/
export TRT_DIRECTORY=/PATH/TO/TensorRT/    # If TensorRT is used, set the TensorRT location here and enable ENABLE_TRT_BACKEND
export ENABLE_TRT_BACKEND=ON
export WITH_GPU=ON
export ENABLE_PADDLE_BACKEND=ON
export ENABLE_OPENVINO_BACKEND=ON
export ENABLE_VISION=ON
export ENABLE_ORT_BACKEND=ON
python setup.py build
python setup.py bdist_wheel
cd dist
pip install fastdeploy_gpu_python-<version>-cpxx-cpxxm-<system architecture>.whl
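A quick way to confirm the new binding made it into the installed wheel (an optional check, not part of the original steps) is to import it from Python:

import fastdeploy as fd
print(fd.vision.classification.ResNet)   # should print the bound class rather than raise AttributeError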

Compile Test Code

  • Create path: FastDeploy/examples/vision/classification/resnet/ (FastDeploy/examples/vision model/task name/model name/)
  • Create the directory structure
.
├── cpp
│   ├── CMakeLists.txt
│   ├── infer.cc    // C++ test code
│   └── README.md   // C++ Readme
├── python
│   ├── infer.py    // Python test code
│   └── README.md   // Python Readme
└── README.md   // ResNet model integration readme
  • C++
    • Write CMakeLists.txt, the C++ code, and README.md. Please refer to cpp/
    • Compile infer.cc
      • Path: FastDeploy/examples/vision/classification/resnet/cpp/
mkdir build && cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=/PATH/TO/FastDeploy/build/fastdeploy-0.0.3/
make
  • Python
    • Please refer to python/ for the Python test code and README.md; a minimal sketch of what infer.py might look like is shown below.
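The sketch below illustrates how the new Python class could be exercised in infer.py. It is not the actual example code: the file names resnet50.onnx and test.jpg and the keyword arguments runtime_option and model_format are illustrative assumptions based on the constructor bound in resnet_pybind.cc.

import cv2
import fastdeploy as fd

# Build a runtime option; the default backend and device are used here.
option = fd.RuntimeOption()

# Constructor arguments follow the pybind11 binding above:
# (model_file, params_file, RuntimeOption, ModelFormat); params_file is empty for ONNX.
model = fd.vision.classification.ResNet(
    "resnet50.onnx", "", runtime_option=option, model_format=fd.ModelFormat.ONNX)

# predict() takes a BGR image (as read by OpenCV) and returns a ClassifyResult.
im = cv2.imread("test.jpg")
result = model.predict(im, topk=1)
print(result)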

Annotate the Code

To make the code easier to understand, developers can annotate the newly added code.

  • C++ code: developers need to add annotations for the functions and variables in resnet.h. There are three annotation styles, shown below; please refer to resnet.h for more details.
/** \brief Predict for the input "im", the result will be saved in "result".
 *
 * \param[in] im Input image for inference.
 * \param[in] result Saving the inference result.
 * \param[in] topk The length of return values, e.g., if topk == 2, the result will include the 2 most likely class labels for the input image.
 */
virtual bool Predict(cv::Mat* im, ClassifyResult* result, int topk = 1);
/// Tuple of (width, height)
std::vector<int> size;
/*! @brief Initialize for ResNet model, assign values to the global variables and call InitRuntime()
*/
bool Initialize();
  • Python code: the following example demonstrates how to annotate the functions and variables in resnet.py. For more details, please refer to resnet.py.
  def predict(self, input_image, topk=1):
    """Classify an input image
    :param input_image: (numpy.ndarray) The input image data, a 3-D array with HWC layout in BGR format
    :param topk: (int) Return the topk results ranked by classification confidence; default is 1
    :return: ClassifyResult
    """
    return self._model.predict(input_image, topk)

Other files in the integration process can also be annotated to explain the details of the implementation.
