#include <MeshRenderer.h>
#include <RenderWindow.h>
#include <Mesh.h>
#include <Shader.h>

#define INVALID_BUFFER UINT32_MAX

// Creates a renderer in world space with an identity model matrix, no mesh,
// and every GL buffer slot marked as unallocated (INVALID_BUFFER, size 0).
MeshRenderer::MeshRenderer()
    : renderSpace(RenderSpace_World)
    , mtxModel(Matrix(1.0f))
    , mesh(NULL)
{
    for (u32 slot = 0; slot < RenderBuffer_MAX; ++slot)
    {
        // Size 0 + INVALID_BUFFER together mean "no GL buffer owned for this slot".
        bufferSizes[slot] = 0;
        buffers[slot] = INVALID_BUFFER;
    }
}

// Releases every GL buffer still owned by this renderer.
// NOTE(review): glDeleteBuffers presumably requires a current GL context —
// confirm renderers are never destroyed after the context is torn down.
MeshRenderer::~MeshRenderer()
{
    this->ClearBuffers();
}

void MeshRenderer::ClearBuffers()
{
    for (size_t i=0; i<RenderBuffer_MAX; i++)
    {
        if (buffers[i] != INVALID_BUFFER && bufferSizes[i] != 0)
        {
            glDeleteBuffers(1, &buffers[i]);
            buffers[i] = INVALID_BUFFER;
        }
        bufferSizes[i] = 0;
    }
}

void MeshRenderer::Render(const Mesh * mesh, const Matrix * matrix)
{
    mtxModel = Matrix(*matrix);
    SetMesh((Mesh*)mesh);
    Render_Internal();
}

// Draws `mesh` using `matrix` (taken by value) as the model transform.
// Switching to a different mesh discards the cached GL buffers (see SetMesh).
void MeshRenderer::Render(const Mesh * mesh, const Matrix matrix)
{
    mtxModel = matrix;
    SetMesh(mesh);
    Render_Internal();
}

// Selects the mesh to render. Assigning the mesh that is already current is a
// no-op and keeps the uploaded GL buffers; any other value (including NULL)
// frees all buffers before storing the new pointer.
void MeshRenderer::SetMesh(const Mesh * newMesh)
{
    if (newMesh != mesh)
    {
        ClearBuffers();
        mesh = newMesh;
    }
}

// Chooses whether subsequent draws use the window's view/projection
// (RenderSpace_World) or the screen-space matrices built in Render_Internal.
void MeshRenderer::SetRenderSpace(RenderSpace _renderSpace)
{
    this->renderSpace = _renderSpace;
}

// Uploads `bufferSize` bytes from `data` into the GL buffer for slot
// `renderBuffer`, bound as `glBufferType`.
//
// If the byte size differs from the recorded size, the old GL buffer (if any)
// is deleted and a new name is generated. A size of 0 leaves the slot
// unallocated. Slots ordered before RenderBuffer_INDEX also get vertex
// attribute array `renderBuffer` enabled. Data is re-uploaded on every call
// with GL_DYNAMIC_DRAW.
void MeshRenderer::SetBuffer(u32 glBufferType, RenderBuffer renderBuffer, u32 bufferSize, void * data)
{
    // A slot with no GL name counts as empty whatever size it has recorded.
    const bool hadGlBuffer = (buffers[renderBuffer] != INVALID_BUFFER);
    const u32 recordedSize = hadGlBuffer ? bufferSizes[renderBuffer] : 0;

    if (recordedSize != bufferSize)
    {
        if (recordedSize > 0 && buffers[renderBuffer] != INVALID_BUFFER)
            glDeleteBuffers(1, &buffers[renderBuffer]);

        bufferSizes[renderBuffer] = bufferSize;
        if (bufferSize == 0)
            buffers[renderBuffer] = INVALID_BUFFER;
        else
            glGenBuffers(1, &buffers[renderBuffer]);
    }

    // Nothing to upload for an empty or unallocated slot.
    if (bufferSize == 0 || buffers[renderBuffer] == INVALID_BUFFER)
        return;

    // Index data is not a vertex attribute; every earlier slot is.
    if (renderBuffer < RenderBuffer_INDEX)
        glEnableVertexAttribArray(renderBuffer);
    glBindBuffer(glBufferType, buffers[renderBuffer]);
    glBufferData(glBufferType, bufferSize, data, GL_DYNAMIC_DRAW);
}

void MeshRenderer::Render_Internal()
{
    RenderWindow * window = RenderWindow::Current();
    RenderMode renderMode = window->GetRenderMode();

    const Material * material = mesh->material.Inst();
    if (material == NULL)
    {
        material = window->GetDefaultMaterial();
    }

    Matrix mtxView = window->GetViewMatrix();
    Matrix mtxProjection = window->GetProjectionMatrix();
    if (renderSpace == RenderSpace_Screen)
    {
        float width = (float)window->GetViewportWidth();
        float height = (float)window->GetViewportHeight();
        mtxProjection = glm::ortho(0.0f, width, 0.0f, height, 0.0f, 100.0f);
        float scale = width / (float)VIEWPORT_DEFAULT_WIDTH;
        mtxView = glm::scale(Matrix(1.0f), Vector3(1.0f, -1.0f, 1.0f));
        mtxView = glm::translate(mtxView, Vector3(0.0f, height * -1.0f, -99.0f));
    }
	Matrix mvp = mtxProjection * mtxView * mtxModel;

    const ShaderPass * program = NULL;
    for (u8 i = 0; i < UINT8_MAX; i++)
    {
        program = material->Use(i);
        if (program == NULL)
            break;

        if (program->renderMode != window->GetRenderMode())
        {
            window->SetRenderMode((RenderMode)program->renderMode);
        }

        if (material->textureId != 0 && program->uniformSAMPLER >= 0)
        {
            glActiveTexture(GL_TEXTURE0);
            glBindTexture(GL_TEXTURE_2D, material->textureId);
            glUniform1i(program->uniformSAMPLER, 0);
        }

        if (program->uniformMVP >= 0)
        {
            glUniformMatrix4fv(program->uniformMVP, 1, GL_FALSE, &mvp[0][0]);
        }

        if (program->uniformTIME >= 0)
        {
            glUniform1f(program->uniformTIME, (float)SDL_GetTicks() / 1000.0f);
        }

        const std::vector<Vector3> & vertices = mesh->GetVertices();
        u32 vertexCount = Mesh::ClampBuffer(vertices.size());
        if (vertexCount < 1)
            return;
        SetBuffer(GL_ARRAY_BUFFER, RenderBuffer_VERTEX, vertexCount * sizeof(GLfloat) * 3, (void *)&vertices[0]);
        glVertexAttribPointer(RenderBuffer_VERTEX, 3, GL_FLOAT, GL_FALSE, 0, (void *)0);

        const std::vector<Vector2> & uvs = mesh->GetUvs();
        u32 uvCount = Mesh::ClampBuffer(uvs.size());
        if (uvCount > 0)
        {
            SetBuffer(GL_ARRAY_BUFFER, RenderBuffer_UV, uvCount * sizeof(GLfloat) * 2, (void *)&uvs[0]);
            glVertexAttribPointer(RenderBuffer_UV, 2, GL_FLOAT, GL_FALSE, 0, (void *)0);
        }

        const std::vector<Vector3> & normals = mesh->GetNormals();
        u32 normalCount = Mesh::ClampBuffer(normals.size());
        if (normalCount > 0)
        {
            SetBuffer(GL_ARRAY_BUFFER, RenderBuffer_NORMAL, normalCount * sizeof(GLfloat) * 3, (void *)&normals[0]);
            glVertexAttribPointer(RenderBuffer_NORMAL, 3, GL_FLOAT, GL_FALSE, 0, (void *)0);
        }

        const std::vector<Vector3> & colors = mesh->GetColors();
        u32 colorCount = Mesh::ClampBuffer(colors.size());
        if (colorCount > 0)
        {
            SetBuffer(GL_ARRAY_BUFFER, RenderBuffer_COLOR, colorCount * sizeof(GLfloat) * 3, (void *)&colors[0]);
            glVertexAttribPointer(RenderBuffer_COLOR, 3, GL_FLOAT, GL_FALSE, 0, (void *)0);
        }

        const std::vector<u32> & indices = mesh->GetIndices();
        u32 indexCount = Mesh::ClampBuffer(indices.size());
        if (indexCount > 0)
        {
            SetBuffer(GL_ELEMENT_ARRAY_BUFFER, RenderBuffer_INDEX, indexCount * sizeof(u32), (void *)&indices[0]);
            glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, buffers[RenderBuffer_INDEX]);
            glDrawElements(GL_TRIANGLES, indexCount, GL_UNSIGNED_INT, (void *)0);
        }
        else
        {
            glDrawArrays(GL_TRIANGLES, 0, vertexCount);
        }

        for (u32 i = 0; i < RenderBuffer_MAX; i++)
        {
            glDisableVertexAttribArray(i);
        }

        window->SetRenderMode(renderMode);
    }
    window->SetRenderMode(renderMode);
}