#include <Shader.h>
#include <string>
#include <set>
#include <desslibs.h>
#include <FileSystem.h>
#include <string.h>
#include <Hash.h>
#include <RenderWindow.h>

// Cache of loaded shaders, keyed by the path passed to LoadShader.
// NOTE: the key type is a raw const char*, so the map stores the pointer itself, not a copy
// of the string -- the pointed-to text must outlive its entry. std::less<const char*> also
// orders keys by pointer value, which is why LoadShader scans the map comparing with strcmp
// instead of calling find().
std::map<const char*, Shader> Shader::shaderCache;

// Returns the cached shader for shaderPathPrefix, loading it on first request.
//
// If FileSystem::IsExtenstion reports a "dsh" extension the file is parsed as a multipass
// shader (LoadMultipass); otherwise shaderPathPrefix is treated as the prefix of a
// "<prefix>.vert" / "<prefix>.frag" pair (LoadShaderPair). A pair that fails to load is
// cached with no passes, so later calls return that empty shader instead of retrying.
//
// The returned pointer refers to the map entry; std::map never relocates its elements, so it
// stays valid as long as the entry remains in shaderCache.
const Shader * Shader::LoadShader(const char * shaderPathPrefix)
{
    // Compare key contents: the map's own ordering compares pointer values.
    for (std::map<const char*, Shader>::iterator it = shaderCache.begin(); it != shaderCache.end(); ++it)
    {
        if (strcmp(it->first, shaderPathPrefix) == 0)
        {
            return &(it->second);
        }
    }

    // shaderCache stores the key pointer, not the characters. The caller's string may be a
    // temporary or a stack buffer, so key the entry on a private copy instead. std::set nodes
    // never move, so c_str() of an inserted element stays valid for the program's lifetime.
    static std::set<std::string> cachedPaths;
    const char * key = cachedPaths.insert(std::string(shaderPathPrefix)).first->c_str();

    // Load straight into the map entry rather than copy-assigning a temporary Shader into it.
    Shader & shader = shaderCache[key];
    if (FileSystem::IsExtenstion(shaderPathPrefix, "dsh"))
    {
        shader.LoadMultipass(shaderPathPrefix);
    }
    else if (LoadShaderPair(&shader, shaderPathPrefix) == false)
    {
        shader.passes.clear();
    }

    return &shader;
}

// Loads "<prefix>.vert" and "<prefix>.frag", compiles and links them, and appends the result
// to shader->passes as a new pass.
//
// shader            - shader that receives the new pass (the pass is appended even on failure).
// shaderPathPrefix  - path without extension; ".vert" / ".frag" are appended to it.
// existingProgramId - program object to link into; forwarded to LinkShader (0 = create one).
// Returns true only when both stages compiled and the program linked. On failure every shader
// object compiled here is deleted, and its id in the pass is reset to 0.
bool Shader::LoadShaderPair(Shader * shader, const char * shaderPathPrefix, u32 existingProgramId)
{
    // Build paths with std::string: a fixed char[64] + sprintf overflowed on long prefixes.
    const std::string prefix(shaderPathPrefix);

    shader->passes.push_back(ShaderPass());
    ShaderPass * pass = &shader->passes.back();

    pass->vertId = LoadAndCompileShader((prefix + ".vert").c_str(), GL_VERTEX_SHADER);
    if (!pass->vertId)
    {
        Log_Error("Failed to load vertex shader\n");
        return false;
    }

    pass->fragId = LoadAndCompileShader((prefix + ".frag").c_str(), GL_FRAGMENT_SHADER);
    if (!pass->fragId)
    {
        Log_Error("Failed to load fragment shader\n");
        // Release the vertex shader compiled above so it is not leaked.
        glDeleteShader(pass->vertId);
        pass->vertId = 0;
        return false;
    }

    pass->programId = LinkShader(pass->vertId, pass->fragId, existingProgramId);
    if (!pass->programId)
    {
        // LinkShader returns 0 on failure and deletes the shaders only when linking succeeds,
        // so release them here and report the failure instead of claiming success.
        Log_Error("Failed to link shader %s\n", shaderPathPrefix);
        glDeleteShader(pass->vertId);
        glDeleteShader(pass->fragId);
        pass->vertId = 0;
        pass->fragId = 0;
        return false;
    }

    pass->uniformSAMPLER = glGetUniformLocation(pass->programId, "SAMPLER");
    pass->uniformMVP = glGetUniformLocation(pass->programId, "MVP");
    pass->uniformTIME = glGetUniformLocation(pass->programId, "TIME");

    return true;
}

// Links a vertex and a fragment shader into a GL program object.
//
// vertId, fragId    - compiled shader objects to attach.
// existingProgramId - program to (re)link into; when 0 a new program is created here.
// Returns the linked program id, or 0 if program creation or linking failed.
//
// Ownership: on success both shaders are detached and deleted. On failure they are left for
// the caller to delete. A program created here is deleted on failure, which also detaches the
// shaders. A caller-supplied program instead has the shaders detached explicitly, so a later
// relink of that program does not pick up these stale attachments.
u32 Shader::LinkShader(u32 vertId, u32 fragId, u32 existingProgramId)
{
    const bool createdProgram = (existingProgramId == 0);
    u32 programId = createdProgram ? glCreateProgram() : existingProgramId;
    if (programId == 0)
    {
        // glCreateProgram returns 0 on error.
        Log_Error("Failed to create shader program\n");
        return 0;
    }

    glAttachShader(programId, vertId);
    glAttachShader(programId, fragId);
    glLinkProgram(programId);

    GLint result = GL_FALSE;
    glGetProgramiv(programId, GL_LINK_STATUS, &result);
    if (result == GL_FALSE)
    {
        GLint infoLogLength = 0;
        glGetProgramiv(programId, GL_INFO_LOG_LENGTH, &infoLogLength);
        if (infoLogLength > 0)
        {
            // std::string owns the buffer, so it is released when it goes out of scope.
            std::string infoLog(static_cast<size_t>(infoLogLength), '\0');
            glGetProgramInfoLog(programId, infoLogLength, NULL, &infoLog[0]);
            Log_Error("Shader program load error%s\n", infoLog.c_str());
        }

        if (createdProgram)
        {
            glDeleteProgram(programId);
        }
        else
        {
            glDetachShader(programId, vertId);
            glDetachShader(programId, fragId);
        }
        return 0;
    }

    glDetachShader(programId, vertId);
    glDetachShader(programId, fragId);
    glDeleteShader(vertId);
    glDeleteShader(fragId);
    return programId;
}

// Reads the text at filePath and compiles it as a shader of type shaderType.
//
// Returns the id produced by CompileShader (0 if compilation failed), or 0 without compiling
// when the loaded text is empty -- presumably what FileSystem::LoadText yields for an
// unreadable file (TODO confirm).
u32 Shader::LoadAndCompileShader(const char * filePath, u32 shaderType)
{
    const std::string source = FileSystem::LoadText(filePath);
    if (!source.empty())
    {
        Log_Info("Shader compile: %s\n", filePath);
        return CompileShader(source.c_str(), shaderType);
    }

    Log_Error("Shader load failed: %s\n", filePath);
    return 0;
}

// Hashes of the directive tokens that LoadMultipass recognises in multipass (".dsh") files.
const HashInt HASH_PASS(Hash("#PASS")); 
const HashInt HASH_VERT(Hash("#VERT"));
const HashInt HASH_FRAG(Hash("#FRAG"));
// NOTE(review): not referenced in this section -- LoadMultipass detects the DEPTHLESS flag
// with std::string::find instead. Possibly unused; confirm before removing.
const HashInt HASH_DEPTHLESS(Hash("DEPTHLESS"));
// GLSL version line prepended to each stage's source extracted from a multipass file.
#define SHADER_PREFIX "#version 330 core\n"

void Shader::LoadMultipass(const char * shaderPath)
{
    std::string shaderCode = FileSystem::LoadText(shaderPath);
    passes.clear();
    u32 shaderType = 0;
    size_t passStartI = 0, lineStartI = 0, opEndI;
    char c;
    HashInt opHash = 0;
    HashInt mode = HASH_PASS;
    bool completePass = false;
    passes.push_back(ShaderPass());
    ShaderPass * pass = &passes.back();
    std::string passCode;
    std::string line;
    std::string op;
    for (size_t i=0, lim=shaderCode.size(); i <= lim; i++)
    {
        c = shaderCode[i];
        if (i > 0 && (c == '\0' || c == '\n'))
        {
            line = shaderCode.substr(lineStartI, i - lineStartI - 1);
            opEndI = line.find_first_of(' ');
            if (opEndI > 0)
            {
                op = line.substr(0, opEndI);
                opHash = Hash(op.c_str());
            }
            else
            {
                opHash = Hash(line.c_str());
            }
            if (opHash == HASH_PASS || opHash == HASH_FRAG || opHash == HASH_VERT || c == '\0')
            {
                if (c == '\0')
                {
                    lineStartI = i;
                }
                if (mode == HASH_FRAG || mode == HASH_VERT)
                {
                    passCode = SHADER_PREFIX + shaderCode.substr(passStartI, lineStartI - passStartI);
                    if (mode == HASH_VERT)
                    {
                        pass->vertId = CompileShader(passCode.c_str(), GL_VERTEX_SHADER);
                        if (pass->vertId == 0)
                        {
                            Log_Error("Failed to load vert shader program %s\n", shaderPath);
                        }
                    }
                    else if (mode == HASH_FRAG)
                    {
                        pass->fragId = CompileShader(passCode.c_str(), GL_FRAGMENT_SHADER);
                        if (pass->fragId == 0)
                        {
                            Log_Error("Failed to load frag shader program %s\n", shaderPath);
                        }
                    }
                }
                if (c == '\0')
                {
                    break;
                }
                if (opHash == HASH_PASS)
                {
                    if (mode != HASH_PASS)
                    {
                        passes.push_back(ShaderPass());
                        pass = &passes.back();
                    }
                    if (line.find("DEPTHLESS") != std::string::npos)
                    {
                        pass->renderMode = RenderMode_Depthless;
                    }
                    else if (line.find("WIREFRAME") != std::string::npos)
                    {
                        pass->renderMode = RenderMode_Wireframe;
                    }
                }
                mode = opHash;
                passStartI = i;
            }
            lineStartI = i + 1;
        }
    }
    for (std::vector<ShaderPass>::iterator it = passes.begin(); it != passes.end(); ++it)
    {
        it->programId = LinkShader(it->vertId, it->fragId);
        if (it->programId == 0)
        {
            Log_Error("Failed to link shader %s\n", shaderPath);
        }
        else
        {
            it->uniformSAMPLER = glGetUniformLocation(it->programId, "SAMPLER");
            it->uniformMVP = glGetUniformLocation(it->programId, "MVP");
            it->uniformTIME = glGetUniformLocation(it->programId, "TIME");
        }
    }
}

// Returns a pointer to the pass at index, or NULL when index is not below GetPassCount().
const ShaderPass * Shader::GetPass(u8 index) const
{
    const bool inRange = index < GetPassCount();
    return inRange ? &passes[index] : NULL;
}

// Returns the GL program id of the pass at index, or 0 when index is out of range.
u32 Shader::GetProgram(u8 index) const
{
    // GetPass applies the same bounds check and yields NULL past the last pass.
    const ShaderPass * pass = GetPass(index);
    if (pass == NULL)
    {
        return 0;
    }
    return pass->programId;
}

// Compiles NUL-terminated GLSL source into a new shader object.
//
// shaderCode - source text handed to glShaderSource.
// shaderType - shader stage, e.g. GL_VERTEX_SHADER or GL_FRAGMENT_SHADER.
// Returns the new shader id, or 0 on failure. On a compile error the info log is reported via
// Log_Error and the shader object is deleted, so nothing leaks.
u32 Shader::CompileShader(const char * shaderCode, u32 shaderType)
{
    GLuint shaderId = glCreateShader(shaderType);
    if (shaderId == 0)
    {
        // glCreateShader returns 0 on error (e.g. an invalid shaderType). Bail out before the
        // queries below run on an invalid name and leave their outputs unwritten.
        Log_Error("Failed to create shader object\n");
        return 0;
    }

    glShaderSource(shaderId, 1, &shaderCode, NULL);
    glCompileShader(shaderId);

    GLint result = GL_FALSE;
    glGetShaderiv(shaderId, GL_COMPILE_STATUS, &result);
    if (result == GL_FALSE)
    {
        GLint infoLogLength = 0;
        glGetShaderiv(shaderId, GL_INFO_LOG_LENGTH, &infoLogLength);
        if (infoLogLength > 0)
        {
            // std::string owns the buffer, so it is released when it goes out of scope.
            std::string infoLog(static_cast<size_t>(infoLogLength), '\0');
            glGetShaderInfoLog(shaderId, infoLogLength, NULL, &infoLog[0]);
            Log_Error("Shader error: %s\n", infoLog.c_str());
        }
        glDeleteShader(shaderId);
        return 0;
    }

    return shaderId;
}