#ifndef MAT4_H
#define MAT4_H

/*
 * Swap(t, a, b): exchange the values of the lvalues `a` and `b`, both of type `t`.
 *
 * Wrapped in do { } while (0) so that `Swap(...);` is exactly one statement
 * (safe as the body of an unbraced if/else). The temporary uses an unusual
 * name so it cannot shadow a caller variable such as `temp`.
 * `a` and `b` are each evaluated twice: do not pass expressions with side effects.
 */
#define Swap(t, a, b) do { t swap_tmp_ = (a); (a) = (b); (b) = swap_tmp_; } while (0)

/*
 * Returns a matrix with all sixteen components set to 0.0f.
 *
 * NOTE(review): the functions in this header are defined without `static`
 * or `inline`, so including it in more than one translation unit would
 * produce duplicate-definition link errors. Presumably this is a unity
 * build; confirm.
 */
MAT4
mat4_zero(void)
{
    MAT4 result = {
        { 0.0f, 0.0f, 0.0f, 0.0f },
        { 0.0f, 0.0f, 0.0f, 0.0f },
        { 0.0f, 0.0f, 0.0f, 0.0f },
        { 0.0f, 0.0f, 0.0f, 0.0f }
    };
    return(result);
}

/*
 * Returns the 4x4 identity matrix: 1.0f on the main diagonal
 * (m0.x, m1.y, m2.z, m3.w) and 0.0f everywhere else.
 */
MAT4
mat4_identity(void)
{
    MAT4 result = {
        { 1.0f, 0.0f, 0.0f, 0.0f },
        { 0.0f, 1.0f, 0.0f, 0.0f },
        { 0.0f, 0.0f, 1.0f, 0.0f },
        { 0.0f, 0.0f, 0.0f, 1.0f }
    };
    return(result);
}

/*
 * Determinant of `mat`, computed by cofactor (Laplace) expansion over the
 * four .x components (a, e, i, m below).
 *
 * Each of m0..m3 is laid out as one line of this grid. Because
 * det(M) == det(M^T), the result is the same whether m0..m3 are read as
 * rows or as columns:
 *
 *     a b c d     <- m0
 *     e f g h     <- m1
 *     i j k l     <- m2
 *     m n o p     <- m3
 *
 * Each *minor is the 3x3 determinant left after deleting the pivot's line
 * and the first column.
 */
F32
mat4_det(MAT4 mat)
{
    F32 a = mat.m0.x, b = mat.m0.y, c = mat.m0.z, d = mat.m0.w;
    F32 e = mat.m1.x, f = mat.m1.y, g = mat.m1.z, h = mat.m1.w;
    F32 i = mat.m2.x, j = mat.m2.y, k = mat.m2.z, l = mat.m2.w;
    F32 m = mat.m3.x, n = mat.m3.y, o = mat.m3.z, p = mat.m3.w;

    /* 3x3 minors, each expanded along its own first line. */
    F32 aminor = f*(k*p - l*o) - g*(j*p - l*n) + h*(j*o - k*n); /* lines e,i,m */
    F32 eminor = b*(k*p - l*o) - c*(j*p - l*n) + d*(j*o - k*n); /* lines a,i,m */
    F32 iminor = b*(g*p - h*o) - c*(f*p - h*n) + d*(f*o - g*n); /* lines a,e,m */
    F32 mminor = b*(g*l - h*k) - c*(f*l - h*j) + d*(f*k - g*j); /* lines a,e,i */

    /* Cofactor signs alternate + - + - down the pivot column. */
    return(a*aminor - e*eminor + i*iminor - m*mminor);
}

/*
 * Returns the transpose of `mat`: component r of vector c becomes
 * component c of vector r.
 *
 * The diagonal (m0.x, m1.y, m2.z, m3.w) comes along unchanged with the
 * initial copy. Every off-diagonal entry is then read from the untouched
 * input, so no temporaries are needed.
 */
MAT4
mat4_transpose(MAT4 mat)
{
    MAT4 t = mat;

    t.m0.y = mat.m1.x;  t.m1.x = mat.m0.y;
    t.m0.z = mat.m2.x;  t.m2.x = mat.m0.z;
    t.m0.w = mat.m3.x;  t.m3.x = mat.m0.w;

    t.m1.z = mat.m2.y;  t.m2.y = mat.m1.z;
    t.m1.w = mat.m3.y;  t.m3.y = mat.m1.w;

    t.m2.w = mat.m3.z;  t.m3.z = mat.m2.w;

    return t;
}

/*
 * Matrix product left * right.
 *
 * Component R of result vector C is
 *     left.m0.R*right.mC.x + left.m1.R*right.mC.y
 *   + left.m2.R*right.mC.z + left.m3.R*right.mC.w
 * so each vector of the result is `left`'s four vectors weighted by the
 * matching vector of `right`. Both operands are passed by value, so the
 * result never aliases an input.
 */
MAT4
mat4_mul(MAT4 left, MAT4 right)
{
    MAT4 out;

    out.m0.x = left.m0.x*right.m0.x + left.m1.x*right.m0.y
             + left.m2.x*right.m0.z + left.m3.x*right.m0.w;
    out.m0.y = left.m0.y*right.m0.x + left.m1.y*right.m0.y
             + left.m2.y*right.m0.z + left.m3.y*right.m0.w;
    out.m0.z = left.m0.z*right.m0.x + left.m1.z*right.m0.y
             + left.m2.z*right.m0.z + left.m3.z*right.m0.w;
    out.m0.w = left.m0.w*right.m0.x + left.m1.w*right.m0.y
             + left.m2.w*right.m0.z + left.m3.w*right.m0.w;

    out.m1.x = left.m0.x*right.m1.x + left.m1.x*right.m1.y
             + left.m2.x*right.m1.z + left.m3.x*right.m1.w;
    out.m1.y = left.m0.y*right.m1.x + left.m1.y*right.m1.y
             + left.m2.y*right.m1.z + left.m3.y*right.m1.w;
    out.m1.z = left.m0.z*right.m1.x + left.m1.z*right.m1.y
             + left.m2.z*right.m1.z + left.m3.z*right.m1.w;
    out.m1.w = left.m0.w*right.m1.x + left.m1.w*right.m1.y
             + left.m2.w*right.m1.z + left.m3.w*right.m1.w;

    out.m2.x = left.m0.x*right.m2.x + left.m1.x*right.m2.y
             + left.m2.x*right.m2.z + left.m3.x*right.m2.w;
    out.m2.y = left.m0.y*right.m2.x + left.m1.y*right.m2.y
             + left.m2.y*right.m2.z + left.m3.y*right.m2.w;
    out.m2.z = left.m0.z*right.m2.x + left.m1.z*right.m2.y
             + left.m2.z*right.m2.z + left.m3.z*right.m2.w;
    out.m2.w = left.m0.w*right.m2.x + left.m1.w*right.m2.y
             + left.m2.w*right.m2.z + left.m3.w*right.m2.w;

    out.m3.x = left.m0.x*right.m3.x + left.m1.x*right.m3.y
             + left.m2.x*right.m3.z + left.m3.x*right.m3.w;
    out.m3.y = left.m0.y*right.m3.x + left.m1.y*right.m3.y
             + left.m2.y*right.m3.z + left.m3.y*right.m3.w;
    out.m3.z = left.m0.z*right.m3.x + left.m1.z*right.m3.y
             + left.m2.z*right.m3.z + left.m3.z*right.m3.w;
    out.m3.w = left.m0.w*right.m3.x + left.m1.w*right.m3.y
             + left.m2.w*right.m3.z + left.m3.w*right.m3.w;

    return out;
}

/*
 * Builds a translation matrix: identity, with (vec.x, vec.y, vec.z) placed
 * in m3.x, m3.y, m3.z. The initializer layout mirrors mat4_identity's.
 */
MAT4
mat4_make_translate(V3F vec)
{
    MAT4 translation = {
        { 1.0f,  0.0f,  0.0f,  0.0f },
        { 0.0f,  1.0f,  0.0f,  0.0f },
        { 0.0f,  0.0f,  1.0f,  0.0f },
        { vec.x, vec.y, vec.z, 1.0f }
    };
    return translation;
}

/*
 * Returns T * mat, where T = mat4_make_translate(vec).
 *
 * mat4_mul computes left * right, and mat4_v4f_mul transforms column
 * vectors. The translation is therefore applied after `mat`'s own
 * transform.
 */
MAT4
mat4_translate(MAT4 mat, V3F vec)
{
    return mat4_mul(mat4_make_translate(vec), mat);
}

/*
 * Builds a scale matrix: identity, with the first three diagonal entries
 * (m0.x, m1.y, m2.z) replaced by scale.x, scale.y, scale.z.
 */
MAT4
mat4_make_scale(V3F scale)
{
    MAT4 scaling = {
        { scale.x, 0.0f,    0.0f,    0.0f },
        { 0.0f,    scale.y, 0.0f,    0.0f },
        { 0.0f,    0.0f,    scale.z, 0.0f },
        { 0.0f,    0.0f,    0.0f,    1.0f }
    };
    return scaling;
}

/*
 * Returns S * mat, where S = mat4_make_scale(scale), so the scale is
 * applied after `mat`'s own transform (column-vector convention).
 */
MAT4
mat4_scale(MAT4 mat, V3F scale)
{
    return mat4_mul(mat4_make_scale(scale), mat);
}

/*
 * Transforms the column vector `v` by `m` and returns m * v.
 *
 * Output component r is the dot product of (m0.r, m1.r, m2.r, m3.r) with
 * (v.x, v.y, v.z, v.w). Terms are summed left to right in that order.
 */
V4F
mat4_v4f_mul(MAT4 m, V4F v)
{
    V4F out;

    out.x = m.m0.x * v.x
          + m.m1.x * v.y
          + m.m2.x * v.z
          + m.m3.x * v.w;

    out.y = m.m0.y * v.x
          + m.m1.y * v.y
          + m.m2.y * v.z
          + m.m3.y * v.w;

    out.z = m.m0.z * v.x
          + m.m1.z * v.y
          + m.m2.z * v.z
          + m.m3.z * v.w;

    out.w = m.m0.w * v.x
          + m.m1.w * v.y
          + m.m2.w * v.z
          + m.m3.w * v.w;

    return out;
}

/*
 * Writes `a` to stdout as a 4x4 grid with four decimal places per entry,
 * followed by a blank line.
 *
 * Printed line r holds component r (x, then y, z, w) of m0, m1, m2, m3 in
 * that order, so m0..m3 appear as the printed columns. This relies on the
 * including file having already included <stdio.h>.
 */
void
mat4_print(MAT4 a)
{
    printf("%.4f %.4f %.4f %.4f\n"
           "%.4f %.4f %.4f %.4f\n"
           "%.4f %.4f %.4f %.4f\n"
           "%.4f %.4f %.4f %.4f\n\n",
           a.m0.x, a.m1.x, a.m2.x, a.m3.x,
           a.m0.y, a.m1.y, a.m2.y, a.m3.y,
           a.m0.z, a.m1.z, a.m2.z, a.m3.z,
           a.m0.w, a.m1.w, a.m2.w, a.m3.w);
}

#endif /* MAT4_H */