I'm new to Zig and experimenting with the type system. I have a piece of code that does some comptime operations to determine the shape of a tensor. I'd like to use this comptime variable for some static type analysis as well as keep it around (for now) as helpful runtime info. I'm getting the error: error: type capture contains reference to comptime var
which I've tried to reproduce by making smaller versions of what I thought it was saying but failed in my attempts.
I believe the error that I'm getting is related to this post where the advice is: To fix the error, copy your finalized array to a const before taking a pointer.
I think my comptime_shape
is a const. I could use some help understanding this error in my context and how to fix it:
const std = @import("std");
pub fn Tensor(comptime data: anytype) type {
const comptime_shape = getShape(@TypeOf(data));
return struct {
data: @TypeOf(data),
shape: [comptime_shape.len]usize,
const Self = @This();
pub fn init() Self {
return Self{ .data = data, .shape = comptime_shape };
}
};
}
fn getShape(comptime T: type) []const usize {
const info = @typeInfo(T);
switch (info) {
.Array => |arr| {
if (@typeInfo(arr.child) == .Array) {
const child_shape = getShape(arr.child);
var result: [child_shape.len + 1]usize = undefined;
result[0] = arr.len;
@memcpy(result[1..], child_shape);
return &result;
} else {
return &[_]usize{arr.len};
}
},
else => @compileError("Invalid tensor type"),
}
}
test "init tensor" {
const data: [2][3]f32 = .{ .{ 1, 2, 3 }, .{ 2, 3, 4 } };
const result = Tensor(data).init();
std.debug.print("tensor [{}]\n", .{result});
std.debug.print("tensor.shape [{}]\n", .{result.shape});
}
main.zig:6:12: error: type capture contains reference to comptime var
main.zig:37:26: note: called from here
It's also pretty interesting to me that if I remove the typehint for const data then I get an entirely new error:
test "init tensor" {
const data = .{ .{ 1, 2, 3 }, .{ 2, 3, 4 } };
const result = Tensor(data).init();
std.debug.print("tensor [{}]\n", .{result});
std.debug.print("tensor.shape [{}]\n", .{result.shape});
}
main.zig:31:17: error: Invalid tensor type
main.zig:4:36: note: called from here
main.zig:37:26: note: called from here
I think my comptime_shape is a const. I could use some help understanding this error in my context and how to fix it:
comptime_shape
is indeed const but it is just a pointer into comptime var inside the getShape
function. You can fix this by returning an array by value. You also need to create a second function that calculates the size of this array:
const std = @import("std");
pub fn Tensor(data: anytype) type {
const comptime_shape = getShape(@TypeOf(data));
return struct {
data: @TypeOf(data),
const shape = comptime_shape;
const Self = @This();
pub fn init() Self {
return Self{ .data = data };
}
};
}
fn getShapeLen(comptime T: type) usize {
const info = @typeInfo(T);
switch (info) {
.Array => |arr| {
if (@typeInfo(arr.child) == .Array) {
return getShapeLen(arr.child) + 1;
} else {
return 1;
}
},
else => @compileError("Invalid tensor type"),
}
}
fn getShape(comptime T: type) [getShapeLen(T)]usize {
const info = @typeInfo(T);
switch (info) {
.Array => |arr| {
if (@typeInfo(arr.child) == .Array) {
const child_shape = getShape(arr.child);
return [1]usize{arr.len} ++ child_shape;
} else {
return [1]usize{arr.len};
}
},
else => @compileError("Invalid tensor type"),
}
}
test "init tensor" {
const data: [2][3]f32 = .{ .{ 1, 2, 3 }, .{ 2, 3, 4 } };
const result = Tensor(data).init();
std.debug.print("tensor [{}]\n", .{result});
std.debug.print("tensor.shape [{any}]\n", .{@TypeOf(result).shape});
}
PS: the data
argument of the Tensor
function need not be comptime
. You are only using its type at comptime and the type of any variable is always comptime known. I would probably take out the type into a separate function and take the data as a runtime argument. But I might be missing some context:
const std = @import("std");
pub fn Tensor(comptime T: type) type {
return struct {
data: T,
const shape = getShape(T);
const Self = @This();
pub fn init(data: T) Self {
return .{ .data = data };
}
};
}
fn getShapeLen(comptime T: type) usize {
const info = @typeInfo(T);
return switch (info) {
.Array => |arr| if (@typeInfo(arr.child) == .Array)
getShapeLen(arr.child) + 1
else
1,
else => @compileError("Invalid tensor type"),
};
}
fn getShape(comptime T: type) [getShapeLen(T)]usize {
const info = @typeInfo(T);
return switch (info) {
.Array => |arr| if (@typeInfo(arr.child) == .Array)
[1]usize{arr.len} ++ getShape(arr.child)
else
[1]usize{arr.len},
else => @compileError("Invalid tensor type"),
};
}