API参数报错修改工作
1.lossAPI
1.需要修改的loss
torch包含size_average和reduce的loss类如下:(共17个)
| Loss Name |
|---|
| L1Loss |
| NLLLoss(以及 NLLLoss2d) |
| PoissonNLLLoss |
| KLDivLoss |
| MSELoss |
| BCELoss |
| BCEWithLogitsLoss |
| HingeEmbeddingLoss |
| MultiLabelMarginLoss |
| SmoothL1Loss |
| SoftMarginLoss |
| CrossEntropyLoss |
| MultiLabelSoftMarginLoss |
| CosineEmbeddingLoss |
| MarginRankingLoss |
| MultiMarginLoss |
| TripletMarginLoss |
2.对应关系
reduce=False(无论 size_average 是 True/False/None)→ reduction='none'
reduce=True 且 size_average=False / reduce=None 且 size_average=False→ reduction='sum'
reduce=True 且 size_average=True / reduce=None 且 size_average=True/reduce=True 且 size_average=None/reduct=None且size_average=None → reduction='mean'
| reduce | size_average | 对应 reduction 值 |
|---|---|---|
False | False | 'none' |
False | True | 'none' |
False | None | 'none' |
True | False | 'sum' |
None | False | 'sum' |
True | True | 'mean' |
None | True | 'mean' |
True | None | 'mean' |
None | None | 'mean' |
3.torch.nn.functional下面还有需要处理的
| Loss Function | size_average index | reduce index |
|---|---|---|
| binary_cross_entropy | 3 | 4 |
| binary_cross_entropy_with_logits | 3 | 4 |
| poisson_nll_loss | 4 | 6 |
| cosine_embedding_loss | 4 | 5 |
| cross_entropy | 3 | 5 |
| hinge_embedding_loss | 3 | 4 |
| kl_div | 2 | 3 |
| l1_loss | 2 | 3 |
| mse_loss | 2 | 3 |
| margin_ranking_loss | 4 | 5 |
| multilabel_margin_loss | 2 | 3 |
| multilabel_soft_margin_loss | 3 | 4 |
| multi_margin_loss | 5 | 6 |
| nll_loss | 3 | 5 |
| smooth_l1_loss | 2 | 3 |
| soft_margin_loss | 2 | 3 |
| triplet_margin_loss | 7 | 8 |
2.WindowAPI
paddle的get_window目前支持的窗口函数共14个:
| 来源 | 窗口函数 | 数量 |
|---|---|---|
| PyTorch | bartlett_window, blackman_window, hamming_window, hann_window, kaiser_window | 5 |
| SciPy | gaussian, general_gaussian, exponential, triang, bohman , cosine, tukey, taylor, nuttall | 9 |
pytorch5个窗口函数暴露参数情况(指对用户暴露的可传参数)dtype、layout、device、pin_memory默认均为None:
| 参数名 / 函数名 | hamming_window | bartlett_window | blackman_window | hann_window | kaiser_window | get_window |
|---|---|---|---|---|---|---|
window_length | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
periodic | ✅默认True | ✅默认为True | ✅默认为True | ✅默认为True | ✅默认为True | fftbins |
alpha | ✅默认0.54 | ❌ | ❌ | ❌默认0.5(底层调用hamming_window) | ❌ | 无 |
beta | ✅默认0.46 | ❌ | ❌ | ❌默认0.5(底层调用hamming_window) | ✅ (默认 12.0) | 无 |
*dtype | ✅ | ✅ | ✅ | ✅ | ✅ | dtype |
*layout | ✅ | ✅ | ✅ | ✅ | ✅ | 无 |
*device | ✅ | ✅ | ✅ | ✅ | ✅ | 无 |
*pin_memory | ✅ | 官方文档没暴露但底层实现可传 | 官方文档没暴露但底层实现可传 | 官方文档没暴露但底层实现可传 | 官方文档没暴露但底层实现可传 | 无 |
*requires_grad | ✅ | ✅ | ✅ | ✅ | ✅ | stop_gradient |