#!/usr/bin/env bash
# trueref launcher (workspace root)
#
# Wraps the fat JAR with:
#   - --enable-native-access=ALL-UNNAMED  (silences FFM Linker warning from DJL tokenizers)
#   - --add-modules=jdk.incubator.vector  (enables Lucene 10 SIMD codepath)
#   - cuDNN 9 (cu12 build) on LD_LIBRARY_PATH so ONNX Runtime CUDA EP loads
#   - CUDA_VISIBLE_DEVICES isolation so ORT BFC arena doesn't trip over the second card
#   - optional per-session GPU memory cap (TRUEREF_MEM_LIMIT; unbounded by default)
#
# Defaults are tuned for this machine (LMDE 7, CUDA 12.4, RTX 2080 SUPER + RTX 3060).
# Override anything via env vars or by appending Spring properties to the command line.
#
# Usage:
#   ./trueref                                       # run with defaults (port 18080)
#   ./trueref --server.port=8080                    # forward Spring properties
#   TRUEREF_GPU=0 ./trueref                         # use the 2080 SUPER instead
#   TRUEREF_GPU=cpu ./trueref                       # disable CUDA, run on CPU
#   TRUEREF_HOME=/data/trueref ./trueref            # custom data dir
#
# Env vars:
#   TRUEREF_GPU         GPU index (matches `nvidia-smi -L`) or "cpu". Default: 1
#   TRUEREF_HOME        Data directory. Default: ./data
#   TRUEREF_PORT        HTTP port. Default: 18080
#   TRUEREF_MEM_LIMIT   Per-session GPU mem cap in bytes. Default: 0 (unbounded).
#                       With session-count=1 there is no multi-session contention, so the BFC
#                       arena can grow freely — capping it risks running out of budget during
#                       model-weight loading (~1.5-2 GB) before inference even starts.
#                       Set to e.g. 8589934592 (8 GiB) only if you run multiple pools on one card.
#   TRUEREF_CUDNN_LIB   Directory containing libcudnn.so.9. Default: ./runtime/cudnn/nvidia/cudnn/lib
#   TRUEREF_JAR         Path to the fat JAR. Default: ./trueref-bootstrap/target/trueref.jar
#   JAVA                java binary. Default: $JAVA_HOME/bin/java or `java` on PATH
#   JAVA_OPTS           Extra JVM flags (e.g. -Xmx16g)

set -euo pipefail

# Workspace root: the absolute directory containing this script. The path
# defaults below hang off it, so the launcher works from any current directory.
# CDPATH is cleared for the `cd`: when cd resolves a relative name through a
# CDPATH entry it also prints the destination, which would end up in ROOT.
# `--` keeps a directory name beginning with '-' from being parsed as an option.
ROOT="$(CDPATH='' cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"

# Effective settings: each TRUEREF_* variable wins when set and non-empty,
# otherwise the default written here applies.
JAR="${TRUEREF_JAR:-$ROOT/trueref-bootstrap/target/trueref.jar}"
GPU="${TRUEREF_GPU:-1}"
HOME_DIR="${TRUEREF_HOME:-$ROOT/data}"
PORT="${TRUEREF_PORT:-18080}"
MEM_LIMIT="${TRUEREF_MEM_LIMIT:-0}"
CUDNN_LIB="${TRUEREF_CUDNN_LIB:-$ROOT/runtime/cudnn/nvidia/cudnn/lib}"

# Fail fast, with a build hint, when the fat JAR has not been produced yet.
[[ -f "$JAR" ]] || {
  echo "trueref: jar not found at $JAR" >&2
  echo "trueref: build it first with: mvn -DskipTests -pl trueref-bootstrap -am package" >&2
  exit 1
}

# Resolve the java executable. Precedence: $JAVA, then $JAVA_HOME/bin/java,
# then the first `java` found on PATH.
if [[ -n "${JAVA:-}" ]]; then
  # Accept either a path or a bare command name (e.g. JAVA=java21). `command -v`
  # echoes an executable path back and looks bare names up on PATH, whereas a
  # plain `-x "$JAVA"` test would treat a bare name as a file in the cwd.
  java_requested="$JAVA"
  JAVA="$(command -v -- "$java_requested" || true)"
  if [[ -z "$JAVA" ]]; then
    echo "trueref: JAVA=$java_requested is neither an executable path nor a command on PATH" >&2
    exit 1
  fi
elif [[ -n "${JAVA_HOME:-}" && -x "${JAVA_HOME}/bin/java" ]]; then
  JAVA="${JAVA_HOME}/bin/java"
else
  JAVA="$(command -v java || true)"
fi
# Final check: whatever was chosen must be an executable file. This also rejects
# a $JAVA that named a shell builtin or function rather than a program.
if [[ -z "${JAVA:-}" || ! -x "${JAVA}" ]]; then
  echo "trueref: java not found; install JDK 21+ or set JAVA_HOME" >&2
  exit 1
fi

# Create the data directory up front (a no-op when it already exists).
mkdir -p "$HOME_DIR"

# Spring properties the launcher always supplies; later steps may append more.
SPRING_ARGS=("--server.port=$PORT" "--trueref.home=$HOME_DIR")

# CUDA / execution-provider setup.
#
# TRUEREF_GPU=cpu (any letter case) switches the embedder/reranker to ORT's CPU
# provider. Any other value selects a GPU by restricting the CUDA runtime's view
# to that card with CUDA_VISIBLE_DEVICES. Inside the process the selected card is
# then device 0, which is why gpu-device-id is always 0 below. Passing the
# physical index to ORT as well would point it at an index that is no longer
# visible.
if [[ "${GPU,,}" == "cpu" ]]; then
  echo "trueref: GPU disabled (TRUEREF_GPU=cpu) — embedder/reranker will run on CPU" >&2
  SPRING_ARGS+=(
    "--trueref.embedding.onnx-providers=cpu"
  )
else
  if [[ -d "$CUDNN_LIB" ]]; then
    # Prepend so this cuDNN build is searched before any other copy.
    export LD_LIBRARY_PATH="${CUDNN_LIB}${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
  else
    # Print the install hint with an absolute, shell-quoted target directory so
    # it can be pasted from any working directory, not only the workspace root.
    cudnn_dir="$(printf '%q' "$ROOT/runtime/cudnn")"
    echo "trueref: TRUEREF_CUDNN_LIB=$CUDNN_LIB not found — CUDA EP will fall back to CPU" >&2
    echo "trueref: download cu12 cuDNN with:" >&2
    echo "  mkdir -p $cudnn_dir && cd $cudnn_dir && \\" >&2
    echo "    pip download --no-deps --only-binary=:all: --python-version 3.12 \\" >&2
    echo "    --platform manylinux2014_x86_64 'nvidia-cudnn-cu12<10' -d . && \\" >&2
    echo "    unzip -q -o nvidia_cudnn_cu12-*.whl 'nvidia/cudnn/lib/*' && rm *.whl" >&2
  fi
  # CUDA runtime respects CUDA_VISIBLE_DEVICES for all allocations (cudaMalloc, BFC arena,
  # etc.). By restricting CUDA's view to exactly the target GPU, we prevent the runtime from
  # creating a default context on device 0 before ORT's cudaSetDevice() takes effect.
  # We always pass gpu-device-id=0 to ORT because CUDA_VISIBLE_DEVICES makes the target
  # card the ONLY visible device (index 0 in the runtime's view).
  #
  # CUDA_DEVICE_ORDER=PCI_BUS_ID ensures CUDA runtime numbering matches nvidia-smi numbering.
  # Without it, the default FASTEST_FIRST ordering can rank GPUs differently from nvidia-smi,
  # so CUDA_VISIBLE_DEVICES=N would expose a different physical card than nvidia-smi GPU N.
  export CUDA_DEVICE_ORDER="PCI_BUS_ID"
  export CUDA_VISIBLE_DEVICES="$GPU"
  SPRING_ARGS+=(
    "--trueref.embedding.gpu-device-id=0"
    "--trueref.embedding.gpu-mem-limit-bytes=$MEM_LIMIT"
  )
fi

# Build the final Spring argument list. Spring's command-line property source
# does not let a later --key=value override an earlier one: repeated options
# are joined with commas (e.g. server.port="18080,8080"), which breaks scalar
# properties. So any launcher default whose key the caller also passes on the
# command line is dropped, and the caller's value is used on its own.
launch_args=()
for default_arg in "${SPRING_ARGS[@]}"; do
  default_key="${default_arg%%=*}"
  for user_arg in "$@"; do
    if [[ "$user_arg" == "$default_key" || "$user_arg" == "$default_key="* ]]; then
      continue 2
    fi
  done
  launch_args+=("$default_arg")
done

# JAVA_OPTS is intentionally word-split into separate JVM flags. Pathname
# expansion is disabled first so flags containing *, ? or [...] reach the JVM
# verbatim instead of being matched against files in the current directory.
# (`set -f` has no other effect here: exec replaces this shell.)
set -f

# Replace this shell with the JVM so the java process keeps the launcher's PID.
# `${arr[@]+...}` expands safely under `set -u` even when every default was
# overridden and launch_args is empty.
# shellcheck disable=SC2086
exec "$JAVA" \
  --enable-native-access=ALL-UNNAMED \
  --add-modules=jdk.incubator.vector \
  ${JAVA_OPTS:-} \
  -jar "$JAR" \
  ${launch_args[@]+"${launch_args[@]}"} \
  "$@"
