Search code examples
zig

error: type capture contains reference to comptime var


I'm new to Zig and experimenting with the type system. I have a piece of code that does some comptime operations to determine the shape of a tensor. I'd like to use this comptime-computed shape for some static type analysis, and also keep it around (for now) as helpful runtime info. I'm getting `error: type capture contains reference to comptime var`, which I've tried to reproduce with smaller versions of what I thought the error was describing, without success.

I believe the error I'm getting is related to this post, where the advice is: "To fix the error, copy your finalized array to a const before taking a pointer."

As far as I can tell, my `comptime_shape` is already a const. Could someone help me understand this error in my context, and how to fix it:

const std = @import("std");

pub fn Tensor(comptime data: anytype) type { // Returns a struct type holding `data` together with its shape.
    const comptime_shape = getShape(@TypeOf(data)); // `getShape` returns a `[]const usize` slice, i.e. a pointer, not an array value.

    return struct { // main.zig:6:12 — the "type capture contains reference to comptime var" error is reported here.
        data: @TypeOf(data),
        shape: [comptime_shape.len]usize,

        const Self = @This();
        pub fn init() Self { // Builds an instance from the comptime `data` and the computed shape.
            return Self{ .data = data, .shape = comptime_shape }; // NOTE(review): assigns a slice to an array field — may need `comptime_shape[0..comptime_shape.len].*`; confirm.
        }
    };
}

fn getShape(comptime T: type) []const usize { // Recursively collects the length of each nested array level of T, outermost first.
    const info = @typeInfo(T);
    switch (info) {
        .Array => |arr| {
            if (@typeInfo(arr.child) == .Array) {
                const child_shape = getShape(arr.child);
                var result: [child_shape.len + 1]usize = undefined; // A comptime `var` local to this call.
                result[0] = arr.len;
                @memcpy(result[1..], child_shape);
                return &result; // Escapes a pointer to the comptime `var` above; a type capturing it is what the compiler rejects.
            } else {
                return &[_]usize{arr.len}; // Innermost level: pointer to a one-element array literal.
            }
        },
        else => @compileError("Invalid tensor type"), // main.zig:31:17 — non-array types (e.g. the untyped `.{ ... }` literal) end up here.
    }
}

test "init tensor" { // Builds a Tensor from a typed 2x3 array and prints it.
    const data: [2][3]f32 = .{ .{ 1, 2, 3 }, .{ 2, 3, 4 } };
    const result = Tensor(data).init(); // main.zig:37:26 — the call site named in the "called from here" note.
    std.debug.print("tensor [{}]\n", .{result});
    std.debug.print("tensor.shape [{}]\n", .{result.shape}); // NOTE(review): std.fmt may require `{any}` to print an array — confirm.
}
main.zig:6:12: error: type capture contains reference to comptime var
main.zig:37:26: note: called from here

Interestingly, if I remove the type annotation from `const data`, I get an entirely different error:

test "init tensor" { // Same test, but without the `[2][3]f32` annotation on `data`.
    const data = .{ .{ 1, 2, 3 }, .{ 2, 3, 4 } }; // Untyped `.{ ... }` literal: its type is not an array, so getShape reaches @compileError("Invalid tensor type").
    const result = Tensor(data).init();
    std.debug.print("tensor [{}]\n", .{result});
    std.debug.print("tensor.shape [{}]\n", .{result.shape});
}
main.zig:31:17: error: Invalid tensor type
main.zig:4:36: note: called from here
main.zig:37:26: note: called from here

Solution

  • I think my comptime_shape is a const. I could use some help understanding this error in my context and how to fix it:

    `comptime_shape` is indeed const, but it is just a slice pointing into the comptime `var result` inside the `getShape` function. You can fix this by returning an array by value. Because the array's length must then appear in `getShape`'s return type, you also need a second function that calculates that length:

    const std = @import("std");
    
    /// Wraps `data` in a struct type and records the shape of its (nested)
    /// array type in the container-level constant `shape`.
    pub fn Tensor(data: anytype) type {
        const Data = @TypeOf(data);
        // Length of each nested array level of `Data`, outermost first, held by value.
        const dims = getShape(Data);

        return struct {
            data: Data,

            const Self = @This();
            const shape = dims;

            /// Builds an instance holding the value that was passed to `Tensor`.
            pub fn init() Self {
                return .{ .data = data };
            }
        };
    }
    
    /// Returns how many nested array levels `T` has (1 for a flat array).
    /// Emits a compile error when `T` is not an array type.
    fn getShapeLen(comptime T: type) usize {
        const arr = switch (@typeInfo(T)) {
            .Array => |a| a,
            else => @compileError("Invalid tensor type"),
        };
        // Comptime-known condition: only the taken branch is analyzed, so the
        // recursion never reaches a non-array child.
        const child_is_array = @typeInfo(arr.child) == .Array;
        return if (child_is_array) 1 + getShapeLen(arr.child) else 1;
    }
    
    /// Returns the length of every nested array level of `T`, outermost first.
    /// The result is an array value rather than a pointer, so it does not
    /// refer to any comptime `var` and can safely be captured by a type.
    fn getShape(comptime T: type) [getShapeLen(T)]usize {
        const arr = switch (@typeInfo(T)) {
            .Array => |a| a,
            else => @compileError("Invalid tensor type"),
        };
        const head = [1]usize{arr.len};
        // Innermost level: nothing left to append.
        if (@typeInfo(arr.child) != .Array) return head;
        return head ++ getShape(arr.child);
    }
    
    // Smoke test: instantiates Tensor for a typed 2x3 array and prints both the
    // instance and the type-level `shape` constant.
    test "init tensor" {
        const data = [2][3]f32{ .{ 1, 2, 3 }, .{ 2, 3, 4 } };
        const TensorType = Tensor(data);
        const result = TensorType.init();
        std.debug.print("tensor [{}]\n", .{result});
        std.debug.print("tensor.shape [{any}]\n", .{TensorType.shape});
    }
    

    PS: the `data` argument of the `Tensor` function need not be `comptime`: computing the shape only needs its type, and the type of any value is always comptime-known. I would probably make `Tensor` take the type as its parameter and pass the data as a runtime argument to `init`, as shown below. But I might be missing some context:

    const std = @import("std");
    
    /// Returns a struct type whose instances hold a runtime value of type `T`.
    /// The shape of `T` is exposed as the container-level constant `shape`.
    pub fn Tensor(comptime T: type) type {
        return struct {
            data: T,

            const Self = @This();
            /// Length of each nested array level of `T`, outermost first.
            const shape = getShape(T);

            /// Wraps the given value in a new instance.
            pub fn init(data: T) Self {
                return Self{ .data = data };
            }
        };
    }
    
    /// Counts the nested array levels of `T` (its rank); 1 for a flat array.
    /// Non-array types are rejected with a compile error.
    fn getShapeLen(comptime T: type) usize {
        switch (@typeInfo(T)) {
            .Array => |level| {
                const nested = @typeInfo(level.child) == .Array;
                return if (nested) 1 + getShapeLen(level.child) else 1;
            },
            else => @compileError("Invalid tensor type"),
        }
    }
    
    /// Lengths of every nested array level of `T`, outermost first, returned
    /// by value so callers receive an array rather than a pointer.
    fn getShape(comptime T: type) [getShapeLen(T)]usize {
        switch (@typeInfo(T)) {
            .Array => |level| {
                const outer = [_]usize{level.len};
                if (@typeInfo(level.child) != .Array) {
                    return outer;
                }
                return outer ++ getShape(level.child);
            },
            else => @compileError("Invalid tensor type"),
        }
    }