Search code examples
Tags: vector, plotly

How can I efficiently plot a 2D vector field with triangles in Plotly?


Disclaimer: I have added a feature request on the plotly github for a function like the one I am looking for.

I have written a piece of code that plots a vector field with Plotly, similar to a 3D cone plot but in 2D. Here is the code:

import numpy as np
import plotly.graph_objects as go
import plotly.colors as pc


def _triangle_vertices(px, py, dx, dy, size):
    """Return the vertex coordinates of many arrow-head triangles at once.

    Each triangle has its tip ``size/2`` ahead of the anchor point along the
    unit direction ``(dx, dy)``. Its two base corners sit ``size/3`` behind the
    anchor, offset ``size/4`` to either side.

    Args:
        px, py: 1-D arrays of anchor coordinates.
        dx, dy: 1-D arrays with the unit direction of each triangle.
        size: 1-D array of triangle sizes, in data units.

    Returns:
        ``(xs, ys)``, each of shape ``(N, 3)``. Column 0 is the tip; columns 1
        and 2 are the base corners.
    """
    xs = np.stack(
        [
            px + dx * size / 2,
            px - dx * size / 3 + dy * size / 4,
            px - dx * size / 3 - dy * size / 4,
        ],
        axis=1,
    )
    ys = np.stack(
        [
            py + dy * size / 2,
            py - dy * size / 3 - dx * size / 4,
            py - dy * size / 3 + dx * size / 4,
        ],
        axis=1,
    )
    return xs, ys


def _gapped_path(xs, ys):
    """Join ``(k, 3)`` triangle vertices into one x/y path separated by ``None`` gaps.

    With ``fill='toself'``, Plotly closes and fills each gap-separated segment
    of a trace on its own. This lets many triangles share a single trace.
    """
    count = xs.shape[0]
    path_x = np.empty((count, 4), dtype=object)
    path_y = np.empty((count, 4), dtype=object)
    path_x[:, :3] = xs
    path_y[:, :3] = ys
    path_x[:, 3] = None  # gap that terminates each triangle
    path_y[:, 3] = None
    return path_x.ravel().tolist(), path_y.ravel().tolist()


def plot_vector_field(x, y, u, v, colorscale='Viridis', size_scale=0.8,
                      n_colors=None):
    """Plot a 2-D vector field as filled triangles, one per sample point.

    Each triangle points along ``(u, v)``. Its size (in data units) and its
    fill color both scale with the vector magnitude relative to the largest
    magnitude in the field.

    Triangles that share a fill color are drawn in a single trace. The figure
    therefore holds one trace per distinct color, not one per point. Pass
    ``n_colors`` to cap the number of distinct colors (and traces) for large
    fields. Because triangles are grouped by color, the stacking order of
    overlapping triangles follows color rather than point order.

    Args:
        x, y: Sample coordinates. Any shape is accepted; values are flattened.
        u, v: Vector components at each sample. Each must have the same number
            of elements as ``x``.
        colorscale: Plotly colorscale (name or ``[[pos, color], ...]`` list),
            used for both the triangle fills and the colorbar.
        size_scale: Triangle size, in data units, of the largest vector.
        n_colors: If given (>= 2), quantize the fill colors to this many evenly
            spaced levels. If ``None``, every distinct magnitude keeps its
            exact color.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: If the inputs differ in length or ``n_colors < 2``.
    """
    x = np.ravel(np.asarray(x, dtype=float))
    y = np.ravel(np.asarray(y, dtype=float))
    u = np.ravel(np.asarray(u, dtype=float))
    v = np.ravel(np.asarray(v, dtype=float))
    if not (x.size == y.size == u.size == v.size):
        raise ValueError('x, y, u and v must have the same number of elements')
    if n_colors is not None and n_colors < 2:
        raise ValueError('n_colors must be at least 2')

    # Magnitude of each vector, normalized to [0, 1] by the field maximum.
    magnitude = np.hypot(u, v)
    max_magnitude = magnitude.max() if magnitude.size else 0.0
    if max_magnitude > 0:
        magnitude_normalized = magnitude / max_magnitude
    else:
        # All-zero (or empty) field: avoid 0/0 -> NaN, which breaks color sampling.
        magnitude_normalized = np.zeros_like(magnitude)

    # Unit direction of each vector; arctan2 gives (1, 0) for zero vectors,
    # whose triangles have zero size anyway.
    angle = np.arctan2(v, u)
    xs, ys = _triangle_vertices(
        x, y, np.cos(angle), np.sin(angle), size_scale * magnitude_normalized
    )

    color_values = magnitude_normalized
    if n_colors is not None:
        steps = n_colors - 1
        color_values = np.round(magnitude_normalized * steps) / steps

    fig = go.Figure()

    if color_values.size:
        # Sample the colorscale once for all distinct levels.
        levels, level_of_point = np.unique(color_values, return_inverse=True)
        level_of_point = level_of_point.ravel()
        scale = (pc.get_colorscale(colorscale)
                 if isinstance(colorscale, str) else colorscale)
        palette = pc.sample_colorscale(scale, levels.tolist())

        # Group point indices by color level: sort by level, then cut at the
        # cumulative per-level counts (every level occurs at least once).
        order = np.argsort(level_of_point, kind='stable')
        groups = np.split(order, np.cumsum(np.bincount(level_of_point))[:-1])

        for color, members in zip(palette, groups):
            path_x, path_y = _gapped_path(xs[members], ys[members])
            fig.add_trace(go.Scatter(
                x=path_x,
                y=path_y,
                fill='toself',
                mode='lines',
                line=dict(color='rgba(0,0,0,0)'),
                fillcolor=color,
                showlegend=False,
                name='',
            ))

    # Invisible trace that exists only to draw a colorbar in real magnitude
    # units. showscale only takes effect when marker.color is numeric.
    colorbar_trace = go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(
            colorscale=colorscale,
            cmin=0,
            cmax=max_magnitude,
            color=[0],
            showscale=True,
            colorbar=dict(
                title=''
            ),
        ),
        showlegend=False,
    )
    fig.add_trace(colorbar_trace)

    return fig


# Demo: evaluate the example field on a 10x10 grid over [-2, 2] x [-2, 2]
# and render it with plot_vector_field.
x = np.linspace(-2, 2, 10)
y = np.linspace(-2, 2, 10)
X, Y = np.meshgrid(x, y)
u = -1 - X**2 + Y
v = 1 + X - Y**2

# plot_vector_field indexes point-by-point, so hand it 1-D copies of the grids.
flat_args = [component.ravel() for component in (X, Y, u, v)]
fig = plot_vector_field(*flat_args)
fig.show()

This produces the following figure: [image: triangle plot of the vector field]

The problem is that the function becomes slow as I increase the number of points, because it adds every triangle to the figure as a separate trace.

Is there a better way to do this?

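I have looked into plotting all the triangles in a single trace, but found it hard to get the fill and colors to work as desired: a single Scatter trace has only one `fillcolor`, so the triangles cannot each get their own color that way.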


Solution

  • I figured out a way to do it myself by using scatter markers instead. Plotly's `arrow` marker symbol can be rotated per point with `marker.angle`, so the whole field fits in a single trace. Here is a working script:

    import numpy as np
    import plotly.graph_objects as go

    # Sample the example field on a 10x10 grid over [-2, 2] x [-2, 2].
    x = np.linspace(-2, 2, 10)
    y = np.linspace(-2, 2, 10)
    X, Y = np.meshgrid(x, y)
    u = -1 - X**2 + Y
    v = 1 + X - Y**2

    # Flatten once; the scatter trace takes 1-D arrays.
    xf, yf, uf, vf = (a.ravel() for a in (X, Y, u, v))

    # Magnitude of each vector, normalized by the field maximum for sizing.
    magnitude = np.hypot(uf, vf)
    max_magnitude = magnitude.max()
    # Guard against an all-zero field, where dividing by the max gives NaN sizes.
    if max_magnitude > 0:
        magnitude_normalized = magnitude / max_magnitude
    else:
        magnitude_normalized = np.zeros_like(magnitude)

    # Direction of each vector in degrees, counter-clockwise from +x.
    angle_deg = np.degrees(np.arctan2(vf, uf))

    fig = go.Figure(data=go.Scatter(
        x=xf,
        y=yf,
        mode='markers',
        marker=dict(
            symbol='arrow',
            # The 'arrow' symbol points up and marker.angle rotates it
            # clockwise in degrees, so a math angle t maps to 90 - t.
            # (marker.angle needs a recent Plotly release.)
            angle=90 - angle_deg,
            # Marker size is in screen pixels, not data units.
            size=50 * magnitude_normalized,
            # Raw magnitudes give the same auto-scaled colors as the
            # normalized ones, but a colorbar labelled in real units.
            color=magnitude,
            colorscale='Viridis',
            showscale=True,
        ),
    ))

    fig.show()