模型可视化方法
本文收集了常用的深度学习模型可视化方法,按场景分为结构可视化、训练监控和可解释性三类。
模型结构可视化
原生 PyTorch
# 打印模型结构
print(model)
# 统计参数量
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total: {total:,} | Trainable: {trainable:,}")PyG 图神经网络可使用 torch_geometric.nn.summary。
torchinfo
torchinfo最推荐的模型摘要工具,提供层级信息、输出形状、参数量和内存占用。
pip install torchinfofrom torchinfo import summary
summary(model, input_size=(1, 3, 224, 224))
# 详细模式
summary(model, input_size=(32, 3, 224, 224),
col_names=["output_size", "num_params", "mult_adds"], depth=4)torchsummary 已停止维护,请使用 torchinfo。
Netron
Netron跨平台模型可视化工具,支持 ONNX/TensorFlow/PyTorch/Keras 等格式。
导出 ONNX 并可视化:
import torch
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx",
opset_version=17, input_names=["input"], output_names=["output"],
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})pip install netron && netron model.onnx也可直接拖拽 .onnx 到 netron.app 在线查看。
PlotNeuralNet
PlotNeuralNet基于 LaTeX/TikZ 生成论文级网络结构图。
# Linux
sudo apt-get install texlive-latex-base texlive-fonts-recommended texlive-latex-extra
# 克隆并使用
git clone https://github.com/HarisIqbal88/PlotNeuralNet.git
cd PlotNeuralNet && bash tikzmake.sh my_arch训练过程可视化
TensorBoard
TensorBoardPyTorch 已原生支持 TensorBoard。
pip install tensorboardfrom torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/exp1')
writer.add_graph(model, torch.randn(1, 3, 224, 224))
for epoch in range(epochs):
writer.add_scalar('Loss/train', loss, epoch)
writer.add_scalar('Acc/val', acc, epoch)
writer.close()tensorboard --logdir=runs # 访问 localhost:6006模型可解释性
SHAP
SHAP基于 Shapley 值的特征重要性解释。
import shap
# 树模型
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)
# 深度学习
explainer = shap.DeepExplainer(model, X_train[:100])LIME
LIME局部可解释的模型无关解释方法。
from lime import lime_tabular
explainer = lime_tabular.LimeTabularExplainer(X_train, feature_names=names, mode='classification')
exp = explainer.explain_instance(X_test[0], model.predict_proba)
exp.show_in_notebook()Captum
CaptumPyTorch 官方可解释性库,支持集成梯度、Grad-CAM 等。
from captum.attr import IntegratedGradients, LayerGradCam
ig = IntegratedGradients(model)
attr = ig.attribute(input_tensor, target=target_class)
# Grad-CAM
gc = LayerGradCam(model, model.layer4[-1])
attr = gc.attribute(input_tensor, target=target_class)Yellowbrick
Yellowbricksklearn 模型评估可视化库。
from yellowbrick.classifier import ConfusionMatrix, ROCAUC
from yellowbrick.model_selection import LearningCurve
cm = ConfusionMatrix(model)
cm.fit(X_train, y_train)
cm.score(X_test, y_test)
cm.show()总结
| 工具 | 场景 | 特点 |
|---|---|---|
| torchinfo | 结构概览 | 轻量、参数统计 |
| Netron | 结构可视化 | 交互式、多格式 |
| PlotNeuralNet | 论文图表 | LaTeX 输出 |
| TensorBoard | 训练监控 | 实时、多指标 |
| SHAP | 特征重要性 | 理论扎实 |
| LIME | 局部解释 | 模型无关 |
| Captum | DL 解释 | PyTorch 官方 |
| Yellowbrick | ML 评估 | sklearn 集成 |