00001 #ifndef HZZ2L2NU_INCLUDE_XGBOOSTPREDICTOR_H_ 00002 #define HZZ2L2NU_INCLUDE_XGBOOSTPREDICTOR_H_ 00003 00004 #include <exception> 00005 #include <string> 00006 00007 #include <xgboost/c_api.h> 00008 00009 00011 class XGBoostPredictorException : public std::exception { 00012 public: 00019 XGBoostPredictorException(std::string const &functionName, int exitCode); 00020 00022 char const *what() const noexcept override { 00023 return message_.c_str(); 00024 } 00025 00026 private: 00028 std::string message_; 00029 }; 00030 00031 00041 class XGBoostPredictor { 00042 public: 00049 XGBoostPredictor(std::string const &path, int numFeatures); 00050 00051 ~XGBoostPredictor(); 00052 00060 float Predict(float const *x) const; 00061 00062 private: 00068 static void CheckCall(std::string const &functionName, int exitCode); 00069 00071 BoosterHandle booster_; 00072 00074 int numFeatures_; 00075 }; 00076 00077 #endif // HZZ2L2NU_INCLUDE_XGBOOSTPREDICTOR_H_ 00078