全層連接後的分類問題,求解

網路成最後一層使用 Linear 時,如何取得推論結果 ex:

nn.Linear(channels, num_classes)

output = model(img) # 模型推論 假定得到 [1,31] 的分類問題
output = output.squeeze() # 擠壓為 [31]
sorted, indices = torch.sort(output, descending=True) # 排序 這邊用到 indices 中會得得大到小排序的索引位置
probs = F.softmax(output, dim=-1) # 使用 softmax 得到機率

假定有個 classes

取得類型

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *