diff --git a/.github/workflows/integration-test-k8s.yml b/.github/workflows/integration-test-k8s.yml
new file mode 100644
index 00000000..5555f7c2
--- /dev/null
+++ b/.github/workflows/integration-test-k8s.yml
@@ -0,0 +1,450 @@
+# SPDX-FileCopyrightText: 2025 Nextcloud GmbH and Nextcloud contributors
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+name: Integration test k8s
+
+on:
+ pull_request:
+ push:
+ branches:
+ - master
+ - stable*
+
+permissions:
+ contents: read
+
+concurrency:
+ group: integration-test-k8s-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+
+jobs:
+ changes:
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ pull-requests: read
+
+ outputs:
+ src: ${{ steps.changes.outputs.src}}
+
+ steps:
+ - uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
+ id: changes
+ continue-on-error: true
+ with:
+ filters: |
+ src:
+ - 'main.py'
+ - 'main_em.py'
+ - 'config.cpu.yaml'
+ - 'config.gpu.yaml'
+ - 'context_chat_backend/**'
+ - 'appinfo/**'
+ - 'example.env'
+ - 'hwdetect.sh'
+ - 'persistent_storage/**'
+ - 'project.toml'
+ - 'requirements.txt'
+ - 'logger_config.k8s.yaml'
+ - 'supervisord.conf'
+ - '.github/workflows/integration-test-k8s.yml'
+
+ integration:
+ runs-on: ubuntu-24.04
+
+ needs: changes
+ if: needs.changes.outputs.src != 'false'
+
+ strategy:
+ # do not stop on another job's failure
+ fail-fast: false
+ matrix:
+ php-versions: [ '8.2' ]
+ databases: [ 'pgsql' ]
+ server-versions: [ 'master' ]
+
+ name: Integration test k8s on ${{ matrix.server-versions }} php@${{ matrix.php-versions }}
+
+ env:
+ MYSQL_PORT: 4444
+ PGSQL_PORT: 4445
+ HP_SHARED_KEY: test_shared_key_12345
+
+ services:
+ mysql:
+ image: mariadb:10.5
+ ports:
+ - 4444:3306/tcp
+ env:
+ MYSQL_ROOT_PASSWORD: rootpassword
+ options: --health-cmd="mysqladmin ping" --health-interval 5s --health-timeout 2s --health-retries 5
+ # use the same db for ccb and nextcloud
+ postgres:
+ image: pgvector/pgvector:pg17
+ ports:
+ - 4445:5432/tcp
+ env:
+ POSTGRES_USER: root
+ POSTGRES_PASSWORD: rootpassword
+ POSTGRES_DB: nextcloud
+ options: --health-cmd pg_isready --health-interval 5s --health-timeout 2s --health-retries 5 --name postgres --hostname postgres
+
+ steps:
+ - name: Checkout server
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
+ with:
+ repository: nextcloud/server
+ ref: ${{ matrix.server-versions }}
+ submodules: 'recursive'
+ persist-credentials: false
+
+ - name: Set up php ${{ matrix.php-versions }}
+ uses: shivammathur/setup-php@9e72090525849c5e82e596468b86eb55e9cc5401 # v2
+ with:
+ php-version: ${{ matrix.php-versions }}
+ tools: phpunit
+ extensions: mbstring, iconv, fileinfo, intl, sqlite, pdo_mysql, pdo_sqlite, pgsql, pdo_pgsql, gd, zip
+
+ - name: Checkout context_chat php app
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
+ with:
+ repository: nextcloud/context_chat
+ path: apps/context_chat
+ persist-credentials: false
+
+ - name: Checkout backend
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
+ with:
+ path: context_chat_backend/
+ persist-credentials: false
+
+ - name: Checkout app_api
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
+ with:
+ repository: nextcloud/app_api
+ ref: ${{ matrix.server-versions == 'master' && 'main' || matrix.server-versions }}
+ path: apps/app_api
+ persist-credentials: false
+
+ - name: Get app version
+ id: appinfo
+ uses: skjnldsv/xpath-action@7e6a7c379d0e9abc8acaef43df403ab4fc4f770c # master
+ with:
+ filename: context_chat_backend/appinfo/info.xml
+ expression: "/info/version/text()"
+
+ - name: Set up Nextcloud MYSQL
+ if: ${{ matrix.databases != 'pgsql'}}
+ run: |
+ sleep 25
+ mkdir data
+ ./occ maintenance:install --verbose --database=${{ matrix.databases }} --database-name=nextcloud --database-host=127.0.0.1 --database-port=$MYSQL_PORT --database-user=root --database-pass=rootpassword --admin-user admin --admin-pass password
+
+ - name: Set up Nextcloud PGSQL
+ if: ${{ matrix.databases == 'pgsql'}}
+ run: |
+ sleep 25
+ mkdir data
+ ./occ maintenance:install --verbose --database=${{ matrix.databases }} --database-name=nextcloud --database-host=127.0.0.1 --database-port=$PGSQL_PORT --database-user=root --database-pass=rootpassword --admin-user admin --admin-pass password
+
+ - name: Enable context_chat, app_api and testing
+ run: ./occ app:enable -vvv -f context_chat app_api testing
+
+ - name: Checkout documentation
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
+ with:
+ repository: nextcloud/documentation
+ path: data/admin/files/documentation
+ persist-credentials: false
+
+ - name: Prepare docs
+ run: |
+ cd data/admin/files
+ mv documentation/admin_manual .
+ cp -R documentation/developer_manual .
+ cd developer_manual
+ find . -type f -name "*.rst" -exec bash -c 'mv "$0" "${0%.rst}.md"' {} \;
+ cd ..
+ cp -R documentation/developer_manual ./developer_manual2
+ cd developer_manual2
+ find . -type f -name "*.rst" -exec bash -c 'mv "$0" "${0%.rst}.txt"' {} \;
+ cd ..
+ rm -rf documentation
+
+ - name: Run files scan
+ run: |
+ ./occ files:scan --all
+
+ - name: Install k3s
+ run: |
+ curl -sfL https://get.k3s.io | INSTALL_K3S_EXEC="--disable traefik --disable servicelb --kubelet-arg=container-log-max-size=0" sh -
+ sudo chmod 644 /etc/rancher/k3s/k3s.yaml
+ echo "KUBECONFIG=/etc/rancher/k3s/k3s.yaml" >> $GITHUB_ENV
+
+ - name: Wait for k3s and create namespace
+ run: |
+ kubectl wait --for=condition=Ready node --all --timeout=120s
+ kubectl create namespace nextcloud-exapps
+ NODE_IP=$(kubectl get node -o jsonpath='{.items[0].status.addresses[?(@.type=="InternalIP")].address}')
+ echo "NODE_IP=${NODE_IP}" >> $GITHUB_ENV
+ echo "k3s node IP: $NODE_IP"
+
+ - name: Configure Nextcloud for k3s networking
+ run: |
+ ./occ config:system:set overwrite.cli.url --value "http://${{ env.NODE_IP }}" --type=string
+ ./occ config:system:set trusted_domains 1 --value "${{ env.NODE_IP }}"
+
+ - name: Create K8s service account for HaRP
+ run: |
+ kubectl -n nextcloud-exapps create serviceaccount harp-sa
+ kubectl create clusterrolebinding harp-admin \
+ --clusterrole=cluster-admin \
+ --serviceaccount=nextcloud-exapps:harp-sa
+ K3S_TOKEN=$(kubectl -n nextcloud-exapps create token harp-sa --duration=2h)
+ echo "K3S_TOKEN=${K3S_TOKEN}" >> $GITHUB_ENV
+
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3
+
+ - name: Login to GitHub Container Registry
+ uses: docker/login-action@184bdaa0721073962dff0199f1fb9940f07167d1 # v3
+ with:
+ registry: ghcr.io
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Build the context_chat_backend cpu image
+ uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6
+ with:
+ context: context_chat_backend
+ push: false
+ platforms: linux/amd64
+ # use local tag so image is not pulled from remote
+ tags: ghcr.io/ccb-cpu:local
+ target: runtime-cpu
+ load: true
+ cache-from: type=gha
+ cache-to: type=gha,mode=max
+
+ - name: Pre-load CCB ExApp image into k3s
+ run: docker save ghcr.io/ccb-cpu:local | sudo k3s ctr images import -
+
+ - name: Start HaRP with K8s backend
+ run: |
+ docker run --net host --name appapi-harp \
+ -e HP_SHARED_KEY="${{ env.HP_SHARED_KEY }}" \
+ -e NC_INSTANCE_URL="http://${{ env.NODE_IP }}" \
+ -e HP_LOG_LEVEL="debug" \
+ -e HP_K8S_ENABLED="true" \
+ -e HP_K8S_API_SERVER="https://127.0.0.1:6443" \
+ -e HP_K8S_BEARER_TOKEN="${{ env.K3S_TOKEN }}" \
+ -e HP_K8S_NAMESPACE="nextcloud-exapps" \
+ -e HP_K8S_VERIFY_SSL="false" \
+ --restart unless-stopped \
+ -d ghcr.io/nextcloud/nextcloud-appapi-harp:latest
+
+ - name: Start nginx proxy
+ run: |
+ docker run --net host --name nextcloud --rm \
+ -v $(pwd)/apps/app_api/tests/simple-nginx-NOT-FOR-PRODUCTION.conf:/etc/nginx/conf.d/default.conf:ro \
+ -d nginx
+
+ - name: Start nextcloud
+ run: PHP_CLI_SERVER_WORKERS=2 php -S 0.0.0.0:8080 &
+
+ - name: Wait for HaRP K8s readiness
+ run: |
+ for i in $(seq 1 30); do
+ if curl -sf http://${{ env.NODE_IP }}:8780/exapps/app_api/info \
+ -H "harp-shared-key: ${{ env.HP_SHARED_KEY }}" 2>/dev/null | grep -q '"kubernetes"'; then
+ echo "HaRP is ready with K8s backend"
+ exit 0
+ fi
+ echo "Waiting for HaRP... ($i/30)"
+ sleep 2
+ done
+ echo "HaRP K8s readiness check failed"
+ docker logs appapi-harp
+ exit 1
+
+ - name: Register K8s daemon
+ run: |
+ ./occ app_api:daemon:register \
+ k8s_test "K8s Test" "kubernetes-install" "http" "${{ env.NODE_IP }}:8780" "http://${{ env.NODE_IP }}" \
+ --harp --harp_shared_key "${{ env.HP_SHARED_KEY }}" \
+ --k8s --k8s_expose_type=nodeport --set-default
+ ./occ app_api:daemon:list
+
+ - name: Register backend
+ run: |
+ sed -i 's;.*;ccb-cpu;' context_chat_backend/appinfo/info.xml
+ sed -i 's;.*;local;' context_chat_backend/appinfo/info.xml
+ timeout 120 ./occ app_api:app:register context_chat_backend k8s_test \
+ --info-xml context_chat_backend/appinfo/info.xml \
+ --env EXTERNAL_DB="postgresql+psycopg://root:rootpassword@${{ env.NODE_IP }}:4445/nextcloud" \
+ --wait-finish
+
+ - name: Run cron jobs
+ run: |
+ # every 10 seconds indefinitely
+ while true; do
+ php cron.php
+ sleep 10
+ done &
+ sleep 30
+ # list all the bg jobs
+ ./occ background-job:list
+
+ - name: Initial dump of DB with context_chat_queue populated
+ if: always()
+ run: |
+ docker exec postgres pg_dump nextcloud > /tmp/0_pgdump_nextcloud
+
+ - name: Periodically check context_chat stats for 15 minutes to allow the backend to index the files
+ run: |
+ success=0
+ echo "::group::Checking stats periodically for 15 minutes to allow the backend to index the files"
+ for i in {1..90}; do
+ echo "Checking stats, attempt $i..."
+
+ stats_err=$(mktemp)
+ stats_exit=0
+ stats=$(timeout 30 ./occ context_chat:stats --json 2>"$stats_err") || stats_exit=$?
+ echo "Stats output:"
+ echo "$stats"
+ if [ -s "$stats_err" ]; then
+ echo "Stderr:"
+ cat "$stats_err"
+ fi
+ echo "---"
+ rm -f "$stats_err"
+
+ # Check for critical errors in output
+ if [ $stats_exit -ne 0 ] || echo "$stats" | grep -q "Error during request"; then
+ echo "Backend connection error detected (exit=$stats_exit), retrying..."
+ sleep 10
+ continue
+ fi
+
+ # Extract total eligible files
+ total_eligible_files=$(echo "$stats" | jq '.eligible_files_count' || echo "")
+
+ # Extract indexed documents count (files__default)
+ indexed_count=$(echo "$stats" | jq '.vectordb_document_counts.files__default' || echo "")
+
+ echo "Total eligible files: $total_eligible_files"
+ echo "Indexed documents (files__default): $indexed_count"
+
+ diff=$((total_eligible_files - indexed_count))
+ threshold=$((total_eligible_files * 3 / 100))
+
+ # Check if difference is within tolerance
+ if [ $diff -le $threshold ]; then
+ echo "Indexing within 3% tolerance (diff=$diff, threshold=$threshold)"
+ success=1
+ break
+ else
+ progress=$((diff * 100 / total_eligible_files))
+ echo "Outside 3% tolerance: diff=$diff (${progress}%), threshold=$threshold"
+ fi
+
+ sleep 10
+ done
+
+ echo "::endgroup::"
+
+ if [ $success -ne 1 ]; then
+ echo "Max attempts reached"
+ exit 1
+ fi
+
+ - name: Run the prompts
+ run: |
+ ./occ background-job:worker 'OC\TaskProcessing\SynchronousBackgroundJob' > worker1_logs 2>&1 &
+ ./occ background-job:worker 'OC\TaskProcessing\SynchronousBackgroundJob' > worker2_logs 2>&1 &
+
+ echo ::group::English prompt
+ OUT1=$(./occ context_chat:prompt admin "Which factors are taken into account for the Ethical AI Rating?")
+ echo "$OUT1"
+ echo "$OUT1" | grep -q "If all of these points are met, we give a Green label." || exit 1
+ echo ::endgroup::
+
+ echo ::group::German prompt
+ OUT2=$(./occ context_chat:prompt admin "Welche Faktoren beeinflussen das Ethical AI Rating?")
+ echo "$OUT2"
+ echo "$OUT2" | grep -q "If all of these points are met, we give a Green label." || exit 1
+ echo ::endgroup::
+
+ - name: Final dump of DB with vectordb populated
+ if: always()
+ run: |
+ docker exec postgres pg_dump nextcloud > /tmp/1_pgdump_nextcloud
+
+ - name: Show server logs
+ if: always()
+ run: |
+ cat data/nextcloud.log
+
+ - name: Show context_chat specific logs
+ if: always()
+ run: |
+ cat data/context_chat.log
+
+ - name: Show task processing worker logs
+ if: always()
+ run: |
+ tail -v -n +1 worker?_logs || echo "No worker logs"
+
+ - name: Show HaRP logs
+ if: always()
+ run: |
+ docker logs appapi-harp
+
+ - name: Show main app indexing logs
+ if: always()
+ run: |
+ kubectl logs -n nextcloud-exapps -l app=nc-app-context-chat-backend-indexing --prefix --tail=-1 --ignore-errors
+
+ - name: Show main app updates processing logs
+ if: always()
+ run: |
+ kubectl logs -n nextcloud-exapps -l app=nc-app-context-chat-backend-updatesproc --prefix --tail=-1 --ignore-errors
+
+ - name: Show main app request processing logs
+ if: always()
+ run: |
+ kubectl logs -n nextcloud-exapps -l app=nc-app-context-chat-backend-requestproc --prefix --tail=-1 --ignore-errors
+
+ - name: Upload database dumps
+ uses: actions/upload-artifact@v4
+ if: always()
+ with:
+ name: database-dumps-${{ matrix.server-versions }}-php@${{ matrix.php-versions }}
+ path: |
+ /tmp/0_pgdump_nextcloud
+ /tmp/1_pgdump_nextcloud
+
+ - name: Final stats log
+ if: always()
+ run: |
+ ./occ context_chat:stats
+ ./occ context_chat:stats --json
+
+ summary:
+ permissions:
+ contents: none
+ runs-on: ubuntu-latest-low
+ needs: [changes, integration]
+
+ if: always()
+
+ # This is the summary, we just avoid to rename it so that branch protection rules still match
+ name: integration-test-k8s
+
+ steps:
+ - name: Summary status
+ run: if ${{ needs.changes.outputs.src != 'false' && needs.integration.result != 'success' }}; then exit 1; fi
diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index 10e2d61b..4a6123c4 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -89,7 +89,7 @@ jobs:
POSTGRES_USER: root
POSTGRES_PASSWORD: rootpassword
POSTGRES_DB: nextcloud
- options: --health-cmd pg_isready --health-interval 5s --health-timeout 2s --health-retries 5
+ options: --health-cmd pg_isready --health-interval 5s --health-timeout 2s --health-retries 5 --name postgres --hostname postgres
steps:
- name: Checkout server
@@ -120,6 +120,14 @@ jobs:
path: context_chat_backend/
persist-credentials: false
+ - name: Checkout app_api
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
+ with:
+ repository: nextcloud/app_api
+ ref: ${{ matrix.server-versions == 'master' && 'main' || matrix.server-versions }}
+ path: apps/app_api
+ persist-credentials: false
+
- name: Get app version
id: appinfo
uses: skjnldsv/xpath-action@7e6a7c379d0e9abc8acaef43df403ab4fc4f770c # master
@@ -167,6 +175,10 @@ jobs:
cd ..
rm -rf documentation
+ - name: Run files scan
+ run: |
+ ./occ files:scan --all
+
- name: Setup python 3.11
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5
with:
@@ -195,42 +207,109 @@ jobs:
timeout 10 ./occ app_api:daemon:register --net host manual_install "Manual Install" manual-install http localhost http://localhost:8080
timeout 120 ./occ app_api:app:register context_chat_backend manual_install --json-info "{\"appid\":\"context_chat_backend\",\"name\":\"Context Chat Backend\",\"daemon_config_name\":\"manual_install\",\"version\":\"${{ fromJson(steps.appinfo.outputs.result).version }}\",\"secret\":\"12345\",\"port\":10034,\"scopes\":[],\"system_app\":0}" --force-scopes --wait-finish
ls -la context_chat_backend/persistent_storage/*
- sleep 30 # Wait for the em server to get ready
-
- - name: Scan files, baseline
- run: |
- ./occ files:scan admin
- ./occ context_chat:scan admin -m text/plain
- - name: Check python memory usage
+ - name: Initial memory usage check
run: |
ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem
ps -p $(cat pid.txt) -o %mem --no-headers > initial_mem.txt
- - name: Scan files
+ - name: Run cron jobs
+ run: |
+ # every 10 seconds indefinitely
+ while true; do
+ php cron.php
+ sleep 10
+ done &
+ sleep 30
+ # list all the bg jobs
+ ./occ background-job:list
+
+ - name: Initial dump of DB with context_chat_queue populated
+ if: always()
run: |
- ./occ files:scan admin
- ./occ context_chat:scan admin -m text/markdown &
- ./occ context_chat:scan admin -m text/x-rst
+ docker exec postgres pg_dump nextcloud > /tmp/0_pgdump_nextcloud
- - name: Check python memory usage
+ - name: Periodically check context_chat stats for 15 minutes to allow the backend to index the files
run: |
- ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem
- ps -p $(cat pid.txt) -o %mem --no-headers > after_scan_mem.txt
+ success=0
+ echo "::group::Checking stats periodically for 15 minutes to allow the backend to index the files"
+ for i in {1..90}; do
+ echo "Checking stats, attempt $i..."
+
+ stats_err=$(mktemp)
+ stats_exit=0
+ stats=$(timeout 30 ./occ context_chat:stats --json 2>"$stats_err") || stats_exit=$?
+ echo "Stats output:"
+ echo "$stats"
+ if [ -s "$stats_err" ]; then
+ echo "Stderr:"
+ cat "$stats_err"
+ fi
+ echo "---"
+ rm -f "$stats_err"
+
+ # Check for critical errors in output
+ if [ $stats_exit -ne 0 ] || echo "$stats" | grep -q "Error during request"; then
+ echo "Backend connection error detected (exit=$stats_exit), retrying..."
+ sleep 10
+ continue
+ fi
+
+ # Extract total eligible files
+ total_eligible_files=$(echo "$stats" | jq '.eligible_files_count' || echo "")
+
+ # Extract indexed documents count (files__default)
+ indexed_count=$(echo "$stats" | jq '.vectordb_document_counts.files__default' || echo "")
+
+ echo "Total eligible files: $total_eligible_files"
+ echo "Indexed documents (files__default): $indexed_count"
+
+ diff=$((total_eligible_files - indexed_count))
+ threshold=$((total_eligible_files * 3 / 100))
+
+ # Check if difference is within tolerance
+ if [ $diff -le $threshold ]; then
+ echo "Indexing within 3% tolerance (diff=$diff, threshold=$threshold)"
+ success=1
+ break
+ else
+ progress=$((diff * 100 / total_eligible_files))
+ echo "Outside 3% tolerance: diff=$diff (${progress}%), threshold=$threshold"
+ fi
+
+ # Check if backend is still alive
+ ccb_alive=$(ps -p $(cat pid.txt) -o cmd= | grep -c "main.py" || echo "0")
+ if [ "$ccb_alive" -eq 0 ]; then
+ echo "Error: Context Chat Backend process is not running. Exiting."
+ exit 1
+ fi
+
+ sleep 10
+ done
+
+ echo "::endgroup::"
+
+ if [ $success -ne 1 ]; then
+ echo "Max attempts reached"
+ exit 1
+ fi
- name: Run the prompts
run: |
./occ background-job:worker 'OC\TaskProcessing\SynchronousBackgroundJob' > worker1_logs 2>&1 &
./occ background-job:worker 'OC\TaskProcessing\SynchronousBackgroundJob' > worker2_logs 2>&1 &
+ echo ::group::English prompt
OUT1=$(./occ context_chat:prompt admin "Which factors are taken into account for the Ethical AI Rating?")
echo "$OUT1"
- echo '--------------------------------------------------'
+ echo "$OUT1" | grep -q "If all of these points are met, we give a Green label." || exit 1
+ echo ::endgroup::
+
+ echo ::group::German prompt
OUT2=$(./occ context_chat:prompt admin "Welche Faktoren beeinflussen das Ethical AI Rating?")
echo "$OUT2"
-
- echo "$OUT1" | grep -q "If all of these points are met, we give a Green label." || exit 1
echo "$OUT2" | grep -q "If all of these points are met, we give a Green label." || exit 1
+ echo ::endgroup::
- name: Check python memory usage
run: |
@@ -250,18 +329,10 @@ jobs:
echo "Memory usage during scan is stable. No memory leak detected."
fi
- - name: Compare memory usage and detect leak
+ - name: Final dump of DB with vectordb populated
+ if: always()
run: |
- initial_mem=$(cat after_scan_mem.txt | tr -d ' ')
- final_mem=$(cat after_prompt_mem.txt | tr -d ' ')
- echo "Initial Memory Usage: $initial_mem%"
- echo "Memory Usage after prompt: $final_mem%"
-
- if (( $(echo "$final_mem > $initial_mem" | bc -l) )); then
- echo "Memory usage has increased during prompt. Possible memory leak detected!"
- else
- echo "Memory usage during prompt is stable. No memory leak detected."
- fi
+ docker exec postgres pg_dump nextcloud > /tmp/1_pgdump_nextcloud
- name: Show server logs
if: always()
@@ -298,6 +369,21 @@ jobs:
run: |
tail -v -n +1 context_chat_backend/persistent_storage/logs/em_server.log* || echo "No logs in logs directory"
+ - name: Upload database dumps
+ uses: actions/upload-artifact@v4
+ if: always()
+ with:
+ name: database-dumps-${{ matrix.server-versions }}-php@${{ matrix.php-versions }}
+ path: |
+ /tmp/0_pgdump_nextcloud
+ /tmp/1_pgdump_nextcloud
+
+ - name: Final stats log
+ if: always()
+ run: |
+ ./occ context_chat:stats
+ ./occ context_chat:stats --json
+
summary:
permissions:
contents: none
diff --git a/Dockerfile b/Dockerfile
index 3430a5ee..63eec9ac 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,24 +1,191 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
-FROM docker.io/nvidia/cuda:12.2.2-runtime-ubuntu22.04
+ARG CPU_IMAGE=ubuntu:22.04
+ARG CUDA_DEVEL_IMAGE=nvidia/cuda:12.4.1-devel-ubuntu22.04
+ARG CUDA_RUNTIME_IMAGE=nvidia/cuda:12.4.1-runtime-ubuntu22.04
+ARG LLAMA_CPP_PYTHON_VERSION=0.3.20
+
+# ============================================================
+# CPU / ARM builder
+# Builds llama_cpp_python for any x86_64 (AVX+, Sandy Bridge 2011+)
+# and for arm64 (NEON always available).
+# ubuntu:22.04 is a multi-arch image so this stage covers both.
+#
+# GGML_NATIVE=OFF: no -march=native; the host build machine's SIMD
+# capabilities are not baked in. AVX/AVX2/FMA/F16C default to ON in
+# llama.cpp cmake and are used when the CPU supports them at runtime
+# (the ggml_cpu_has_*() guards). On arm64 those x86 flags are never
+# emitted by cmake, so NEON/SVE detection remains intact.
+# ============================================================
+FROM ubuntu:22.04 AS llama-builder-cpu
+ARG LLAMA_CPP_PYTHON_VERSION
+
+ENV DEBIAN_FRONTEND=noninteractive
+WORKDIR /build
+ADD dockerfile_scripts/install_py11.sh dockerfile_scripts/install_py11.sh
+RUN ./dockerfile_scripts/install_py11.sh
+# install_py11.sh leaves apt lists in place – install build tools in one layer
+RUN apt-get install -y --no-install-recommends \
+ python3.11-dev \
+ cmake build-essential ninja-build git \
+ libgomp1 \
+ && rm -rf /var/lib/apt/lists/*
+
+RUN python3.11 -m pip install --no-cache-dir --upgrade pip setuptools wheel
+
+ENV CMAKE_ARGS="-DGGML_NATIVE=OFF"
+
+RUN python3.11 -m pip wheel \
+ --no-cache-dir \
+ --no-binary llama-cpp-python \
+ --wheel-dir=/wheels \
+ "llama-cpp-python==${LLAMA_CPP_PYTHON_VERSION}"
+
+# ============================================================
+# CUDA (NVIDIA) builder
+# Builds llama_cpp_python with CUDA support.
+# sm_90 is the maximum compute capability supported by CUDA 12.4
+# (Hopper / H100). Blackwell sm_100 requires CUDA 12.8+.
+# ============================================================
+FROM ${CUDA_DEVEL_IMAGE} AS llama-builder-cuda
+ARG LLAMA_CPP_PYTHON_VERSION
+
+ENV DEBIAN_FRONTEND=noninteractive
+WORKDIR /build
+ADD dockerfile_scripts/install_py11.sh dockerfile_scripts/install_py11.sh
+RUN ./dockerfile_scripts/install_py11.sh
+# gcc-12 is required: Ubuntu 22.04 ships gcc-11 by default which CUDA 12.4
+# treats as "unsupported"; we pin gcc-12 to match the official CI workflow.
+RUN apt-get install -y --no-install-recommends \
+ python3.11-dev \
+ cmake build-essential ninja-build git \
+ gcc-12 g++-12 \
+ libgomp1 \
+ && rm -rf /var/lib/apt/lists/*
+
+ENV CC=/usr/bin/gcc-12
+ENV CXX=/usr/bin/g++-12
+ENV CUDAHOSTCXX=/usr/bin/g++-12
+
+RUN python3.11 -m pip install --no-cache-dir --upgrade pip setuptools wheel
+
+# Architecture list aligned with the official llama-cpp-python CUDA CI workflow:
+# https://github.com/abetlen/llama-cpp-python/blob/main/.github/workflows/build-wheels-cuda.yaml
+ENV CMAKE_ARGS="-DGGML_CUDA=ON -DGGML_CUDA_FORCE_MMQ=ON -DGGML_NATIVE=OFF \
+ -DCMAKE_CUDA_ARCHITECTURES=70-real;75-real;80-real;86-real;89-real;90-real;90-virtual \
+ -DCMAKE_CUDA_FLAGS=--allow-unsupported-compiler \
+ -DCMAKE_CUDA_HOST_COMPILER=/usr/bin/g++-12"
+
+RUN python3.11 -m pip wheel \
+ --no-cache-dir \
+ --no-binary llama-cpp-python \
+ --wheel-dir=/wheels \
+ "llama-cpp-python==${LLAMA_CPP_PYTHON_VERSION}"
+
+# ============================================================
+# Vulkan (AMD / Intel / any Vulkan-capable GPU) builder
+# Builds llama_cpp_python with Vulkan compute backend.
+# Works on RDNA1/2/3, GCN, Intel Arc, and more.
+# ============================================================
+FROM ubuntu:22.04 AS llama-builder-vulkan
+ARG LLAMA_CPP_PYTHON_VERSION
+
+ENV DEBIAN_FRONTEND=noninteractive
+WORKDIR /build
+ADD dockerfile_scripts/install_py11.sh dockerfile_scripts/install_py11.sh
+RUN ./dockerfile_scripts/install_py11.sh
+# Vulkan headers + glslang (shader compiler) are build-time only
+RUN apt-get install -y --no-install-recommends \
+ python3.11-dev \
+ cmake build-essential ninja-build git \
+ libgomp1 \
+ libvulkan-dev glslang-tools \
+ && rm -rf /var/lib/apt/lists/*
+
+RUN python3.11 -m pip install --no-cache-dir --upgrade pip setuptools wheel
+
+ENV CMAKE_ARGS="-DGGML_VULKAN=ON -DGGML_NATIVE=OFF"
+
+RUN python3.11 -m pip wheel \
+ --no-cache-dir \
+ --no-binary llama-cpp-python \
+ --wheel-dir=/wheels \
+ "llama-cpp-python==${LLAMA_CPP_PYTHON_VERSION}"
+
+# ============================================================
+# CPU / ARM runtime
+# ============================================================
+FROM ubuntu:22.04 AS runtime-cpu
+
+ARG CCB_DB_NAME=ccb
+ARG CCB_DB_USER=ccbuser
+ARG CCB_DB_PASS=ccbpass
+
+ENV CCB_DB_NAME=${CCB_DB_NAME}
+ENV CCB_DB_USER=${CCB_DB_USER}
+ENV CCB_DB_PASS=${CCB_DB_PASS}
+ENV DEBIAN_FRONTEND=noninteractive
+ENV AA_DOCKER_ENV=1
+
+WORKDIR /app
+
+ADD dockerfile_scripts/install_deps.sh dockerfile_scripts/install_deps.sh
+RUN ./dockerfile_scripts/install_deps.sh
+ADD dockerfile_scripts/install_py11.sh dockerfile_scripts/install_py11.sh
+RUN ./dockerfile_scripts/install_py11.sh
+ADD dockerfile_scripts/pgsql dockerfile_scripts/pgsql
+RUN ./dockerfile_scripts/pgsql/install.sh
+ADD dockerfile_scripts/install_frpc.sh dockerfile_scripts/install_frpc.sh
+RUN ./dockerfile_scripts/install_frpc.sh
+RUN apt-get autoclean
+ADD dockerfile_scripts/entrypoint.sh dockerfile_scripts/entrypoint.sh
+
+ENV DEBIAN_FRONTEND=dialog
+
+# Install llama_cpp_python from the CPU builder wheel
+COPY --from=llama-builder-cpu /wheels /wheels
+RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel \
+ && python3 -m pip install --no-cache-dir --no-index --find-links=/wheels llama-cpp-python \
+ && python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu \
+ && rm -rf /wheels \
+ && pip cache purge
+
+COPY requirements.txt .
+RUN sed -i '/^llama_cpp_python/d' requirements.txt \
+ && python3 -m pip install --no-cache-dir -r requirements.txt \
+ && python3 -m pip cache purge
+
+COPY context_chat_backend context_chat_backend
+COPY main.py .
+COPY main_em.py .
+COPY config.?pu.yaml .
+COPY logger_config*.yaml .
+COPY hwdetect.sh .
+COPY harp_connect.sh .
+COPY supervisord.conf /etc/supervisor/supervisord.conf
+
+ENTRYPOINT ["supervisord", "-c", "/etc/supervisor/supervisord.conf"]
+
+# ============================================================
+# CUDA (NVIDIA GPU) runtime
+# ============================================================
+FROM ${CUDA_RUNTIME_IMAGE} AS runtime-cuda
ARG CCB_DB_NAME=ccb
ARG CCB_DB_USER=ccbuser
ARG CCB_DB_PASS=ccbpass
-ENV CCB_DB_NAME ${CCB_DB_NAME}
-ENV CCB_DB_USER ${CCB_DB_USER}
-ENV CCB_DB_PASS ${CCB_DB_PASS}
-ENV DEBIAN_FRONTEND noninteractive
-ENV NVIDIA_VISIBLE_DEVICES all
-ENV NVIDIA_DRIVER_CAPABILITIES compute
-ENV AA_DOCKER_ENV 1
+ENV CCB_DB_NAME=${CCB_DB_NAME}
+ENV CCB_DB_USER=${CCB_DB_USER}
+ENV CCB_DB_PASS=${CCB_DB_PASS}
+ENV DEBIAN_FRONTEND=noninteractive
+ENV NVIDIA_VISIBLE_DEVICES=all
+ENV NVIDIA_DRIVER_CAPABILITIES=compute
+ENV AA_DOCKER_ENV=1
-# Set working directory
WORKDIR /app
-# Install dependencies
ADD dockerfile_scripts/install_deps.sh dockerfile_scripts/install_deps.sh
RUN ./dockerfile_scripts/install_deps.sh
ADD dockerfile_scripts/install_py11.sh dockerfile_scripts/install_py11.sh
@@ -30,27 +197,91 @@ RUN ./dockerfile_scripts/install_frpc.sh
RUN apt-get autoclean
ADD dockerfile_scripts/entrypoint.sh dockerfile_scripts/entrypoint.sh
-# Restore interactivity
-ENV DEBIAN_FRONTEND dialog
+ENV DEBIAN_FRONTEND=dialog
+
+# Install llama_cpp_python from the CUDA builder wheel
+COPY --from=llama-builder-cuda /wheels /wheels
+RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel \
+ && python3 -m pip install --no-cache-dir --no-index --find-links=/wheels llama-cpp-python \
+ && rm -rf /wheels \
+ && pip cache purge
-# Copy requirements files
COPY requirements.txt .
+RUN sed -i '/^llama_cpp_python/d' requirements.txt \
+ && python3 -m pip install --no-cache-dir -r requirements.txt \
+ && python3 -m pip cache purge
+
+COPY context_chat_backend context_chat_backend
+COPY main.py .
+COPY main_em.py .
+COPY config.?pu.yaml .
+COPY logger_config*.yaml .
+COPY hwdetect.sh .
+COPY harp_connect.sh .
+COPY supervisord.conf /etc/supervisor/supervisord.conf
+
+ENTRYPOINT ["supervisord", "-c", "/etc/supervisor/supervisord.conf"]
+
+# ============================================================
+# Vulkan (AMD / Intel / any Vulkan-capable GPU) runtime
+# Run with: --device /dev/dri (and optionally --device /dev/kfd for AMD)
+# The RADV Mesa driver (mesa-vulkan-drivers) is included and covers
+# GCN, RDNA1/2/3 and newer AMD GPUs out of the box.
+# ============================================================
+FROM ubuntu:22.04 AS runtime-vulkan
-# Install requirements
-RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel
-RUN python3 -m pip install --no-cache-dir https://github.com/abetlen/llama-cpp-python/releases/download/v0.3.13-cu122/llama_cpp_python-0.3.13-cp311-cp311-linux_x86_64.whl
-RUN sed -i '/llama_cpp_python/d' requirements.txt
-RUN python3 -m pip install --no-cache-dir -r requirements.txt && python3 -m pip cache purge
+ARG CCB_DB_NAME=ccb
+ARG CCB_DB_USER=ccbuser
+ARG CCB_DB_PASS=ccbpass
+
+ENV CCB_DB_NAME=${CCB_DB_NAME}
+ENV CCB_DB_USER=${CCB_DB_USER}
+ENV CCB_DB_PASS=${CCB_DB_PASS}
+ENV DEBIAN_FRONTEND=noninteractive
+ENV AA_DOCKER_ENV=1
+
+WORKDIR /app
+
+ADD dockerfile_scripts/install_deps.sh dockerfile_scripts/install_deps.sh
+RUN ./dockerfile_scripts/install_deps.sh
+ADD dockerfile_scripts/install_py11.sh dockerfile_scripts/install_py11.sh
+RUN ./dockerfile_scripts/install_py11.sh
+ADD dockerfile_scripts/pgsql dockerfile_scripts/pgsql
+RUN ./dockerfile_scripts/pgsql/install.sh
+ADD dockerfile_scripts/install_frpc.sh dockerfile_scripts/install_frpc.sh
+RUN ./dockerfile_scripts/install_frpc.sh
+RUN apt-get autoclean
+ADD dockerfile_scripts/entrypoint.sh dockerfile_scripts/entrypoint.sh
+
+# Install Vulkan runtime + AMD RADV open-source driver
+RUN apt-get update \
+ && apt-get install -y --no-install-recommends \
+ libvulkan1 mesa-vulkan-drivers \
+ && rm -rf /var/lib/apt/lists/*
+
+ENV DEBIAN_FRONTEND=dialog
+
+# Install llama_cpp_python from the Vulkan builder wheel
+COPY --from=llama-builder-vulkan /wheels /wheels
+RUN python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel \
+ && python3 -m pip install --no-cache-dir --no-index --find-links=/wheels llama-cpp-python \
+ && rm -rf /wheels \
+ && pip cache purge
+
+COPY requirements.txt .
+RUN sed -i '/^llama_cpp_python/d' requirements.txt \
+ && python3 -m pip install --no-cache-dir -r requirements.txt \
+ && python3 -m pip cache purge
-# Copy application files
COPY context_chat_backend context_chat_backend
COPY main.py .
COPY main_em.py .
COPY config.?pu.yaml .
-COPY logger_config.yaml .
-COPY logger_config_em.yaml .
+COPY logger_config*.yaml .
COPY hwdetect.sh .
COPY harp_connect.sh .
COPY supervisord.conf /etc/supervisor/supervisord.conf
ENTRYPOINT ["supervisord", "-c", "/etc/supervisor/supervisord.conf"]
+
+FROM runtime-cpu AS final
diff --git a/appinfo/info.xml b/appinfo/info.xml
index 9760cd29..e5896385 100644
--- a/appinfo/info.xml
+++ b/appinfo/info.xml
@@ -82,5 +82,25 @@ Setup background job workers as described here: https://docs.nextcloud.com/serve
Password to be used for authenticating requests to the OpenAI-compatible endpoint set in CC_EM_BASE_URL.
+
+
+ requestproc
+ Request Processing Mode
+ APP_ROLE=requestproc
+ true
+
+
+ updatesproc
+ Metadata Updates Processing Mode
+ APP_ROLE=updatesproc
+ false
+
+
+ indexing
+ Indexing Mode
+ APP_ROLE=indexing
+ false
+
+
diff --git a/config.cpu.yaml b/config.cpu.yaml
index 1512ea07..6ceac915 100644
--- a/config.cpu.yaml
+++ b/config.cpu.yaml
@@ -7,7 +7,10 @@ verify_ssl: true
use_colors: true
uvicorn_workers: 1
embedding_chunk_size: 2000
-doc_parser_worker_limit: 10
+doc_indexing_batch_size: 32 # theoretical max RAM usage: 32 * 100 MiB
+actions_batch_size: 512
+file_parsing_cpu_count: -1 # divides the batch into these many chunks, -1 = auto
+concurrent_file_fetches: 10 # maximum number of files to fetch concurrently to not overload the NC server
vectordb:
@@ -43,6 +46,9 @@ embedding:
llm:
nc_texttotext:
+ # template:
+ # n_ctx:
+ # max_tokens:
llama:
# all options: https://python.langchain.com/api_reference/community/llms/langchain_community.llms.llamacpp.LlamaCpp.html
@@ -52,14 +58,12 @@ llm:
max_tokens: 4096
template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant, good at finding relevant context from documents to answer questions provided by the user. <|im_end|>\n<|im_start|> user\nUse the following documents as context to answer the question at the end. REMEMBER to excersice source critisicm as the documents are returned by a search provider that can return unrelated documents.\n\nSTART OF CONTEXT: \n{context} \n\nEND OF CONTEXT!\n\nIf you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. Don't mention the context in your answer but rather just answer the question directly. Detect the language of the question and make sure to use the same language that was used in the question to answer the question. Don't mention which language was used, but just answer the question directly in the same langauge. \nQuestion: {question} Let's think this step-by-step. \n<|im_end|>\n<|im_start|> assistant\n"
no_ctx_template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant.<|im_end|>\n<|im_start|> user\n{question}<|im_end|>\n<|im_start|> assistant\n"
- end_separator: "<|im_end|>"
ctransformer:
# all options: https://python.langchain.com/api_reference/community/llms/langchain_community.llms.ctransformers.CTransformers.html
model: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant, good at finding relevant context from documents to answer questions provided by the user. <|im_end|>\n<|im_start|> user\nUse the following documents as context to answer the question at the end. REMEMBER to excersice source critisicm as the documents are returned by a search provider that can return unrelated documents.\n\nSTART OF CONTEXT: \n{context} \n\nEND OF CONTEXT!\n\nIf you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. Don't mention the context in your answer but rather just answer the question directly. Detect the language of the question and make sure to use the same language that was used in the question to answer the question. Don't mention which language was used, but just answer the question directly in the same langauge. \nQuestion: {question} Let's think this step-by-step. \n<|im_end|>\n<|im_start|> assistant\n"
no_ctx_template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant.<|im_end|>\n<|im_start|> user\n{question}<|im_end|>\n<|im_start|> assistant\n"
- end_separator: "<|im_end|>"
config:
context_length: 8192
max_new_tokens: 4096
diff --git a/config.gpu.yaml b/config.gpu.yaml
index fc3acaf2..a12fd1be 100644
--- a/config.gpu.yaml
+++ b/config.gpu.yaml
@@ -7,7 +7,10 @@ verify_ssl: true
use_colors: true
uvicorn_workers: 1
embedding_chunk_size: 2000
-doc_parser_worker_limit: 10
+doc_indexing_batch_size: 32 # theoretical max RAM usage: 32 * 100 MiB
+actions_batch_size: 512
+file_parsing_cpu_count: -1 # divides the batch into these many chunks, -1 = auto
+concurrent_file_fetches: 10 # maximum number of files to fetch concurrently to not overload the NC server
vectordb:
@@ -44,6 +47,9 @@ embedding:
llm:
nc_texttotext:
+ # template:
+ # n_ctx:
+ # max_tokens:
llama:
# all options: https://python.langchain.com/api_reference/community/llms/langchain_community.llms.llamacpp.LlamaCpp.html
@@ -53,7 +59,6 @@ llm:
max_tokens: 4096
template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant, good at finding relevant context from documents to answer questions provided by the user. <|im_end|>\n<|im_start|> user\nUse the following documents as context to answer the question at the end. REMEMBER to excersice source critisicm as the documents are returned by a search provider that can return unrelated documents.\n\nSTART OF CONTEXT: \n{context} \n\nEND OF CONTEXT!\n\nIf you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. Don't mention the context in your answer but rather just answer the question directly. Detect the language of the question and make sure to use the same language that was used in the question to answer the question. Don't mention which language was used, but just answer the question directly in the same langauge. \nQuestion: {question} Let's think this step-by-step. \n<|im_end|>\n<|im_start|> assistant\n"
no_ctx_template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant.<|im_end|>\n<|im_start|> user\n{question}<|im_end|>\n<|im_start|> assistant\n"
- end_separator: "<|im_end|>"
n_gpu_layers: -1
model_kwargs:
device: cuda
@@ -63,7 +68,6 @@ llm:
model: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant, good at finding relevant context from documents to answer questions provided by the user. <|im_end|>\n<|im_start|> user\nUse the following documents as context to answer the question at the end. REMEMBER to excersice source critisicm as the documents are returned by a search provider that can return unrelated documents.\n\nSTART OF CONTEXT: \n{context} \n\nEND OF CONTEXT!\n\nIf you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. Don't mention the context in your answer but rather just answer the question directly. Detect the language of the question and make sure to use the same language that was used in the question to answer the question. Don't mention which language was used, but just answer the question directly in the same langauge. \nQuestion: {question} Let's think this step-by-step. \n<|im_end|>\n<|im_start|> assistant\n"
no_ctx_template: "<|im_start|> system \nYou're an AI assistant named Nextcloud Assistant.<|im_end|>\n<|im_start|> user\n{question}<|im_end|>\n<|im_start|> assistant\n"
- end_separator: "<|im_end|>"
config:
context_length: 8192
max_new_tokens: 4096
diff --git a/context_chat_backend/chain/context.py b/context_chat_backend/chain/context.py
index adbac2d6..81a58f97 100644
--- a/context_chat_backend/chain/context.py
+++ b/context_chat_backend/chain/context.py
@@ -32,21 +32,11 @@ def get_context_docs(
return vectordb.doc_search(user_id, query, ctx_limit, scope_type, scope_list)
-def get_context_chunks(context_docs: list[Document]) -> list[str]:
- context_chunks = []
- for doc in context_docs:
- if title := doc.metadata.get('title'):
- context_chunks.append(title)
- context_chunks.append(doc.page_content)
-
- return context_chunks
-
-
def do_doc_search(
user_id: str,
query: str,
vectordb_loader: VectorDBLoader,
- ctx_limit: int = 20,
+ ctx_limit: int = 30,
scope_type: ScopeType | None = None,
scope_list: list[str] | None = None,
) -> list[SearchResult]:
diff --git a/context_chat_backend/chain/ingest/doc_loader.py b/context_chat_backend/chain/ingest/doc_loader.py
index efb81b6d..832c8331 100644
--- a/context_chat_backend/chain/ingest/doc_loader.py
+++ b/context_chat_backend/chain/ingest/doc_loader.py
@@ -3,15 +3,13 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
#
-import logging
import re
import tempfile
from collections.abc import Callable
-from typing import BinaryIO
+from io import BytesIO
import docx2txt
from epub2txt import epub2txt
-from fastapi import UploadFile
from langchain_unstructured import UnstructuredLoader
from odfdo import Document
from pandas import read_csv, read_excel
@@ -19,9 +17,10 @@
from pypdf.errors import FileNotDecryptedError as PdfFileNotDecryptedError
from striprtf import striprtf
-logger = logging.getLogger('ccb.doc_loader')
+from ...types import IndexingException, SourceItem
-def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str:
+
+def _temp_file_wrapper(file: BytesIO, loader: Callable, sep: str = '\n') -> str:
raw_bytes = file.read()
with tempfile.NamedTemporaryFile(mode='wb') as tmp:
tmp.write(raw_bytes)
@@ -35,49 +34,49 @@ def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str
# -- LOADERS -- #
-def _load_pdf(file: BinaryIO) -> str:
+def _load_pdf(file: BytesIO) -> str:
pdf_reader = PdfReader(file)
return '\n\n'.join([page.extract_text().strip() for page in pdf_reader.pages])
-def _load_csv(file: BinaryIO) -> str:
+def _load_csv(file: BytesIO) -> str:
return read_csv(file).to_string(header=False, na_rep='')
-def _load_epub(file: BinaryIO) -> str:
+def _load_epub(file: BytesIO) -> str:
return _temp_file_wrapper(file, epub2txt).strip()
-def _load_docx(file: BinaryIO) -> str:
+def _load_docx(file: BytesIO) -> str:
return docx2txt.process(file).strip()
-def _load_odt(file: BinaryIO) -> str:
+def _load_odt(file: BytesIO) -> str:
return _temp_file_wrapper(file, lambda fp: Document(fp).get_formatted_text()).strip()
-def _load_ppt_x(file: BinaryIO) -> str:
+def _load_ppt_x(file: BytesIO) -> str:
return _temp_file_wrapper(file, lambda fp: UnstructuredLoader(fp).load()).strip()
-def _load_rtf(file: BinaryIO) -> str:
+def _load_rtf(file: BytesIO) -> str:
return striprtf.rtf_to_text(file.read().decode('utf-8', 'ignore')).strip()
-def _load_xml(file: BinaryIO) -> str:
+def _load_xml(file: BytesIO) -> str:
data = file.read().decode('utf-8', 'ignore')
data = re.sub(r'', '', data)
return data.strip()
-def _load_xlsx(file: BinaryIO) -> str:
+def _load_xlsx(file: BytesIO) -> str:
return read_excel(file, na_filter=False).to_string(header=False, na_rep='')
-def _load_email(file: BinaryIO, ext: str = 'eml') -> str | None:
+def _load_email(file: BytesIO, ext: str = 'eml') -> str:
# NOTE: msg format is not tested
if ext not in ['eml', 'msg']:
- return None
+ raise IndexingException(f'Unsupported email format: {ext}')
# TODO: implement attachment partitioner using unstructured.partition.partition_{email,msg}
# since langchain does not pass through the attachment_partitioner kwarg
@@ -115,30 +114,36 @@ def attachment_partitioner(
}
-def decode_source(source: UploadFile) -> str | None:
+def decode_source(source: SourceItem) -> str:
+ '''
+ Raises
+ ------
+ IndexingException
+ '''
+
+ io_obj: BytesIO | None = None
try:
# .pot files are powerpoint templates but also plain text files,
# so we skip them to prevent decoding errors
- if source.headers['title'].endswith('.pot'):
- return None
-
- mimetype = source.headers['type']
- if mimetype is None:
- return None
-
- if _loader_map.get(mimetype):
- result = _loader_map[mimetype](source.file)
- source.file.close()
- return result.encode('utf-8', 'ignore').decode('utf-8', 'ignore')
-
- result = source.file.read().decode('utf-8', 'ignore')
- source.file.close()
- return result
- except PdfFileNotDecryptedError:
- logger.warning(f'PDF file ({source.filename}) is encrypted and cannot be read')
- return None
- except Exception:
- logger.exception(f'Error decoding source file ({source.filename})', stack_info=True)
- return None
+ if source.title.endswith('.pot'):
+ raise IndexingException('PowerPoint template files (.pot) are not supported')
+
+ if isinstance(source.content, str):
+ io_obj = BytesIO(source.content.encode('utf-8', 'ignore'))
+ else:
+ io_obj = source.content
+
+ if _loader_map.get(source.type):
+ result = _loader_map[source.type](io_obj)
+ return result.encode('utf-8', 'ignore').decode('utf-8', 'ignore').strip()
+
+ return io_obj.read().decode('utf-8', 'ignore').strip()
+ except IndexingException:
+ raise
+ except PdfFileNotDecryptedError as e:
+ raise IndexingException('PDF file is encrypted and cannot be read') from e
+ except Exception as e:
+ raise IndexingException(f'Error decoding source file: {e}') from e
finally:
- source.file.close() # Ensure file is closed after processing
+ if io_obj is not None:
+ io_obj.close()
diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py
index 5871ebb8..ad2777ed 100644
--- a/context_chat_backend/chain/ingest/injest.py
+++ b/context_chat_backend/chain/ingest/injest.py
@@ -2,65 +2,240 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
+import asyncio
import logging
import re
+from collections.abc import Mapping
+from io import BytesIO
+from time import perf_counter_ns
-from fastapi.datastructures import UploadFile
+import niquests
from langchain.schema import Document
+from nc_py_api import AsyncNextcloudApp
from ...dyn_loader import VectorDBLoader
-from ...types import TConfig
-from ...utils import is_valid_source_id, to_int
+from ...types import IndexingError, IndexingException, ReceivedFileItem, SourceItem, TConfig
from ...vectordb.base import BaseVectorDB
from ...vectordb.types import DbException, SafeDbException, UpdateAccessOp
from ..types import InDocument
from .doc_loader import decode_source
from .doc_splitter import get_splitter_for
-from .mimetype_list import SUPPORTED_MIMETYPES
logger = logging.getLogger('ccb.injest')
-def _allowed_file(file: UploadFile) -> bool:
- return file.headers['type'] in SUPPORTED_MIMETYPES
+MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB, all loaded in RAM at once
+
+
+async def __fetch_file_content(
+ semaphore: asyncio.Semaphore,
+ file_id: int,
+ user_id: str,
+ _rlimit = 3,
+) -> BytesIO:
+ '''
+ Raises
+ ------
+ IndexingException
+ '''
+
+ async with semaphore:
+ nc = AsyncNextcloudApp()
+ try:
+ # a file pointer for storing the stream in memory until it is consumed
+ fp = BytesIO()
+ await nc._session.download2fp(
+ url_path=f'/ocs/v2.php/apps/context_chat/files/{file_id}',
+ fp=fp,
+ dav=False,
+ params={ 'userId': user_id },
+ )
+ fp.seek(0)
+ return fp
+ except niquests.exceptions.RequestException as e:
+ if e.response is None:
+ raise
+
+ if e.response.status_code == niquests.codes.too_many_requests: # pyright: ignore[reportAttributeAccessIssue]
+ # todo: implement rate limits in php CC?
+ wait_for = int(e.response.headers.get('Retry-After', '30'))
+ if _rlimit <= 0:
+ raise IndexingException(
+ f'Rate limited when fetching content for file id {file_id}, user id {user_id},'
+ ' max retries exceeded',
+ retryable=True,
+ ) from e
+ logger.warning(
+ f'Rate limited when fetching content for file id {file_id}, user id {user_id},'
+ f' waiting {wait_for} before retrying',
+ exc_info=e,
+ )
+ await asyncio.sleep(wait_for)
+ return await __fetch_file_content(semaphore, file_id, user_id, _rlimit - 1)
+
+ raise
+ except IndexingException:
+ raise
+ except Exception as e:
+ logger.error(f'Error fetching content for file id {file_id}, user id {user_id}: {e}', exc_info=e)
+ raise IndexingException(f'Error fetching content for file id {file_id}, user id {user_id}: {e}') from e
+
+
+async def __fetch_files_content(
+ sources: Mapping[int, SourceItem | ReceivedFileItem],
+ concurrent_file_fetches: int,
+) -> tuple[Mapping[int, SourceItem], Mapping[int, IndexingError]]:
+ source_items = {}
+ error_items = {}
+ tasks = []
+ task_sources = {}
+ semaphore = asyncio.Semaphore(concurrent_file_fetches)
+
+ file_count = sum(1 for s in sources.values() if isinstance(s, ReceivedFileItem))
+ logger.debug('Fetching content for %d file(s) (max %d concurrent)', file_count, concurrent_file_fetches)
+
+ for db_id, file in sources.items():
+ if isinstance(file, SourceItem):
+ continue
+
+ try:
+ # to detect any validation errors but it should not happen since file.reference is validated
+ file.file_id # noqa: B018
+ except ValueError as e:
+ logger.error(
+ f'Invalid file reference format for db id {db_id}, file reference {file.reference}: {e}',
+ exc_info=e,
+ )
+ error_items[db_id] = IndexingError(
+ error=f'Invalid file reference format: {file.reference}',
+ retryable=False,
+ )
+ continue
+
+ if file.size > MAX_FILE_SIZE:
+ logger.info(
+ f'Skipping db id {db_id}, file id {file.file_id}, source id {file.reference} due to size'
+ f' {(file.size/(1024*1024)):.2f} MiB exceeding the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB',
+ )
+ error_items[db_id] = IndexingError(
+ error=(
+ f'File size {(file.size/(1024*1024)):.2f} MiB'
+ f' exceeds the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB'
+ ),
+ retryable=False,
+ )
+ continue
+ # any user id from the list should have read access to the file
+ tasks.append(asyncio.ensure_future(__fetch_file_content(semaphore, file.file_id, file.userIds[0])))
+ task_sources[db_id] = file
+
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+ for (db_id, file), result in zip(task_sources.items(), results, strict=True):
+ if isinstance(result, str) or isinstance(result, BytesIO):
+ source_items[db_id] = SourceItem(
+ **{
+ **file.model_dump(),
+ 'content': result,
+ }
+ )
+ elif isinstance(result, IndexingException):
+ logger.error(
+ f'Error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}'
+ f': {result}',
+ exc_info=result,
+ )
+ error_items[db_id] = IndexingError(
+ error=str(result),
+ retryable=result.retryable,
+ )
+ elif isinstance(result, BaseException):
+ logger.error(
+ f'Unexpected error fetching content for db id {db_id}, file id {file.file_id},'
+ f' reference {file.reference}: {result}',
+ exc_info=result,
+ )
+ error_items[db_id] = IndexingError(
+ error=f'Unexpected error: {result}',
+ retryable=True,
+ )
+ else:
+ logger.error(
+ f'Unknown error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}'
+ f': {result}',
+ exc_info=True,
+ )
+ error_items[db_id] = IndexingError(
+ error='Unknown error',
+ retryable=True,
+ )
+
+ # add the content providers from the orginal "sources" to the result unprocessed
+ for db_id, source in sources.items():
+ if isinstance(source, SourceItem):
+ source_items[db_id] = source
+
+ return source_items, error_items
def _filter_sources(
vectordb: BaseVectorDB,
- sources: list[UploadFile]
-) -> tuple[list[UploadFile], list[UploadFile]]:
+ sources: Mapping[int, SourceItem | ReceivedFileItem]
+) -> tuple[Mapping[int, SourceItem | ReceivedFileItem], Mapping[int, SourceItem | ReceivedFileItem]]:
'''
Returns
-------
- tuple[list[str], list[UploadFile]]
+ tuple[Mapping[int, SourceItem | ReceivedFileItem], Mapping[int, SourceItem | ReceivedFileItem]]:
First value is a list of sources that already exist in the vectordb.
Second value is a list of sources that are new and should be embedded.
'''
try:
- existing_sources, new_sources = vectordb.check_sources(sources)
+ existing_source_ids, to_embed_source_ids = vectordb.check_sources(sources)
except Exception as e:
- raise DbException('Error: Vectordb sources_to_embed error') from e
+ raise DbException('Error: Vectordb error while checking existing sources in indexing') from e
+
+ existing_sources = {}
+ to_embed_sources = {}
+
+ for db_id, source in sources.items():
+ if source.reference in existing_source_ids:
+ existing_sources[db_id] = source
+ elif source.reference in to_embed_source_ids:
+ to_embed_sources[db_id] = source
- return ([
- source for source in sources
- if source.filename in existing_sources
- ], [
- source for source in sources
- if source.filename in new_sources
- ])
+ return existing_sources, to_embed_sources
-def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[InDocument]:
- indocuments = []
+def _sources_to_indocuments(
+ config: TConfig,
+ sources: Mapping[int, SourceItem]
+) -> tuple[Mapping[int, InDocument], Mapping[int, IndexingError]]:
+ indocuments = {}
+ errored_docs = {}
- for source in sources:
- logger.debug('processing source', extra={ 'source_id': source.filename })
+ for db_id, source in sources.items():
+ logger.debug('processing source', extra={ 'source_id': source.reference })
# transform the source to have text data
- content = decode_source(source)
+ try:
+ logger.debug('Decoding source %s (type: %s)', source.reference, source.type)
+ t0 = perf_counter_ns()
+ content = decode_source(source)
+ elapsed_ms = (perf_counter_ns() - t0) / 1e6
+ logger.debug('Decoded source %s in %.2f ms (%d chars)', source.reference, elapsed_ms, len(content))
+ except IndexingException as e:
+ logger.error(f'Error decoding source ({source.reference}): {e}', exc_info=e)
+ errored_docs[db_id] = IndexingError(
+ error=str(e),
+ retryable=False,
+ )
+ continue
- if content is None or (content := content.strip()) == '':
- logger.debug('decoded empty source', extra={ 'source_id': source.filename })
+ if content == '':
+ logger.debug('decoded empty source', extra={ 'source_id': source.reference })
+ errored_docs[db_id] = IndexingError(
+ error='Decoded content is empty',
+ retryable=False,
+ )
continue
# replace more than two newlines with two newlines (also blank spaces, more than 4)
@@ -68,97 +243,151 @@ def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[
# NOTE: do not use this with all docs when programming files are added
content = re.sub(r'(\s){5,}', r'\g<1>', content)
# filter out null bytes
- content = content.replace('\0', '')
-
- if content is None or content == '':
- logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.filename })
+ content = content.replace('\0', '').strip()
+
+ if content == '':
+ logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.reference })
+ errored_docs[db_id] = IndexingError(
+ error='Cleaned up content is empty',
+ retryable=False,
+ )
continue
- logger.debug('decoded non empty source', extra={ 'source_id': source.filename })
+ logger.debug('decoded non empty source', extra={ 'source_id': source.reference })
metadata = {
- 'source': source.filename,
- 'title': _decode_latin_1(source.headers['title']),
- 'type': source.headers['type'],
+ 'source': source.reference,
+ 'title': _decode_latin_1(source.title),
+ 'type': source.type,
}
doc = Document(page_content=content, metadata=metadata)
- splitter = get_splitter_for(config.embedding_chunk_size, source.headers['type'])
+ splitter = get_splitter_for(config.embedding_chunk_size, source.type)
split_docs = splitter.split_documents([doc])
logger.debug('split document into chunks', extra={
- 'source_id': source.filename,
+ 'source_id': source.reference,
'len(split_docs)': len(split_docs),
})
- indocuments.append(InDocument(
+ indocuments[db_id] = InDocument(
documents=split_docs,
- userIds=list(map(_decode_latin_1, source.headers['userIds'].split(','))),
- source_id=source.filename, # pyright: ignore[reportArgumentType]
- provider=source.headers['provider'],
- modified=to_int(source.headers['modified']),
- ))
+ userIds=list(map(_decode_latin_1, source.userIds)),
+ source_id=source.reference,
+ provider=source.provider,
+ modified=source.modified, # pyright: ignore[reportArgumentType]
+ )
- return indocuments
+ return indocuments, errored_docs
+
+
+def _increase_access_for_existing_sources(
+ vectordb: BaseVectorDB,
+ existing_sources: Mapping[int, SourceItem | ReceivedFileItem]
+) -> Mapping[int, IndexingError | None]:
+ '''
+ update userIds for existing sources
+ allow the userIds as additional users, not as the only users
+ '''
+ if len(existing_sources) == 0:
+ return {}
+
+ results = {}
+ logger.debug('Increasing access for existing sources', extra={
+ 'source_ids': [source.reference for source in existing_sources.values()]
+ })
+ for db_id, source in existing_sources.items():
+ try:
+ vectordb.update_access(
+ UpdateAccessOp.ALLOW,
+ list(map(_decode_latin_1, source.userIds)),
+ source.reference,
+ )
+ results[db_id] = None
+ except SafeDbException as e:
+ logger.error(f'Failed to update access for source ({source.reference}): {e.args[0]}')
+ results[db_id] = IndexingError(
+ error=str(e),
+ retryable=False,
+ )
+ continue
+ except Exception as e:
+ logger.error(f'Unexpected error while updating access for source ({source.reference}): {e}')
+ results[db_id] = IndexingError(
+ error='Unexpected error while updating access',
+ retryable=True,
+ )
+ continue
+ return results
def _process_sources(
vectordb: BaseVectorDB,
config: TConfig,
- sources: list[UploadFile],
-) -> tuple[list[str],list[str]]:
+ sources: Mapping[int, SourceItem | ReceivedFileItem]
+) -> Mapping[int, IndexingError | None]:
'''
Processes the sources and adds them to the vectordb.
Returns the list of source ids that were successfully added and those that need to be retried.
'''
- existing_sources, filtered_sources = _filter_sources(vectordb, sources)
+ existing_sources, to_embed_sources = _filter_sources(vectordb, sources)
logger.debug('db filter source results', extra={
'len(existing_sources)': len(existing_sources),
'existing_sources': existing_sources,
- 'len(filtered_sources)': len(filtered_sources),
- 'filtered_sources': filtered_sources,
+ 'len(to_embed_sources)': len(to_embed_sources),
+ 'to_embed_sources': to_embed_sources,
})
- loaded_source_ids = [source.filename for source in existing_sources]
- # update userIds for existing sources
- # allow the userIds as additional users, not as the only users
- if len(existing_sources) > 0:
- logger.debug('Increasing access for existing sources', extra={
- 'source_ids': [source.filename for source in existing_sources]
- })
- for source in existing_sources:
- try:
- vectordb.update_access(
- UpdateAccessOp.allow,
- list(map(_decode_latin_1, source.headers['userIds'].split(','))),
- source.filename, # pyright: ignore[reportArgumentType]
- )
- except SafeDbException as e:
- logger.error(f'Failed to update access for source ({source.filename}): {e.args[0]}')
- continue
-
- if len(filtered_sources) == 0:
+ source_proc_results = _increase_access_for_existing_sources(vectordb, existing_sources)
+
+ logger.debug(
+ 'Fetching file contents for %d source(s) from Nextcloud',
+ len(to_embed_sources),
+ )
+ t0 = perf_counter_ns()
+ populated_to_embed_sources, errored_sources = asyncio.run(
+ __fetch_files_content(to_embed_sources, config.concurrent_file_fetches)
+ )
+ elapsed_ms = (perf_counter_ns() - t0) / 1e6
+ logger.debug(
+ 'File content fetch complete in %.2f ms: %d fetched, %d errored',
+ elapsed_ms, len(populated_to_embed_sources), len(errored_sources),
+ )
+ source_proc_results.update(errored_sources) # pyright: ignore[reportAttributeAccessIssue]
+
+ if len(populated_to_embed_sources) == 0:
# no new sources to embed
logger.debug('Filtered all sources, nothing to embed')
- return loaded_source_ids, [] # pyright: ignore[reportReturnType]
+ return source_proc_results
logger.debug('Filtered sources:', extra={
- 'source_ids': [source.filename for source in filtered_sources]
+ 'source_ids': [source.reference for source in populated_to_embed_sources.values()]
})
# invalid/empty sources are filtered out here and not counted in loaded/retryable
- indocuments = _sources_to_indocuments(config, filtered_sources)
+ indocuments, errored_docs = _sources_to_indocuments(config, populated_to_embed_sources)
- logger.debug('Converted all sources to documents')
+ source_proc_results.update(errored_docs) # pyright: ignore[reportAttributeAccessIssue]
+ logger.debug('Converted sources to documents')
if len(indocuments) == 0:
# filtered document(s) were invalid/empty, not an error
logger.debug('All documents were found empty after being processed')
- return loaded_source_ids, [] # pyright: ignore[reportReturnType]
+ return source_proc_results
+
+ logger.debug('Adding documents to vectordb', extra={
+ 'source_ids': [indoc.source_id for indoc in indocuments.values()]
+ })
- added_source_ids, retry_source_ids = vectordb.add_indocuments(indocuments)
- loaded_source_ids.extend(added_source_ids)
+ t0 = perf_counter_ns()
+ doc_add_results = vectordb.add_indocuments(indocuments)
+ elapsed_ms = (perf_counter_ns() - t0) / 1e6
+ logger.info(
+ 'vectordb.add_indocuments completed in %.2f ms for %d document(s)',
+ elapsed_ms, len(indocuments),
+ )
+ source_proc_results.update(doc_add_results) # pyright: ignore[reportAttributeAccessIssue]
logger.debug('Added documents to vectordb')
- return loaded_source_ids, retry_source_ids # pyright: ignore[reportReturnType]
+ return source_proc_results
def _decode_latin_1(s: str) -> str:
@@ -172,31 +401,15 @@ def _decode_latin_1(s: str) -> str:
def embed_sources(
vectordb_loader: VectorDBLoader,
config: TConfig,
- sources: list[UploadFile],
-) -> tuple[list[str],list[str]]:
- # either not a file or a file that is allowed
- sources_filtered = [
- source for source in sources
- if is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType]
- or _allowed_file(source)
- ]
-
+ sources: Mapping[int, SourceItem | ReceivedFileItem]
+) -> Mapping[int, IndexingError | None]:
logger.debug('Embedding sources:', extra={
'source_ids': [
- f'{source.filename} ({_decode_latin_1(source.headers["title"])})'
- for source in sources_filtered
- ],
- 'invalid_source_ids': [
- source.filename for source in sources
- if not is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType]
- ],
- 'not_allowed_file_ids': [
- source.filename for source in sources
- if not _allowed_file(source)
+ f'{source.reference} ({_decode_latin_1(source.title)})'
+ for source in sources.values()
],
- 'len(source_ids)': len(sources_filtered),
- 'len(total_source_ids)': len(sources),
+ 'len(source_ids)': len(sources),
})
vectordb = vectordb_loader.load()
- return _process_sources(vectordb, config, sources_filtered)
+ return _process_sources(vectordb, config, sources)
diff --git a/context_chat_backend/chain/one_shot.py b/context_chat_backend/chain/one_shot.py
index 1c0521bf..3bd45573 100644
--- a/context_chat_backend/chain/one_shot.py
+++ b/context_chat_backend/chain/one_shot.py
@@ -8,41 +8,29 @@
from ..dyn_loader import VectorDBLoader
from ..types import TConfig
-from .context import get_context_chunks, get_context_docs
+from .context import get_context_docs
from .query_proc import get_pruned_query
-from .types import ContextException, LLMOutput, ScopeType
+from .types import ContextException, LLMOutput, ScopeType, SearchResult
-_LLM_TEMPLATE = '''Answer based only on this context and do not add any imaginative details. Make sure to use the same language as the question in your answer.
+_LLM_TEMPLATE = '''You're an AI assistant named Nextcloud Assistant, good at finding relevant context from documents to answer questions provided by the user.
+Use the following documents as context to answer the question at the end. REMEMBER to excersice source critisicm as the documents are returned by a search provider that can return unrelated documents.
+
+START OF CONTEXT:
{context}
-{question}
-''' # noqa: E501
+END OF CONTEXT!
-logger = logging.getLogger('ccb.chain')
+If you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer.
+Don't mention the context in your answer but rather just answer the question directly.
+Detect the language of the question and make sure to use the same language that was used in the question to answer the question.
+Don't mention which language was used, but just answer the question directly in the same langauge.
-def process_query(
- user_id: str,
- llm: LLM,
- app_config: TConfig,
- query: str,
- no_ctx_template: str | None = None,
- end_separator: str = '',
-):
- """
- Raises
- ------
- ValueError
- If the context length is too small to fit the query
- """
- stop = [end_separator] if end_separator else None
- output = llm.invoke(
- (query, get_pruned_query(llm, app_config, query, no_ctx_template, []))[no_ctx_template is not None], # pyright: ignore[reportArgumentType]
- stop=stop,
- userid=user_id,
- ).strip()
+Question: {question}
- return LLMOutput(output=output, sources=[])
+Let's think this step-by-step.
+''' # noqa: E501
+logger = logging.getLogger('ccb.chain')
def process_context_query(
user_id: str,
@@ -50,11 +38,10 @@ def process_context_query(
llm: LLM,
app_config: TConfig,
query: str,
- ctx_limit: int = 20,
+ ctx_limit: int = 30,
scope_type: ScopeType | None = None,
scope_list: list[str] | None = None,
template: str | None = None,
- end_separator: str = '',
):
"""
Raises
@@ -65,19 +52,21 @@ def process_context_query(
db = vectordb_loader.load()
context_docs = get_context_docs(user_id, query, db, ctx_limit, scope_type, scope_list)
if len(context_docs) == 0:
+ if scope_type is not None:
+ raise ContextException('No documents retrieved, please choose a wider scope of documents to search from')
raise ContextException('No documents retrieved, please index a few documents first')
- context_chunks = get_context_chunks(context_docs)
logger.debug('context retrieved', extra={
'len(context_docs)': len(context_docs),
- 'len(context_chunks)': len(context_chunks),
})
output = llm.invoke(
- get_pruned_query(llm, app_config, query, template or _LLM_TEMPLATE, context_chunks),
- stop=[end_separator],
+ get_pruned_query(llm, app_config, query, template or _LLM_TEMPLATE, context_docs),
userid=user_id,
).strip()
- unique_sources: list[str] = list({source for d in context_docs if (source := d.metadata.get('source'))})
+ unique_sources = [SearchResult(
+ source_id=source,
+ title=d.metadata.get('title', ''),
+ ) for d in context_docs if (source := d.metadata.get('source'))]
return LLMOutput(output=output, sources=unique_sources)
diff --git a/context_chat_backend/chain/query_proc.py b/context_chat_backend/chain/query_proc.py
index b6a99829..1fe68270 100644
--- a/context_chat_backend/chain/query_proc.py
+++ b/context_chat_backend/chain/query_proc.py
@@ -7,6 +7,7 @@
from sys import maxsize as SYS_MAXSIZE
from langchain.llms.base import LLM
+from langchain.schema import Document
from transformers import GPT2Tokenizer
from ..types import TConfig
@@ -22,7 +23,7 @@ def get_num_tokens(text: str, tokenizer: GPT2Tokenizer) -> int:
return len(tokenizer.encode(text, max_length=SYS_MAXSIZE, truncation=True))
-def get_pruned_query(llm: LLM, config: TConfig, query: str, template: str, text_chunks: list[str]) -> str:
+def get_pruned_query(llm: LLM, config: TConfig, query: str, template: str, doc_chunks: list[Document]) -> str:
'''
Truncates the input to fit the model's maximum context length
and returns the model's prediction
@@ -39,7 +40,7 @@ def get_pruned_query(llm: LLM, config: TConfig, query: str, template: str, text_
n_ctx = llm_config.get('n_ctx') \
or llm_config.get('config', {}).get('context_length') \
or llm_config.get('pipeline_kwargs', {}).get('config', {}).get('max_length') \
- or 8192
+ or 16384
# fav: tokens to generate
n_gen = llm_config.get('max_tokens') \
@@ -69,19 +70,21 @@ def get_pruned_query(llm: LLM, config: TConfig, query: str, template: str, text_
accepted_chunks = []
- while text_chunks and remaining_tokens > 0:
- context = text_chunks.pop(0)
+ for chunk in doc_chunks:
+ context = f'{chunk.metadata.get("title", "")}:\n\n{chunk.page_content}'
context_tokens = get_num_tokens(context, tokenizer)
- if context_tokens <= remaining_tokens:
- accepted_chunks.append(context)
- remaining_tokens -= context_tokens
+ if context_tokens > remaining_tokens or remaining_tokens <= 0:
+ break
+
+ accepted_chunks.append(context)
+ remaining_tokens -= context_tokens
logger.debug('pruned query stats', extra={
'total tokens': n_ctx - remaining_tokens,
'remaining tokens': remaining_tokens,
'accepted chunks': len(accepted_chunks),
- 'total chunks': len(text_chunks),
+ 'total chunks': len(doc_chunks),
})
return template.format(context='\n\n'.join(accepted_chunks), question=query)
diff --git a/context_chat_backend/chain/types.py b/context_chat_backend/chain/types.py
index b006ad1a..3afdf297 100644
--- a/context_chat_backend/chain/types.py
+++ b/context_chat_backend/chain/types.py
@@ -33,12 +33,24 @@ class ContextException(Exception):
...
+class SearchResult(TypedDict):
+ source_id: str
+ title: str
+
+
class LLMOutput(TypedDict):
output: str
- sources: list[str]
- # todo: add "titles" field
+ sources: list[SearchResult]
-class SearchResult(TypedDict):
- source_id: str
- title: str
+class EnrichedSource(BaseModel):
+ id: str
+ label: str
+ icon: str
+ url: str
+
+class EnrichedSourceList(BaseModel):
+ sources: list[EnrichedSource]
+
+class ScopeList(BaseModel):
+ source_ids: list[str]
diff --git a/context_chat_backend/config_parser.py b/context_chat_backend/config_parser.py
index dafef75f..0a62019a 100644
--- a/context_chat_backend/config_parser.py
+++ b/context_chat_backend/config_parser.py
@@ -103,17 +103,11 @@ def get_config(file_path: str) -> TConfig:
except Exception as e:
raise AssertionError('Error: could not create embedding config from config file') from e
- return TConfig(
- debug=config.get('debug', False),
- uvicorn_log_level=config.get('uvicorn_log_level', 'info'),
- disable_aaa=config.get('disable_aaa', False),
- verify_ssl=config.get('verify_ssl', config.get('httpx_verify_ssl', True)),
- use_colors=config.get('use_colors', True),
- uvicorn_workers=config.get('uvicorn_workers', 1),
- embedding_chunk_size=config.get('embedding_chunk_size', 1000),
- doc_parser_worker_limit=config.get('doc_parser_worker_limit', 10),
-
- vectordb=vectordb,
- embedding=embedding_config,
- llm=llm,
- )
+ config['verify_ssl'] = config.get('verify_ssl', config.get('httpx_verify_ssl', True))
+ config.pop('httpx_verify_ssl', None)
+
+ config['llm'] = llm
+ config['vectordb'] = vectordb
+ config['embedding'] = embedding_config
+
+ return TConfig(**config)
diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py
index c26b930a..3dadf18c 100644
--- a/context_chat_backend/controller.py
+++ b/context_chat_backend/controller.py
@@ -2,11 +2,14 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
+import time
+
+from nc_py_api.ex_app.providers.task_processing import TaskProcessingProvider
# isort: off
-from .chain.types import ContextException, LLMOutput, ScopeType, SearchResult
-from .types import LoaderException, EmbeddingException
-from .vectordb.types import DbException, SafeDbException, UpdateAccessOp
+from .chain.types import ContextException
+from .types import AppRole, LoaderException, EmbeddingException
+from .vectordb.types import DbException, SafeDbException
from .setup_functions import ensure_config_file, repair_run, setup_env_vars
# setup env vars before importing other modules
@@ -23,39 +26,31 @@
from collections.abc import Callable
from contextlib import asynccontextmanager
from functools import wraps
-from threading import Event, Thread
-from time import sleep
-from typing import Annotated, Any
-from fastapi import Body, FastAPI, Request, UploadFile
-from langchain.llms.base import LLM
+from fastapi import FastAPI, Request
from nc_py_api import AsyncNextcloudApp, NextcloudApp
from nc_py_api.ex_app import persistent_storage, set_handlers
-from pydantic import BaseModel, ValidationInfo, field_validator
from starlette.responses import FileResponse
-from .chain.context import do_doc_search
-from .chain.ingest.injest import embed_sources
-from .chain.one_shot import process_context_query, process_query
from .config_parser import get_config
-from .dyn_loader import LLMModelLoader, VectorDBLoader
+from .dyn_loader import VectorDBLoader
from .models.types import LlmException
from nc_py_api.ex_app import AppAPIAuthMiddleware
-from .utils import JSONResponse, exec_in_proc, is_valid_provider_id, is_valid_source_id, value_of
-from .vectordb.service import (
- count_documents_by_provider,
- decl_update_access,
- delete_by_provider,
- delete_by_source,
- delete_user,
- update_access,
-)
+from .utils import JSONResponse, exec_in_proc, get_app_role, is_k8s_env
+from .task_fetcher import THREAD_STOP_EVENT, start_bg_threads, trigger_handler, wait_for_bg_threads
+from .vectordb.service import count_documents_by_provider
# setup
-repair_run()
-ensure_config_file()
+# only run once
+APP_ROLE = get_app_role()
+if mp.current_process().name == 'MainProcess' and APP_ROLE in (AppRole.NORMAL, AppRole.REQUEST_PROC):
+ # normal docker containers and RP role in k8s
+ repair_run()
+ ensure_config_file()
+
logger = logging.getLogger('ccb.controller')
+app_config = get_config(os.environ['CC_CONFIG_PATH'])
__download_models_from_hf = os.environ.get('CC_DOWNLOAD_MODELS_FROM_HF', 'true').lower() in ('1', 'true', 'yes')
models_to_fetch = {
@@ -70,13 +65,55 @@
'revision': '607a30d783dfa663caf39e06633721c8d4cfcd7e',
}
} if __download_models_from_hf else {}
-app_enabled = Event()
-def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str:
- if enabled:
- app_enabled.set()
- else:
- app_enabled.clear()
+
+app_enabled = threading.Event()
+last_enabled_check: float | None = None
+enabled_check_lock: threading.Lock = threading.Lock()
+def get_enabled_state() -> bool:
+ global last_enabled_check
+ with enabled_check_lock:
+ if last_enabled_check is None or time.time() - last_enabled_check > 30:
+ nc = NextcloudApp()
+ if nc.enabled_state:
+ app_enabled.set()
+ else:
+ app_enabled.clear()
+ last_enabled_check = time.time()
+ return app_enabled.is_set()
+
+def enabled_handler(enabled: bool, nc: NextcloudApp | AsyncNextcloudApp) -> str:
+ try:
+ if enabled:
+ provider = TaskProcessingProvider(
+ id='context_chat-context_chat_search',
+ name='Context Chat',
+ task_type='context_chat:context_chat_search',
+ expected_runtime=30,
+ input_shape_defaults={
+ 'limit': 10,
+ },
+ )
+ nc.providers.task_processing.register(provider)
+ provider = TaskProcessingProvider(
+ id='context_chat-context_chat',
+ name='Context Chat',
+ task_type='context_chat:context_chat',
+ expected_runtime=30,
+ )
+ nc.providers.task_processing.register(provider)
+ app_enabled.set()
+ if THREAD_STOP_EVENT.is_set():
+ # If the threads were previously stopped, we start them again
+ # otherwise the lifecycle handler has already started them
+ start_bg_threads(app_config, get_enabled_state)
+ THREAD_STOP_EVENT.clear()
+ else:
+ app_enabled.clear()
+ wait_for_bg_threads()
+ except Exception as e:
+ logger.exception('Error in enabled handler:', exc_info=e)
+ return f'Error in enabled handler: {e}'
logger.info(f'App {("disabled", "enabled")[enabled]}')
return ''
@@ -84,28 +121,28 @@ def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str:
@asynccontextmanager
async def lifespan(app: FastAPI):
- set_handlers(app, enabled_handler, models_to_fetch=models_to_fetch)
- nc = NextcloudApp()
- if nc.enabled_state:
- app_enabled.set()
- logger.info(f'App enable state at startup: {app_enabled.is_set()}')
- t = Thread(target=background_thread_task, args=())
- t.start()
+ if APP_ROLE == AppRole.NORMAL:
+ set_handlers(app, enabled_handler, models_to_fetch=models_to_fetch, trigger_handler=trigger_handler)
+ else:
+ # k8s' rp role pulls tasks
+ set_handlers(app, enabled_handler, models_to_fetch=models_to_fetch)
+
+ start_bg_threads(app_config, get_enabled_state)
+ logger.info(f'App enable state at startup: {get_enabled_state()}')
yield
vectordb_loader.offload()
- llm_loader.offload()
+ wait_for_bg_threads()
-app_config = get_config(os.environ['CC_CONFIG_PATH'])
app = FastAPI(debug=app_config.debug, lifespan=lifespan) # pyright: ignore[reportArgumentType]
app.extra['CONFIG'] = app_config
+k8s_env = is_k8s_env()
# loaders
vectordb_loader = VectorDBLoader(app_config)
-llm_loader = LLMModelLoader(app, app_config)
# locks and semaphores
@@ -117,22 +154,12 @@ async def lifespan(app: FastAPI):
index_lock = threading.Lock()
_indexing = {}
-# limit the number of concurrent document parsing
-doc_parse_semaphore = mp.Semaphore(app_config.doc_parser_worker_limit)
-
# middlewares
if not app_config.disable_aaa:
app.add_middleware(AppAPIAuthMiddleware)
-# logger background thread
-
-def background_thread_task():
- while(True):
- logger.info(f'Currently indexing {len(_indexing)} documents (filename, size): ', extra={'_indexing': _indexing})
- sleep(10)
-
# exception handlers
@app.exception_handler(DbException)
@@ -189,7 +216,7 @@ def decorator(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
disable_aaa = app.extra['CONFIG'].disable_aaa
- if not disable_aaa and not app_enabled.is_set():
+ if not disable_aaa and not get_enabled_state():
return JSONResponse('Context Chat is disabled, enable it from AppAPI to use it.', 503)
return func(*args, **kwargs)
@@ -210,122 +237,7 @@ def _(request: Request):
@app.get('/enabled')
def _():
- return JSONResponse(content={'enabled': app_enabled.is_set()}, status_code=200)
-
-
-@app.post('/updateAccessDeclarative')
-@enabled_guard(app)
-def _(
- userIds: Annotated[list[str], Body()],
- sourceId: Annotated[str, Body()],
-):
- logger.debug('Update access declarative request:', extra={
- 'user_ids': userIds,
- 'source_id': sourceId,
- })
-
- if len(userIds) == 0:
- return JSONResponse('Empty list of user ids', 400)
-
- if not is_valid_source_id(sourceId):
- return JSONResponse('Invalid source id', 400)
-
- exec_in_proc(target=decl_update_access, args=(vectordb_loader, userIds, sourceId))
-
- return JSONResponse('Access updated')
-
-
-@app.post('/updateAccess')
-@enabled_guard(app)
-def _(
- op: Annotated[UpdateAccessOp, Body()],
- userIds: Annotated[list[str], Body()],
- sourceId: Annotated[str, Body()],
-):
- logger.debug('Update access request', extra={
- 'op': op,
- 'user_ids': userIds,
- 'source_id': sourceId,
- })
-
- if len(userIds) == 0:
- return JSONResponse('Empty list of user ids', 400)
-
- if not is_valid_source_id(sourceId):
- return JSONResponse('Invalid source id', 400)
-
- exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, sourceId))
-
- return JSONResponse('Access updated')
-
-
-@app.post('/updateAccessProvider')
-@enabled_guard(app)
-def _(
- op: Annotated[UpdateAccessOp, Body()],
- userIds: Annotated[list[str], Body()],
- providerId: Annotated[str, Body()],
-):
- logger.debug('Update access by provider request', extra={
- 'op': op,
- 'user_ids': userIds,
- 'provider_id': providerId,
- })
-
- if len(userIds) == 0:
- return JSONResponse('Empty list of user ids', 400)
-
- if not is_valid_provider_id(providerId):
- return JSONResponse('Invalid provider id', 400)
-
- exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, providerId))
-
- return JSONResponse('Access updated')
-
-
-@app.post('/deleteSources')
-@enabled_guard(app)
-def _(sourceIds: Annotated[list[str], Body(embed=True)]):
- logger.debug('Delete sources request', extra={
- 'source_ids': sourceIds,
- })
-
- sourceIds = [source.strip() for source in sourceIds if source.strip() != '']
-
- if len(sourceIds) == 0:
- return JSONResponse('No sources provided', 400)
-
- res = exec_in_proc(target=delete_by_source, args=(vectordb_loader, sourceIds))
- if res is False:
- return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400)
-
- return JSONResponse('All valid sources deleted')
-
-
-@app.post('/deleteProvider')
-@enabled_guard(app)
-def _(providerKey: str = Body(embed=True)):
- logger.debug('Delete sources by provider for all users request', extra={ 'provider_key': providerKey })
-
- if value_of(providerKey) is None:
- return JSONResponse('Invalid provider key provided', 400)
-
- exec_in_proc(target=delete_by_provider, args=(vectordb_loader, providerKey))
-
- return JSONResponse('All valid sources deleted')
-
-
-@app.post('/deleteUser')
-@enabled_guard(app)
-def _(userId: str = Body(embed=True)):
- logger.debug('Remove access list for user, and orphaned sources', extra={ 'user_id': userId })
-
- if value_of(userId) is None:
- return JSONResponse('Invalid userId provided', 400)
-
- exec_in_proc(target=delete_user, args=(vectordb_loader, userId))
-
- return JSONResponse('User deleted')
+ return JSONResponse(content={'enabled': get_enabled_state()}, status_code=200)
@app.post('/countIndexedDocuments')
@@ -335,179 +247,14 @@ def _():
return JSONResponse(counts)
-@app.put('/loadSources')
-@enabled_guard(app)
-def _(sources: list[UploadFile]):
- global _indexing
-
- if len(sources) == 0:
- return JSONResponse('No sources provided', 400)
-
- filtered_sources = []
-
- for source in sources:
- if not value_of(source.filename):
- logger.warning('Skipping source with invalid source_id', extra={
- 'source_id': source.filename,
- 'title': source.headers.get('title'),
- })
- continue
-
- with index_lock:
- if source.filename in _indexing:
- # this request will be retried by the client
- return JSONResponse(
- f'This source ({source.filename}) is already being processed in another request, try again later',
- 503,
- headers={'cc-retry': 'true'},
- )
-
- if not (
- value_of(source.headers.get('userIds'))
- and source.headers.get('title', None) is not None
- and value_of(source.headers.get('type'))
- and value_of(source.headers.get('modified'))
- and source.headers['modified'].isdigit()
- and value_of(source.headers.get('provider'))
- ):
- logger.warning('Skipping source with invalid/missing headers', extra={
- 'source_id': source.filename,
- 'title': source.headers.get('title'),
- 'headers': source.headers,
- })
- continue
-
- filtered_sources.append(source)
-
- # wait for 10 minutes before failing the request
- semres = doc_parse_semaphore.acquire(block=True, timeout=10*60)
- if not semres:
+@app.get('/downloadLogs')
+def download_logs():
+ if k8s_env:
return JSONResponse(
- 'Document parser worker limit reached, try again in some time or consider increasing the limit',
- 503,
- headers={'cc-retry': 'true'}
+ 'Download of logs is not supported in Kubernetes environment. Use the standard logging infrastructure.',
+ status_code=400,
)
- with index_lock:
- for source in filtered_sources:
- _indexing[source.filename] = source.size
-
- try:
- loaded_sources, not_added_sources = exec_in_proc(
- target=embed_sources,
- args=(vectordb_loader, app.extra['CONFIG'], filtered_sources)
- )
- except (DbException, EmbeddingException):
- raise
- except Exception as e:
- raise DbException('Error: failed to load sources') from e
- finally:
- with index_lock:
- for source in filtered_sources:
- _indexing.pop(source.filename, None)
- doc_parse_semaphore.release()
-
- if len(loaded_sources) != len(filtered_sources):
- logger.debug('Some sources were not loaded', extra={
- 'Count of loaded sources': f'{len(loaded_sources)}/{len(filtered_sources)}',
- 'source_ids': loaded_sources,
- })
-
- # loaded sources include the existing sources that may only have their access updated
- return JSONResponse({'loaded_sources': loaded_sources, 'sources_to_retry': not_added_sources})
-
-
-class Query(BaseModel):
- userId: str
- query: str
- useContext: bool = True
- scopeType: ScopeType | None = None
- scopeList: list[str] | None = None
- ctxLimit: int = 20
-
- @field_validator('userId', 'query', 'ctxLimit')
- @classmethod
- def check_empty_values(cls, value: Any, info: ValidationInfo):
- if value_of(value) is None:
- raise ValueError('Empty value for field', info.field_name)
-
- return value
-
- @field_validator('ctxLimit')
- @classmethod
- def at_least_one_context(cls, value: int):
- if value < 1:
- raise ValueError('Invalid context chunk limit')
-
- return value
-
-
-def execute_query(query: Query, in_proc: bool = True) -> LLMOutput:
- llm: LLM = llm_loader.load()
- template = app.extra.get('LLM_TEMPLATE')
- no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE']
- # todo: array
- end_separator = app.extra.get('LLM_END_SEPARATOR', '')
-
- if query.useContext:
- target = process_context_query
- args=(
- query.userId,
- vectordb_loader,
- llm,
- app_config,
- query.query,
- query.ctxLimit,
- query.scopeType,
- query.scopeList,
- template,
- end_separator,
- )
- else:
- target=process_query
- args=(
- query.userId,
- llm,
- app_config,
- query.query,
- no_ctx_template,
- end_separator,
- )
-
- if in_proc:
- return exec_in_proc(target=target, args=args)
-
- return target(*args) # pyright: ignore
-
-
-@app.post('/query')
-@enabled_guard(app)
-def _(query: Query) -> LLMOutput:
- logger.debug('received query request', extra={ 'query': query.dict() })
-
- if app_config.llm[0] == 'nc_texttotext':
- return execute_query(query)
-
- with llm_lock:
- return execute_query(query, in_proc=False)
-
-
-@app.post('/docSearch')
-@enabled_guard(app)
-def _(query: Query) -> list[SearchResult]:
- # useContext from Query is not used here
- return exec_in_proc(target=do_doc_search, args=(
- query.userId,
- query.query,
- vectordb_loader,
- query.ctxLimit,
- query.scopeType,
- query.scopeList,
- ))
-
-
-@app.get('/downloadLogs')
-def download_logs() -> FileResponse:
with tempfile.NamedTemporaryFile('wb', delete=False) as tmp:
with zipfile.ZipFile(tmp, mode='w', compression=zipfile.ZIP_DEFLATED) as zip_file:
files = os.listdir(os.path.join(persistent_storage(), 'logs'))
diff --git a/context_chat_backend/dyn_loader.py b/context_chat_backend/dyn_loader.py
index d67310ff..47b19575 100644
--- a/context_chat_backend/dyn_loader.py
+++ b/context_chat_backend/dyn_loader.py
@@ -7,11 +7,9 @@
import gc
import logging
from abc import ABC, abstractmethod
-from time import time
from typing import Any
import torch
-from fastapi import FastAPI
from langchain.llms.base import LLM
from .models.loader import init_model
@@ -54,19 +52,11 @@ def offload(self) -> None:
class LLMModelLoader(Loader):
- def __init__(self, app: FastAPI, config: TConfig) -> None:
+ def __init__(self, config: TConfig) -> None:
self.config = config
- self.app = app
def load(self) -> LLM:
- if self.app.extra.get('LLM_MODEL') is not None:
- self.app.extra['LLM_LAST_ACCESSED'] = time()
- return self.app.extra['LLM_MODEL']
-
llm_name, llm_config = self.config.llm
- self.app.extra['LLM_TEMPLATE'] = llm_config.pop('template', '')
- self.app.extra['LLM_NO_CTX_TEMPLATE'] = llm_config.pop('no_ctx_template', '')
- self.app.extra['LLM_END_SEPARATOR'] = llm_config.pop('end_separator', '')
try:
model = init_model('llm', (llm_name, llm_config))
@@ -75,13 +65,9 @@ def load(self) -> LLM:
if not isinstance(model, LLM):
raise LoaderException(f'Error: {model} does not implement "llm" type or has returned an invalid object')
- self.app.extra['LLM_MODEL'] = model
- self.app.extra['LLM_LAST_ACCESSED'] = time()
return model
def offload(self) -> None:
- if self.app.extra.get('LLM_MODEL') is not None:
- del self.app.extra['LLM_MODEL']
clear_cache()
diff --git a/context_chat_backend/logger.py b/context_chat_backend/logger.py
index 79e99aff..25fb1613 100644
--- a/context_chat_backend/logger.py
+++ b/context_chat_backend/logger.py
@@ -51,6 +51,7 @@ def __init__(
self,
*,
fmt_keys: dict[str, str] | None = None,
+ use_colors: bool = False,
):
super().__init__()
self.fmt_keys = fmt_keys if fmt_keys is not None else {}
diff --git a/context_chat_backend/chain/ingest/mimetype_list.py b/context_chat_backend/mimetype_list.py
similarity index 100%
rename from context_chat_backend/chain/ingest/mimetype_list.py
rename to context_chat_backend/mimetype_list.py
diff --git a/context_chat_backend/network_em.py b/context_chat_backend/network_em.py
index 18bb11f4..5ba8faf5 100644
--- a/context_chat_backend/network_em.py
+++ b/context_chat_backend/network_em.py
@@ -3,14 +3,16 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
#
import logging
+import socket
from time import sleep
from typing import Literal, TypedDict
+from urllib.parse import urlparse
import niquests
from langchain_core.embeddings import Embeddings
-from pydantic import BaseModel
from .types import (
+ DocErrorEmbeddingException,
EmbeddingException,
FatalEmbeddingException,
RetryableEmbeddingException,
@@ -20,6 +22,7 @@
)
logger = logging.getLogger('ccb.nextwork_em')
+TCP_CONNECT_TIMEOUT = 2.0 # seconds
# Copied from llama_cpp/llama_types.py
@@ -41,8 +44,35 @@ class CreateEmbeddingResponse(TypedDict):
usage: EmbeddingUsage
-class NetworkEmbeddings(Embeddings, BaseModel):
- app_config: TConfig
+class NetworkEmbeddings(Embeddings):
+ def __init__(self, app_config: TConfig):
+ self.app_config = app_config
+
+ def _get_host_and_port(self) -> tuple[str, int]:
+ parsed = urlparse(self.app_config.embedding.base_url)
+ host = parsed.hostname
+
+ if not host:
+ raise ValueError("Invalid URL: Missing hostname")
+
+ if parsed.port:
+ port = parsed.port
+ else:
+ port = 443 if parsed.scheme == "https" else 80
+
+ return host, port
+
+ def check_connection(self, check_origin: str) -> bool:
+ try:
+ host, port = self._get_host_and_port()
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(TCP_CONNECT_TIMEOUT)
+ sock.connect((host, port))
+ sock.close()
+ return True
+ except (ValueError, TimeoutError, ConnectionRefusedError, socket.gaierror) as e:
+ logger.warning(f'[{check_origin}] Embedding server is not reachable, retrying after some time: {e}')
+ return False
def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float] | list[list[float]]:
emconf = self.app_config.embedding
@@ -76,13 +106,27 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float]
if response.status_code is None:
raise EmbeddingException('Error: no response from embedding service')
if response.status_code // 100 == 4:
- raise FatalEmbeddingException(response.text)
+ raise FatalEmbeddingException(
+ response.text or f'Error: embedding request returned non-2xx status code {response.status_code}',
+ )
if response.status_code // 100 != 2:
- raise EmbeddingException(response.text)
+ raise EmbeddingException(
+ response.text or f'Error: embedding request returned non-2xx status code {response.status_code}',
+ response,
+ )
except FatalEmbeddingException as e:
logger.error('Fatal error while getting embeddings: %s', str(e), exc_info=e)
raise e
except EmbeddingException as e:
+ try:
+ if e.response is not None:
+ err_msg = e.response.json().get('error', {}).get('message', '')
+ if err_msg == 'llama_decode returned -1':
+ # the document coult not be processed
+ raise DocErrorEmbeddingException(f'Failed to embed the document: {err_msg}') from e
+ except niquests.exceptions.JSONDecodeError:
+ ...
+
if try_ > 0:
logger.debug('Retrying embedding request in 5 secs', extra={'try': try_})
sleep(5)
@@ -108,10 +152,14 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float]
logger.error('Unexpected error while getting embeddings', exc_info=e)
raise EmbeddingException('Error: unexpected error while getting embeddings') from e
- # converts TypedDict to a pydantic model
- resp = CreateEmbeddingResponse(**response.json())
- if isinstance(input_, str):
- return resp['data'][0]['embedding']
+ try:
+ # converts TypedDict to a pydantic model
+ resp = CreateEmbeddingResponse(**response.json())
+ if isinstance(input_, str):
+ return resp['data'][0]['embedding']
+ except Exception as e:
+ logger.error('Error parsing embedding response', exc_info=e)
+ raise EmbeddingException('Error: failed to parse embedding response') from e
# only one embedding in d['embedding'] since truncate is True
return [d['embedding'] for d in resp['data']] # pyright: ignore[reportReturnType]
diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py
new file mode 100644
index 00000000..baa882db
--- /dev/null
+++ b/context_chat_backend/task_fetcher.py
@@ -0,0 +1,775 @@
+#
+# SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors
+# SPDX-License-Identifier: AGPL-3.0-or-later
+#
+import logging
+import math
+import os
+from collections.abc import Mapping
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import suppress
+from enum import Enum
+from threading import Event, Thread
+from time import sleep
+from typing import Any
+
+import niquests
+from langchain.llms.base import LLM
+from nc_py_api import NextcloudApp, NextcloudException
+from niquests import JSONDecodeError, RequestException
+from pydantic import ValidationError
+
+from .chain.context import do_doc_search
+from .chain.ingest.injest import embed_sources
+from .chain.one_shot import process_context_query
+from .chain.types import ContextException, EnrichedSourceList, LLMOutput, ScopeList, SearchResult
+from .dyn_loader import LLMModelLoader, VectorDBLoader
+from .network_em import NetworkEmbeddings
+from .types import (
+ ActionsQueueItems,
+ ActionType,
+ AppRole,
+ EmbeddingException,
+ FilesQueueItems,
+ IndexingError,
+ LoaderException,
+ ReceivedFileItem,
+ SourceItem,
+ TConfig,
+)
+from .utils import SubprocessKilledError, exec_in_proc, get_app_role
+from .vectordb.service import (
+ decl_update_access,
+ delete_by_provider,
+ delete_by_source,
+ delete_user,
+ update_access,
+ update_access_provider,
+)
+from .vectordb.types import DbException, SafeDbException
+
+APP_ROLE = get_app_role()
+THREADS = {}
+THREAD_STOP_EVENT = Event()
+LOGGER = logging.getLogger('ccb.task_fetcher')
+MIN_FILES_PER_CPU = 4
+POLLING_COOLDOWN = 30
+
+# task processing or request processing
+TP_TRIGGER = Event()
+TP_CHECK_INTERVAL = 5
+TP_CHECK_INTERVAL_WITH_TRIGGER = 5 * 60
+TP_CHECK_INTERVAL_ON_ERROR = 15
+CONTEXT_LIMIT = 30
+
+
+class ThreadType(Enum):
+ FILES_INDEXING = 'files_indexing'
+ UPDATES_PROCESSING = 'updates_processing'
+ REQUEST_PROCESSING = 'request_processing'
+
+
+def files_indexing_thread(app_config: TConfig, get_enabled_state) -> None:
+ try:
+ network_em = NetworkEmbeddings(app_config)
+ vectordb_loader = VectorDBLoader(app_config)
+ except LoaderException as e:
+ LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e)
+ return
+
+ def _load_sources(source_items: Mapping[int, SourceItem | ReceivedFileItem]) -> Mapping[int, IndexingError | None]:
+ source_refs = [s.reference for s in source_items.values()]
+ LOGGER.info('Starting embed_sources subprocess for %d source(s)', len(source_items), extra={
+ 'source_ids': source_refs,
+ })
+ try:
+ result = exec_in_proc(
+ target=embed_sources,
+ args=(vectordb_loader, app_config, source_items),
+ )
+ errors = {k: v for k, v in result.items() if isinstance(v, IndexingError)}
+ LOGGER.info(
+ 'embed_sources finished for %d source(s): %d succeeded, %d errored',
+ len(source_items), len(result) - len(errors), len(errors),
+ extra={'errors': errors},
+ )
+ return result
+ except SubprocessKilledError as e:
+ LOGGER.error(
+ 'embed_sources subprocess was killed for %d source(s) with exitcode %s',
+ len(source_items), e.exitcode, exc_info=e, extra={
+ 'source_ids': source_refs,
+ },
+ )
+ if len(source_items) == 1:
+ return dict.fromkeys(
+ source_items,
+ IndexingError(error=f'Subprocess killed with exitcode {e.exitcode}: {e}', retryable=False),
+ )
+
+ # Fall back to one-by-one to isolate the problematic file.
+ LOGGER.warning(
+ 'Falling back to individual processing for %d sources',
+ len(source_items),
+ )
+ fallback: dict[int, IndexingError | None] = {}
+ for db_id, item in source_items.items():
+ fallback.update(_load_sources({db_id: item}))
+ return fallback
+ except Exception as e:
+ err = IndexingError(
+ error=f'{e.__class__.__name__}: {e}',
+ retryable=True,
+ )
+ LOGGER.error(
+ 'embed_sources subprocess raised a %s error for %d sources, marking all as retryable',
+ e.__class__.__name__, len(source_refs), exc_info=e, extra={
+ 'source_ids': source_refs,
+ }
+ )
+ return dict.fromkeys(source_items, err)
+
+
+ # divides the batch into these many chunks
+ file_parsing_cpu_count = (
+ app_config.file_parsing_cpu_count, # when set to a positive value
+ max(1, (os.cpu_count() or 2) - 1), # when set to auto (-1)
+ )[app_config.file_parsing_cpu_count == -1]
+ LOGGER.info(f'Using {file_parsing_cpu_count} parallel file parsing workers')
+
+ nc = NextcloudApp()
+ while True:
+ if THREAD_STOP_EVENT.is_set():
+ LOGGER.info('Files indexing thread is stopping due to stop event being set')
+ return
+
+ if not get_enabled_state():
+ LOGGER.info('App is disabled, files indexing thread will sleep until next enabled state check')
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ try:
+ if not network_em.check_connection(ThreadType.FILES_INDEXING.value):
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ q_items_res = nc.ocs(
+ 'GET',
+ '/ocs/v2.php/apps/context_chat/queues/documents',
+ params={ 'n': app_config.doc_indexing_batch_size }
+ )
+
+ try:
+ q_items: FilesQueueItems = FilesQueueItems.model_validate(q_items_res)
+ except ValidationError as e:
+ raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e
+
+ if not q_items.files and not q_items.content_providers:
+ LOGGER.debug('No documents to index')
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ files_result = {}
+ providers_result = {}
+
+ # chunk file parsing for better file operation parallelism
+ file_chunk_size = max(MIN_FILES_PER_CPU, math.ceil(len(q_items.files) / file_parsing_cpu_count))
+ file_chunks = [
+ dict(list(q_items.files.items())[i:i+file_chunk_size])
+ for i in range(0, len(q_items.files), file_chunk_size)
+ ]
+ provider_chunk_size = max(
+ MIN_FILES_PER_CPU,
+ math.ceil(len(q_items.content_providers) / file_parsing_cpu_count),
+ )
+ provider_chunks = [
+ dict(list(q_items.content_providers.items())[i:i+provider_chunk_size])
+ for i in range(0, len(q_items.content_providers), provider_chunk_size)
+ ]
+
+ with ThreadPoolExecutor(
+ max_workers=file_parsing_cpu_count,
+ thread_name_prefix='IndexingPool',
+ ) as executor:
+ LOGGER.info(
+ 'Dispatching %d file chunk(s) and %d provider chunk(s)',
+ len(file_chunks), len(provider_chunks),
+ )
+ file_futures = [executor.submit(_load_sources, chunk) for chunk in file_chunks]
+ provider_futures = [executor.submit(_load_sources, chunk) for chunk in provider_chunks]
+
+ for i, future in enumerate(file_futures):
+ LOGGER.debug('Waiting for file chunk %d/%d future to complete', i + 1, len(file_futures))
+ files_result.update(future.result())
+ LOGGER.debug('File chunk %d/%d future completed', i + 1, len(file_futures))
+ for i, future in enumerate(provider_futures):
+ LOGGER.debug('Waiting for provider chunk %d/%d future to complete', i + 1, len(provider_futures))
+ providers_result.update(future.result())
+ LOGGER.debug('Provider chunk %d/%d future completed', i + 1, len(provider_futures))
+
+ if (
+ any(isinstance(res, IndexingError) for res in files_result.values())
+ or any(isinstance(res, IndexingError) for res in providers_result.values())
+ ):
+ LOGGER.error('Some sources failed to index', extra={
+ 'file_errors': {
+ db_id: error
+ for db_id, error in files_result.items()
+ if isinstance(error, IndexingError)
+ },
+ 'provider_errors': {
+ provider_id: error
+ for provider_id, error in providers_result.items()
+ if isinstance(error, IndexingError)
+ },
+ })
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error fetching documents to index, will retry:', exc_info=e)
+ sleep(5)
+ continue
+ except Exception as e:
+ LOGGER.exception('Error fetching documents to index:', exc_info=e)
+ sleep(5)
+ continue
+
+ # delete the entries from the PHP side queue where indexing succeeded or the error is not retryable
+ to_delete_files_db_ids = [
+ db_id for db_id, result in files_result.items()
+ if result is None or (isinstance(result, IndexingError) and not result.retryable)
+ ]
+ to_delete_provider_db_ids = [
+ db_id for db_id, result in providers_result.items()
+ if result is None or (isinstance(result, IndexingError) and not result.retryable)
+ ]
+
+ try:
+ nc.ocs(
+ 'DELETE',
+ '/ocs/v2.php/apps/context_chat/queues/documents/',
+ json={
+ 'files': to_delete_files_db_ids,
+ 'content_providers': to_delete_provider_db_ids,
+ },
+ )
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error reporting indexing results, will retry:', exc_info=e)
+ sleep(5)
+ with suppress(Exception):
+ nc = NextcloudApp()
+ nc.ocs(
+ 'DELETE',
+ '/ocs/v2.php/apps/context_chat/queues/documents/',
+ json={
+ 'files': to_delete_files_db_ids,
+ 'content_providers': to_delete_provider_db_ids,
+ },
+ )
+ continue
+ except Exception as e:
+ LOGGER.exception('Error reporting indexing results:', exc_info=e)
+ sleep(5)
+ continue
+
+
+
+def updates_processing_thread(app_config: TConfig, get_enabled_state) -> None:
+ try:
+ vectordb_loader = VectorDBLoader(app_config)
+ except LoaderException as e:
+ LOGGER.error('Error initializing vector DB loader, updates processing thread will not start:', exc_info=e)
+ return
+
+ nc = NextcloudApp()
+ while True:
+ if THREAD_STOP_EVENT.is_set():
+ LOGGER.info('Updates processing thread is stopping due to stop event being set')
+ return
+
+ if not get_enabled_state():
+ LOGGER.info('App is disabled, updates processing thread will sleep until next enabled state check')
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ try:
+ q_items_res = nc.ocs(
+ 'GET',
+ '/ocs/v2.php/apps/context_chat/queues/actions',
+ params={ 'n': app_config.actions_batch_size }
+ )
+
+ try:
+ q_items: ActionsQueueItems = ActionsQueueItems.model_validate(q_items_res)
+ except ValidationError as e:
+ raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error fetching updates to process, will retry:', exc_info=e)
+ sleep(5)
+ continue
+ except Exception as e:
+ LOGGER.exception('Error fetching updates to process:', exc_info=e)
+ sleep(5)
+ continue
+
+ if not q_items.actions:
+ LOGGER.debug('No updates to process')
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ processed_event_ids = []
+ errored_events = {}
+ for i, (db_id, action_item) in enumerate(q_items.actions.items()):
+ try:
+ match action_item.type:
+ case ActionType.DELETE_SOURCE_IDS:
+ exec_in_proc(target=delete_by_source, args=(vectordb_loader, action_item.payload.sourceIds))
+
+ case ActionType.DELETE_PROVIDER_ID:
+ exec_in_proc(target=delete_by_provider, args=(vectordb_loader, action_item.payload.providerId))
+
+ case ActionType.DELETE_USER_ID:
+ exec_in_proc(target=delete_user, args=(vectordb_loader, action_item.payload.userId))
+
+ case ActionType.UPDATE_ACCESS_SOURCE_ID:
+ exec_in_proc(
+ target=update_access,
+ args=(
+ vectordb_loader,
+ action_item.payload.op,
+ action_item.payload.userIds,
+ action_item.payload.sourceId,
+ ),
+ )
+
+ case ActionType.UPDATE_ACCESS_PROVIDER_ID:
+ exec_in_proc(
+ target=update_access_provider,
+ args=(
+ vectordb_loader,
+ action_item.payload.op,
+ action_item.payload.userIds,
+ action_item.payload.providerId,
+ ),
+ )
+
+ case ActionType.UPDATE_ACCESS_DECL_SOURCE_ID:
+ exec_in_proc(
+ target=decl_update_access,
+ args=(
+ vectordb_loader,
+ action_item.payload.userIds,
+ action_item.payload.sourceId,
+ ),
+ )
+
+ case _:
+ LOGGER.warning(
+ f'Unknown action type {action_item.type} for action id {db_id},'
+ f' type {action_item.type}, skipping and marking as processed',
+ extra={ 'action_item': action_item },
+ )
+ continue
+
+ processed_event_ids.append(db_id)
+ except SafeDbException as e:
+ LOGGER.debug(
+ f'Safe DB error thrown while processing action id {db_id}, type {action_item.type},'
+ " it's safe to ignore and mark as processed.",
+ exc_info=e,
+ extra={ 'action_item': action_item },
+ )
+ processed_event_ids.append(db_id)
+ continue
+
+ except (LoaderException, DbException) as e:
+ LOGGER.error(
+ f'Error deleting source for action id {db_id}, type {action_item.type}: {e}',
+ exc_info=e,
+ extra={ 'action_item': action_item },
+ )
+ errored_events[db_id] = str(e)
+ continue
+
+ except Exception as e:
+ LOGGER.error(
+ f'Unexpected error processing action id {db_id}, type {action_item.type}: {e}',
+ exc_info=e,
+ extra={ 'action_item': action_item },
+ )
+ errored_events[db_id] = f'Unexpected error: {e}'
+ continue
+
+ if (i + 1) % 20 == 0:
+ LOGGER.debug(f'Processed {i + 1} updates, sleeping for a bit to allow other operations to proceed')
+ sleep(2)
+
+ LOGGER.info(f'Processed {len(processed_event_ids)} updates with {len(errored_events)} errors', extra={
+ 'errored_events': errored_events,
+ })
+
+ if len(processed_event_ids) == 0:
+ LOGGER.debug('No updates processed, skipping reporting to the server')
+ continue
+
+ try:
+ nc.ocs(
+ 'DELETE',
+ '/ocs/v2.php/apps/context_chat/queues/actions/',
+ json={ 'actions': processed_event_ids },
+ )
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error reporting processed updates, will retry:', exc_info=e)
+ sleep(5)
+ with suppress(Exception):
+ nc = NextcloudApp()
+ nc.ocs(
+ 'DELETE',
+ '/ocs/v2.php/apps/context_chat/queues/actions/',
+ json={ 'ids': processed_event_ids },
+ )
+ continue
+ except Exception as e:
+ LOGGER.exception('Error reporting processed updates:', exc_info=e)
+ sleep(5)
+ continue
+
+
+def resolve_scope_list(source_ids: list[str], userId: str) -> list[str]:
+ """
+
+ Parameters
+ ----------
+ source_ids
+
+ Returns
+ -------
+ source_ids with only files, no folders (or source_ids in case of non-file provider)
+ """
+ nc = NextcloudApp()
+ data = nc.ocs('POST', '/ocs/v2.php/apps/context_chat/resolve_scope_list', json={
+ 'source_ids': source_ids,
+ 'userId': userId,
+ })
+ return ScopeList.model_validate(data).source_ids
+
+
+def request_processing_thread(app_config: TConfig, get_enabled_state) -> None:
+ LOGGER.info('Starting request processing thread')
+
+ try:
+ network_em = NetworkEmbeddings(app_config)
+ vectordb_loader = VectorDBLoader(app_config)
+ llm_loader = LLMModelLoader(app_config)
+ except LoaderException as e:
+ LOGGER.error('Error initializing vector DB loader, request processing thread will not start:', exc_info=e)
+ return
+
+ nc = NextcloudApp()
+ llm: LLM = llm_loader.load()
+
+ while True:
+ if THREAD_STOP_EVENT.is_set():
+ LOGGER.info('Request processing thread is stopping due to stop event being set')
+ return
+
+ if not get_enabled_state():
+ LOGGER.info('App is disabled, request processing thread will sleep until next enabled state check')
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ try:
+ if not network_em.check_connection(ThreadType.REQUEST_PROCESSING.value):
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ # Fetch pending task
+ try:
+ response = nc.providers.task_processing.next_task(
+ ['context_chat-context_chat', 'context_chat-context_chat_search'],
+ ['context_chat:context_chat', 'context_chat:context_chat_search'],
+ )
+ if not response:
+ wait_for_tasks()
+ continue
+ except (NextcloudException, RequestException, JSONDecodeError) as e:
+ LOGGER.error(f"Network error fetching the next task {e}", exc_info=e)
+ wait_for_tasks(TP_CHECK_INTERVAL_ON_ERROR)
+ continue
+
+ # Process task
+ task = response["task"]
+ userId = task['userId']
+
+ try:
+ LOGGER.debug(f'Processing task {task["id"]}')
+
+ if task['input'].get('scopeType') == 'source':
+ # Resolve scope list to only files, no folders
+ task['input']['scopeList'] = resolve_scope_list(task['input'].get('scopeList'), userId)
+
+ if task['type'] == 'context_chat:context_chat':
+ result: LLMOutput = process_normal_task(task, vectordb_loader, llm, app_config)
+ # Return result to Nextcloud
+ success = return_result_to_nextcloud(task['id'], userId, {
+ 'output': result['output'],
+ 'sources': enrich_sources(result['sources'], userId),
+ })
+ elif task['type'] == 'context_chat:context_chat_search':
+ search_result: list[SearchResult] = process_search_task(task, vectordb_loader)
+ # Return result to Nextcloud
+ success = return_result_to_nextcloud(task['id'], userId, {
+ 'sources': enrich_sources(search_result, userId),
+ })
+ else:
+ LOGGER.error(f'Unknown task type {task["type"]}')
+ success = return_error_to_nextcloud(task['id'], Exception(f'Unknown task type {task["type"]}'))
+
+ if success:
+ LOGGER.info(f'Task {task["id"]} completed successfully')
+ else:
+ LOGGER.error(f'Failed to return result for task {task["id"]}')
+
+ except EmbeddingException as e:
+ LOGGER.warning(f'Embedding server error for task {task["id"]}: {e}')
+ return_error_to_nextcloud(task['id'], e)
+ except ContextException as e:
+ LOGGER.warning(f'Context error for task {task["id"]}: {e}')
+ return_error_to_nextcloud(task['id'], e)
+ except ValueError as e:
+ LOGGER.warning(f'Validation error for task {task["id"]}: {e}')
+ return_error_to_nextcloud(task['id'], e)
+ except Exception as e:
+ LOGGER.exception(f'Unexpected error processing task {task["id"]}', exc_info=e)
+ return_error_to_nextcloud(task['id'], e)
+
+ except Exception as e:
+ LOGGER.exception('Error in task fetcher loop', exc_info=e)
+ wait_for_tasks(TP_CHECK_INTERVAL_ON_ERROR)
+
+def trigger_handler(provider_id: str):
+ global TP_TRIGGER
+ LOGGER.debug('Task processing trigger received', extra={'provider_id': provider_id})
+ TP_TRIGGER.set()
+
+def wait_for_tasks(interval = None):
+ global TP_TRIGGER
+ global TP_CHECK_INTERVAL
+ global TP_CHECK_INTERVAL_WITH_TRIGGER
+ actual_interval = TP_CHECK_INTERVAL if interval is None else interval
+ if TP_TRIGGER.wait(timeout=actual_interval):
+ TP_CHECK_INTERVAL = TP_CHECK_INTERVAL_WITH_TRIGGER
+ TP_TRIGGER.clear()
+
+
+def enrich_sources(results: list[SearchResult], userId: str) -> list[str]:
+ nc = NextcloudApp()
+ data = nc.ocs('POST', '/ocs/v2.php/apps/context_chat/enrich_sources', json={'sources': results, 'userId': userId})
+ sources = EnrichedSourceList.model_validate(data).sources
+ return [s.model_dump_json() for s in sources]
+
+
+def return_result_to_nextcloud(task_id: int, userId: str, result: dict[str, Any]) -> bool:
+ """
+ Return query result back to Nextcloud.
+
+ Args:
+ result: dict[str, Any]
+
+ Returns:
+ True if successful, False otherwise
+ """
+ LOGGER.debug('Returning result to Nextcloud', extra={
+ 'task_id': task_id,
+ 'result': result,
+ })
+
+ nc = NextcloudApp()
+
+ try:
+ nc.providers.task_processing.report_result(task_id, result)
+ except (NextcloudException, RequestException, JSONDecodeError) as e:
+ LOGGER.error(f"Network error reporting task result {e}", exc_info=e)
+ return False
+
+ return True
+
+
+def return_error_to_nextcloud(task_id: int, e: Exception) -> bool:
+ """
+ Return error result back to Nextcloud.
+
+ Args:
+ task_id: Unique task identifier
+ e: error object
+
+ Returns:
+ True if successful, False otherwise
+ """
+ LOGGER.debug('Returning error to Nextcloud', exc_info=e)
+
+ nc = NextcloudApp()
+
+ if isinstance(e, ValueError):
+ message = "Validation error: " + str(e)
+ elif isinstance(e, ContextException):
+ message = "Context error" + str(e)
+ else:
+ message = "Unexpected error" + str(e)
+
+ try:
+ nc.providers.task_processing.report_result(task_id, None, message)
+ except (NextcloudException, RequestException, JSONDecodeError) as e:
+ LOGGER.error(f"Network error reporting task result {e}", exc_info=e)
+ return False
+
+ return True
+
+
+def process_normal_task(
+ task: dict[str, Any],
+ vectordb_loader: VectorDBLoader,
+ llm: LLM,
+ app_config: TConfig,
+) -> LLMOutput:
+ """
+ Process a single query task.
+
+ Args:
+ task: Task dictionary from fetch_query_tasks_from_nextcloud
+ vectordb_loader: Vector database loader instance
+ llm: Language model instance
+ app_config: Application configuration
+
+ Returns:
+ LLMOutput with generated text and sources
+
+ Raises:
+ Various exceptions from query execution
+ """
+ user_id = task['userId']
+ task_input = task['input']
+ if task_input.get('scopeType') == 'none':
+ task_input['scopeType'] = None
+
+ return exec_in_proc(target=process_context_query,
+ args=(
+ user_id,
+ vectordb_loader,
+ llm,
+ app_config,
+ task_input.get('prompt'),
+ CONTEXT_LIMIT,
+ task_input.get('scopeType'),
+ task_input.get('scopeList'),
+ app_config.llm[1].get('template'),
+ )
+ )
+
+def process_search_task(
+ task: dict[str, Any],
+ vectordb_loader: VectorDBLoader,
+) -> list[SearchResult]:
+ """
+ Process a single search task.
+
+ Args:
+ task: Task dictionary from fetch_query_tasks_from_nextcloud
+ vectordb_loader: Vector database loader instance
+
+ Returns:
+ list of Search results
+
+ Raises:
+ Various exceptions from query execution
+ """
+ user_id = task['userId']
+ task_input = task['input']
+ if task_input.get('scopeType') == 'none':
+ task_input['scopeType'] = None
+
+ return exec_in_proc(target=do_doc_search,
+ args=(
+ user_id,
+ task_input.get('prompt'),
+ vectordb_loader,
+ CONTEXT_LIMIT,
+ task_input.get('scopeType'),
+ task_input.get('scopeList'),
+ )
+ )
+
+
+def start_bg_threads(app_config: TConfig, get_enabled_state):
+ THREAD_STOP_EVENT.clear()
+
+ if APP_ROLE == AppRole.INDEXING or APP_ROLE == AppRole.NORMAL:
+ if ThreadType.FILES_INDEXING in THREADS:
+ LOGGER.info('Indexing background threads are already up, skipping start')
+ return
+
+ THREADS[ThreadType.FILES_INDEXING] = Thread(
+ target=files_indexing_thread,
+ args=(app_config,get_enabled_state),
+ name='FilesIndexingThread',
+ )
+ THREADS[ThreadType.FILES_INDEXING].start()
+
+ if APP_ROLE == AppRole.UPDATES_PROC or APP_ROLE == AppRole.NORMAL:
+ if ThreadType.UPDATES_PROCESSING in THREADS:
+ LOGGER.info('Updates processing background threads are already up, skipping start')
+ return
+
+ THREADS[ThreadType.UPDATES_PROCESSING] = Thread(
+ target=updates_processing_thread,
+ args=(app_config,get_enabled_state),
+ name='UpdatesProcessingThread',
+ )
+ THREADS[ThreadType.UPDATES_PROCESSING].start()
+
+ if APP_ROLE == AppRole.REQUEST_PROC or APP_ROLE == AppRole.NORMAL:
+ if ThreadType.REQUEST_PROCESSING in THREADS:
+ LOGGER.info('Request processing background threads are already up, skipping start')
+ return
+
+ THREADS[ThreadType.REQUEST_PROCESSING] = Thread(
+ target=request_processing_thread,
+ args=(app_config,get_enabled_state),
+ name='RequestProcessingThread',
+ )
+ THREADS[ThreadType.REQUEST_PROCESSING].start()
+
+
+def wait_for_bg_threads():
+ THREAD_STOP_EVENT.set()
+
+ if APP_ROLE == AppRole.INDEXING or APP_ROLE == AppRole.NORMAL:
+ if ThreadType.FILES_INDEXING not in THREADS:
+ return
+
+ THREADS[ThreadType.FILES_INDEXING].join()
+ THREADS.pop(ThreadType.FILES_INDEXING)
+
+ if APP_ROLE == AppRole.UPDATES_PROC or APP_ROLE == AppRole.NORMAL:
+ if ThreadType.UPDATES_PROCESSING not in THREADS:
+ return
+
+ THREADS[ThreadType.UPDATES_PROCESSING].join()
+ THREADS.pop(ThreadType.UPDATES_PROCESSING)
+
+ if APP_ROLE == AppRole.REQUEST_PROC or APP_ROLE == AppRole.NORMAL:
+ if (ThreadType.REQUEST_PROCESSING not in THREADS):
+ return
+
+ THREADS[ThreadType.REQUEST_PROCESSING].join()
+ THREADS.pop(ThreadType.REQUEST_PROCESSING)
diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py
index 500a97d0..700d7ddf 100644
--- a/context_chat_backend/types.py
+++ b/context_chat_backend/types.py
@@ -2,7 +2,17 @@
# SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
-from pydantic import BaseModel
+import re
+from collections.abc import Mapping
+from enum import Enum
+from io import BytesIO
+from typing import Annotated, Literal, Self
+
+import niquests
+from pydantic import AfterValidator, BaseModel, Discriminator, computed_field, field_validator, model_validator
+
+from .mimetype_list import SUPPORTED_MIMETYPES
+from .vectordb.types import UpdateAccessOp
__all__ = [
'DEFAULT_EM_MODEL_ALIAS',
@@ -15,6 +25,65 @@
]
DEFAULT_EM_MODEL_ALIAS = 'em_model'
+FILES_PROVIDER_ID = 'files__default'
+
+
+def is_valid_source_id(source_id: str) -> bool:
+ # note the ":" in the item id part
+ return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None
+
+
+def is_valid_provider_id(provider_id: str) -> bool:
+ return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None
+
+
+def _validate_source_ids(source_ids: list[str]) -> list[str]:
+ if (
+ not isinstance(source_ids, list)
+ or not all(isinstance(sid, str) and sid.strip() != '' for sid in source_ids)
+ or len(source_ids) == 0
+ ):
+ raise ValueError('sourceIds must be a non-empty list of non-empty strings')
+ return [sid.strip() for sid in source_ids]
+
+
+def _validate_source_id(source_id: str) -> str:
+ return _validate_source_ids([source_id])[0]
+
+
+def _validate_provider_id(provider_id: str) -> str:
+ if not isinstance(provider_id, str) or not is_valid_provider_id(provider_id):
+ raise ValueError('providerId must be a valid provider ID string')
+ return provider_id
+
+
+def _validate_user_ids(user_ids: list[str]) -> list[str]:
+ if (
+ not isinstance(user_ids, list)
+ or not all(isinstance(uid, str) and uid.strip() != '' for uid in user_ids)
+ or len(user_ids) == 0
+ ):
+ raise ValueError('userIds must be a non-empty list of non-empty strings')
+ return [uid.strip() for uid in user_ids]
+
+
+def _validate_user_id(user_id: str) -> str:
+ return _validate_user_ids([user_id])[0]
+
+
+def _get_file_id_from_source_ref(source_ref: str) -> int:
+ '''
+ source reference is in the format "FILES_PROVIDER_ID: ".
+ '''
+ if not source_ref.startswith(f'{FILES_PROVIDER_ID}: '):
+ raise ValueError(f'Source reference does not start with expected prefix: {source_ref}')
+
+ try:
+ return int(source_ref[len(f'{FILES_PROVIDER_ID}: '):])
+ except ValueError as e:
+ raise ValueError(
+ f'Invalid source reference format for extracting file_id: {source_ref}'
+ ) from e
class TEmbeddingAuthApiKey(BaseModel):
@@ -36,14 +105,17 @@ class TEmbeddingConfig(BaseModel):
class TConfig(BaseModel):
- debug: bool
- uvicorn_log_level: str
- disable_aaa: bool
- verify_ssl: bool
- use_colors: bool
- uvicorn_workers: int
- embedding_chunk_size: int
- doc_parser_worker_limit: int
+ debug: bool = False
+ uvicorn_log_level: str = 'info'
+ disable_aaa: bool = False
+ verify_ssl: bool = True
+ use_colors: bool = True
+ uvicorn_workers: int = 1
+ embedding_chunk_size: int = 2000
+ doc_indexing_batch_size: int = 32
+ actions_batch_size: int = 512
+ file_parsing_cpu_count: int = -1
+ concurrent_file_fetches: int = 10
vectordb: tuple[str, dict]
embedding: TEmbeddingConfig
@@ -55,7 +127,9 @@ class LoaderException(Exception):
class EmbeddingException(Exception):
- ...
+ def __init__(self, msg: str, response: niquests.Response | None = None):
+ super().__init__(msg)
+ self.response = response
class RetryableEmbeddingException(EmbeddingException):
"""
@@ -71,3 +145,215 @@ class FatalEmbeddingException(EmbeddingException):
Either malformed request, authentication error, or other non-retryable error.
"""
+
+class DocErrorEmbeddingException(EmbeddingException):
+ """
+ Exception that indicates a fatal error for the document, this document should not be retried.
+ """
+
+
+class AppRole(str, Enum):
+ NORMAL = 'normal'
+ INDEXING = 'indexing'
+ REQUEST_PROC = 'requestproc'
+ UPDATES_PROC = 'updatesproc'
+
+
+class CommonSourceItem(BaseModel):
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ # source_id of the form "appId__providerId: itemId"
+ reference: Annotated[str, AfterValidator(_validate_source_id)]
+ title: str
+ modified: int
+ type: str
+ provider: Annotated[str, AfterValidator(_validate_provider_id)]
+ size: float
+
+ @field_validator('modified', mode='before')
+ @classmethod
+ def validate_modified(cls, v):
+ if isinstance(v, int):
+ return v
+ if isinstance(v, str):
+ try:
+ return int(v)
+ except ValueError as e:
+ raise ValueError(f'Invalid modified value: {v}') from e
+ raise ValueError(f'Invalid modified type: {type(v)}')
+
+ @field_validator('reference', 'title', 'type', 'provider')
+ @classmethod
+ def validate_strings_non_empty(cls, v):
+ if not isinstance(v, str) or v.strip() == '':
+ raise ValueError('Must be a non-empty string')
+ return v.strip()
+
+ @field_validator('size')
+ @classmethod
+ def validate_size(cls, v):
+ if isinstance(v, int | float) and v >= 0:
+ return float(v)
+ raise ValueError(f'Invalid size value: {v}, must be a non-negative number')
+
+ @model_validator(mode='after')
+ def validate_type(self) -> Self:
+ if self.reference.startswith(FILES_PROVIDER_ID) and self.type not in SUPPORTED_MIMETYPES:
+ raise ValueError(f'Unsupported file type: {self.type} for reference {self.reference}')
+ return self
+
+
+class ReceivedFileItem(CommonSourceItem):
+ content: None
+
+ @computed_field
+ @property
+ def file_id(self) -> int:
+ return _get_file_id_from_source_ref(self.reference)
+
+
+class SourceItem(CommonSourceItem):
+ '''
+ Used for the unified queue of items to process, after fetching the content for files
+ and for directly fetched content providers.
+ '''
+ content: str | BytesIO
+
+ @field_validator('content')
+ @classmethod
+ def validate_content(cls, v):
+ if isinstance(v, str):
+ if v.strip() == '':
+ raise ValueError('Content must be a non-empty string')
+ return v.strip()
+ if isinstance(v, BytesIO):
+ if v.getbuffer().nbytes == 0:
+ raise ValueError('Content must be a non-empty BytesIO')
+ return v
+ raise ValueError('Content must be either a non-empty string or a non-empty BytesIO')
+
+ class Config:
+ # to allow BytesIO in content field
+ arbitrary_types_allowed = True
+
+
+class FilesQueueItems(BaseModel):
+ files: Mapping[int, ReceivedFileItem] # [db id]: FileItem
+ content_providers: Mapping[int, SourceItem] # [db id]: SourceItem
+
+
+class IndexingException(Exception):
+ retryable: bool = False
+
+ def __init__(self, message: str, retryable: bool = False):
+ super().__init__(message)
+ self.retryable = retryable
+
+
+class IndexingError(BaseModel):
+ error: str
+ retryable: bool = False
+
+
+# PHP equivalent for reference:
+
+# class ActionType {
+# // { sourceIds: array }
+# public const DELETE_SOURCE_IDS = 'delete_source_ids';
+# // { providerId: string }
+# public const DELETE_PROVIDER_ID = 'delete_provider_id';
+# // { userId: string }
+# public const DELETE_USER_ID = 'delete_user_id';
+# // { op: string, userIds: array, sourceId: string }
+# public const UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id';
+# // { op: string, userIds: array, providerId: string }
+# public const UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id';
+# // { userIds: array, sourceId: string }
+# public const UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id';
+# }
+
+
+class ActionPayloadDeleteSourceIds(BaseModel):
+ sourceIds: Annotated[list[str], AfterValidator(_validate_source_ids)]
+
+
+class ActionPayloadDeleteProviderId(BaseModel):
+ providerId: Annotated[str, AfterValidator(_validate_provider_id)]
+
+
+class ActionPayloadDeleteUserId(BaseModel):
+ userId: Annotated[str, AfterValidator(_validate_user_id)]
+
+
+class ActionPayloadUpdateAccessSourceId(BaseModel):
+ op: UpdateAccessOp
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ sourceId: Annotated[str, AfterValidator(_validate_source_id)]
+
+
+class ActionPayloadUpdateAccessProviderId(BaseModel):
+ op: UpdateAccessOp
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ providerId: Annotated[str, AfterValidator(_validate_provider_id)]
+
+
+class ActionPayloadUpdateAccessDeclSourceId(BaseModel):
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ sourceId: Annotated[str, AfterValidator(_validate_source_id)]
+
+
+class ActionType(str, Enum):
+ DELETE_SOURCE_IDS = 'delete_source_ids'
+ DELETE_PROVIDER_ID = 'delete_provider_id'
+ DELETE_USER_ID = 'delete_user_id'
+ UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id'
+ UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id'
+ UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id'
+
+
+class CommonActionsQueueItem(BaseModel):
+ id: int
+
+
+class ActionsQueueItemDeleteSourceIds(CommonActionsQueueItem):
+ type: Literal[ActionType.DELETE_SOURCE_IDS]
+ payload: ActionPayloadDeleteSourceIds
+
+
+class ActionsQueueItemDeleteProviderId(CommonActionsQueueItem):
+ type: Literal[ActionType.DELETE_PROVIDER_ID]
+ payload: ActionPayloadDeleteProviderId
+
+
+class ActionsQueueItemDeleteUserId(CommonActionsQueueItem):
+ type: Literal[ActionType.DELETE_USER_ID]
+ payload: ActionPayloadDeleteUserId
+
+
+class ActionsQueueItemUpdateAccessSourceId(CommonActionsQueueItem):
+ type: Literal[ActionType.UPDATE_ACCESS_SOURCE_ID]
+ payload: ActionPayloadUpdateAccessSourceId
+
+
+class ActionsQueueItemUpdateAccessProviderId(CommonActionsQueueItem):
+ type: Literal[ActionType.UPDATE_ACCESS_PROVIDER_ID]
+ payload: ActionPayloadUpdateAccessProviderId
+
+
+class ActionsQueueItemUpdateAccessDeclSourceId(CommonActionsQueueItem):
+ type: Literal[ActionType.UPDATE_ACCESS_DECL_SOURCE_ID]
+ payload: ActionPayloadUpdateAccessDeclSourceId
+
+
+ActionsQueueItem = Annotated[
+ ActionsQueueItemDeleteSourceIds
+ | ActionsQueueItemDeleteProviderId
+ | ActionsQueueItemDeleteUserId
+ | ActionsQueueItemUpdateAccessSourceId
+ | ActionsQueueItemUpdateAccessProviderId
+ | ActionsQueueItemUpdateAccessDeclSourceId,
+ Discriminator('type'),
+]
+
+
+class ActionsQueueItems(BaseModel):
+ actions: Mapping[int, ActionsQueueItem]
diff --git a/context_chat_backend/utils.py b/context_chat_backend/utils.py
index f6d6e672..d5727140 100644
--- a/context_chat_backend/utils.py
+++ b/context_chat_backend/utils.py
@@ -2,11 +2,16 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
+import faulthandler
+import io
import logging
import multiprocessing as mp
-import re
+import os
+import signal
+import sys
import traceback
from collections.abc import Callable
+from contextlib import suppress
from functools import partial, wraps
from multiprocessing.connection import Connection
from time import perf_counter_ns
@@ -14,10 +19,11 @@
from fastapi.responses import JSONResponse as FastAPIJSONResponse
-from .types import TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig
+from .types import AppRole, TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig
T = TypeVar('T')
_logger = logging.getLogger('ccb.utils')
+_MAX_STD_CAPTURE_CHARS = 64 * 1024
def not_none(value: T | None) -> TypeGuard[T]:
@@ -69,19 +75,105 @@ def JSONResponse(
return FastAPIJSONResponse(content, status_code, **kwargs)
-def exception_wrap(fun: Callable | None, *args, resconn: Connection, **kwargs):
- try:
- if fun is None:
- return resconn.send({ 'value': None, 'error': None })
- resconn.send({ 'value': fun(*args, **kwargs), 'error': None })
- except Exception as e:
- tb = traceback.format_exc()
- resconn.send({ 'value': None, 'error': e, 'traceback': tb })
+class SubprocessKilledError(RuntimeError):
+ """Raised when a subprocess is terminated by a signal (for example SIGKILL)."""
+
+ def __init__(self, pid: int | None, target_name: str, exitcode: int):
+ super().__init__(
+ f'Subprocess PID {pid} for {target_name} exited with signal {abs(exitcode)} '
+ f'(raw exit code: {exitcode})'
+ )
+ self.exitcode = exitcode
+
+
+class SubprocessExecutionError(RuntimeError):
+ """Raised when a subprocess exits without a recoverable Python exception payload."""
+
+ def __init__(self, pid: int | None, target_name: str, exitcode: int, details: str = ''):
+ msg = f'Subprocess PID {pid} for {target_name} exited with exit code {exitcode}'
+ if details:
+ msg = f'{msg}: {details}'
+ super().__init__(msg)
+ self.exitcode = exitcode
+
+def _truncate_capture(text: str) -> str:
+ if len(text) <= _MAX_STD_CAPTURE_CHARS:
+ return text
-def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daemon=None): # noqa: B006
+ head = _MAX_STD_CAPTURE_CHARS // 2
+ tail = _MAX_STD_CAPTURE_CHARS - head
+ omitted = len(text) - _MAX_STD_CAPTURE_CHARS
+ return (
+ f'[truncated {omitted} chars]\n'
+ f'{text[:head]}\n'
+ '[...snip...]\n'
+ f'{text[-tail:]}'
+ )
+
+
+def exception_wrap(fun: Callable | None, *args, resconn: Connection, stdconn: Connection, **kwargs):
+ # ignore SIGINT and SIGTERM in child processes these signals don't immediately stop these processes
+ # the handling is done in the fastapi lifetime to do a graceful shutdown
+ # SIGKILL is not ignored
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+ signal.signal(signal.SIGTERM, signal.SIG_IGN)
+
+ # Preserve real stderr FD for faulthandler before we redirect sys.stderr.
+ _faulthandler_fd = os.dup(2)
+ with suppress(Exception):
+ faulthandler.enable(
+ file=os.fdopen(_faulthandler_fd, 'w', closefd=False),
+ all_threads=True,
+ )
+
+ stdout_capture = io.StringIO()
+ stderr_capture = io.StringIO()
+ orig_stdout = sys.stdout
+ orig_stderr = sys.stderr
+ sys.stdout = stdout_capture
+ sys.stderr = stderr_capture
+
+ try:
+ value = None if fun is None else fun(*args, **kwargs)
+ try:
+ resconn.send({ 'value': value, 'error': None })
+ except (BrokenPipeError, OSError, EOFError):
+ ... # parent closed the pipe during shutdown, exit cleanly
+ except BaseException as e:
+ tb = traceback.format_exc()
+ payload = {
+ 'value': None,
+ 'error': e,
+ 'traceback': tb,
+ }
+ try:
+ resconn.send(payload)
+ except Exception as send_err:
+ stderr_capture.write(f'Original error: {e}, pipe send error: {send_err}')
+ finally:
+ sys.stdout = orig_stdout
+ sys.stderr = orig_stderr
+ stdout_text = _truncate_capture(stdout_capture.getvalue())
+ stderr_text = _truncate_capture(stderr_capture.getvalue())
+ with suppress(Exception):
+ stdconn.send({
+ 'stdout': stdout_text,
+ 'stderr': stderr_text,
+ })
+ with suppress(Exception):
+ os.close(_faulthandler_fd)
+
+
+def exec_in_proc(group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None):
+ if not kwargs:
+ kwargs = {}
+
+ # parent, child
pconn, cconn = mp.Pipe()
+ std_pconn, std_cconn = mp.Pipe()
kwargs['resconn'] = cconn
+ kwargs['stdconn'] = std_cconn
p = mp.Process(
group=group,
target=partial(exception_wrap, target),
@@ -90,24 +182,92 @@ def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daem
kwargs=kwargs,
daemon=daemon,
)
+ target_name = getattr(target, '__name__', str(target))
+ start = perf_counter_ns()
p.start()
+ _logger.debug('Subprocess PID %d started for %s', p.pid, target_name)
+
+ result = None
+ stdobj = { 'stdout': '', 'stderr': '' }
+ got_result = False
+ got_std = False
+
+ # Drain result/std pipes while child is still alive to avoid deadlock on full pipe buffers.
+ # Pipe's buffer size is 64 KiB
+ while p.is_alive() and (not got_result or not got_std):
+ if not got_result and pconn.poll(0.1):
+ with suppress(EOFError, OSError, BrokenPipeError):
+ result = pconn.recv()
+ got_result = True
+ if not got_std and std_pconn.poll():
+ with suppress(EOFError, OSError, BrokenPipeError):
+ stdobj = std_pconn.recv()
+ got_std = True
+
p.join()
+ elapsed_ms = (perf_counter_ns() - start) / 1e6
+ _logger.debug(
+ 'Subprocess PID %d for %s finished in %.2f ms (exit code: %s)',
+ p.pid, target_name, elapsed_ms, p.exitcode,
+ )
- result = pconn.recv()
- if result['error'] is not None:
- _logger.error('original traceback: %s', result['traceback'])
+ if not got_std:
+ with suppress(EOFError, OSError, BrokenPipeError):
+ if std_pconn.poll():
+ stdobj = std_pconn.recv()
+ # no need to update got_std here
+ if stdobj.get('stdout') or stdobj.get('stderr'):
+ _logger.info('std info for %s', target_name, extra={
+ 'stdout': stdobj.get('stdout', ''),
+ 'stderr': stdobj.get('stderr', ''),
+ })
+
+ if not got_result:
+ with suppress(EOFError, OSError, BrokenPipeError):
+ if pconn.poll():
+ result = pconn.recv()
+ # no need to update got_result here
+
+ if result is not None and result.get('error') is not None:
+ _logger.error(
+ 'original traceback of %s (PID %d, exitcode: %s): %s',
+ target_name,
+ p.pid,
+ p.exitcode,
+ result.get('traceback', ''),
+ )
raise result['error']
- return result['value']
-
-
-def is_valid_source_id(source_id: str) -> bool:
- # note the ":" in the item id part
- return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None
+ if result is not None and 'value' in result:
+ if p.exitcode not in (None, 0):
+ _logger.warning(
+ 'Subprocess PID %d for %s exited with code %s after %.2f ms'
+ ' but returned a valid result',
+ p.pid, target_name, p.exitcode, elapsed_ms,
+ )
+ return result['value']
+ if p.exitcode and p.exitcode < 0:
+ _logger.warning(
+ 'Subprocess PID %d for %s exited due to signal %d, exitcode %d after %.2f ms',
+ p.pid, target_name, abs(p.exitcode), p.exitcode, elapsed_ms,
+ )
+ raise SubprocessKilledError(p.pid, target_name, p.exitcode)
+
+ if p.exitcode not in (None, 0):
+ raise SubprocessExecutionError(
+ p.pid,
+ target_name,
+ p.exitcode,
+ f'No structured exception payload received from child process: {result}',
+ )
-def is_valid_provider_id(provider_id: str) -> bool:
- return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None
+ raise SubprocessExecutionError(
+ p.pid,
+ target_name,
+ 0,
+ f'Subprocess exited successfully but returned no result payload: {result}',
+ )
def timed(func: Callable):
@@ -144,3 +304,19 @@ def redact_config(config: TConfig | TEmbeddingConfig) -> TConfig | TEmbeddingCon
em_conf.auth.password = '***REDACTED***' # noqa: S105
return config_copy
+
+
+def get_app_role() -> AppRole:
+ role = os.getenv('APP_ROLE', '').lower()
+ if role == '':
+ return AppRole.NORMAL
+ try:
+ return AppRole(role)
+ except ValueError:
+ _logger.warning(f'Invalid app role: {role}, defaulting to all roles')
+ return AppRole.NORMAL
+
+
+def is_k8s_env():
+ role = get_app_role()
+ return role != AppRole.NORMAL
diff --git a/context_chat_backend/vectordb/base.py b/context_chat_backend/vectordb/base.py
index 0bf10200..2b4aa35e 100644
--- a/context_chat_backend/vectordb/base.py
+++ b/context_chat_backend/vectordb/base.py
@@ -3,14 +3,15 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
#
from abc import ABC, abstractmethod
+from collections.abc import Mapping
from typing import Any
-from fastapi import UploadFile
from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from ..chain.types import InDocument, ScopeType
+from ..types import IndexingError, ReceivedFileItem, SourceItem
from ..utils import timed
from .types import UpdateAccessOp
@@ -62,7 +63,7 @@ def get_instance(self) -> VectorStore:
'''
@abstractmethod
- def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list[str]]:
+ def add_indocuments(self, indocuments: Mapping[int, InDocument]) -> Mapping[int, IndexingError | None]:
'''
Adds the given indocuments to the vectordb and updates the docs + access tables.
@@ -79,10 +80,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list
@timed
@abstractmethod
- def check_sources(
- self,
- sources: list[UploadFile],
- ) -> tuple[list[str], list[str]]:
+ def check_sources(self, sources: Mapping[int, SourceItem | ReceivedFileItem]) -> tuple[list[str], list[str]]:
'''
Checks the sources in the vectordb if they are already embedded
and are up to date.
@@ -91,8 +89,8 @@ def check_sources(
Args
----
- sources: list[UploadFile]
- List of source ids to check.
+ sources: Mapping[int, SourceItem | ReceivedFileItem]
+ Dict of sources to check.
Returns
-------
diff --git a/context_chat_backend/vectordb/pgvector.py b/context_chat_backend/vectordb/pgvector.py
index 2b7fc060..4b820cd3 100644
--- a/context_chat_backend/vectordb/pgvector.py
+++ b/context_chat_backend/vectordb/pgvector.py
@@ -4,21 +4,30 @@
#
import logging
import os
+from collections.abc import Mapping
from datetime import datetime
+from time import perf_counter_ns
import psycopg
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as postgresql_dialects
import sqlalchemy.orm as orm
from dotenv import load_dotenv
-from fastapi import UploadFile
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from langchain_core.embeddings import Embeddings
from langchain_postgres.vectorstores import Base, PGVector
from ..chain.types import InDocument, ScopeType
-from ..types import EmbeddingException, RetryableEmbeddingException
+from ..types import (
+ DocErrorEmbeddingException,
+ EmbeddingException,
+ FatalEmbeddingException,
+ IndexingError,
+ ReceivedFileItem,
+ RetryableEmbeddingException,
+ SourceItem,
+)
from ..utils import timed
from .base import BaseVectorDB
from .types import DbException, SafeDbException, UpdateAccessOp
@@ -112,7 +121,15 @@ def __init__(self, embedding: Embeddings | None = None, **kwargs):
kwargs['connection'] = os.environ['CCB_DB_URL']
# setup langchain db + our access list table
- self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs)
+ try:
+ self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs)
+ except sa.exc.IntegrityError as ie: # pyright: ignore[reportAttributeAccessIssue]
+ if not isinstance(ie.orig, psycopg.errors.UniqueViolation):
+ raise
+
+ # tried to create the tables but it was already created in another process
+ # init the client again to detect it already exists, and continue from there
+ self.client = PGVector(embedding, collection_name=COLLECTION_NAME, **kwargs)
def get_instance(self) -> VectorStore:
return self.client
@@ -130,24 +147,40 @@ def get_users(self) -> list[str]:
except Exception as e:
raise DbException('Error: getting a list of all users from access list') from e
- def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], list[str]]:
+ def add_indocuments(self, indocuments: Mapping[int, InDocument]) -> Mapping[int, IndexingError | None]:
"""
Raises
EmbeddingException: if the embedding request definitively fails
"""
- added_sources = []
- retry_sources = []
+ results = {}
batch_size = PG_BATCH_SIZE // 5
with self.session_maker() as session:
- for indoc in indocuments:
+ for php_db_id, indoc in indocuments.items():
try:
# query paramerters limitation in postgres is 65535 (https://www.postgresql.org/docs/current/limits.html)
# so we chunk the documents into (5 values * 10k) chunks
# change the chunk size when there are more inserted values per document
chunk_ids = []
- for i in range(0, len(indoc.documents), batch_size):
+ total_chunks = len(indoc.documents)
+ num_batches = max(1, -(-total_chunks // batch_size)) # ceiling division
+ logger.debug(
+ 'Embedding source %s: %d chunk(s) in %d batch(es)',
+ indoc.source_id, total_chunks, num_batches,
+ )
+ for i in range(0, total_chunks, batch_size):
+ batch_num = i // batch_size + 1
+ logger.debug(
+ 'Sending embedding batch %d/%d (%d chunk(s)) for source %s',
+ batch_num, num_batches, len(indoc.documents[i:i+batch_size]), indoc.source_id,
+ )
+ t0 = perf_counter_ns()
chunk_ids.extend(self.client.add_documents(indoc.documents[i:i+batch_size]))
+ elapsed_ms = (perf_counter_ns() - t0) / 1e6
+ logger.debug(
+ 'Embedding batch %d/%d for source %s completed in %.2f ms',
+ batch_num, num_batches, indoc.source_id, elapsed_ms,
+ )
doc = DocumentsStore(
source_id=indoc.source_id,
@@ -170,7 +203,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis
)
self.decl_update_access(indoc.userIds, indoc.source_id, session)
- added_sources.append(indoc.source_id)
+ results[php_db_id] = None
session.commit()
except SafeDbException as e:
# for when the source_id is not found. This here can be an error in the DB
@@ -178,51 +211,73 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis
logger.exception('Error adding documents to vectordb', exc_info=e, extra={
'source_id': indoc.source_id,
})
- retry_sources.append(indoc.source_id)
+ results[php_db_id] = IndexingError(
+ error=str(e),
+ retryable=True,
+ )
+ continue
+ except DocErrorEmbeddingException as e:
+ logger.warning(
+ 'Error adding documents to vectordb, server failed to index it, it will not be retried',
+ exc_info=e,
+ extra={ 'source_id': indoc.source_id },
+ )
+ results[php_db_id] = IndexingError(
+ error=str(e),
+ retryable=False,
+ )
continue
- except RetryableEmbeddingException as e:
+ except FatalEmbeddingException as e:
+ raise EmbeddingException(
+ f'Fatal error while embedding documents for source {indoc.source_id}: {e}'
+ ) from e
+ except (RetryableEmbeddingException, EmbeddingException) as e:
# temporary error, continue with the next document
- logger.exception('Error adding documents to vectordb, should be retried later.', exc_info=e, extra={
+ logger.warning('Error adding documents to vectordb, should be retried later.', exc_info=e, extra={
'source_id': indoc.source_id,
})
- retry_sources.append(indoc.source_id)
+ results[php_db_id] = IndexingError(
+ error=str(e),
+ retryable=True,
+ )
continue
- except EmbeddingException as e:
- logger.exception('Error adding documents to vectordb', exc_info=e, extra={
- 'source_id': indoc.source_id,
- })
- raise
except Exception as e:
logger.exception('Error adding documents to vectordb', exc_info=e, extra={
'source_id': indoc.source_id,
})
- retry_sources.append(indoc.source_id)
+ results[php_db_id] = IndexingError(
+ error='An unexpected error occurred while adding documents to the database.',
+ retryable=True,
+ )
continue
- return added_sources, retry_sources
+ return results
@timed
- def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str]]:
+ def check_sources(self, sources: Mapping[int, SourceItem | ReceivedFileItem]) -> tuple[list[str], list[str]]:
+ '''
+ returns a tuple of (existing_source_ids, to_embed_source_ids)
+ '''
with self.session_maker() as session:
try:
stmt = (
sa.select(DocumentsStore.source_id)
- .filter(DocumentsStore.source_id.in_([source.filename for source in sources]))
+ .filter(DocumentsStore.source_id.in_([source.reference for source in sources.values()]))
.with_for_update()
)
results = session.execute(stmt).fetchall()
existing_sources = {r.source_id for r in results}
- to_embed = [source.filename for source in sources if source.filename not in existing_sources]
+ to_embed = [source.reference for source in sources.values() if source.reference not in existing_sources]
to_delete = []
- for source in sources:
+ for source in sources.values():
stmt = (
sa.select(DocumentsStore.source_id)
- .filter(DocumentsStore.source_id == source.filename)
+ .filter(DocumentsStore.source_id == source.reference)
.filter(DocumentsStore.modified < sa.cast(
- datetime.fromtimestamp(int(source.headers['modified'])),
+ datetime.fromtimestamp(int(source.modified)),
sa.DateTime,
))
)
@@ -239,14 +294,13 @@ def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str]
session.rollback()
raise DbException('Error: checking sources in vectordb') from e
- still_existing_sources = [
- source
- for source in existing_sources
- if source not in to_delete
+ still_existing_source_ids = [
+ source_id
+ for source_id in existing_sources
+ if source_id not in to_delete
]
- # the pyright issue stems from source.filename, which has already been validated
- return list(still_existing_sources), to_embed # pyright: ignore[reportReturnType]
+ return list(still_existing_source_ids), to_embed
def decl_update_access(self, user_ids: list[str], source_id: str, session_: orm.Session | None = None):
session = session_ or self.session_maker()
@@ -325,7 +379,7 @@ def update_access(
)
match op:
- case UpdateAccessOp.allow:
+ case UpdateAccessOp.ALLOW:
for i in range(0, len(user_ids), PG_BATCH_SIZE):
batched_uids = user_ids[i:i+PG_BATCH_SIZE]
stmt = (
@@ -342,7 +396,7 @@ def update_access(
session.execute(stmt)
session.commit()
- case UpdateAccessOp.deny:
+ case UpdateAccessOp.DENY:
for i in range(0, len(user_ids), PG_BATCH_SIZE):
batched_uids = user_ids[i:i+PG_BATCH_SIZE]
stmt = (
@@ -435,15 +489,17 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None
# entry from "AccessListStore" is deleted automatically due to the foreign key constraint
# batch the deletion to avoid hitting the query parameter limit
chunks_to_delete = []
+ deleted_source_ids = []
for i in range(0, len(source_ids), PG_BATCH_SIZE):
batched_ids = source_ids[i:i+PG_BATCH_SIZE]
stmt_doc = (
sa.delete(DocumentsStore)
.filter(DocumentsStore.source_id.in_(batched_ids))
- .returning(DocumentsStore.chunks)
+ .returning(DocumentsStore.chunks, DocumentsStore.source_id)
)
doc_result = session.execute(stmt_doc)
chunks_to_delete.extend(str(c) for res in doc_result for c in res.chunks)
+ deleted_source_ids.extend(str(res.source_id) for res in doc_result)
for i in range(0, len(chunks_to_delete), PG_BATCH_SIZE):
batched_chunks = chunks_to_delete[i:i+PG_BATCH_SIZE]
@@ -463,6 +519,14 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None
if session_ is None:
session.close()
+ undeleted_source_ids = set(source_ids) - set(deleted_source_ids)
+ if len(undeleted_source_ids) > 0:
+ logger.info(
+ f'Source ids {undeleted_source_ids} were not deleted from documents store.'
+ ' This can be due to the source ids not existing in the documents store due to'
+ ' already being deleted or not having been added yet.'
+ )
+
def delete_provider(self, provider_key: str):
with self.session_maker() as session:
try:
@@ -506,7 +570,16 @@ def delete_user(self, user_id: str):
session.rollback()
raise DbException('Error: deleting user from access list') from e
- self._cleanup_if_orphaned(list(source_ids), session)
+ try:
+ self._cleanup_if_orphaned(list(source_ids), session)
+ except Exception as e:
+ session.rollback()
+ logger.error(
+ 'Error cleaning up orphaned source ids after deleting user, manual cleanup might be required',
+ exc_info=e,
+ extra={ 'source_ids': list(source_ids) },
+ )
+ raise DbException('Error: cleaning up orphaned source ids after deleting user') from e
def count_documents_by_provider(self) -> dict[str, int]:
try:
@@ -537,10 +610,9 @@ def doc_search(
try:
with self.session_maker() as session:
doc_filters = [AccessListStore.uid == user_id]
- match scope_type:
- case ScopeType.PROVIDER:
+ if scope_type == ScopeType.PROVIDER.value:
doc_filters.append(DocumentsStore.provider.in_(scope_list)) # pyright: ignore[reportArgumentType]
- case ScopeType.SOURCE:
+ elif scope_type == ScopeType.SOURCE.value:
doc_filters.append(DocumentsStore.source_id.in_(scope_list)) # pyright: ignore[reportArgumentType]
# get chunks associated with the user
@@ -552,8 +624,13 @@ def doc_search(
result = session.execute(stmt).fetchall()
chunk_ids = [str(c) for res in result for c in res.chunks]
+ if len(chunk_ids) == 0:
+ return []
+
# get embeddings
return self._similarity_search(session, query, chunk_ids, k)
+ except EmbeddingException:
+ raise
except Exception as e:
raise DbException('Error: performing doc search in vectordb') from e
@@ -563,7 +640,7 @@ def _similarity_search(
session: orm.Session,
query: str,
chunk_ids: list[str],
- k: int = 20,
+ k: int,
) -> list[Document]:
embedding = self.client.embeddings.embed_query(query)
collection = self.client.get_collection(session)
diff --git a/context_chat_backend/vectordb/service.py b/context_chat_backend/vectordb/service.py
index 620a0b39..06a8e19e 100644
--- a/context_chat_backend/vectordb/service.py
+++ b/context_chat_backend/vectordb/service.py
@@ -6,27 +6,42 @@
from ..dyn_loader import VectorDBLoader
from .base import BaseVectorDB
-from .types import DbException, UpdateAccessOp
+from .types import UpdateAccessOp
logger = logging.getLogger('ccb.vectordb')
-# todo: return source ids that were successfully deleted
+
def delete_by_source(vectordb_loader: VectorDBLoader, source_ids: list[str]):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('deleting sources by id', extra={ 'source_ids': source_ids })
- try:
- db.delete_source_ids(source_ids)
- except Exception as e:
- raise DbException('Error: Vectordb delete_source_ids error') from e
+ db.delete_source_ids(source_ids)
def delete_by_provider(vectordb_loader: VectorDBLoader, provider_key: str):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug(f'deleting sources by provider: {provider_key}')
db.delete_provider(provider_key)
def delete_user(vectordb_loader: VectorDBLoader, user_id: str):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug(f'deleting user from db: {user_id}')
db.delete_user(user_id)
@@ -38,6 +53,13 @@ def update_access(
user_ids: list[str],
source_id: str,
):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ SafeDbException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('updating access', extra={ 'op': op, 'user_ids': user_ids, 'source_id': source_id })
db.update_access(op, user_ids, source_id)
@@ -49,6 +71,13 @@ def update_access_provider(
user_ids: list[str],
provider_id: str,
):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ SafeDbException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('updating access by provider', extra={ 'op': op, 'user_ids': user_ids, 'provider_id': provider_id })
db.update_access_provider(op, user_ids, provider_id)
@@ -59,11 +88,24 @@ def decl_update_access(
user_ids: list[str],
source_id: str,
):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ SafeDbException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('decl update access', extra={ 'user_ids': user_ids, 'source_id': source_id })
db.decl_update_access(user_ids, source_id)
def count_documents_by_provider(vectordb_loader: VectorDBLoader):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('counting documents by provider')
return db.count_documents_by_provider()
diff --git a/context_chat_backend/vectordb/types.py b/context_chat_backend/vectordb/types.py
index df5c6dd7..30811797 100644
--- a/context_chat_backend/vectordb/types.py
+++ b/context_chat_backend/vectordb/types.py
@@ -14,5 +14,5 @@ class SafeDbException(Exception):
class UpdateAccessOp(Enum):
- allow = 'allow'
- deny = 'deny'
+ ALLOW = 'allow'
+ DENY = 'deny'
diff --git a/dockerfile_scripts/pgsql/setup.sh b/dockerfile_scripts/pgsql/setup.sh
index cee4295b..7578ed83 100755
--- a/dockerfile_scripts/pgsql/setup.sh
+++ b/dockerfile_scripts/pgsql/setup.sh
@@ -18,7 +18,7 @@ fi
# Check if EXTERNAL_DB is set
if [ -n "${EXTERNAL_DB}" ]; then
if [[ "$EXTERNAL_DB" != "postgresql+psycopg://"* ]]; then
- echo "EXTERNAL_DB must be a PostgreSQL URL and start with 'postgresql+psycopg://'"
+ printf "%s\n" "EXTERNAL_DB must be a PostgreSQL URL and start with 'postgresql+psycopg://'" >&2
exit 1
fi
@@ -31,6 +31,11 @@ if [ -n "${EXTERNAL_DB}" ]; then
exit 0
fi
+if [[ -n "${APP_ROLE}" && "$APP_ROLE" != "normal" && "$APP_ROLE" != "" ]]; then
+ printf "%s\n" "Refusing to start the internal postgresql server in Kubernetes environment, use an external database through the EXTERNAL_DB env var." >&2
+ exit 1
+fi
+
# Ensure the directory exists and has the correct permissions
mkdir -p "$DATA_DIR"
chmod +rx "${APP_PERSISTENT_STORAGE:-persistent_storage}"
diff --git a/logger_config.k8s.yaml b/logger_config.k8s.yaml
new file mode 100644
index 00000000..6d5c7298
--- /dev/null
+++ b/logger_config.k8s.yaml
@@ -0,0 +1,43 @@
+#
+# SPDX-FileCopyrightText: 2022 MCODING, LLC
+# SPDX-FileCopyrightText: 2025 Nextcloud GmbH and Nextcloud contributors
+# SPDX-License-Identifier: AGPL-3.0-or-later
+#
+
+version: 1
+disable_existing_loggers: false
+
+formatters:
+ json:
+ (): context_chat_backend.logger.JSONFormatter
+ fmt_keys:
+ timestamp: timestamp
+ level: levelname
+ logger: name
+ message: message
+ filename: filename
+ function: funcName
+ line: lineno
+ thread_name: threadName
+ pid: process
+
+
+handlers:
+ stderr:
+ class: logging.StreamHandler
+ level: DEBUG
+ formatter: json
+ stream: ext://sys.stderr
+
+
+loggers:
+ root:
+ level: WARNING
+ handlers:
+ - stderr
+
+ ccb:
+ level: WARNING
+ handlers:
+ - stderr
+ propagate: false
diff --git a/main.py b/main.py
index c4ffa1fd..8a2bedaa 100755
--- a/main.py
+++ b/main.py
@@ -3,18 +3,22 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
+
import logging
-from os import getenv
+import multiprocessing as mp
+from os import cpu_count, getenv
+import psutil
import uvicorn
from nc_py_api.ex_app import run_app
from context_chat_backend.types import TConfig # isort: skip
from context_chat_backend.controller import app # isort: skip
from context_chat_backend.logger import get_logging_config, setup_logging # isort: skip
-from context_chat_backend.utils import redact_config # isort: skip
+from context_chat_backend.utils import is_k8s_env, redact_config # isort: skip
LOGGER_CONFIG_NAME = 'logger_config.yaml'
+LOGGER_K8S_CONFIG_NAME = 'logger_config.k8s.yaml'
def _setup_log_levels(debug: bool):
'''
@@ -43,19 +47,37 @@ def _setup_log_levels(debug: bool):
if __name__ == '__main__':
- logging_config = get_logging_config(LOGGER_CONFIG_NAME)
+ k8s_env = is_k8s_env()
+ logging_config = get_logging_config(LOGGER_K8S_CONFIG_NAME if k8s_env else LOGGER_CONFIG_NAME)
setup_logging(logging_config)
app_config: TConfig = app.extra['CONFIG']
_setup_log_levels(app_config.debug)
+ # do forks from a clean process that doesn't have any threads or locks
+ mp.set_start_method('forkserver')
+ mp.set_forkserver_preload([
+ 'context_chat_backend.chain.ingest.injest',
+ 'context_chat_backend.vectordb.pgvector',
+ 'langchain',
+ 'logging',
+ 'numpy',
+ 'sqlalchemy',
+ ])
+
+ print(f'CPU count: {cpu_count()}, Memory: {psutil.virtual_memory()}')
print('App config:\n' + redact_config(app_config).model_dump_json(indent=2), flush=True)
uv_log_config = uvicorn.config.LOGGING_CONFIG # pyright: ignore[reportAttributeAccessIssue]
- uv_log_config['formatters']['json'] = logging_config['formatters']['json']
- uv_log_config['handlers']['file_json'] = logging_config['handlers']['file_json']
+ use_colors = False if k8s_env else (app_config.use_colors and getenv('CI', 'false') == 'false')
- uv_log_config['loggers']['uvicorn']['handlers'].append('file_json')
- uv_log_config['loggers']['uvicorn.access']['handlers'].append('file_json')
+ if k8s_env:
+ uv_log_config['formatters']['default'] = logging_config['formatters']['json']
+ uv_log_config['formatters']['access'] = logging_config['formatters']['json']
+ else:
+ uv_log_config['formatters']['json'] = logging_config['formatters']['json']
+ uv_log_config['handlers']['file_json'] = logging_config['handlers']['file_json']
+ uv_log_config['loggers']['uvicorn']['handlers'].append('file_json')
+ uv_log_config['loggers']['uvicorn.access']['handlers'].append('file_json')
run_app(
uvicorn_app=app,
@@ -63,7 +85,7 @@ def _setup_log_levels(debug: bool):
interface='asgi3',
log_config=uv_log_config,
log_level=app_config.uvicorn_log_level,
- use_colors=bool(app_config.use_colors and getenv('CI', 'false') == 'false'),
+ use_colors=use_colors,
# limit_concurrency=10,
# backlog=20,
timeout_keep_alive=120,
diff --git a/main_em.py b/main_em.py
index b7d5a93b..addcfd60 100755
--- a/main_em.py
+++ b/main_em.py
@@ -12,14 +12,15 @@
import niquests
import uvicorn
-from context_chat_backend.types import DEFAULT_EM_MODEL_ALIAS # isort: skip
+from context_chat_backend.types import DEFAULT_EM_MODEL_ALIAS, AppRole # isort: skip
from context_chat_backend.config_parser import get_config # isort: skip
from context_chat_backend.logger import get_logging_config, setup_logging # isort: skip
from context_chat_backend.setup_functions import ensure_config_file, setup_env_vars # isort: skip
-from context_chat_backend.utils import redact_config # isort: skip
+from context_chat_backend.utils import get_app_role, is_k8s_env, redact_config # isort: skip
LOGGER_CONFIG_NAME = 'logger_config_em.yaml'
+LOGGER_K8S_CONFIG_NAME = 'logger_config.k8s.yaml'
STARTUP_CHECK_SEC = 10
MAX_TRIES = 180 # 180*10 secs = 30 minutes max
@@ -88,9 +89,14 @@ def _wait_main_app_enabled() -> None:
if __name__ == '__main__':
+ app_role = get_app_role()
+ if app_role == AppRole.UPDATES_PROC:
+ print('Internal embedding server is not required for the Updates Processing role, stopping this process.')
+ exit(0)
+
# intial buffer
print(
- f"Waiting for {STARTUP_CHECK_SEC} seconds before starting embedding server to allow main app to start",
+ f'Waiting for {STARTUP_CHECK_SEC} seconds before starting embedding server to allow main app to start',
flush=True,
)
sleep(STARTUP_CHECK_SEC)
@@ -108,7 +114,8 @@ def _wait_main_app_enabled() -> None:
# in local embedding server config
print('Embedder config:\n' + redact_config(em_conf).model_dump_json(indent=2), flush=True)
- logging_config = get_logging_config(LOGGER_CONFIG_NAME)
+ k8s_env = is_k8s_env()
+ logging_config = get_logging_config(LOGGER_K8S_CONFIG_NAME if k8s_env else LOGGER_CONFIG_NAME)
setup_logging(logging_config)
logger = logging.getLogger('emserver')
if app_config.debug:
@@ -158,11 +165,16 @@ def _wait_main_app_enabled() -> None:
)
uv_log_config = uvicorn.config.LOGGING_CONFIG # pyright: ignore[reportAttributeAccessIssue]
- uv_log_config['formatters']['json'] = logging_config['formatters']['json']
- uv_log_config['handlers']['file_json'] = logging_config['handlers']['file_json']
+ use_colors = False if k8s_env else (app_config.use_colors and os.getenv('CI', 'false') == 'false')
- uv_log_config['loggers']['uvicorn']['handlers'].append('file_json')
- uv_log_config['loggers']['uvicorn.access']['handlers'].append('file_json')
+ if k8s_env:
+ uv_log_config['formatters']['default'] = logging_config['formatters']['json']
+ uv_log_config['formatters']['access'] = logging_config['formatters']['json']
+ else:
+ uv_log_config['formatters']['json'] = logging_config['formatters']['json']
+ uv_log_config['handlers']['file_json'] = logging_config['handlers']['file_json']
+ uv_log_config['loggers']['uvicorn']['handlers'].append('file_json')
+ uv_log_config['loggers']['uvicorn.access']['handlers'].append('file_json')
uvicorn.run(
# todo: use string import of the app
@@ -173,6 +185,6 @@ def _wait_main_app_enabled() -> None:
interface='asgi3',
log_config=uv_log_config,
log_level=app_config.uvicorn_log_level,
- use_colors=bool(app_config.use_colors and os.getenv('CI', 'false') == 'false'),
+ use_colors=use_colors,
workers=em_conf.workers,
)