如何解决如何在 DJL深度 Java 库中调用自定义 mxnet 运算符?
如何从 DJL 调用自定义 mxnet 操作员?例如。 examples 中的 my_gemm
运算符。
解决方法
可以通过与内置 mxnet 引擎相同的方式手动调用 JnaUtils,只需使用您的自定义库。对于 my_gemm
示例,这看起来像这样:
import ai.djl.Device;
import ai.djl.mxnet.jna.FunctionInfo;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.util.PairList;
import java.util.Map;
// Load the external mxnet operator library
JnaUtils.loadLib("path/to/incubator-mxnet/example/extensions/lib_custom_op/libgemm_lib.so",1);
// get a handle to the loaded operator
Map<String,FunctionInfo> allFunctionsAfterLoading = JnaUtils.getNdArrayFunctions();
FunctionInfo myGemmFunction = allFunctionsAfterLoading.get("my_gemm");
// create a manager to execute the example with
try (NDManager ndManager = NDManager.newBaseManager().newSubManager(Device.cpu())) {
// create input for the gemm call
NDArray a = ndManager.create(new float[][]{new float[]{1,2,3},new float[]{4,5,6}});
NDArray b = ndManager.create(new float[][]{new float[]{7},new float[]{8},new float[]{9}});
// call the function manually (NDManager.invoke will not work,as it caches the mxnet
// engine operators and ignores external ones)
PairList<String,Object> params = new PairList<>();
NDArray result = myGemmFunction.invoke(ndManager,new NDArray[]{a,b},params)[0];
// prints
// ND: (2,1) cpu() float32
//[[ 50.],// [122.],//]
// (same as the python example)
System.out.println(result);
}
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。