package ctrlflow

import (
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"log"
	"math"
	mathrand "math/rand"
	"os"
	"strconv"
	"strings"

	ah "github.com/guno1928/alosgarble/internal/asthelper"
	"github.com/guno1928/alosgarble/internal/ssa2ast"
	"golang.org/x/tools/go/ast/astutil"
	"golang.org/x/tools/go/ssa"
)

// Names used for the generated output and for recognizing the directive.
const (
	// mergedFileName is the name of the generated file that receives every
	// obfuscated function.
	mergedFileName = "GARBLE_controlflow.go"
	// directiveName is the comment prefix that marks a function for
	// control-flow obfuscation.
	directiveName = "//garble:controlflow"
	// importPrefix prefixes the aliases of imports added to the generated file.
	importPrefix = "___garble_import"
)

// Values used when a directive does not specify the corresponding parameter.
const (
	defaultBlockSplits   = 0
	defaultJunkJumps     = 0
	defaultFlattenPasses = 1
	defaultTrashBlocks   = 0
)

// Upper bounds for directive parameters; these are also the values selected
// when a parameter is given the literal value "max".
const (
	maxBlockSplits   = math.MaxInt32
	maxJunkJumps     = 256
	maxFlattenPasses = 4
	maxTrashBlocks   = 1024
)

// Bounds on the number of statements generated for each trash block.
const (
	minTrashBlockStmts = 1
	maxTrashBlockStmts = 32
)

// directiveParamMap holds the key=value parameters that follow a
// //garble:controlflow directive. Keys written without "=" map to "".
type directiveParamMap map[string]string

// GetInt returns the integer value of the parameter called name.
//
// If the parameter is absent, def is returned. The literal value "max"
// selects max. Otherwise the value must be a base-10 integer in [0, max];
// anything else (malformed, negative or too large) yields an error.
func (m directiveParamMap) GetInt(name string, def, max int) (int, error) {
	rawVal, ok := m[name]
	if !ok {
		return def, nil
	}

	if rawVal == "max" {
		return max, nil
	}

	val, err := strconv.Atoi(rawVal)
	if err != nil {
		return 0, fmt.Errorf("invalid flag %q format: %w", name, err)
	}
	// These parameters are counts (splits, jumps, passes, blocks); a negative
	// value would otherwise silently behave like zero.
	if val < 0 {
		return 0, fmt.Errorf("negative flag %q value: %d", name, val)
	}
	if val > max {
		return 0, fmt.Errorf("too big flag %q value: %d (max: %d)", name, val, max)
	}
	return val, nil
}

// StringSlice returns the comma-separated items stored under name.
//
// Empty items (as produced by a bare key, a trailing comma or "a,,b") are
// dropped. The result is nil if the parameter is absent or has no non-empty
// items.
func (m directiveParamMap) StringSlice(name string) []string {
	rawVal, ok := m[name]
	if !ok {
		return nil
	}

	var items []string
	for _, item := range strings.Split(rawVal, ",") {
		if item != "" {
			items = append(items, item)
		}
	}
	return items
}

// parseDirective reports whether directive (the text of one comment, including
// its leading "//") is a //garble:controlflow directive and, if it is, returns
// the directive's parameters.
//
// Parameters are whitespace-separated "key=value" fields; a field without "="
// is stored with an empty value. The returned map is nil when the directive
// carries no parameters. The directive name must be followed by whitespace or
// the end of the comment, so a comment such as "//garble:controlflowx" is not
// treated as the directive.
func parseDirective(directive string) (directiveParamMap, bool) {
	fieldsStr, ok := strings.CutPrefix(directive, directiveName)
	if !ok {
		return nil, false
	}
	// Require a word boundary after the directive name.
	if fieldsStr != "" && fieldsStr[0] != ' ' && fieldsStr[0] != '\t' {
		return nil, false
	}

	fields := strings.Fields(fieldsStr)
	if len(fields) == 0 {
		return nil, true
	}
	m := make(directiveParamMap, len(fields))
	for _, field := range fields {
		// When there is no "=", Cut returns the whole field as key and "" as value.
		key, value, _ := strings.Cut(field, "=")
		m[key] = value
	}
	return m, true
}

// Obfuscate applies control-flow obfuscation to every function in files whose
// doc comment contains a //garble:controlflow directive.
//
// For each such function the original declaration is neutralized in place
// (renamed to "_", given an empty body, no receiver and no parameters). An
// obfuscated re-implementation, converted from the function's SSA form with
// ssa2ast, is appended to a single new file named mergedFileName. Anonymous
// functions belonging to a directive function are obfuscated with the same
// parameters.
//
// Recognized directive parameters are block_splits, junk_jumps,
// flatten_passes and trash_blocks (integers or "max"), plus flatten_hardening
// (a comma-separated list).
//
// It returns the new file's name and AST together with the files whose
// declarations were modified. If no function carries the directive, every
// result is the zero value. Note that files may already have been modified
// when an error is returned.
func Obfuscate(fset *token.FileSet, ssaPkg *ssa.Package, files []*ast.File, obfRand *mathrand.Rand) (newFileName string, newFile *ast.File, affectedFiles []*ast.File, err error) {
	// ssaFuncs[i] is obfuscated using the directive parameters in ssaParams[i].
	var ssaFuncs []*ssa.Function
	var ssaParams []directiveParamMap

	for _, file := range files {
		affected := false
		for _, decl := range file.Decls {
			funcDecl, ok := decl.(*ast.FuncDecl)
			if !ok || funcDecl.Doc == nil {
				continue
			}

			for _, comment := range funcDecl.Doc.List {
				params, hasDirective := parseDirective(comment.Text)
				if !hasDirective {
					continue
				}

				path, _ := astutil.PathEnclosingInterval(file, funcDecl.Pos(), funcDecl.Pos())
				ssaFunc := ssa.EnclosingFunction(ssaPkg, path)
				if ssaFunc == nil {
					panic("function exists in ast but not found in ssa")
				}

				ssaFuncs = append(ssaFuncs, ssaFunc)
				ssaParams = append(ssaParams, params)

				log.Printf("detected function for controlflow %s (params: %v)", funcDecl.Name.Name, params)

				// Neutralize the original declaration; its obfuscated
				// re-implementation is appended to newFile further down.
				funcDecl.Name = ast.NewIdent("_")
				funcDecl.Body = ah.BlockStmt()
				funcDecl.Recv = nil
				funcDecl.Type = &ast.FuncType{Params: &ast.FieldList{}}
				affected = true

				// Only the first directive in a doc comment is honored.
				break
			}
		}

		if affected {
			affectedFiles = append(affectedFiles, file)
		}
	}

	if len(ssaFuncs) == 0 {
		return
	}

	newFile = &ast.File{
		Package: token.Pos(fset.Base()),
		Name:    ast.NewIdent(files[0].Name.Name),
	}
	// Register a 1-byte file starting at newFile.Package so that position
	// resolves to mergedFileName within fset.
	fset.AddFile(mergedFileName, int(newFile.Package), 1)

	funcConfig := ssa2ast.DefaultConfig()
	// Every package referenced by generated code is imported into newFile
	// under a unique alias (importPrefix + index); the current package itself
	// needs no qualifier.
	imports := make(map[string]string)
	funcConfig.ImportNameResolver = func(pkg *types.Package) *ast.Ident {
		if pkg == nil || pkg.Path() == ssaPkg.Pkg.Path() {
			return nil
		}

		name, ok := imports[pkg.Path()]
		if !ok {
			name = importPrefix + strconv.Itoa(len(imports))
			imports[pkg.Path()] = name
			astutil.AddNamedImport(fset, newFile, name, pkg.Path())
		}
		return ast.NewIdent(name)
	}

	// Created lazily: only needed once some function requests trash blocks.
	var trashGen *trashGenerator

	for idx, ssaFunc := range ssaFuncs {
		params := ssaParams[idx]

		split, err := params.GetInt("block_splits", defaultBlockSplits, maxBlockSplits)
		if err != nil {
			return "", nil, nil, fmt.Errorf("controlflow directive on %s: %w", ssaFunc, err)
		}
		junkCount, err := params.GetInt("junk_jumps", defaultJunkJumps, maxJunkJumps)
		if err != nil {
			return "", nil, nil, fmt.Errorf("controlflow directive on %s: %w", ssaFunc, err)
		}
		passes, err := params.GetInt("flatten_passes", defaultFlattenPasses, maxFlattenPasses)
		if err != nil {
			return "", nil, nil, fmt.Errorf("controlflow directive on %s: %w", ssaFunc, err)
		}
		// Without at least one flattening pass the remaining transformations
		// have no effect on the output, so warn the user.
		if passes <= 0 {
			fmt.Fprintf(os.Stderr, "control flow obfuscation for %q function has no effect on the resulting binary, to fix this flatten_passes must be greater than zero\n", ssaFunc)
		}
		flattenHardening := params.StringSlice("flatten_hardening")

		trashBlockCount, err := params.GetInt("trash_blocks", defaultTrashBlocks, maxTrashBlocks)
		if err != nil {
			return "", nil, nil, fmt.Errorf("controlflow directive on %s: %w", ssaFunc, err)
		}
		if trashBlockCount > 0 && trashGen == nil {
			trashGen = newTrashGenerator(ssaPkg.Prog, funcConfig.ImportNameResolver, obfRand)
		}

		// applyObfuscation transforms one SSA function in place and returns
		// the dispatchers created by its flattening passes.
		applyObfuscation := func(ssaFunc *ssa.Function) []dispatcherInfo {
			if trashBlockCount > 0 {
				addTrashBlockMarkers(ssaFunc, trashBlockCount, obfRand)
			}
			for range split {
				if !applySplitting(ssaFunc, obfRand) {
					break
				}
			}
			if junkCount > 0 {
				addJunkBlocks(ssaFunc, junkCount, obfRand)
			}
			var dispatchers []dispatcherInfo
			for range passes {
				if info := applyFlattening(ssaFunc, obfRand); info != nil {
					dispatchers = append(dispatchers, info)
				}
			}
			fixBlockIndexes(ssaFunc)
			return dispatchers
		}

		dispatchers := applyObfuscation(ssaFunc)
		for _, anonFunc := range ssaFunc.AnonFuncs {
			dispatchers = append(dispatchers, applyObfuscation(anonFunc)...)
		}

		// funcConfig is shared by all functions, so per-function settings
		// (SsaValueRemap, MarkerInstrCallback) are reset on every iteration.
		var prologues []ast.Stmt
		if len(flattenHardening) > 0 && len(dispatchers) > 0 {
			hardening := newDispatcherHardening(flattenHardening)

			ssaRemap := make(map[ssa.Value]ast.Expr)
			for _, dispatcher := range dispatchers {
				decl, stmt := hardening.Apply(dispatcher, ssaRemap, obfRand)
				if decl != nil {
					newFile.Decls = append(newFile.Decls, decl)
				}
				if stmt != nil {
					prologues = append(prologues, stmt)
				}
			}
			funcConfig.SsaValueRemap = ssaRemap
		} else {
			funcConfig.SsaValueRemap = nil
		}

		funcConfig.MarkerInstrCallback = nil
		if trashBlockCount > 0 {
			funcConfig.MarkerInstrCallback = func(m map[string]types.Type) []ast.Stmt {
				// Pick a statement count in [minTrashBlockStmts, maxTrashBlockStmts];
				// Intn's bound is exclusive, hence the +1.
				stmtCount := minTrashBlockStmts + obfRand.Intn(maxTrashBlockStmts-minTrashBlockStmts+1)
				return trashGen.Generate(stmtCount, m)
			}
		}

		astFunc, err := ssa2ast.Convert(ssaFunc, funcConfig)
		if err != nil {
			return "", nil, nil, err
		}
		// Hardening prologue statements run before the converted body.
		if len(prologues) > 0 {
			astFunc.Body.List = append(prologues, astFunc.Body.List...)
		}
		newFile.Decls = append(newFile.Decls, astFunc)
	}

	newFileName = mergedFileName
	return
}
