import http.server
import json
import os
import socket
import socketserver
import threading
from datetime import datetime
from urllib.parse import parse_qs, urlparse

# TCP port the HTTP server binds to.
PORT = 9000
# Shared secret compared against the "token" supplied to the protected relay/agent
# endpoints. Read from ENGMODE_RELAY_TOKEN, with a built-in default when unset.
RELAY_TOKEN = os.environ.get("ENGMODE_RELAY_TOKEN", "kaios-local-relay")
# --- Relay state (shared by every request-handler thread) ---
# Id given to the next command queued through POST /relay/command.
relay_next_id = 1
# FIFO of queued relay command dicts; GET /relay/wait pops from the front.
relay_commands = []
# Every result payload received on POST /relay/result, in arrival order.
relay_results = []
# client name -> {"last_poll", "address", "kind"}, written by remember_client().
relay_clients = {}
# Name of the client currently allowed to take relay commands, or None.
relay_active_client = None
# Seconds since its last recorded poll after which the active client counts as stale.
ACTIVE_CLIENT_TTL_SECONDS = 300
# Long-polling /relay/wait requests wait on this; queuing a command notifies it.
relay_condition = threading.Condition()

# --- Agent state (same layout as the relay state, minus the active-client notion) ---
# Id given to the next command queued through POST /agent/command.
agent_next_id = 1
# FIFO of queued agent command dicts; GET /agent/wait pops from the front.
agent_commands = []
# Every result received on POST /agent/result, in arrival order.
agent_results = []
# client name -> {"last_poll", "address", "kind"}, written by remember_agent().
agent_clients = {}
# Long-polling /agent/wait requests wait on this; queuing a command notifies it.
agent_condition = threading.Condition()


def send_cors(handler, status=200, content_type="text/plain"):
    """Write the status line and the full CORS/no-cache header set, then end the headers.

    Args:
        handler: Request handler providing ``send_response``, ``send_header``
            and ``end_headers``.
        status: HTTP status code for the response.
        content_type: Value of the ``Content-Type`` header.

    The caller writes the response body (if any) afterwards.
    """
    # Flag read by CaptureHandler.end_headers so it does not add the CORS headers again.
    handler._cors_headers_sent = True
    handler.send_response(status)

    response_headers = (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, POST, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type"),
        ("Cache-Control", "no-store"),
        ("Connection", "close"),
        ("Content-Type", content_type),
    )
    for header_name, header_value in response_headers:
        handler.send_header(header_name, header_value)

    handler.end_headers()


def safe_filename(name):
    """Reduce *name* to a single filename component made of safe characters.

    The value is stringified, stripped to its final path component, and every
    character that is not alphanumeric or one of ``.``, ``_``, ``-`` becomes
    ``_``. The result is cut to at most 120 characters.
    """
    base = os.path.basename(str(name))
    # Neutralise any separator that basename() left in place on this platform.
    for separator in ("\\", "/"):
        base = base.replace(separator, "_")

    sanitized = []
    for ch in base:
        keep = ch.isalnum() or ch in "._-"
        sanitized.append(ch if keep else "_")
    return "".join(sanitized)[:120]


def save_probe_text(prefix, client, phase, content):
    """Write *content* to a timestamped text file and return the file name.

    The name is ``<prefix>_<YYYYmmdd_HHMMSS>_<client>_<phase>.txt``, where the
    client and phase parts are passed through ``safe_filename`` and default to
    ``client`` / ``event`` when empty. The path is relative, so the file lands
    in the current working directory.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    client_part = safe_filename(client or "client")
    phase_part = safe_filename(phase or "event")
    output_file = f"{prefix}_{stamp}_{client_part}_{phase_part}.txt"

    with open(output_file, "w", encoding="utf-8") as handle:
        handle.write(content)
    return output_file


def get_lan_addresses():
    """Return this host's non-loopback IPv4 addresses as a sorted list of strings.

    Two best-effort sources are combined: the addresses the hostname resolves
    to, and the local address chosen for an outbound UDP socket. A source that
    fails with ``OSError`` is skipped, so the result may be empty.
    """
    addresses = set()

    try:
        hostname = socket.gethostname()
        for info in socket.getaddrinfo(hostname, None, socket.AF_INET):
            addr = info[4][0]
            if not addr.startswith("127."):
                addresses.add(addr)
    except OSError:
        # Best effort: hostname resolution is optional.
        pass

    try:
        # The context manager closes the socket even when connect() or
        # getsockname() raises; previously the socket leaked on that path.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            # connect() on a datagram socket only fixes the peer address, which
            # makes the OS pick the outbound local address we read back below.
            probe.connect(("8.8.8.8", 80))
            addr = probe.getsockname()[0]
        if not addr.startswith("127."):
            addresses.add(addr)
    except OSError:
        # Best effort: no usable route (e.g. offline host).
        pass

    return sorted(addresses)


def remember_client(client, address, kind="poll"):
    """Record (or overwrite) the latest contact from relay client *client*.

    Stores the current local time (ISO format, second precision), the remote
    *address* and the contact *kind* in ``relay_clients``.
    """
    seen_at = datetime.now().isoformat(timespec="seconds")
    relay_clients[client] = dict(last_poll=seen_at, address=address, kind=kind)


def remember_agent(client, address, kind="poll"):
    """Record (or overwrite) the latest contact from agent *client*.

    Stores the current local time (ISO format, second precision), the remote
    *address* and the contact *kind* in ``agent_clients``.
    """
    seen_at = datetime.now().isoformat(timespec="seconds")
    agent_clients[client] = dict(last_poll=seen_at, address=address, kind=kind)


def active_client_is_fresh():
    """Tell whether the active relay client has been seen recently enough.

    Returns False when no active client is set, when it has no entry in
    ``relay_clients``, or when its ``last_poll`` value cannot be parsed.
    Otherwise returns True if that timestamp is at most
    ``ACTIVE_CLIENT_TTL_SECONDS`` seconds old.
    """
    record = relay_clients.get(relay_active_client) if relay_active_client else None
    if not record:
        return False

    try:
        seen_at = datetime.fromisoformat(record.get("last_poll", ""))
    except ValueError:
        return False

    age = datetime.now() - seen_at
    return age.total_seconds() <= ACTIVE_CLIENT_TTL_SECONDS


class CaptureHandler(http.server.SimpleHTTPRequestHandler):
    """Serve the relay/agent command queues, the probe endpoints and a capture sink.

    GET requests are routed through ``_GET_ROUTES``; any other path is handed
    to ``SimpleHTTPRequestHandler`` (with ``/`` mapped to
    ``/shell_capture_debug.js``). POST requests are routed through
    ``_POST_ROUTES``; any other POST body is written to a ``captured_*`` file.
    """

    # Longest long-poll wait in seconds; also used when no valid timeout is given.
    MAX_WAIT_SECONDS = 55

    # Actions accepted by POST /agent/command.
    _AGENT_ACTIONS = frozenset(
        {"identity", "service_tree", "remote_tree", "restore_updater", "stop", "shell"}
    )

    # Request path -> name of the method that serves it.
    _GET_ROUTES = {
        "/relay/wait": "_get_relay_wait",
        "/relay/next": "_get_relay_next",
        "/relay/status": "_get_relay_status",
        "/agent/status": "_get_agent_status",
        "/agent/wait": "_get_agent_wait",
        "/agent/ping": "_get_agent_ping",
        "/relay/ping": "_get_relay_ping",
        "/probe/beacon": "_get_probe_beacon",
    }
    _POST_ROUTES = {
        "/relay/command": "_post_relay_command",
        "/relay/clear": "_post_relay_clear",
        "/relay/active": "_post_relay_active",
        "/relay/result": "_post_relay_result",
        "/agent/command": "_post_agent_command",
        "/agent/clear": "_post_agent_clear",
        "/agent/result": "_post_agent_result",
    }

    # ------------------------------------------------------------------ helpers

    def end_headers(self):
        """Add CORS and no-cache headers to responses not produced by send_cors()."""
        # send_cors() already wrote these headers (Cache-Control included) and set
        # the flag; adding them again here would duplicate them on the wire.
        if not getattr(self, "_cors_headers_sent", False):
            self.send_header("Access-Control-Allow-Origin", "*")
            self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
            self.send_header("Access-Control-Allow-Headers", "Content-Type")
            self.send_header("Cache-Control", "no-store")
        return super().end_headers()

    @staticmethod
    def _first_param(query, key, default=""):
        """Return the first value for *key* in a ``parse_qs`` mapping, else *default*."""
        return (query.get(key) or [default])[0]

    @classmethod
    def _parse_timeout(cls, query):
        """Return the requested long-poll timeout clamped to 1..MAX_WAIT_SECONDS.

        A missing or non-integer ``timeout`` value falls back to
        ``MAX_WAIT_SECONDS`` instead of raising.
        """
        raw = cls._first_param(query, "timeout", str(cls.MAX_WAIT_SECONDS))
        try:
            seconds = int(raw)
        except ValueError:
            seconds = cls.MAX_WAIT_SECONDS
        return min(max(seconds, 1), cls.MAX_WAIT_SECONDS)

    @staticmethod
    def _json_body(post_data):
        """Decode a POST body as JSON; an empty body is treated as ``{}``."""
        return json.loads(post_data.decode("utf-8", errors="replace") or "{}")

    def _reply(self, status, content_type, body):
        """Send the CORS headers for *status* followed by the raw *body* bytes."""
        send_cors(self, status, content_type)
        self.wfile.write(body)

    def _deny(self, plain=False):
        """Answer 403 for a bad token, as JSON or (``plain=True``) as text."""
        if plain:
            self._reply(403, "text/plain", b"ERR bad token\n")
        else:
            self._reply(403, "application/json", b'{"error":"bad token"}')

    def _probe_record(self, parsed, query, client, phase):
        """Build the dict that ping/beacon requests persist to disk."""
        return {
            "path": parsed.path,
            "client": client,
            "phase": phase,
            "address": self.client_address[0],
            # Keep only the last value of each repeated query parameter.
            "query": {key: values[-1] if values else "" for key, values in query.items()},
            "saved_at": datetime.now().isoformat(timespec="seconds"),
        }

    # ---------------------------------------------------------------------- GET

    def do_GET(self):
        """Dispatch a GET to its route method, or fall back to static file serving."""
        parsed = urlparse(self.path)
        route = self._GET_ROUTES.get(parsed.path)
        if route is not None:
            getattr(self, route)(parsed)
            return None

        if self.path == "/":
            self.path = "/shell_capture_debug.js"

        print(f"[GET] {self.path}")
        return super().do_GET()

    def _get_relay_wait(self, parsed):
        """Long-poll for the next relay command (token required).

        Only the active client, or any client when none is active, may take a
        command. When commands are queued but this client is not the active
        one, the reply is sent immediately with a null command.
        """
        global relay_active_client

        query = parse_qs(parsed.query)
        client = self._first_param(query, "client", "unknown")
        if self._first_param(query, "token") != RELAY_TOKEN:
            self._deny()
            return
        timeout = self._parse_timeout(query)

        remember_client(client, self.client_address[0], "wait")

        command = None
        with relay_condition:
            if relay_active_client and not active_client_is_fresh():
                print(f"\n[relay] active client expired: {relay_active_client}")
                relay_active_client = None

            if relay_commands and (relay_active_client is None or client == relay_active_client):
                relay_active_client = client
                command = relay_commands.pop(0)
            elif not relay_commands:
                relay_condition.wait(timeout=timeout)
                remember_client(client, self.client_address[0], "wait")
                if relay_active_client and not active_client_is_fresh():
                    relay_active_client = None
                if relay_commands and (relay_active_client is None or client == relay_active_client):
                    relay_active_client = client
                    command = relay_commands.pop(0)

        if command:
            command["client"] = client
            print(f"\n[relay] assigned #{command['id']} to {client}")

        body = json.dumps(command or {"id": None, "command": None}).encode("utf-8")
        self._reply(200, "application/json", body)

    def _get_relay_next(self, parsed):
        """Record a relay poll and answer with a null command (token required)."""
        query = parse_qs(parsed.query)
        client = self._first_param(query, "client", "unknown")
        if self._first_param(query, "token") != RELAY_TOKEN:
            self._deny()
            return

        remember_client(client, self.client_address[0])
        body = json.dumps({"id": None, "command": None}).encode("utf-8")
        self._reply(200, "application/json", body)

    def _get_relay_status(self, parsed):
        """Report relay queue sizes, known clients and the active client."""
        global relay_active_client

        if relay_active_client and not active_client_is_fresh():
            relay_active_client = None
        body = json.dumps(
            {
                "queued": len(relay_commands),
                "results": len(relay_results),
                "clients": relay_clients,
                "active_client": relay_active_client,
                "token_hint": "set ENGMODE_RELAY_TOKEN to change the default",
            },
            indent=2,
        ).encode("utf-8")
        self._reply(200, "application/json", body)

    def _get_agent_status(self, parsed):
        """Report agent queue sizes and known agent clients."""
        body = json.dumps(
            {
                "queued": len(agent_commands),
                "results": len(agent_results),
                "clients": agent_clients,
                "token_hint": "set ENGMODE_RELAY_TOKEN to change the default",
            },
            indent=2,
        ).encode("utf-8")
        self._reply(200, "text/plain" if False else "application/json", body)

    def _get_agent_wait(self, parsed):
        """Long-poll for the next agent command; reply in ``KEY=value`` text form."""
        query = parse_qs(parsed.query)
        client = self._first_param(query, "client", "unknown")
        if self._first_param(query, "token") != RELAY_TOKEN:
            self._deny(plain=True)
            return
        timeout = self._parse_timeout(query)

        remember_agent(client, self.client_address[0], "wait")
        command = None
        with agent_condition:
            if not agent_commands:
                agent_condition.wait(timeout=timeout)
                remember_agent(client, self.client_address[0], "wait")
            if agent_commands:
                command = agent_commands.pop(0)

        if not command:
            self._reply(200, "text/plain", b"NONE\n")
            return

        command["client"] = client
        print(f"\n[agent] assigned #{command['id']} {command['action']} to {client}")
        body = (
            f"ID={command['id']}\n"
            f"ACTION={command['action']}\n"
            f"ARG={command.get('arg', '')}\n"
        ).encode("utf-8")
        self._reply(200, "text/plain", body)

    def _get_agent_ping(self, parsed):
        """Record an agent heartbeat (token required)."""
        query = parse_qs(parsed.query)
        client = self._first_param(query, "client", "unknown")
        if self._first_param(query, "token") != RELAY_TOKEN:
            self._deny(plain=True)
            return
        remember_agent(client, self.client_address[0], "ping")
        self._reply(200, "text/plain", b"ok\n")

    def _get_relay_ping(self, parsed):
        """Record a relay ping; optionally persist it when ``save=1``.

        No token is checked here. A ``script-start`` phase, or a stale active
        client, makes the pinging client the active one.
        """
        global relay_active_client

        query = parse_qs(parsed.query)
        client = self._first_param(query, "client", "unknown")
        phase = self._first_param(query, "phase", "ping")
        remember_client(client, self.client_address[0], f"ping:{phase}")
        if phase == "script-start" or not active_client_is_fresh():
            relay_active_client = client
        print(f"[relay] ping from {client} at {self.client_address[0]}")

        if self._first_param(query, "save") == "1":
            record = self._probe_record(parsed, query, client, phase)
            output_file = save_probe_text(
                "relay_ping",
                client,
                phase,
                json.dumps(record, indent=2, sort_keys=True),
            )
            print(f"[relay] ping saved to {output_file}")

        self._reply(200, "text/plain", b"ok")

    def _get_probe_beacon(self, parsed):
        """Persist the beacon's query data and answer with a 1x1 GIF."""
        query = parse_qs(parsed.query)
        client = self._first_param(query, "client", "unknown")
        phase = self._first_param(query, "phase", "beacon")
        record = self._probe_record(parsed, query, client, phase)
        output_file = save_probe_text(
            "probe_beacon",
            client,
            phase,
            json.dumps(record, indent=2, sort_keys=True),
        )
        print(f"[probe] beacon saved to {output_file}")
        self._reply(
            200,
            "image/gif",
            b"GIF89a\x01\x00\x01\x00\x80\x00\x00\x00\x00\x00\xff\xff\xff!"
            b"\xf9\x04\x01\x00\x00\x00\x00,\x00\x00\x00\x00\x01\x00\x01\x00"
            b"\x00\x02\x02D\x01\x00;",
        )

    # ------------------------------------------------------------------ OPTIONS

    def do_OPTIONS(self):
        """Answer CORS preflight requests with 204 and the CORS headers."""
        send_cors(self, 204)

    # --------------------------------------------------------------------- POST

    def do_POST(self):
        """Read the body and dispatch to a route method, else save it as a capture.

        Any exception raised while handling the request is printed and answered
        with a bare 500 response.
        """
        try:
            parsed = urlparse(self.path)
            content_length = int(self.headers.get("Content-Length", 0))
            post_data = self.rfile.read(content_length)

            route = self._POST_ROUTES.get(parsed.path)
            if route is not None:
                getattr(self, route)(parsed, post_data)
            else:
                self._post_capture(post_data)
        except Exception as e:
            print(f"[!] Error: {e}")
            send_cors(self, 500)

    def _post_relay_command(self, parsed, post_data):
        """Queue a relay command and wake any waiting /relay/wait pollers."""
        global relay_next_id

        data = self._json_body(post_data)
        if data.get("token") != RELAY_TOKEN:
            self._deny()
            return
        command = str(data.get("command") or "").strip()
        script = str(data.get("script") or "")
        if not command:
            self._reply(400, "application/json", b'{"error":"empty command"}')
            return

        mode = str(data.get("mode") or "shell")
        wait_ms = int(data.get("wait_ms") or 8000)

        # Allocate the id and enqueue under the condition's lock: requests run on
        # separate threads, so doing this unlocked could hand out duplicate ids.
        with relay_condition:
            item = {
                "id": relay_next_id,
                "command": command,
                "mode": mode,
                "wait_ms": wait_ms,
                "queued_at": datetime.now().isoformat(timespec="seconds"),
            }
            if script:
                item["script"] = script
            relay_next_id += 1
            relay_commands.append(item)
            # Serialise before releasing the lock, i.e. before a poller can pop
            # the item and add its "client" key to it.
            body = json.dumps(item).encode("utf-8")
            relay_condition.notify_all()

        print(f"\n[relay] queued #{item['id']}: {command}")
        self._reply(200, "application/json", body)

    def _post_relay_clear(self, parsed, post_data):
        """Drop every queued relay command and report how many were removed."""
        data = self._json_body(post_data)
        if data.get("token") != RELAY_TOKEN:
            self._deny()
            return
        with relay_condition:
            cleared = len(relay_commands)
            relay_commands.clear()
        print(f"\n[relay] cleared {cleared} queued command(s)")
        self._reply(200, "application/json", json.dumps({"cleared": cleared}).encode("utf-8"))

    def _post_relay_active(self, parsed, post_data):
        """Set (or, with an empty value, unset) the active relay client."""
        global relay_active_client

        data = self._json_body(post_data)
        if data.get("token") != RELAY_TOKEN:
            self._deny()
            return
        relay_active_client = data.get("client") or None
        print(f"\n[relay] active client set to {relay_active_client}")
        body = json.dumps({"active_client": relay_active_client}).encode("utf-8")
        self._reply(200, "application/json", body)

    def _post_relay_result(self, parsed, post_data):
        """Store a relay result in memory and in a ``relay_result_*`` text file."""
        data = self._json_body(post_data)
        if data.get("token") != RELAY_TOKEN:
            self._deny()
            return

        relay_results.append(data)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        rid = safe_filename(data.get("id") or "unknown")
        client = safe_filename(data.get("client") or "client")
        output_file = f"relay_result_{timestamp}_{client}_{rid}.txt"
        content = (
            f"id={data.get('id')}\n"
            f"client={data.get('client')}\n"
            f"command={data.get('command')}\n"
            f"started_at={data.get('started_at')}\n"
            f"finished_at={data.get('finished_at')}\n\n"
            f"{data.get('output', '')}"
        )
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(content)

        print(f"\n[relay] result #{data.get('id')} saved to {output_file}")
        print(f"[relay] Preview: {str(data.get('output', ''))[:300]}...")
        self._reply(200, "application/json", b'{"ok":true}')

    def _post_agent_command(self, parsed, post_data):
        """Queue an agent action and wake any waiting /agent/wait pollers.

        A non-empty ``command`` forces the ``shell`` action; a ``shell`` action
        without ``command`` takes its command text from ``arg``.
        """
        global agent_next_id

        data = self._json_body(post_data)
        if data.get("token") != RELAY_TOKEN:
            self._deny()
            return
        command = str(data.get("command") or "").strip()
        action = str(data.get("action") or "").strip()
        if command:
            action = "shell"
        if action not in self._AGENT_ACTIONS:
            body = json.dumps(
                {"error": "bad action", "allowed": sorted(self._AGENT_ACTIONS)}
            ).encode("utf-8")
            self._reply(400, "application/json", body)
            return
        if action == "shell" and not command:
            command = str(data.get("arg") or "").strip()
        if action == "shell" and not command:
            self._reply(400, "application/json", b'{"error":"empty command"}')
            return

        arg = command if action == "shell" else str(data.get("arg") or "")

        # Same locking rationale as for relay commands: id allocation and the
        # append must be atomic with respect to other request threads.
        with agent_condition:
            item = {
                "id": agent_next_id,
                "action": action,
                "arg": arg,
                "queued_at": datetime.now().isoformat(timespec="seconds"),
            }
            agent_next_id += 1
            agent_commands.append(item)
            body = json.dumps(item).encode("utf-8")
            agent_condition.notify_all()

        if action == "shell":
            print(f"\n[agent] queued #{item['id']}: {command}")
        else:
            print(f"\n[agent] queued #{item['id']}: {action}")
        self._reply(200, "application/json", body)

    def _post_agent_clear(self, parsed, post_data):
        """Drop every queued agent command and report how many were removed."""
        data = self._json_body(post_data)
        if data.get("token") != RELAY_TOKEN:
            self._deny()
            return
        with agent_condition:
            cleared = len(agent_commands)
            agent_commands.clear()
        print(f"\n[agent] cleared {cleared} queued command(s)")
        self._reply(200, "application/json", json.dumps({"cleared": cleared}).encode("utf-8"))

    def _post_agent_result(self, parsed, post_data):
        """Store an agent result; metadata comes from the query string, output from the body."""
        query = parse_qs(parsed.query)
        if self._first_param(query, "token") != RELAY_TOKEN:
            self._deny(plain=True)
            return
        rid = self._first_param(query, "id", "unknown")
        client = self._first_param(query, "client", "agent")
        action = self._first_param(query, "action", "unknown")
        output = post_data.decode("utf-8", errors="replace")
        agent_results.append({"id": rid, "client": client, "action": action, "output": output})

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f"agent_result_{timestamp}_{safe_filename(client)}_{safe_filename(rid)}_{safe_filename(action)}.txt"
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(f"id={rid}\nclient={client}\naction={action}\n\n{output}")
        print(f"\n[agent] result #{rid} {action} saved to {output_file}")
        print(f"[agent] Preview: {output[:300]}...")
        self._reply(200, "text/plain", b"ok\n")

    def _post_capture(self, post_data):
        """Save a POST body that matched no route to a ``captured_*`` file.

        JSON bodies provide ``filename`` and ``data``; any other body is saved
        verbatim under a timestamped name.
        """
        content_type = self.headers.get("Content-Type", "")
        if "application/json" in content_type:
            data = json.loads(post_data.decode("utf-8", errors="replace"))
            filename = data.get("filename") or "shell_output.txt"
            content = data.get("data") or ""
            output_file = f"captured_{safe_filename(filename)}"
        else:
            content = post_data.decode("utf-8", errors="replace")
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_file = f"captured_shell_output_{timestamp}.txt"

        with open(output_file, "w", encoding="utf-8") as f:
            f.write(content)

        print("\n[!] SUCCESS: Received data!")
        print(f"[!] Saved to: {output_file}")
        print(f"[!] Preview: {content[:300]}...")

        self._reply(200, "text/plain", b"Success")


# Work from this script's directory: the handlers above write their output files
# using relative names, so they end up next to the script.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# Set on the class before the server below is constructed, so the listening
# socket is bound with address reuse enabled.
socketserver.ThreadingTCPServer.allow_reuse_address = True

# Start-up banner: listening URLs and the token clients must present.
print(f"Server running at http://0.0.0.0:{PORT}")
for address in get_lan_addresses():
    print(f"LAN URL: http://{address}:{PORT}")
print(f"Engmode relay token: {RELAY_TOKEN}")

# Bind on all interfaces (empty host) and serve until the process is interrupted;
# ThreadingTCPServer handles each request on its own thread.
with socketserver.ThreadingTCPServer(("", PORT), CaptureHandler) as httpd:
    httpd.serve_forever()
