
#include "statement.h"
#include "expression.h"
#include "type.h"

#define indent(depth, txt) for(int i = 0; i < depth; i++) os << "    "; os << txt << std::endl;



// Out-of-line (virtual) destructor for the statement base class; the base owns
// nothing itself, derived nodes release their own children.
Statement::~Statement() = default;

// ===========================================================

Block::Block(std::vector<Statement*>* statements)
{
    this->statements = statements;
}
Block::~Block()
{
    for(Statement* statement : *statements)
        delete statement;
    delete statements;
}
void Block::dump(std::ostream& os, int depth)
{
    indent(depth, "Block")
    for(Statement* statement : *statements)
        statement->dump(os, depth+1);
}

void Block::translate(TranslationContext* ctx)
{
    ctx->symboltable->push_scope();
    for(Statement* statement : *statements)
        statement->translate(ctx);
    ctx->symboltable->pop_scope();
}

// ===========================================================

// Takes ownership of type, name and (possibly null) initializer value; all
// three are deleted in the destructor.
VarDeclaration::VarDeclaration(Type* type, std::string* name, RValue* value)
{
    this->type = type;
    this->name = name;
    this->value = value;
}
VarDeclaration::~VarDeclaration()
{
    delete type;
    delete name;
    delete value;   // deleting nullptr is a no-op, so a missing initializer is fine
}
// Prints the declaration name, then its type and (if present) its initializer.
void VarDeclaration::dump(std::ostream& os, int depth)
{
    indent(depth, "VarDeclaration " << *name)
    type->dump(os, depth+1);
    // NOTE(review): guard mirrors the nullable-child handling in IfThenElse;
    // confirm whether the parser can emit declarations without an initializer.
    if(value != nullptr)
        value->dump(os, depth+1);
}
// Allocates stack storage for the variable and registers it in the current
// scope; if an initializer exists, stores its value at the statement position.
void VarDeclaration::translate(TranslationContext* ctx)
{
    // move the alloca to the beginning of the function
    llvm::Type* typ = type->translate(ctx);
    llvm::Value* ptr = ctx->entryblockBuilder->CreateAlloca(typ, nullptr, *name);
    // The name is registered before the initializer is lowered, so an
    // initializer that mentions the same name resolves to this new variable.
    ctx->symboltable->insert(typ, *name, ptr);

    // Without an initializer the slot is left as-is (no store is emitted).
    if(value == nullptr)
        return;

    // put only an assignment to the ptr at the position of the statement
    llvm::Value* val = value->translate(ctx);
    ctx->builder->CreateStore(val, ptr);
}


// ===========================================================

Assignment::Assignment(LValue* lvalue, RValue* rvalue)
{
    this->lvalue = lvalue;
    this->rvalue = rvalue;
}
Assignment::~Assignment()
{
    delete lvalue;
    delete rvalue;
}
void Assignment::dump(std::ostream& os, int depth)
{
    indent(depth, "Assignment")
    lvalue->dump(os, depth+1);
    rvalue->dump(os, depth+1);
}
void Assignment::translate(TranslationContext* ctx)
{
    llvm::Value* lv = lvalue->translateLValue(ctx);
    llvm::Value* rv = rvalue->translate(ctx);
    ctx->builder->CreateStore(rv, lv);
}

// ===========================================================

// Takes ownership of the target lvalue (released in the destructor).
Input::Input(LValue* lvalue)
{
    this->lvalue = lvalue;
}
Input::~Input()
{
    delete lvalue;
}
// Prints "Input" followed by the target lvalue one level deeper.
void Input::dump(std::ostream& os, int depth)
{
    indent(depth, "Input")
    lvalue->dump(os, depth+1);
}
// Lowers to a call `scanf("%d", &target)`, declaring `int scanf(char*, ...)`
// in the module on first use.
void Input::translate(TranslationContext* ctx)
{
    llvm::Type* charptrTy = llvm::PointerType::getUnqual(ctx->builder->getInt8Ty());
    llvm::FunctionCallee scanfFnc = ctx->llvmmodule->getOrInsertFunction("scanf",
        llvm::FunctionType::get(ctx->builder->getInt32Ty(), charptrTy, true));

    // "%d", not "%i": scanf's %i auto-detects the base from the prefix, so
    // "010" would be read as 8 and "0x1F" as 31. %d always parses decimal,
    // matching what Output prints.
    llvm::Value* format = ctx->builder->CreateGlobalString("%d");
    std::vector<llvm::Value*> args {format, lvalue->translateLValue(ctx)};
    // NOTE(review): scanf's return value (number of items converted) is
    // discarded, so malformed input leaves the target unchanged silently.
    ctx->builder->CreateCall(scanfFnc, args);
}

// ===========================================================

// Takes ownership of the printed expression.
Output::Output(RValue* rvalue)
    : rvalue(rvalue)
{
}

Output::~Output()
{
    delete rvalue;
}

// Prints "Output" followed by the expression one level deeper.
void Output::dump(std::ostream& os, int depth)
{
    indent(depth, "Output")
    rvalue->dump(os, depth + 1);
}

// Lowers to `printf("%i\n", value)`, declaring `int printf(char*, ...)` in the
// module on first use. The format global is emitted before the operand.
void Output::translate(TranslationContext* ctx)
{
    auto& irb = *ctx->builder;

    llvm::Type* bytePtrTy = llvm::PointerType::getUnqual(irb.getInt8Ty());
    llvm::FunctionType* printfTy = llvm::FunctionType::get(irb.getInt32Ty(), bytePtrTy, true);
    llvm::FunctionCallee printfFn = ctx->llvmmodule->getOrInsertFunction("printf", printfTy);

    llvm::Value* fmt = irb.CreateGlobalString("%i\n");
    llvm::Value* printed = rvalue->translate(ctx);
    llvm::Value* callArgs[] = {fmt, printed};
    irb.CreateCall(printfFn, callArgs);
}

// ===========================================================

IfThenElse::IfThenElse(BooleanExpression* condition, Statement* thenCase, Statement* elseCase)
{
    this->condition = condition;
    this->thenCase = thenCase;
    this->elseCase = elseCase;
}
IfThenElse::~IfThenElse()
{
    delete condition;
    delete thenCase;
    delete elseCase;
}
void IfThenElse::dump(std::ostream& os, int depth)
{
    indent(depth, "IfThenElse")
    condition->dump(os, depth+1);
    thenCase->dump(os, depth+1);
    if(elseCase)
        elseCase->dump(os, depth+1);
}
void IfThenElse::translate(TranslationContext* ctx)
{
    llvm::BasicBlock* thenBB = llvm::BasicBlock::Create(*(ctx->llvmcontext), "_then");
    llvm::BasicBlock* elseBB = elseCase != nullptr
                             ? llvm::BasicBlock::Create(*(ctx->llvmcontext), "_else")
                             : nullptr;
    llvm::BasicBlock* endifBB = llvm::BasicBlock::Create(*(ctx->llvmcontext), "_endif");

    llvm::Value* cond = condition->translate(ctx);
    ctx->builder->CreateCondBr(cond, thenBB,
                               elseBB != nullptr ? elseBB : endifBB);

    // then
    thenBB->insertInto(ctx->func);
    ctx->builder->SetInsertPoint(thenBB);
    thenCase->translate(ctx);
    ctx->builder->CreateBr(endifBB);

    // else
    if(elseBB != nullptr)
    {
        elseBB->insertInto(ctx->func);
        ctx->builder->SetInsertPoint(elseBB);
        elseCase->translate(ctx);
        ctx->builder->CreateBr(endifBB);
    }

    // end-if
    endifBB->insertInto(ctx->func);
    ctx->builder->SetInsertPoint(endifBB);
}

// ===========================================================

WhileLoop::WhileLoop(BooleanExpression* condition, Statement* body)
{
    this->condition = condition;
    this->body = body;
}
WhileLoop::~WhileLoop()
{
    delete condition;
    delete body;
}
void WhileLoop::dump(std::ostream& os, int depth)
{
    indent(depth, "WhileLoop")
    condition->dump(os, depth+1);
    body->dump(os, depth+1);
}
void WhileLoop::translate(TranslationContext* ctx)
{
    llvm::BasicBlock* whileBB = llvm::BasicBlock::Create(*(ctx->llvmcontext), "_while");
    llvm::BasicBlock* doBB = llvm::BasicBlock::Create(*(ctx->llvmcontext), "_do");
    llvm::BasicBlock* doneBB = llvm::BasicBlock::Create(*(ctx->llvmcontext), "_done");
    ctx->builder->CreateBr(whileBB);

    // while
    whileBB->insertInto(ctx->func);
    ctx->builder->SetInsertPoint(whileBB);
    llvm::Value* cond = condition->translate(ctx);
    ctx->builder->CreateCondBr(cond, doBB, doneBB);

    // do
    doBB->insertInto(ctx->func);
    ctx->builder->SetInsertPoint(doBB);
    body->translate(ctx);
    ctx->builder->CreateBr(whileBB);

    // done
    doneBB->insertInto(ctx->func);
    ctx->builder->SetInsertPoint(doneBB);
}


