#!/usr/bin/env python3
"""
Axiom Streaming Client - Async Hebbian Learning

Replaces BOINC work unit model with continuous streaming for sample-by-sample
Hebbian learning. "Neurons that fire together, wire together."

Key features:
- Async data loading (never blocks training)
- Hebbian updates (no backward pass, ~130x less memory than backprop)
- WebSocket weight sync (gossip protocol, non-blocking)
- Sample-by-sample processing (1000+ samples/sec vs BOINC's 1-2)

Usage:
    python axiom_streaming_client.py [--expert N] [--server ws://host:port]
"""

import asyncio
import json
import os
import sys
import time
import struct
import base64
import hashlib
import argparse
import signal
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, List
import queue
import threading
import gc
import psutil

# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from simple_ml import MoEConfig, ExpertWorker, HAS_NUMPY
if HAS_NUMPY:
    import numpy as np

# Optional WebSocket support
HAS_WEBSOCKETS = False
try:
    import websockets
    HAS_WEBSOCKETS = True
except ImportError:
    print("[Client] websockets not installed - running in offline mode")
    print("[Client] Install with: pip install websockets")


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class ClientConfig:
    """Tunable settings for one streaming-client instance.

    Every field has a default, so ``ClientConfig()`` is valid as-is. The two
    data directories are resolved in ``__post_init__`` when left as ``None``.
    """
    # -- Expert assignment ----------------------------------------------------
    expert_idx: int = -1       # -1 means "ask the coordinator for an expert"
    auto_assign: bool = True   # when True, an expert is requested from the coordinator

    # -- User authentication (sent to the server for credit tracking) ---------
    authenticator: str = ""

    # -- Model shape (intended to mirror the server's MoE configuration) ------
    input_size: int = 64
    output_size: int = 2
    num_experts: int = 420
    expert_hidden: int = 3072         # full transformer FFN width
    expert_layers: int = 6            # full depth
    expert_type: str = "transformer"  # 42.6M params per expert → 17.8B total
    d_model: int = 768
    n_heads: int = 12

    # -- Training batch size ---------------------------------------------------
    mini_batch_size: int = 1   # 1 = true sample-by-sample Hebbian learning

    # -- Hebbian learning parameters ------------------------------------------
    hebbian_lr: float = 0.001
    hebbian_decay: float = 0.999

    # -- Weight synchronisation ------------------------------------------------
    sync_interval: int = 1000   # queue a weight delta every N trained samples
    gossip_alpha: float = 0.15  # 85/15 split: 85% local, 15% peer

    # -- Server endpoints ------------------------------------------------------
    server_url: str = "ws://65.21.196.61:8765"
    model_server_url: str = "https://axiom.heliex.net"

    # -- Data locations (None = use the defaults under ~/Axiom) ---------------
    contribute_path: Optional[Path] = None
    cache_path: Optional[Path] = None

    # -- Performance -----------------------------------------------------------
    data_buffer_size: int = 500   # sample-queue capacity (reduced from 10000 to save memory)
    batch_report_interval: int = 100

    def __post_init__(self):
        """Replace unset data directories with their ``~/Axiom`` defaults."""
        # The home directory is looked up only for paths that are actually
        # missing, so explicitly supplied paths never touch Path.home().
        for attr_name, folder in (("contribute_path", "contribute"),
                                  ("cache_path", ".cache")):
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, Path.home() / "Axiom" / folder)


# =============================================================================
# Data Producer (Async file reading)
# =============================================================================

class DataProducer:
    """Async producer that continuously reads local files into a sample queue.

    Each file below ``config.contribute_path`` is expanded into a stream of
    bits (most-significant bit of each byte first). A sample is a window of
    ``config.input_size`` consecutive bits paired with the bit that follows
    it. Samples go into a bounded ``asyncio.Queue``, so the producer pauses
    whenever the consumer falls behind.
    """

    # Patterns of the form "*suffix" or "*substring*", matched against the
    # lower-cased full path. Matching files are never read.
    EXCLUDED_PATTERNS = (
        '*.env', '*.pem', '*.key', '*.p12', '*.pfx',
        '*password*', '*secret*', '*credential*'
    )

    # A file is skipped when any component of its path equals one of these.
    EXCLUDED_DIRS = ('.git', '.ssh', '.gnupg', '__pycache__', 'node_modules', '.gradients')

    COMPRESSED_EXTENSIONS = (
        '.zip', '.gz', '.gzip', '.bz2', '.xz', '.7z', '.rar', '.tar',
        '.tgz', '.tbz2', '.txz', '.lz', '.lzma', '.lz4', '.zst', '.zstd',
        '.cab', '.arj', '.lzh', '.z'
    )

    # Number of bytes read from disk per I/O call.
    READ_CHUNK_BYTES = 4096

    def __init__(self, config: ClientConfig):
        """Create an idle producer; nothing is read until ``start()`` runs."""
        self.config = config
        self.queue = asyncio.Queue(maxsize=config.data_buffer_size)
        self.running = False
        self.files_processed = 0
        self.samples_produced = 0
        self.warned_compressed = set()  # files already reported as compressed

    def is_compressed(self, filepath: Path) -> bool:
        """Return True when the file's (case-insensitive) suffix marks an archive."""
        return filepath.suffix.lower() in self.COMPRESSED_EXTENSIONS

    def should_exclude(self, filepath: Path) -> bool:
        """Return True when the file must not be used for training.

        A file is excluded if its lower-cased path matches one of
        ``EXCLUDED_PATTERNS`` or if any path component is in ``EXCLUDED_DIRS``.
        """
        path_lower = str(filepath).lower()
        for pattern in self.EXCLUDED_PATTERNS:
            if not pattern.startswith('*'):
                continue  # only "*..." patterns are understood
            if pattern.endswith('*'):
                # "*text*": substring match anywhere in the path.
                if pattern[1:-1] in path_lower:
                    return True
            elif path_lower.endswith(pattern[1:]):
                # "*suffix": match on the end of the path.
                return True
        return any(part in self.EXCLUDED_DIRS for part in filepath.parts)

    async def start(self):
        """Scan the contribute directory in a loop and enqueue samples.

        Runs until ``stop()`` clears ``self.running``. The directory is
        created if missing and re-scanned after every full pass, so newly
        added files are picked up.
        """
        self.running = True
        self.config.contribute_path.mkdir(parents=True, exist_ok=True)

        print(f"[DataProducer] Reading from: {self.config.contribute_path}")

        while self.running:
            files = [
                candidate
                for candidate in sorted(self.config.contribute_path.rglob('*'))
                if candidate.is_file() and not self.should_exclude(candidate)
            ]

            if not files:
                print("[DataProducer] No files found, waiting...")
                await asyncio.sleep(5)
                continue

            for filepath in files:
                if not self.running:
                    break

                # Skip compressed files with a warning (once per file).
                if self.is_compressed(filepath):
                    if filepath not in self.warned_compressed:
                        self.warned_compressed.add(filepath)
                        print(f"[DataProducer] Skipping compressed file (extract first): {filepath.name}")
                    continue

                try:
                    await self._process_file(filepath)
                    self.files_processed += 1
                except Exception as e:
                    # Best effort: one unreadable file must not stop the scan.
                    print(f"[DataProducer] Error reading {filepath}: {e}")

            # Small delay before re-scanning.
            await asyncio.sleep(0.1)

    async def _process_file(self, filepath: Path):
        """Stream one file and enqueue every sliding-window sample from it.

        For a file of ``B`` bits and window size ``W`` this enqueues
        ``max(0, B - W)`` tuples ``(bits[i:i + W], bits[i + W])`` in order.
        The file is processed chunk by chunk and only the last ``W`` bits are
        carried between chunks, so memory use does not grow with file size.
        The file stays open while the queue applies back-pressure.
        """
        window_size = self.config.input_size
        carry = []  # trailing bits of the previous chunk still needed as context

        with open(filepath, 'rb') as f:
            while True:
                chunk = f.read(self.READ_CHUNK_BYTES)
                if not chunk:
                    break

                bits = carry
                bits.extend(
                    (byte >> shift) & 1
                    for byte in chunk
                    for shift in range(7, -1, -1)  # MSB first
                )

                for i in range(len(bits) - window_size):
                    await self.queue.put((bits[i:i + window_size], bits[i + window_size]))
                    self.samples_produced += 1

                    # put() only suspends when the queue is full, so yield
                    # explicitly every 100 samples to keep the loop responsive.
                    if self.samples_produced % 100 == 0:
                        await asyncio.sleep(0)

                # Keep exactly the context the next chunk's first sample needs.
                # (bits[-0:] would keep everything, hence the explicit guard.)
                carry = bits[-window_size:] if window_size > 0 else []

                # Yield between chunks even when no sample was produced.
                await asyncio.sleep(0)

    async def get_sample(self):
        """Wait for and return the next ``(sample, target)`` tuple."""
        return await self.queue.get()

    def stop(self):
        """Ask the scan loop in ``start()`` to finish."""
        self.running = False


# =============================================================================
# Weight Sync (WebSocket gossip)
# =============================================================================

class WeightSync:
    """Async weight synchronization via WebSocket gossip protocol.

    The trainer hands weight deltas to ``queue_delta``; while connected, a
    send loop forwards them to the server and a receive loop applies
    ``peer_delta`` messages to the local expert.
    """

    # Maximum number of deltas waiting to be sent. Once reached, further
    # deltas are dropped, which keeps memory bounded when nothing drains the
    # queue (offline mode or a long disconnect).
    MAX_PENDING_DELTAS = 4

    def __init__(self, config: ClientConfig, expert: ExpertWorker, worker_id: int = 0):
        """Store collaborators; no connection is opened until ``start()``."""
        self.config = config
        self.expert = expert
        self.worker_id = worker_id
        self.running = False
        self.ws = None
        self.reference_weights = None  # snapshot the next delta is measured against
        self.pending_deltas = asyncio.Queue(maxsize=self.MAX_PENDING_DELTAS)
        self.sync_count = 0
        self.connected = False

    async def start(self):
        """Snapshot the reference weights, then keep a connection alive.

        Returns immediately when the ``websockets`` package is unavailable,
        idles when no server URL is configured, and otherwise reconnects
        (with a 5 second pause) until ``stop()`` is called.
        """
        self.running = True
        self.reference_weights = self.expert.get_weights()

        if not HAS_WEBSOCKETS:
            print("[WeightSync] WebSocket not available - offline mode")
            return

        if not self.config.server_url:
            print("[WeightSync] No server URL - offline mode")
            while self.running:
                await asyncio.sleep(1)
            return

        while self.running:
            try:
                await self._connect_and_sync()
            except Exception as e:
                print(f"[WeightSync] Connection error: {e}")
            finally:
                self.connected = False
                self.ws = None
            # Pause before every reconnect, including after a session that
            # ended without raising, so a flapping link cannot spin the loop.
            if self.running:
                await asyncio.sleep(5)  # Retry after 5 seconds

    async def _connect_and_sync(self):
        """Open one WebSocket session, register, and run both loops to completion."""
        print(f"[WeightSync] Connecting to {self.config.server_url}...")

        async with websockets.connect(
            self.config.server_url,
            ping_interval=None,  # Disable client-side pings (server handles it)
            ping_timeout=None,
            open_timeout=30,     # Allow 30 seconds for initial handshake
            close_timeout=10
        ) as ws:
            self.ws = ws
            self.connected = True
            print("[WeightSync] Connected!")

            # Register our expert with authenticator for credit tracking
            await ws.send(json.dumps({
                "type": "register",
                "expert_idx": self.config.expert_idx,
                "param_count": self.expert.get_param_count(),
                "authenticator": self.config.authenticator
            }))

            # Run send/receive loops concurrently; each one clears
            # ``self.connected`` on failure so the other stops as well.
            await asyncio.gather(
                self._send_loop(),
                self._receive_loop()
            )

    async def _send_loop(self):
        """Forward queued weight deltas to the server until disconnected."""
        while self.running and self.connected:
            try:
                # Wake up at least once per second to re-check the flags.
                delta = await asyncio.wait_for(
                    self.pending_deltas.get(),
                    timeout=1.0
                )

                # Compress and send with samples for credit tracking
                compressed = self._compress_delta(delta)
                await self.ws.send(json.dumps({
                    "type": "weight_delta",
                    "expert_idx": self.config.expert_idx,
                    "delta": compressed,
                    # Reports the configured interval, not a measured count.
                    "samples": self.config.sync_interval
                }))
                self.sync_count += 1
                print(f"[Sync] Worker-{self.worker_id} Expert-{self.config.expert_idx}: Sync #{self.sync_count} complete")

            except asyncio.TimeoutError:
                pass  # No delta to send
            except Exception as e:
                print(f"[WeightSync] Send error: {e}")
                self.connected = False  # let the receive loop exit too
                break

    async def _receive_loop(self):
        """Apply ``peer_delta`` messages from the server until disconnected."""
        while self.running and self.connected:
            try:
                msg = await asyncio.wait_for(
                    self.ws.recv(),
                    timeout=0.1
                )

                data = json.loads(msg)
                if data.get("type") == "peer_delta":
                    # Merge the peer's delta, weighted by gossip_alpha.
                    delta = self._decompress_delta(data["delta"])
                    self.expert.apply_weight_delta(delta, alpha=self.config.gossip_alpha)
                    print(f"[Sync] Worker-{self.worker_id}: Received peer delta ({self.config.gossip_alpha:.0%} merge)")

            except asyncio.TimeoutError:
                pass  # No message
            except Exception as e:
                if self.running:
                    print(f"[WeightSync] Receive error: {e}")
                self.connected = False  # let the send loop exit too
                break

    def queue_delta(self, delta):
        """Queue a weight delta for sending without blocking.

        The delta is silently dropped when ``MAX_PENDING_DELTAS`` deltas are
        already waiting.
        """
        try:
            self.pending_deltas.put_nowait(delta)
        except asyncio.QueueFull:
            pass  # Drop if queue full

    def _compress_delta(self, delta) -> str:
        """Encode a weight delta as a compact string.

        With NumPy: keep the (up to) 1000 largest-magnitude entries, quantise
        each to a signed byte relative to the largest magnitude, and return
        base64 of ``<IfI`` (length, scale, count) followed by ``<Ib``
        (index, value) records. Without NumPy: JSON list of the 100
        largest-magnitude ``[index, value]`` pairs.
        """
        if HAS_NUMPY:
            arr = np.array(delta, dtype=np.float32)
            # max() of an empty array raises, so guard the empty delta.
            scale = np.abs(arr).max() if arr.size else 1.0
            if scale < 1e-10:
                scale = 1.0  # avoid dividing by (near) zero below

            # Top-K sparsification
            k = min(1000, len(arr))
            indices = np.argsort(np.abs(arr))[-k:]

            # Quantize to int8
            buff = bytearray()
            buff.extend(struct.pack('<IfI', len(arr), scale, len(indices)))
            for idx in indices:
                q = int(max(-127, min(127, round(arr[idx] / scale * 127))))
                buff.extend(struct.pack('<Ib', int(idx), q))

            return base64.b64encode(bytes(buff)).decode()
        else:
            # Fallback: just JSON encode top values
            indexed = sorted(enumerate(delta), key=lambda x: abs(x[1]), reverse=True)[:100]
            return json.dumps(indexed)

    def _decompress_delta(self, compressed: str) -> list:
        """Decode a string produced by ``_compress_delta`` into a dense list.

        Entries not present in the sparse encoding are 0.0. A payload that
        starts with ``[`` is treated as the JSON format, anything else as the
        base64 binary format. Raises ``struct.error`` on a truncated binary
        payload and ``IndexError`` on an out-of-range index.
        """
        if compressed.startswith('['):
            # JSON format: list of [index, value] pairs.
            indexed = json.loads(compressed)
            if not indexed:
                return []
            # Parentheses matter: the list must be sized to max index + 1.
            delta = [0.0] * (max(idx for idx, _ in indexed) + 1)
            for idx, val in indexed:
                delta[idx] = val
            return delta

        # Binary sparse format
        buff = base64.b64decode(compressed)
        vec_len, scale, count = struct.unpack_from('<IfI', buff, 0)
        delta = [0.0] * vec_len
        offset = 12  # size of the '<IfI' header
        for _ in range(count):
            # 'b' reads the quantised value as a signed byte directly.
            idx, q = struct.unpack_from('<Ib', buff, offset)
            delta[idx] = q / 127.0 * scale
            offset += 5
        return delta

    def get_and_reset_delta(self):
        """Return the expert's change since the last snapshot and re-snapshot."""
        delta = self.expert.get_weight_delta(self.reference_weights)
        self.reference_weights = self.expert.get_weights()
        return delta

    def stop(self):
        """Signal all loops to finish."""
        self.running = False
        self.connected = False


# =============================================================================
# Training Loop (Hebbian)
# =============================================================================

class EndocrineSystem:
    """
    Hormone-like GC modulation system driven by real system metrics.

    Two "hormone" levels in [0, 1] are derived from total CPU usage:
    - cortisol rises while CPU is busy and suppresses non-critical GC
    - adrenaline rises while CPU has spare capacity

    GC urgency is derived from SYSTEM memory usage as a percentage of total
    RAM, relative to ``memory_threshold_pct`` (85% by default):
    light GC from 70% of the threshold, major GC from 85% of the threshold,
    and a full flush at or above the threshold itself.
    """

    # Generations passed to gc.collect(), in order, for each GC level.
    # 'sleep' deliberately runs the full collection twice.
    _GC_GENERATIONS = {
        'light': (0,),     # generation 0 only
        'major': (1,),     # generations 0 + 1
        'sleep': (2, 2),   # all generations, then once more
    }

    def __init__(self, target_cpu: float = 95.0, memory_threshold_pct: float = 85.0):
        """
        Args:
            target_cpu: Desired total CPU utilisation in percent (stored only).
            memory_threshold_pct: System memory usage, as a percentage of
                total RAM, at which a full GC flush is requested.
        """
        self.cortisol = 0.0         # Stress level (0-1) - high CPU = defer GC
        self.adrenaline = 0.0       # Boost signal (0-1) - low CPU = work harder
        self.last_gc_time = time.time()
        self.gc_count = 0

        # System metrics
        self.cpu_usage = 0.0        # Total CPU % (all cores averaged)
        self.memory_pct = 0.0       # System memory usage %
        self.memory_gb = 0.0        # For display
        self.total_ram_gb = 0.0     # Total system RAM
        self.target_cpu = target_cpu
        self.memory_threshold_pct = memory_threshold_pct

    def update_metrics(self):
        """Sample CPU and memory usage, then adjust both hormone levels."""
        try:
            # Total CPU across ALL cores
            self.cpu_usage = psutil.cpu_percent(interval=None)

            # System-wide memory percentage
            mem = psutil.virtual_memory()
            self.memory_pct = mem.percent
            self.memory_gb = mem.used / (1024 ** 3)
            self.total_ram_gb = mem.total / (1024 ** 3)
        except Exception:
            # Metrics unavailable: fall back to neutral placeholder values.
            # (Exception, not a bare except, so Ctrl-C still propagates.)
            self.cpu_usage = 50.0
            self.memory_pct = 50.0
            self.memory_gb = 4.0
            self.total_ram_gb = 16.0

        # Cortisol: HIGH when total CPU is high (system stressed, defer GC!)
        if self.cpu_usage > 90:
            self.cortisol = min(1.0, self.cortisol + 0.15)
        elif self.cpu_usage > 80:
            self.cortisol = min(1.0, self.cortisol + 0.05)
        elif self.cpu_usage < 60:
            self.cortisol = max(0.0, self.cortisol - 0.1)
        else:
            self.cortisol = max(0.0, self.cortisol - 0.02)

        # Adrenaline: HIGH when CPU is LOW (spare capacity, push harder!)
        if self.cpu_usage < 50:
            self.adrenaline = min(1.0, self.adrenaline + 0.2)
        elif self.cpu_usage < 70:
            self.adrenaline = min(1.0, self.adrenaline + 0.1)
        elif self.cpu_usage > 90:
            self.adrenaline = max(0.0, self.adrenaline - 0.02)
        else:
            self.adrenaline = max(0.0, self.adrenaline - 0.05)

    def should_gc(self) -> str:
        """Return the GC level currently warranted.

        Returns one of ``'none'``, ``'light'``, ``'major'`` or ``'sleep'``
        based on the last sampled system memory percentage, the cortisol
        level, and the time since the previous GC.
        """
        # Derived thresholds, e.g. 59.5% and 72.25% for the default 85%.
        light_thresh = self.memory_threshold_pct * 0.7
        major_thresh = self.memory_threshold_pct * 0.85

        # Critical - full flush when at or above threshold (always trigger)
        if self.memory_pct >= self.memory_threshold_pct:
            return 'sleep'

        # High cortisol (busy CPU) suppresses non-critical GC
        if self.cortisol > 0.7 and self.memory_pct < major_thresh:
            return 'none'

        # Major GC when memory getting high
        if self.memory_pct >= major_thresh:
            return 'major'

        # Light GC at lower threshold
        if self.memory_pct >= light_thresh:
            return 'light'

        # Time-based fallback - major GC at least every 60 seconds
        if time.time() - self.last_gc_time > 60:
            return 'major'

        return 'none'

    def do_gc(self, level: str):
        """Run garbage collection for ``level`` and record that it happened.

        ``'none'`` and unrecognised levels do nothing and are not counted.
        """
        generations = self._GC_GENERATIONS.get(level)
        if not generations:
            return

        for generation in generations:
            gc.collect(generation)

        self.last_gc_time = time.time()
        self.gc_count += 1

    def status(self) -> str:
        """Return a one-line summary of the current metrics and hormone levels."""
        return f"cpu={self.cpu_usage:.0f}% mem={self.memory_pct:.0f}%({self.memory_gb:.1f}GB) cort={self.cortisol:.1f} adren={self.adrenaline:.1f} gc={self.gc_count}"


class HebbianTrainer:
    """Main training loop using Hebbian learning.

    Owns one expert, a ``DataProducer`` feeding it samples, and a
    ``WeightSync`` that exchanges weight deltas with the server.
    """

    def __init__(self, config: ClientConfig, worker_id: int = 0):
        """Create an uninitialised trainer; ``run()`` builds the components."""
        self.config = config
        self.worker_id = worker_id
        self.expert = None
        self.data_producer = None
        self.weight_sync = None
        self.running = False

        # Endocrine system for GC modulation
        self.endocrine = EndocrineSystem()
        self.last_gc_check = time.time()

        # Stats
        self.samples_trained = 0
        self.start_time = None
        self.last_report_time = None
        self.last_report_samples = 0
        self.samples_since_sync = 0  # Track progress toward next sync

        # Shared counters for multiprocessing (set by run_single_worker)
        self.shared_samples = None
        self.shared_syncs = None
        self.shared_progress = None  # Per-worker progress array
        self.shared_experts = None   # Per-worker expert assignment array

    async def _request_expert_assignment(self) -> int:
        """Ask the coordinator which expert to train.

        Returns the assigned index, or a random index in
        ``[0, num_experts)`` when no server is usable or the request fails.
        """
        import random
        if not HAS_WEBSOCKETS or not self.config.server_url:
            # Fallback to random if no server
            return random.randint(0, self.config.num_experts - 1)

        try:
            async with websockets.connect(
                self.config.server_url,
                open_timeout=10,
                close_timeout=5
            ) as ws:
                await ws.send(json.dumps({
                    "type": "request_assignment",
                    "authenticator": self.config.authenticator
                }))

                response = await asyncio.wait_for(ws.recv(), timeout=10)
                data = json.loads(response)

                if data.get("type") == "assignment":
                    expert_idx = data.get("expert_idx", 0)
                    print(f"[Worker-{self.worker_id}] Assigned expert {expert_idx}")
                    return expert_idx

        except Exception as e:
            print(f"[Trainer] Failed to get assignment: {e}, using random")

        # Fallback to random (also reached on an unexpected reply type)
        return random.randint(0, self.config.num_experts - 1)

    async def initialize(self):
        """Resolve the expert index and build expert, producer and sync objects."""

        # Request expert assignment from coordinator if auto_assign is enabled
        if self.config.auto_assign and self.config.expert_idx < 0:
            self.config.expert_idx = await self._request_expert_assignment()

        print(f"[Trainer] Initializing expert {self.config.expert_idx}...")

        # Create MoE config
        moe_config = MoEConfig(
            input_size=self.config.input_size,
            output_size=self.config.output_size,
            num_experts=self.config.num_experts,
            expert_hidden=self.config.expert_hidden,
            expert_layers=self.config.expert_layers,
            expert_type=self.config.expert_type,
            d_model=self.config.d_model,
            n_heads=self.config.n_heads,
        )

        # Clear old cache on startup (fresh weights each session)
        self._clear_cache()

        # Cached weights are intentionally not loaded (always start fresh),
        # so the expert is created with an empty weight list.
        self.expert = ExpertWorker(
            config=moe_config,
            expert_idx=self.config.expert_idx,
            expert_weights=[],
            seed=self.config.expert_idx
        )

        param_count = self.expert.get_param_count()
        print(f"[Trainer] Expert {self.config.expert_idx} initialized: {param_count:,} parameters")
        print(f"[Trainer] Memory usage: ~{param_count * 4 / 1024 / 1024:.1f} MB")

        # Initialize data producer and weight sync
        self.data_producer = DataProducer(self.config)
        self.weight_sync = WeightSync(self.config, self.expert, self.worker_id)

    def _clear_cache(self):
        """Delete the whole ``<cache_path>/models`` directory (best effort)."""
        cache_dir = self.config.cache_path / "models"
        if cache_dir.exists():
            try:
                import shutil
                shutil.rmtree(cache_dir)
                print(f"[Trainer] Cleared weight cache")
            except Exception as e:
                print(f"[Trainer] Failed to clear cache: {e}")

    def _load_cached_weights(self) -> Optional[list]:
        """Return cached weights for this expert, or None if unavailable.

        The cache file is read as raw float32 values; this requires NumPy.
        """
        cache_file = self.config.cache_path / "models" / f"expert_{self.config.expert_idx}.bin"
        if cache_file.exists():
            try:
                if HAS_NUMPY:
                    weights = np.fromfile(str(cache_file), dtype=np.float32).tolist()
                    print(f"[Trainer] Loaded cached weights: {len(weights):,} params")
                    return weights
            except Exception as e:
                print(f"[Trainer] Failed to load cache: {e}")
        return None

    def _save_cached_weights(self):
        """Write the current weights to the cache as raw float32 (best effort)."""
        cache_dir = self.config.cache_path / "models"
        cache_file = cache_dir / f"expert_{self.config.expert_idx}.bin"

        try:
            # Directory creation is part of the best-effort save.
            cache_dir.mkdir(parents=True, exist_ok=True)
            weights = self.expert.get_weights()
            if HAS_NUMPY:
                np.array(weights, dtype=np.float32).tofile(str(cache_file))
        except Exception as e:
            print(f"[Trainer] Failed to save cache: {e}")

    def _train_batch(self, samples) -> int:
        """Run one forward pass plus Hebbian update and bump the counters.

        Returns the number of samples in the batch.
        """
        x = np.array(samples, dtype=np.float32) if HAS_NUMPY else samples

        # The forward result itself is not used here; the Hebbian update
        # takes no inputs, so presumably it relies on state kept by forward().
        self.expert.forward(x)
        self.expert.hebbian_update(
            lr=self.config.hebbian_lr,
            decay=self.config.hebbian_decay
        )

        trained = len(samples)
        self.samples_trained += trained
        if self.shared_samples is not None:
            with self.shared_samples.get_lock():
                self.shared_samples.value += trained
        self.samples_since_sync += trained
        return trained

    def _maybe_sync(self):
        """Queue a weight delta once ``sync_interval`` samples have accumulated."""
        if self.samples_since_sync < self.config.sync_interval:
            return

        self.weight_sync.queue_delta(self.weight_sync.get_and_reset_delta())
        if self.shared_syncs is not None:
            with self.shared_syncs.get_lock():
                self.shared_syncs.value += 1
        self.samples_since_sync = 0  # Reset progress

    async def train_loop(self):
        """Consume samples and train in mini-batches until stopped.

        Samples are collected until ``mini_batch_size`` is reached. If no
        sample arrives for 5 seconds, a partial batch is flushed. The target
        bit delivered with each sample is not used by this loop.
        """
        self.running = True
        self.start_time = time.time()
        self.last_report_time = self.start_time

        # Fixed for the whole run; a partial flush must not shrink it.
        batch_size = self.config.mini_batch_size

        print("[Trainer] Starting Hebbian training loop...")
        print(f"[Trainer] Learning rate: {self.config.hebbian_lr}")
        print(f"[Trainer] Weight decay: {self.config.hebbian_decay}")
        print(f"[Trainer] Mini-batch size: {batch_size}")
        print(f"[Trainer] Sync interval: {self.config.sync_interval} samples")

        sample_buffer = []

        while self.running:
            try:
                # Get next sample (waits up to 5 s if the queue is empty)
                sample, _target = await asyncio.wait_for(
                    self.data_producer.get_sample(),
                    timeout=5.0
                )
                sample_buffer.append(sample)

                if len(sample_buffer) < batch_size:
                    continue

                self._train_batch(sample_buffer)
                sample_buffer = []
                self._maybe_sync()

                # Report once at least batch_report_interval samples have been
                # trained since the last report (works for any batch size).
                if (self.samples_trained - self.last_report_samples
                        >= self.config.batch_report_interval):
                    self._report_progress()

                # Yield control briefly
                await asyncio.sleep(0)

            except asyncio.TimeoutError:
                # Data dried up: flush whatever has been collected so far.
                if sample_buffer:
                    self._train_batch(sample_buffer)
                    sample_buffer = []
                    self._maybe_sync()
                    self._report_progress()
                else:
                    print("[Trainer] Waiting for data...")
            except Exception as e:
                print(f"[Trainer] Error: {e}")
                import traceback
                traceback.print_exc()
                await asyncio.sleep(0.1)

    def _report_progress(self):
        """Refresh system metrics and publish progress to the shared arrays.

        Also records the time and sample count of this report.
        """
        # Update system metrics for endocrine system
        self.endocrine.update_metrics()

        # Update shared progress arrays for main process to read
        if self.shared_progress is not None:
            self.shared_progress[self.worker_id] = self.samples_since_sync
        if self.shared_experts is not None:
            self.shared_experts[self.worker_id] = self.config.expert_idx

        self.last_report_time = time.time()
        self.last_report_samples = self.samples_trained

    async def run(self):
        """Initialise everything and run producer, trainer and sync together."""
        await self.initialize()

        # WeightSync.start() handles the offline cases itself and snapshots
        # the reference weights that train_loop's delta computation relies
        # on, so it is always scheduled.
        tasks = [
            asyncio.create_task(self.data_producer.start()),
            asyncio.create_task(self.train_loop()),
            asyncio.create_task(self.weight_sync.start()),
        ]

        try:
            await asyncio.gather(*tasks)
        except KeyboardInterrupt:
            print("\n[Trainer] Shutting down...")
        finally:
            self.running = False
            self.data_producer.stop()
            self.weight_sync.stop()
            # Weights are deliberately not saved: fresh weights each session.
            print(f"[Trainer] Final: {self.samples_trained:,} samples trained")

    def stop(self):
        """Ask ``train_loop`` to finish after its current iteration."""
        self.running = False


# =============================================================================
# Parallel Multi-Expert Training
# =============================================================================

def run_single_worker(worker_id: int, server_url: str, lr: float, decay: float,
                      sync_interval: int, expert_type: str, batch_size: int,
                      memory_threshold_pct: float, authenticator: str = "",
                      auto_assign: bool = True,
                      shared_samples=None, shared_syncs=None,
                      shared_progress=None, shared_experts=None):
    """Run a single expert worker (target function for multiprocessing).

    Builds a ``ClientConfig`` and a ``HebbianTrainer`` for this worker,
    attaches the shared-state handles, and runs the trainer's event loop
    until it finishes or the process is interrupted.

    Args:
        worker_id: Index of this worker; also used as the expert index when
            ``auto_assign`` is False.
        server_url: WebSocket server URL (empty string for offline mode).
        lr: Hebbian learning rate.
        decay: Weight decay factor.
        sync_interval: Samples between weight syncs.
        expert_type: Expert architecture name.
        batch_size: Mini-batch size for Hebbian updates.
        memory_threshold_pct: Memory threshold handed to the trainer's
            endocrine system.
        authenticator: User authenticator for credit tracking.
        auto_assign: When True the expert index is left at -1 so the
            coordinator assigns one; otherwise ``worker_id`` is used.
        shared_samples, shared_syncs: Optional shared counters, attached to
            the trainer as-is.
        shared_progress, shared_experts: Optional shared per-worker arrays,
            attached to the trainer as-is.
    """
    config = ClientConfig(
        expert_idx=-1 if auto_assign else worker_id,  # -1 = request from coordinator
        auto_assign=auto_assign,
        server_url=server_url,
        hebbian_lr=lr,
        hebbian_decay=decay,
        sync_interval=sync_interval,
        expert_type=expert_type,
        mini_batch_size=batch_size,
        authenticator=authenticator,
    )

    trainer = HebbianTrainer(config, worker_id=worker_id)
    trainer.endocrine.memory_threshold_pct = memory_threshold_pct
    trainer.shared_samples = shared_samples
    trainer.shared_syncs = shared_syncs
    trainer.shared_progress = shared_progress
    trainer.shared_experts = shared_experts

    # The startup line must reflect the actual assignment mode: with
    # auto_assign disabled the expert is fixed to worker_id and nothing is
    # requested from the coordinator.
    if auto_assign:
        print(f"[Worker-{worker_id}] Starting (will request expert from coordinator)")
    else:
        print(f"[Worker-{worker_id}] Starting (fixed expert {worker_id})")

    try:
        asyncio.run(trainer.run())
    except KeyboardInterrupt:
        print(f"[Worker-{worker_id}] Shutdown")


def _suspend_worker(pid):
    """Suspend the process with the given PID; return True on success."""
    try:
        psutil.Process(pid).suspend()
    except (psutil.Error, OSError):
        # Best effort: the worker may have exited or be inaccessible.
        return False
    return True


def _resume_worker(pid):
    """Resume the process with the given PID; return True on success."""
    try:
        psutil.Process(pid).resume()
    except (psutil.Error, OSError):
        # Best effort: the worker may have exited or be inaccessible.
        return False
    return True


def run_parallel_workers(num_workers: int, server_url: str,
                         lr: float, decay: float, sync_interval: int,
                         expert_type: str, batch_size: int, memory_threshold_pct: float,
                         cpu_mode: str = "adaptive", cpu_fixed_pct: float = 50.0,
                         cpu_backoff_threshold: float = 70.0, authenticator: str = "",
                         auto_assign: bool = True):
    """Launch multiple expert workers in parallel using multiprocessing.

    Spawns ``num_workers`` processes running ``run_single_worker`` and then
    monitors them roughly once per second, printing status lines and
    suspending/resuming workers according to ``cpu_mode``. Returns once all
    workers have exited; on Ctrl-C the workers are resumed (if suspended),
    terminated and joined.

    Args:
        num_workers: Requested number of worker processes. Capped at
            ``round(cpu_count * memory_threshold_pct / 100)`` (minimum 1).
        server_url: WebSocket server URL (empty string for offline mode).
        lr: Hebbian learning rate.
        decay: Weight decay factor.
        sync_interval: Samples between weight syncs; also shown as the
            denominator of the per-worker progress lines.
        expert_type: Expert architecture name.
        batch_size: Mini-batch size for Hebbian updates.
        memory_threshold_pct: Passed to every worker and used for the worker
            cap described above.
        cpu_mode: ``"adaptive"`` suspends the second half of the active
            workers while system CPU exceeds ``cpu_backoff_threshold`` and
            resumes all of them once it drops 10 points below it.
            ``"fixed"`` suspends or resumes one worker at a time to stay
            within 10 points of ``cpu_fixed_pct``.
        cpu_fixed_pct: Target CPU percentage for fixed mode.
        cpu_backoff_threshold: Back-off threshold for adaptive mode.
        authenticator: User authenticator for credit tracking.
        auto_assign: Forwarded to the workers; when False each worker uses
            its own index as the expert index.
    """
    import multiprocessing as mp

    # Limit workers based on RAM threshold - snap to nearest valid count
    cpu_count = os.cpu_count() or 4
    max_workers_by_ram = max(1, round(cpu_count * (memory_threshold_pct / 100)))
    if num_workers > max_workers_by_ram:
        print(f"[RAM Limit] Reducing workers from {num_workers} to {max_workers_by_ram} (RAM threshold {memory_threshold_pct:.0f}% of {cpu_count} CPUs)")
        num_workers = max_workers_by_ram

    # Shared counters for stats
    shared_samples = mp.Value('i', 0)
    shared_syncs = mp.Value('i', 0)
    # Per-worker progress and expert arrays
    shared_progress = mp.Array('i', num_workers)  # samples_since_sync per worker
    shared_experts = mp.Array('i', [-1] * num_workers)  # expert_idx per worker

    # Get system RAM info
    total_ram = psutil.virtual_memory().total / (1024**3)

    # CPU mode description
    if cpu_mode == "fixed":
        cpu_desc = f"Fixed {cpu_fixed_pct:.0f}%"
    else:
        cpu_desc = f"Adaptive (back off > {cpu_backoff_threshold:.0f}%)"

    assign_mode = "Load-balanced (coordinator assigns)" if auto_assign else "Sequential"

    print("=" * 60)
    print("  Axiom PARALLEL Streaming Client - Hebbian Learning")
    print("  'Neurons that fire together, wire together'")
    print("=" * 60)
    print(f"  Workers: {num_workers}")
    print(f"  Expert Assignment: {assign_mode}")
    print(f"  Expert Type: {expert_type.upper()}")
    print(f"  System RAM: {total_ram:.1f}GB")
    print(f"  Max Workers: {memory_threshold_pct:.0f}% of CPU threads")
    print(f"  CPU Mode: {cpu_desc}")
    print(f"  Server: {server_url or 'OFFLINE'}")
    print("=" * 60)
    print()

    # Check for compressed files in contribute folder and warn
    contribute_path = Path.home() / "Axiom" / "contribute"
    if contribute_path.exists():
        compressed_exts = ('.zip', '.gz', '.gzip', '.bz2', '.xz', '.7z', '.rar', '.tar',
                          '.tgz', '.tbz2', '.txz', '.lz', '.lzma', '.lz4', '.zst', '.zstd')
        for filepath in contribute_path.rglob('*'):
            if filepath.is_file() and filepath.suffix.lower() in compressed_exts:
                print(f"[Warning] Skipping compressed file (extract first): {filepath.name}")

    # Spawn worker processes - all monitor system memory %
    processes = []
    for i in range(num_workers):
        p = mp.Process(
            target=run_single_worker,
            args=(i, server_url, lr, decay, sync_interval,
                  expert_type, batch_size, memory_threshold_pct, authenticator,
                  auto_assign, shared_samples, shared_syncs,
                  shared_progress, shared_experts)
        )
        p.start()
        processes.append(p)
        print(f"[Main] Launched worker {i} (PID: {p.pid})")

    # Indices of workers currently suspended for CPU management
    suspended = set()

    # Monitor and wait
    last_status_time = time.time()
    try:
        while any(p.is_alive() for p in processes):
            time.sleep(0.5)
            cpu = psutil.cpu_percent(interval=0.5)
            mem = psutil.virtual_memory()

            alive_workers = [(i, p) for i, p in enumerate(processes) if p.is_alive()]
            # Forget suspended workers that have died in the meantime;
            # otherwise a stale entry could be picked for resumption forever
            # and block the live suspended workers from being resumed.
            suspended.intersection_update(i for i, _ in alive_workers)
            active_workers = [(i, p) for i, p in alive_workers if i not in suspended]

            # Log status every 10 seconds
            if time.time() - last_status_time >= 10:
                last_status_time = time.time()
                print(f"[Monitor] CPU: {cpu:.0f}%, MEM: {mem.percent:.0f}% (threshold: {memory_threshold_pct:.0f}%), Active: {len(active_workers)}, Suspended: {len(suspended)}")

            # CPU management based on mode
            if cpu_mode == "adaptive":
                # Adaptive: back off when system is busy
                if cpu > cpu_backoff_threshold and len(active_workers) > 1:
                    # Suspend the second half of the active workers
                    for i, p in active_workers[len(active_workers) // 2:]:
                        if _suspend_worker(p.pid):
                            suspended.add(i)
                            print(f"[CPU] Suspended worker {i} (system CPU {cpu:.0f}% > {cpu_backoff_threshold:.0f}%)")
                elif cpu < cpu_backoff_threshold - 10 and suspended:
                    # Resume workers when CPU is lower. A worker whose
                    # resume fails is dropped from tracking as well.
                    for i in list(suspended):
                        proc = processes[i]
                        resumed = proc.is_alive() and _resume_worker(proc.pid)
                        suspended.discard(i)
                        if resumed:
                            print(f"[CPU] Resumed worker {i} (system CPU {cpu:.0f}%)")

            elif cpu_mode == "fixed":
                # Fixed: try to maintain target CPU
                target = cpu_fixed_pct
                if cpu > target + 10 and len(active_workers) > 1:
                    # Too high - suspend the last active worker
                    i, p = active_workers[-1]
                    if _suspend_worker(p.pid):
                        suspended.add(i)
                        print(f"[CPU] Suspended worker {i} (CPU {cpu:.0f}% > target {target:.0f}%)")
                elif cpu < target - 10 and suspended:
                    # Too low - resume a single worker
                    for i in list(suspended)[:1]:
                        proc = processes[i]
                        resumed = proc.is_alive() and _resume_worker(proc.pid)
                        suspended.discard(i)
                        if resumed:
                            print(f"[CPU] Resumed worker {i} (CPU {cpu:.0f}% < target {target:.0f}%)")

            samples = shared_samples.value
            syncs = shared_syncs.value
            status = f"[Monitor] CPU: {cpu:.0f}% | RAM: {mem.percent:.0f}% ({mem.used/(1024**3):.1f}/{mem.total/(1024**3):.1f}GB) | Samples: {samples:,} | Syncs: {syncs}"
            if suspended:
                status += f" | Suspended: {len(suspended)}"
            print(status)

            # Output per-worker progress for UI to parse
            for i in range(num_workers):
                prog = shared_progress[i]
                expert = shared_experts[i]
                print(f"[Worker-{i}] Expert: {expert} | Progress: {prog}/{sync_interval}")

    except KeyboardInterrupt:
        print("\n[Main] Shutting down workers...")
        # Resume any suspended workers before terminating
        for i in suspended:
            if processes[i].is_alive():
                _resume_worker(processes[i].pid)
        for p in processes:
            p.terminate()

    for p in processes:
        p.join()
    print("[Main] All workers finished")


# =============================================================================
# Main Entry Point
# =============================================================================

def _build_arg_parser():
    """Build the command-line parser for the streaming client."""
    parser = argparse.ArgumentParser(
        description="Axiom Streaming Client - Async Hebbian Learning"
    )
    parser.add_argument(
        "--expert", "-e", type=int, default=None,
        help="Expert index to train (0-419); only used in single-worker mode "
             "together with --no-auto-assign (default: 0)"
    )
    parser.add_argument(
        "--server", "-s", type=str, default="ws://65.21.196.61:8765",
        help="WebSocket server URL for weight sync"
    )
    parser.add_argument(
        "--lr", type=float, default=0.001,
        help="Hebbian learning rate"
    )
    parser.add_argument(
        "--decay", type=float, default=0.999,
        help="Weight decay factor"
    )
    parser.add_argument(
        "--sync-interval", type=int, default=1000,
        help="Samples between weight syncs"
    )
    parser.add_argument(
        "--offline", action="store_true",
        help="Run in offline mode (no server connection)"
    )
    parser.add_argument(
        "--expert-type", type=str, default="transformer", choices=["mlp", "transformer"],
        help="Expert architecture: mlp (fast) or transformer (42.6M params for 17.8B model)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1,
        help="Mini-batch size for Hebbian updates (1 = true sample-by-sample)"
    )
    parser.add_argument(
        "--parallel", "-p", type=int, default=1,
        help="Number of parallel expert workers (uses multiple CPU cores)"
    )
    parser.add_argument(
        "--memory-threshold", type=float, default=85.0,
        help="System memory %% that triggers GC; with --parallel it also caps "
             "the worker count at this %% of CPU threads (default: 85%%)"
    )
    parser.add_argument(
        "--cpu-mode", type=str, default="adaptive", choices=["fixed", "adaptive"],
        help="CPU usage mode: fixed (constant %%) or adaptive (yield when busy)"
    )
    parser.add_argument(
        "--cpu-fixed-pct", type=float, default=50.0,
        help="For fixed mode: target CPU percentage (default: 50%%)"
    )
    parser.add_argument(
        "--cpu-backoff-threshold", type=float, default=70.0,
        help="For adaptive mode: back off when system CPU exceeds this %% (default: 70%%)"
    )
    parser.add_argument(
        "--authenticator", type=str, default="",
        help="User authenticator for credit tracking"
    )
    parser.add_argument(
        "--no-auto-assign", action="store_true",
        help="Disable load-balanced expert assignment (use sequential instead)"
    )
    return parser


def main():
    """Parse command-line arguments and start the client.

    With ``--parallel N`` (N > 1) the work is delegated to
    ``run_parallel_workers``. Otherwise a single ``HebbianTrainer`` is
    created in this process and run until it finishes or is interrupted.
    """
    args = _build_arg_parser().parse_args()

    server_url = "" if args.offline else args.server
    auto_assign = not args.no_auto_assign
    parallel = args.parallel > 1

    # --expert only has an effect for a single worker with auto-assignment
    # disabled; tell the user instead of silently dropping the value.
    if args.expert is not None and (auto_assign or parallel):
        print("[Client] Warning: --expert is ignored unless running a single "
              "worker with --no-auto-assign")

    # Parallel mode: spawn multiple workers
    if parallel:
        run_parallel_workers(
            num_workers=args.parallel,
            server_url=server_url,
            lr=args.lr,
            decay=args.decay,
            sync_interval=args.sync_interval,
            expert_type=args.expert_type,
            batch_size=args.batch_size,
            memory_threshold_pct=args.memory_threshold,
            cpu_mode=args.cpu_mode,
            cpu_fixed_pct=args.cpu_fixed_pct,
            cpu_backoff_threshold=args.cpu_backoff_threshold,
            authenticator=args.authenticator,
            auto_assign=auto_assign
        )
        return

    # Single worker mode
    if auto_assign:
        expert_idx = -1  # -1 = request from coordinator
    else:
        expert_idx = args.expert if args.expert is not None else 0

    config = ClientConfig(
        expert_idx=expert_idx,
        auto_assign=auto_assign,
        server_url=server_url,
        hebbian_lr=args.lr,
        hebbian_decay=args.decay,
        sync_interval=args.sync_interval,
        expert_type=args.expert_type,
        mini_batch_size=args.batch_size,
        authenticator=args.authenticator,
    )

    print("=" * 60)
    print("  Axiom Streaming Client - Hebbian Learning")
    print("  'Neurons that fire together, wire together'")
    print("=" * 60)
    print(f"  Expert: {config.expert_idx}")
    print(f"  Learning Rate: {config.hebbian_lr}")
    print(f"  Weight Decay: {config.hebbian_decay}")
    print(f"  Sync Interval: {config.sync_interval} samples")
    print(f"  Expert Type: {config.expert_type.upper()}")
    print(f"  Mini-batch: {config.mini_batch_size}")
    # In single-worker mode the threshold is only handed to the endocrine
    # system below; it does not limit any worker count.
    print(f"  Memory Threshold: {args.memory_threshold:.0f}%")
    print(f"  Server: {config.server_url or 'OFFLINE'}")
    print(f"  Data Path: {config.contribute_path}")
    print("=" * 60)
    print()

    trainer = HebbianTrainer(config)
    trainer.endocrine.memory_threshold_pct = args.memory_threshold

    try:
        asyncio.run(trainer.run())
    except KeyboardInterrupt:
        print("\nShutdown complete.")


# Script entry point: run the client only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
