
Three.js 3D Visualization

modelviz includes a Three.js-based 3D renderer that turns a model into an interactive 3D visualization of its architecture.

Overview

The Three.js renderer creates self-contained HTML files with:

  • Distinct 3D shapes for each layer type
  • A horizontal layout with data flowing left to right
  • Text labels above each layer
  • Animated particles showing data flow
  • Interactive controls for exploration
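One way to picture the left-to-right layout is to place each layer at evenly spaced points along the x-axis, centered on the origin so the camera can frame the whole network. The sketch below illustrates that idea under this assumption; `layer_x_positions` and its `spacing` parameter are hypothetical and not part of the modelviz API.

```python
from __future__ import annotations

from typing import List


def layer_x_positions(num_layers: int, spacing: float = 3.0) -> List[float]:
    """Return x-coordinates for num_layers shapes laid out left to right.

    Hypothetical helper: layers sit `spacing` units apart and the row is
    centered on x = 0, so the input is on the left and the output on the right.
    """
    if num_layers <= 0:
        return []
    # Shift the row so its midpoint lands at the origin.
    start = -spacing * (num_layers - 1) / 2
    return [start + i * spacing for i in range(num_layers)]


print(layer_x_positions(3, spacing=2.0))  # [-2.0, 0.0, 2.0]
```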

Usage

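import torch.nn as nn
from modelviz import visualize_threejs

# Placeholder model so the snippet runs; substitute your own network
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())

html = visualize_threejs(
    model,
    input_shape=(1, 3, 224, 224),
    title="My Network",
    save_path="network.html"
)

# Open network.html in any web browser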
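To open the generated file from a script instead of by hand, Python's standard-library `webbrowser` module can be given a `file://` URI. This is a minimal sketch; `open_visualization` is a hypothetical helper, and its `opener` parameter exists only so the function can be exercised without launching a browser.

```python
import webbrowser
from pathlib import Path
from typing import Callable


def open_visualization(path: str, opener: Callable[[str], bool] = webbrowser.open) -> str:
    """Open a saved visualization (e.g. "network.html") in the default browser.

    Returns the file:// URI that was handed to `opener`.
    """
    # Path.as_uri() requires an absolute path, hence resolve().
    uri = Path(path).resolve().as_uri()
    opener(uri)
    return uri
```

Calling `open_visualization("network.html")` after the snippet above opens the page in the system's default browser.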

Layer Shapes

Each layer type has a semantically meaningful 3D representation:

| Layer Type | Shape | Rationale |
|---|---|---|
| Conv2d | 3D Box | Feature maps are 3D volumes (Channels × Height × Width) |
| Linear/Dense | Flat Plane | Weight matrix is 2D (input features × output features) |
| Pooling | Small Cube | Reduces spatial dimensions → smaller representation |
| Activation | Glowing Sphere | Element-wise operation applied uniformly |
| BatchNorm | Thin Slab | Normalizes across the batch, "flattening" the distribution |
| Flatten | Cone | Funnels multi-dimensional data into a 1D vector |
| Dropout | Wireframe Cube | Sparse/transparent = random neurons "dropped" |
| RNN/LSTM | Cylinder | Circular shape suggests recurrent/cyclical flow |
| Attention | Octahedron | Multi-faceted for multi-head attention patterns |
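One plausible way to implement this table is a lookup from a layer's class name to a Three.js geometry. The sketch below assumes that approach; `SHAPE_RULES`, `ACTIVATION_NAMES`, `geometry_for`, and the fallback to a plain box for unlisted layers are illustrative assumptions, not modelviz internals. The returned strings name real Three.js geometry classes (`BoxGeometry`, `PlaneGeometry`, `SphereGeometry`, `ConeGeometry`, `CylinderGeometry`, `OctahedronGeometry`); the Dropout entry reuses `BoxGeometry` with a wireframe flag.

```python
from __future__ import annotations

from typing import List, Tuple

# Hypothetical rules: (keyword in the layer's class name, Three.js geometry, wireframe?).
# The first matching keyword wins. Relative sizes (small pooling cube, thin
# BatchNorm slab) would be applied separately, e.g. via the mesh's scale.
SHAPE_RULES: List[Tuple[str, str, bool]] = [
    ("Conv", "BoxGeometry", False),            # 3D box: C × H × W volume
    ("Linear", "PlaneGeometry", False),        # flat plane: 2D weight matrix
    ("Pool", "BoxGeometry", False),            # small cube
    ("BatchNorm", "BoxGeometry", False),       # thin slab
    ("Flatten", "ConeGeometry", False),        # funnel into a 1D vector
    ("Dropout", "BoxGeometry", True),          # wireframe cube
    ("LSTM", "CylinderGeometry", False),       # recurrent layers
    ("RNN", "CylinderGeometry", False),
    ("Attention", "OctahedronGeometry", False),
]

# Assumed set of activation class names drawn as spheres.
ACTIVATION_NAMES = ("ReLU", "GELU", "SiLU", "Sigmoid", "Tanh")


def geometry_for(layer_class_name: str) -> Tuple[str, bool]:
    """Return (Three.js geometry class name, wireframe flag) for a layer."""
    for keyword, geometry, wireframe in SHAPE_RULES:
        if keyword in layer_class_name:
            return geometry, wireframe
    if any(act in layer_class_name for act in ACTIVATION_NAMES):
        return "SphereGeometry", False
    return "BoxGeometry", False  # assumed fallback for unlisted layer types
```

Matching on substrings lets one rule cover variants such as `Conv1d`/`Conv2d` or `MaxPool2d`/`AvgPool2d`.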

Color Scheme

| Layer Type | Color | Hex Code |
|---|---|---|
| Convolution | Indigo | #6366f1 |
| Linear/Dense | Purple | #8b5cf6 |
| Pooling | Cyan | #06b6d4 |
| Normalization | Emerald | #10b981 |
| Activation | Amber | #f59e0b |
| Dropout | Red | #ef4444 |
| Flatten | Pink | #ec4899 |
| Embedding | Lime | #84cc16 |
| RNN/LSTM | Teal | #14b8a6 |
| Attention | Orange | #f97316 |
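To make other figures (for example matplotlib plots in a report) match this palette, the hex codes can be reused directly. The snippet below copies the table into a dictionary and adds two converters: one to an RGB triple and one to the `0xRRGGBB` integer form that Three.js material colors accept. `PALETTE`, `hex_to_rgb`, and `hex_to_int` are illustrative names, not part of modelviz.

```python
from typing import Tuple

# Category colors copied from the table above.
PALETTE = {
    "Convolution": "#6366f1",
    "Linear/Dense": "#8b5cf6",
    "Pooling": "#06b6d4",
    "Normalization": "#10b981",
    "Activation": "#f59e0b",
    "Dropout": "#ef4444",
    "Flatten": "#ec4899",
    "Embedding": "#84cc16",
    "RNN/LSTM": "#14b8a6",
    "Attention": "#f97316",
}


def hex_to_rgb(hex_code: str) -> Tuple[int, int, int]:
    """Convert '#RRGGBB' to an (r, g, b) tuple of 0-255 integers."""
    value = hex_code.lstrip("#")
    if len(value) != 6:
        raise ValueError(f"expected '#RRGGBB', got {hex_code!r}")
    r, g, b = (int(value[i:i + 2], 16) for i in (0, 2, 4))
    return r, g, b


def hex_to_int(hex_code: str) -> int:
    """Convert '#RRGGBB' to the 0xRRGGBB integer form, e.g. for a Three.js color."""
    return int(hex_code.lstrip("#"), 16)


print(hex_to_rgb(PALETTE["Convolution"]))  # (99, 102, 241)
```

matplotlib accepts hex strings as colors, so the dictionary alone is enough there; the converters help where an RGB tuple or integer is expected.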

Interactive Controls

| Action | Mouse/Keyboard |
|---|---|
| Rotate | Click and drag |
| Zoom | Scroll wheel |
| Pan | Shift + drag |
| Layer details | Hover over shape |

Labels

Each 3D shape has a floating label showing:

  • Layer type (e.g., "Conv2d + BatchNorm2d + ReLU")
  • Output dimensions (e.g., "16×16")

Labels always face the camera for readability.
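The two label lines can be thought of as a join over the (possibly grouped) layer names plus the trailing dimensions of the output shape. The sketch assumes the dimensions line shows the spatial size (H×W) for 4-D convolutional outputs and the remaining dimensions after the batch axis otherwise; modelviz's exact formatting rules are not documented here, so `format_label` is hypothetical.

```python
from typing import List, Sequence, Tuple


def format_label(layer_names: List[str], output_shape: Sequence[int]) -> Tuple[str, str]:
    """Return (title line, dimensions line) for a layer's floating label."""
    title = " + ".join(layer_names)
    if len(output_shape) == 4:
        # (batch, channels, height, width): show the spatial size only.
        dims = output_shape[2:]
    else:
        # Otherwise drop the batch dimension and show what remains.
        dims = output_shape[1:]
    return title, "×".join(str(d) for d in dims)
```

For example, `format_label(["Conv2d", "BatchNorm2d", "ReLU"], (1, 64, 16, 16))` returns `("Conv2d + BatchNorm2d + ReLU", "16×16")`, matching the sample label above.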

Tooltips

Hovering over any layer shows detailed information:

  • Full layer name
  • Complete output shape
  • Parameter count
  • Whether the layer is grouped with others
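The parameter count in the tooltip follows from each layer's weight and bias shapes. In PyTorch the usual way to get it is `sum(p.numel() for p in layer.parameters())`; the dependency-free sketch below spells out the formulas for common layers so the numbers can be checked by hand. It assumes the default `bias=True`, square kernels with `groups=1`, and, for BatchNorm, counts only the learnable scale and shift (running mean and variance are buffers, not parameters). The helper names are illustrative.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    # Weight matrix (out × in) plus one bias per output feature.
    return in_features * out_features + (out_features if bias else 0)


def conv2d_params(in_ch: int, out_ch: int, kernel_size: int, bias: bool = True) -> int:
    # One (in_ch × k × k) kernel per output channel, plus one bias per output channel.
    return out_ch * in_ch * kernel_size * kernel_size + (out_ch if bias else 0)


def batchnorm_params(num_features: int) -> int:
    # Learnable weight (gamma) and bias (beta); running stats are buffers.
    return 2 * num_features


# Parameter-free layers (ReLU, MaxPool2d, Dropout, Flatten) contribute 0.
# Total for the "Simple MLP" example on this page:
mlp_total = linear_params(784, 512) + linear_params(512, 256) + linear_params(256, 10)
print(mlp_total)  # 535818
```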

Animations

  • Data flow particles: Blue spheres animate along connection lines
  • Activation rotation: Spheres slowly rotate
  • Attention oscillation: Octahedrons gently rotate
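The data-flow particles can be modeled as points that move along each connection line at a constant rate and wrap back to the start when they reach the end. This is a minimal sketch of that idea using linear interpolation; `particle_position`, `speed`, and `phase` are illustrative and not taken from the generated HTML.

```python
from typing import Tuple

Vec3 = Tuple[float, float, float]


def particle_position(start: Vec3, end: Vec3, time_s: float,
                      speed: float = 0.5, phase: float = 0.0) -> Vec3:
    """Position of a particle travelling from start to end, looping forever.

    speed is the fraction of the line covered per second; phase (0-1)
    staggers several particles on the same line.
    """
    t = (time_s * speed + phase) % 1.0  # progress along the line, in [0, 1)
    x, y, z = (s + (e - s) * t for s, e in zip(start, end))
    return x, y, z
```

Giving each particle on a line a different `phase` spreads them out so the flow reads as a continuous stream.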

Examples

Simple MLP

import torch.nn as nn
from modelviz import visualize_threejs

model = nn.Sequential(
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

visualize_threejs(model, input_shape=(1, 784), save_path="mlp.html")

CNN

import torch.nn as nn
from modelviz import visualize_threejs

model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(128 * 8 * 8, 10)
)

visualize_threejs(model, input_shape=(1, 3, 32, 32), save_path="cnn.html")
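The `128 * 8 * 8` input size of the final `Linear` layer follows from tracking the spatial size through the network: a convolution or pooling layer maps a side length H to ⌊(H + 2p − k)/s⌋ + 1 (dilation 1), and `MaxPool2d(2)` uses stride 2 by default, halving it. BatchNorm and ReLU leave the shape unchanged. The check below walks through the example's layers with that formula; `conv_out` is just a helper for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a Conv2d/MaxPool2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


h = 32                                   # input_shape=(1, 3, 32, 32)
h = conv_out(h, kernel=3, padding=1)     # Conv2d(3, 64, 3, padding=1)   -> 32
h = conv_out(h, kernel=2, stride=2)      # MaxPool2d(2)                  -> 16
h = conv_out(h, kernel=3, padding=1)     # Conv2d(64, 128, 3, padding=1) -> 16
h = conv_out(h, kernel=2, stride=2)      # MaxPool2d(2)                  -> 8
flat_features = 128 * h * h              # Flatten: channels × H × W
print(flat_features)  # 8192
```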

Without grouping

To show every layer individually instead of grouping them into blocks, pass group_blocks=False:

visualize_threejs(
    model,
    input_shape=(1, 3, 32, 32),
    group_blocks=False,  # Show Conv, BatchNorm, ReLU separately
    save_path="cnn_detailed.html"
)
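With grouping on, the sample label "Conv2d + BatchNorm2d + ReLU" suggests that a convolution and the normalization and activation that follow it are drawn as one block. The exact grouping rules modelviz applies are not stated here, so the sketch below is an assumption: it greedily merges a `Conv*` layer with an immediately following `*Norm*` layer and then an activation, and leaves every other layer on its own. `group_layers` and the activation set are hypothetical.

```python
from __future__ import annotations

from typing import List

# Assumed set of activation class names that can close a block.
ACTIVATIONS = {"ReLU", "LeakyReLU", "GELU", "SiLU", "Sigmoid", "Tanh"}


def group_layers(names: List[str]) -> List[List[str]]:
    """Greedily merge Conv -> [Norm] -> [Activation] runs into single blocks."""
    groups: List[List[str]] = []
    i = 0
    while i < len(names):
        block = [names[i]]
        if names[i].startswith("Conv"):
            j = i + 1
            if j < len(names) and "Norm" in names[j]:
                block.append(names[j])
                j += 1
            if j < len(names) and names[j] in ACTIVATIONS:
                block.append(names[j])
                j += 1
            i = j
        else:
            i += 1
        groups.append(block)
    return groups
```

Under this sketch, the CNN example collapses from ten layers to six blocks, and `group_blocks=False` would correspond to skipping the merge and giving each layer its own group.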

Technical Details

  • Renderer: Three.js r160 with WebGL
  • Labels: CSS2DRenderer for crisp text
  • Self-contained: a single HTML file with no local dependencies; Three.js is loaded via CDN imports, so the page needs network access when opened
  • File size: ~25-35 KB per visualization
  • Browser support: Chrome, Firefox, Safari, Edge (modern versions)