//! Vector, matrix and quaternion operations.

const std = @import("std");

/// A 2-component vector of `f32`, represented as a SIMD `@Vector` so that
/// arithmetic operators act component-wise.
pub const Vec2 = @Vector(2, f32);

/// A 3-component vector of `f32`, represented as a SIMD `@Vector`.
pub const Vec3 = @Vector(3, f32);

/// Operations on `Vec3`.
pub const vec3 = struct {
    /// Unit vector along +X.
    pub const XUP: Vec3 = .{ 1.0, 0.0, 0.0 };
    /// Unit vector along +Y.
    pub const YUP: Vec3 = .{ 0.0, 1.0, 0.0 };
    /// Unit vector along +Z.
    pub const ZUP: Vec3 = .{ 0.0, 0.0, 1.0 };

    /// Returns the dot product of `a` and `b`.
    pub fn dot(a: Vec3, b: Vec3) f32 {
        return @reduce(.Add, a * b);
    }

    /// Returns the squared Euclidean length of `v` (no square root).
    pub fn lengthSq(v: Vec3) f32 {
        return dot(v, v);
    }

    /// Returns the Euclidean length of `v`.
    pub fn length(v: Vec3) f32 {
        return @sqrt(lengthSq(v));
    }

    /// Returns `v` scaled to unit length.
    ///
    /// A zero-length `v` is divided by zero (0/0), so the result is not finite;
    /// callers must ensure `v` is non-zero.
    pub fn normalize(v: Vec3) Vec3 {
        const len: Vec3 = @splat(length(v));
        return v / len;
    }

    /// Returns the cross product `a x b`.
    pub fn cross(a: Vec3, b: Vec3) Vec3 {
        return Vec3{
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0],
        };
    }

    /// Linearly interpolates from `a` (at `t == 0`) to `b` (at `t == 1`).
    ///
    /// `t` is not clamped, so values outside [0, 1] extrapolate.
    pub fn lerp(a: Vec3, b: Vec3, t: f32) Vec3 {
        const t_vec: Vec3 = @splat(t);
        return a + (b - a) * t_vec;
    }

    /// Rotates vector `v` around the axis `k` by `angle` radians using
    /// Rodrigues' rotation formula:
    /// `v*cos + (k x v)*sin + k*(k . v)*(1 - cos)`.
    ///
    /// `k` must already be normalized; it is not normalized here.
    pub fn rotate(v: Vec3, k: Vec3, angle: f32) Vec3 {
        // Use the trig builtins, as the rest of this file does
        // (`quat.fromAxisAngle`, `Mat4.perspective`).
        const cos_v: Vec3 = @splat(@cos(angle));
        const sin_v: Vec3 = @splat(@sin(angle));

        const term1 = v * cos_v;
        const term2 = vec3.cross(k, v) * sin_v;

        // Component of `v` along the axis, which the rotation leaves unchanged.
        const dot_val = vec3.dot(k, v);
        const one_minus_cos = @as(Vec3, @splat(1.0)) - cos_v;
        const term3 = k * @as(Vec3, @splat(dot_val)) * one_minus_cos;

        return term1 + term2 + term3;
    }
};

/// A 4-component vector of `f32`, represented as a SIMD `@Vector`.
pub const Vec4 = @Vector(4, f32);

/// Operations on `Vec4`.
pub const vec4 = struct {
    // TODO: implement anytype-generic versions of these

    /// Linearly interpolates from `a` (at `t == 0`) to `b` (at `t == 1`).
    ///
    /// `t` is not clamped, so values outside [0, 1] extrapolate.
    pub fn lerp(a: Vec4, b: Vec4, t: f32) Vec4 {
        const delta = b - a;
        return a + delta * @as(Vec4, @splat(t));
    }
};

/// A 4x4 square matrix.
pub const Mat4 = extern struct {
    /// ## Column-major layout
    /// | `col[0][0]` | `col[1][0]` | `col[2][0]` | `col[3][0]` |
    /// | `col[0][1]` | `col[1][1]` | `col[2][1]` | `col[3][1]` |
    /// | `col[0][2]` | `col[1][2]` | `col[2][2]` | `col[3][2]` |
    /// | `col[0][3]` | `col[1][3]` | `col[2][3]` | `col[3][3]` |
    col: [4][4]f32,

    /// Computes the result of `self * other`.
    ///
    /// When transforming column vectors, the product applies `other` first and
    /// then `self`.
    pub fn mmul(self: Mat4, other: Mat4) Mat4 {
        var result: Mat4 = undefined;

        // Column i of the product is a linear combination of self's columns,
        // weighted by the entries of other's column i.
        inline for (0..4) |i| {
            // We use @Vector to enable SIMD instructions.
            // Zig arrays [4]f32 coerce implicitly to @Vector(4, f32)
            var acc: @Vector(4, f32) = @splat(0.0);

            inline for (0..4) |k| {
                // column k of self * scalar k of other's column i
                const self_col: @Vector(4, f32) = self.col[k];
                const scalar: @Vector(4, f32) = @splat(other.col[i][k]);

                // `@mulAdd` is always fused (rounds once), not a hint; targets
                // without an FMA instruction may have to emulate it.
                acc = @mulAdd(@Vector(4, f32), self_col, scalar, acc);
            }

            result.col[i] = acc;
        }

        return result;
    }

    /// Constructs a perspective projection matrix.
    ///
    /// The matrix transforms vertices from camera space to clip space. The
    /// resulting clip-space `w` equals the camera-space `-z`, which the
    /// perspective divide uses.
    ///
    /// Camera Space (Right-Handed):
    /// * X is right
    /// * Y is up
    /// * -Z is forward (the camera looks down -Z; +Z points out of the screen)
    ///
    /// NDC Space:
    /// * X: [-1.0, 1.0]
    /// * Y: [-1.0, 1.0]
    /// * Z: [ 0.0, 1.0] (0 at the near plane, 1 at the far plane)
    pub fn perspective(
        /// The vertical field of view in radians.
        fov: f32,
        /// The width of the viewport divided by the height.
        aspect: f32,
        /// The distance to the near clipping plane (must be > 0).
        near: f32,
        /// The distance to the far clipping plane (must be > near).
        far: f32,
    ) Mat4 {
        var matrix = std.mem.zeroes(Mat4);

        const yscale = 1.0 / @tan(fov / 2.0);
        const xscale = yscale / aspect;
        const frustum_len = far - near;

        // --- Column 0 ---
        // Scale the X coordinate in camera space
        matrix.col[0][0] = xscale;

        // --- Column 1 ---
        // Scale the Y coordinate in camera space
        matrix.col[1][1] = yscale;

        // --- Column 2 ---
        // Scale the Z coordinate so that z = -near maps to depth 0 and
        // z = -far maps to depth 1 after the divide.
        matrix.col[2][2] = -far / frustum_len;
        // Copy -z into w for the perspective divide
        matrix.col[2][3] = -1.0;

        // --- Column 3 ---
        // Translate the Z coordinate
        matrix.col[3][2] = -(far * near) / frustum_len;

        return matrix;
    }

    /// Constructs a view matrix that looks at `target` from `pos`.
    ///
    /// This matrix transforms coordinates from world space to camera space.
    /// In camera space, the viewing direction maps to -Z, matching `perspective`.
    /// `roll` (radians) rotates the camera's right/up axes around the viewing
    /// direction.
    ///
    /// `target` must differ from `pos`. Otherwise the forward direction is
    /// normalized from a zero vector and the result is not finite.
    pub fn lookAt(pos: Vec3, target: Vec3, roll: f32) Mat4 {
        const f = vec3.normalize(target - pos);
        // When looking almost straight up or down, crossing with +Y degenerates,
        // so use +Z as the reference up vector instead.
        const world_up = if (@abs(f[1]) > 0.999) vec3.ZUP else vec3.YUP;
        const r_base = vec3.normalize(vec3.cross(f, world_up));
        const u_base = vec3.cross(r_base, f);

        // Apply roll by rotating the right/up basis around the forward axis.
        // Builtins are used for consistency with the rest of this file.
        const c: Vec3 = @splat(@cos(roll));
        const s: Vec3 = @splat(@sin(roll));
        const r = (r_base * c) + (u_base * s);
        const u = (u_base * c) - (r_base * s);

        // Rows 0..2 are r, u and -f. Column 3 is the translation that maps `pos`
        // to the camera-space origin.
        return .{
            .col = .{
                .{ r[0], u[0], -f[0], 0 },
                .{ r[1], u[1], -f[1], 0 },
                .{ r[2], u[2], -f[2], 0 },
                .{ -vec3.dot(r, pos), -vec3.dot(u, pos), vec3.dot(f, pos), 1 },
            },
        };
    }
};

/// A 4-component vector representing a rotation (x, y, z, w).
///
/// `x`, `y`, `z` form the vector part and `w` is the scalar part.
pub const Quat = @Vector(4, f32);

/// Operations on `Quat`.
pub const quat = struct {
    /// The identity quaternion (no rotation).
    pub const IDENTITY: Quat = .{ 0.0, 0.0, 0.0, 1.0 }; // TODO: lowercase these

    /// Creates a rotation from an axis and an angle (in radians).
    ///
    /// The axis must be normalized; it is not normalized here.
    pub fn fromAxisAngle(axis: Vec3, angle: f32) Quat {
        const half_angle = angle * 0.5;
        const s = @sin(half_angle);
        const c = @cos(half_angle);

        // Vector part = axis * sin(angle / 2), scalar part = cos(angle / 2).
        return .{ axis[0] * s, axis[1] * s, axis[2] * s, c };
    }

    /// Creates a rotation that aligns the `start` vector to the `dest` vector.
    ///
    /// Both vectors must be normalized. When they are nearly opposite
    /// (`dot(start, dest) < -0.999`), the cross product is too small to give a
    /// reliable axis. In that case the result is a 180-degree rotation around
    /// an arbitrary axis perpendicular to `start`, which maps `start` to
    /// `-start`. That only approximates `dest` unless the inputs are exactly
    /// opposite.
    pub fn rotationBetween(start: Vec3, dest: Vec3) Quat {
        const cos_theta = vec3.dot(start, dest);

        if (cos_theta < -1.0 + 0.001) {
            // Pick any axis perpendicular to `start`: try +Z first, and fall
            // back to +X when `start` is (nearly) parallel to Z.
            const perp_z = vec3.cross(vec3.ZUP, start);
            const raw_axis = if (vec3.lengthSq(perp_z) < 0.01)
                vec3.cross(vec3.XUP, start)
            else
                perp_z;
            const perp_axis = vec3.normalize(raw_axis);
            // A 180-degree rotation has w = cos(90 deg) = 0 and the unit axis
            // as its vector part.
            return .{ perp_axis[0], perp_axis[1], perp_axis[2], 0.0 };
        }

        // Standard case. |start x dest| = sin(theta) and s = 2 * cos(theta / 2),
        // so dividing the cross product by s yields axis * sin(theta / 2), and
        // s / 2 = cos(theta / 2). The result is already unit length.
        const axis = vec3.cross(start, dest);
        const s = @sqrt((1.0 + cos_theta) * 2.0);
        const invs = 1.0 / s;

        return .{ axis[0] * invs, axis[1] * invs, axis[2] * invs, s * 0.5 };
    }

    /// Normalizes the quaternion.
    ///
    /// Rotations must always be unit length. A zero quaternion is divided by
    /// zero, so the result is not finite.
    pub fn normalize(q: Quat) Quat {
        const dot = @reduce(.Add, q * q);
        return q / @as(Quat, @splat(@sqrt(dot)));
    }

    /// Combines two rotations (the Hamilton product `lhs * rhs`).
    ///
    /// The result represents the rotation of `rhs` followed by `lhs`.
    pub fn mul(lhs: Quat, rhs: Quat) Quat {
        const q1x = lhs[0];
        const q1y = lhs[1];
        const q1z = lhs[2];
        const q1w = lhs[3];
        const q2x = rhs[0];
        const q2y = rhs[1];
        const q2z = rhs[2];
        const q2w = rhs[3];

        return Quat{
            q1w * q2x + q1x * q2w + q1y * q2z - q1z * q2y,
            q1w * q2y - q1x * q2z + q1y * q2w + q1z * q2x,
            q1w * q2z + q1x * q2y - q1y * q2x + q1z * q2w,
            q1w * q2w - q1x * q2x - q1y * q2y - q1z * q2z,
        };
    }
};

/// Converts an angle from degrees to radians.
pub fn radians(degrees: f32) f32 {
    // Comptime-known conversion factor (pi / 180).
    const rad_per_deg = std.math.pi / 180.0;
    return degrees * rad_per_deg;
}

/// Evaluates the cubic Hermite smoothstep polynomial `3t^2 - 2t^3`.
///
/// Maps 0 to 0 and 1 to 1 with zero slope at both ends. `t` is not clamped,
/// so inputs outside [0, 1] produce values outside [0, 1].
pub fn smoothstep(t: f32) f32 {
    const t_sq = t * t;
    return t_sq * (3.0 - 2.0 * t);
}
