Skip to main content

API参数报错修改工作

1.lossAPI

1.需要修改的loss

torch包含size_average和reduce的loss类如下:(共17个)

Loss Name
L1Loss
NLLLoss(以及 NLLLoss2d)
PoissonNLLLoss
KLDivLoss
MSELoss
BCELoss
BCEWithLogitsLoss
HingeEmbeddingLoss
MultiLabelMarginLoss
SmoothL1Loss
SoftMarginLoss
CrossEntropyLoss
MultiLabelSoftMarginLoss
CosineEmbeddingLoss
MarginRankingLoss
MultiMarginLoss
TripletMarginLoss

2.对应关系

  • reduce=False(无论 size_average 是 True/False/None)→ reduction='none'

  • reduce=True 且 size_average=False / reduce=None 且 size_average=False→ reduction='sum'

  • reduce=True 且 size_average=True / reduce=None 且 size_average=True/reduce=True 且 size_average=None/reduct=None且size_average=None → reduction='mean'

reducesize_average对应 reduction 值
FalseFalse'none'
FalseTrue'none'
FalseNone'none'
TrueFalse'sum'
NoneFalse'sum'
TrueTrue'mean'
NoneTrue'mean'
TrueNone'mean'
NoneNone'mean'

3.torch.nn.functional下面还有需要处理的

Loss Functionsize_average indexreduce index
binary_cross_entropy34
binary_cross_entropy_with_logits34
poisson_nll_loss46
cosine_embedding_loss45
cross_entropy35
hinge_embedding_loss34
kl_div23
l1_loss23
mse_loss23
margin_ranking_loss45
multilabel_margin_loss23
multilabel_soft_margin_loss34
multi_margin_loss56
nll_loss35
smooth_l1_loss23
soft_margin_loss23
triplet_margin_loss78

2.WindowAPI

paddle的get_window目前支持的窗口函数共14个:

来源窗口函数数量
PyTorchbartlett_window, blackman_window, hamming_window, hann_window, kaiser_window5
SciPygaussian, general_gaussian, exponential, triang, bohman , cosine, tukey, taylor, nuttall9

pytorch5个窗口函数暴露参数情况(指对用户暴露的可传参数)dtype、layout、device、pin_memory默认均为None

参数名 / 函数名hamming_windowbartlett_windowblackman_windowhann_windowkaiser_windowget_window
window_length
periodic✅默认True✅默认为True✅默认为True✅默认为True✅默认为Truefftbins
alpha✅默认0.54❌默认0.5(底层调用hamming_window)
beta✅默认0.46❌默认0.5(底层调用hamming_window)(默认 12.0)
*dtypedtype
*layout
*device
*pin_memory官方文档没暴露但底层实现可传官方文档没暴露但底层实现可传官方文档没暴露但底层实现可传官方文档没暴露但底层实现可传
*requires_gradstop_gradient