
[pytorch] Could not export Python function call 'Scatter'

Source: Internet  Collected by: 自由互联  Published: 2022-06-15



pytorch


Exporting a model with PyTorch's trace (torch.jit.trace) fails with the following error:

error

RuntimeError:
Could not export Python function call 'Scatter'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(13): scatter_map
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(15): scatter_map
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(28): scatter
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/scatter_gather.py(36): scatter_kwargs
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py(168): scatter
/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py(157): forward
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py(709): _slow_forward
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py(725): _call_impl
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py(940): trace_module
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py(742): trace
<ipython-input-14-e92379b43790>(2): <module>

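Solution

The traceback passes through data_parallel.py, which means the model being traced is wrapped in nn.DataParallel. The wrapper's forward calls Scatter, a Python function that the tracer cannot export. Trace the underlying network instead by changing model to: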

model = model.module
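
For context, here is a minimal sketch of the whole flow: unwrap, then trace. The helper name unwrap_model, the toy network, and the input shape are assumptions made for illustration and are not part of the original post. Substitute your own model and example input.

```python
import torch
import torch.nn as nn


def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the wrapped network if `model` is a (Distributed)DataParallel wrapper.

    Hypothetical helper: both wrappers expose the original network as `.module`.
    """
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


# Toy stand-in network (assumption); replace with your own model.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
wrapped = nn.DataParallel(net)

# Trace the underlying network, not the DataParallel wrapper.
model = unwrap_model(wrapped).eval()
example_input = torch.randn(1, 4)  # assumed input shape for the toy network
traced = torch.jit.trace(model, example_input)
```

Checking the type before unwrapping keeps the same export script usable whether or not the model was wrapped for multi-GPU training.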


