const std = @import("std");
const Allocator = std.mem.Allocator;

const sdlerr = @import("../err.zig").sdlerr;
const resource = @import("../resource.zig");
const c = @import("c.zig").c;

const log = std.log.scoped(.shader);

/// Shader pipeline stage. Each tag's integer value is the matching SDL GPU
/// shader-stage constant, so `@intFromEnum` yields the value SDL expects.
/// Tag names double as shader file extensions (see `stageFromExtension`).
pub const Stage = enum(c_uint) {
    vert = c.SDL_GPU_SHADERSTAGE_VERTEX,
    frag = c.SDL_GPU_SHADERSTAGE_FRAGMENT,
};

/// Derive the shader stage from the final extension of `name`
/// (e.g. "basic.vert" -> `.vert`, "basic.frag" -> `.frag`).
/// Returns `error.NoStageExtension` when `name` has no extension or the
/// extension does not name a `Stage` tag.
pub fn stageFromExtension(name: []const u8) !Stage {
    const ext = std.fs.path.extension(name);
    if (ext.len > 0) {
        // `ext` includes the leading '.', which is skipped before matching tag names.
        if (std.meta.stringToEnum(Stage, ext[1..])) |stage| return stage;
    }
    return error.NoStageExtension;
}

/// Load the shader called `name` and create it on `device`.
/// `info` is forwarded as extra `SDL_GPUShaderCreateInfo` fields.
/// `gpa` is used only for temporary allocations made while loading.
/// At present this always goes through the SPIR-V loader.
pub fn loadShader(gpa: Allocator, device: *c.SDL_GPUDevice, name: []const u8, info: anytype) !*c.SDL_GPUShader {
    const shader = try loadShaderSpirv(gpa, device, name, info);
    return shader;
}

/// Create an SDL GPU shader from SPIR-V bytecode in `data`.
///
/// `info` is a struct literal of extra `SDL_GPUShaderCreateInfo` fields;
/// every field it does not set is zero-initialized by `std.mem.zeroInit`.
/// `code`, `code_size`, `format` and `stage` are always derived from the
/// arguments. `entrypoint` defaults to "main" unless `info` supplies one.
/// `data` is taken as a read-only view; callers with `[]u8` still coerce.
fn createShader(device: *c.SDL_GPUDevice, data: []const u8, stage: Stage, info: anytype) !*c.SDL_GPUShader {
    var create_info = std.mem.zeroInit(c.SDL_GPUShaderCreateInfo, info);
    create_info.code_size = data.len;
    create_info.code = data.ptr;
    // Only fall back to "main" when the caller did not provide an entry point;
    // zeroInit leaves the pointer null in that case. Previously a caller-supplied
    // entrypoint was silently overwritten.
    if (create_info.entrypoint == null) create_info.entrypoint = "main";
    create_info.format = c.SDL_GPU_SHADERFORMAT_SPIRV;
    create_info.stage = @intFromEnum(stage);

    return sdlerr(c.SDL_CreateGPUShader(device, &create_info));
}

/// Load the SPIR-V binary for shader `name` and create a GPU shader from it.
///
/// The stage comes from `name`'s extension. The binary's file name is `name`
/// with ".spv" appended, resolved through `resource.dataFilePath`. All
/// allocations made here come from `gpa` and are released before returning.
fn loadShaderSpirv(gpa: Allocator, device: *c.SDL_GPUDevice, name: []const u8, info: anytype) !*c.SDL_GPUShader {
    // Fails early with error.NoStageExtension before touching the filesystem.
    const stage = try stageFromExtension(name);

    const binary_name = try std.fmt.allocPrint(gpa, "{s}.spv", .{name});
    defer gpa.free(binary_name);

    const binary_path = try resource.dataFilePath(gpa, binary_name);
    defer gpa.free(binary_path);

    // The bytecode buffer is freed once createShader returns
    // (assumes SDL does not retain the code pointer — TODO confirm).
    const bytecode = try resource.loadFileZ(gpa, binary_path);
    defer gpa.free(bytecode);

    return createShader(device, bytecode, stage, info);
}
