Home

模型可视化方法

本文收集了常用的深度学习模型可视化方法,按场景分为结构可视化训练监控可解释性三类。


模型结构可视化

原生 PyTorch

# 打印模型结构
print(model)

# 统计参数量
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total: {total:,} | Trainable: {trainable:,}")

PyG 图神经网络可使用 torch_geometric.nn.summary


torchinfo

torchinfo

最推荐的模型摘要工具,提供层级信息、输出形状、参数量和内存占用。

pip install torchinfo
from torchinfo import summary

summary(model, input_size=(1, 3, 224, 224))

# 详细模式
summary(model, input_size=(32, 3, 224, 224),
        col_names=["output_size", "num_params", "mult_adds"], depth=4)

torchsummary 已停止维护,请使用 torchinfo


Netron

Netron

跨平台模型可视化工具,支持 ONNX/TensorFlow/PyTorch/Keras 等格式。

导出 ONNX 并可视化:

import torch

model.eval()
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(model, dummy_input, "model.onnx",
                  opset_version=17, input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
pip install netron && netron model.onnx

也可直接拖拽 .onnxnetron.app 在线查看。


PlotNeuralNet

PlotNeuralNet

基于 LaTeX/TikZ 生成论文级网络结构图。

# Linux
sudo apt-get install texlive-latex-base texlive-fonts-recommended texlive-latex-extra

# 克隆并使用
git clone https://github.com/HarisIqbal88/PlotNeuralNet.git
cd PlotNeuralNet && bash tikzmake.sh my_arch

训练过程可视化

TensorBoard

TensorBoard

PyTorch 已原生支持 TensorBoard。

pip install tensorboard
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/exp1')
writer.add_graph(model, torch.randn(1, 3, 224, 224))

for epoch in range(epochs):
    writer.add_scalar('Loss/train', loss, epoch)
    writer.add_scalar('Acc/val', acc, epoch)
writer.close()
tensorboard --logdir=runs  # 访问 localhost:6006

模型可解释性

SHAP

SHAP

基于 Shapley 值的特征重要性解释。

import shap

# 树模型
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)

# 深度学习
explainer = shap.DeepExplainer(model, X_train[:100])

LIME

LIME

局部可解释的模型无关解释方法。

from lime import lime_tabular

explainer = lime_tabular.LimeTabularExplainer(X_train, feature_names=names, mode='classification')
exp = explainer.explain_instance(X_test[0], model.predict_proba)
exp.show_in_notebook()

Captum

Captum

PyTorch 官方可解释性库,支持集成梯度、Grad-CAM 等。

from captum.attr import IntegratedGradients, LayerGradCam

ig = IntegratedGradients(model)
attr = ig.attribute(input_tensor, target=target_class)

# Grad-CAM
gc = LayerGradCam(model, model.layer4[-1])
attr = gc.attribute(input_tensor, target=target_class)

Yellowbrick

Yellowbrick

sklearn 模型评估可视化库。

from yellowbrick.classifier import ConfusionMatrix, ROCAUC
from yellowbrick.model_selection import LearningCurve

cm = ConfusionMatrix(model)
cm.fit(X_train, y_train)
cm.score(X_test, y_test)
cm.show()

总结

工具场景特点
torchinfo结构概览轻量、参数统计
Netron结构可视化交互式、多格式
PlotNeuralNet论文图表LaTeX 输出
TensorBoard训练监控实时、多指标
SHAP特征重要性理论扎实
LIME局部解释模型无关
CaptumDL 解释PyTorch 官方
YellowbrickML 评估sklearn 集成