Skip to content

Commit d1d9f2e

Browse files
authored
Merge pull request #171 from NillionNetwork/feat/add_pricing_management
feat: added pricing management
2 parents d4e6630 + af6ddb0 commit d1d9f2e

14 files changed

Lines changed: 860 additions & 116 deletions

File tree

.env.ci

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@ ATTESTATION_HOST = "attestation"
2323
ATTESTATION_PORT = 8080
2424

2525
# nilAuth Trusted URLs
26-
NILAUTH_TRUSTED_ROOT_ISSUERS = "http://nilauth-credit-server:3000" # "http://nilauth:30921"
26+
NILAUTH_TRUSTED_ROOT_ISSUERS = "http://nilauth-credit-server:3000"
2727
CREDIT_API_TOKEN = "n i l l i o n"
2828

29+
# Admin token for pricing management API
30+
ADMIN_TOKEN = "SecretAdminToken"
31+
2932
# Postgres Docker Compose Config
3033
POSTGRES_HOST = "postgres"
3134
POSTGRES_USER = "user"

nilai-api/src/nilai_api/app.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from fastapi import Depends, FastAPI
66
from nilai_api.auth import get_auth_info
77
from nilai_api.rate_limiting import setup_redis_conn
8-
from nilai_api.routers import private, public
8+
from nilai_api.routers import private, public, pricing
9+
from nilai_api.pricing_service import PricingService, set_pricing_service
910
from nilai_api import config
1011
from contextlib import asynccontextmanager
1112
from fastapi.middleware.cors import CORSMiddleware
@@ -16,6 +17,11 @@
1617
async def lifespan(app: FastAPI):
1718
client, rate_limit_command = await setup_redis_conn(config.CONFIG.redis.url)
1819

20+
# Initialize pricing service
21+
pricing_service = PricingService(client)
22+
await pricing_service.initialize_from_config()
23+
set_pricing_service(pricing_service)
24+
1925
yield {"redis": client, "redis_rate_limit_command": rate_limit_command}
2026

2127

@@ -88,6 +94,7 @@ async def lifespan(app: FastAPI):
8894

8995
app.include_router(public.router)
9096
app.include_router(private.router, dependencies=[Depends(get_auth_info)])
97+
app.include_router(pricing.router)
9198

9299
app.add_middleware(
93100
CORSMiddleware,

nilai-api/src/nilai_api/config/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .nildb import NilDBConfig
99
from .web_search import WebSearchSettings
1010
from .rate_limiting import RateLimitingConfig
11+
from .pricing import LLMPricingConfig, LLMPriceConfig
1112
from .utils import create_config_model, CONFIG_DATA
1213

1314

@@ -37,6 +38,9 @@ class NilAIConfig(BaseModel):
3738
nildb: NilDBConfig = create_config_model(
3839
NilDBConfig, "nildb", CONFIG_DATA, "NILDB_"
3940
)
41+
llm_pricing: LLMPricingConfig = create_config_model(
42+
LLMPricingConfig, "llm_pricing", CONFIG_DATA
43+
)
4044

4145
def prettify(self):
4246
"""Print the config in a pretty format removing passwords and other sensitive information"""
@@ -66,7 +70,10 @@ def prettify(self):
6670
CONFIG = NilAIConfig()
6771
__all__ = [
6872
# Main config object
69-
"CONFIG"
73+
"CONFIG",
74+
# Pricing config for external use
75+
"LLMPriceConfig",
76+
"LLMPricingConfig",
7077
]
7178

7279
logging.info(CONFIG.prettify())

nilai-api/src/nilai_api/config/auth.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ class AuthConfig(BaseModel):
1313
auth_token: Optional[str] = Field(
1414
default=None, description="Auth token for e2e tests and development"
1515
)
16+
admin_token: Optional[str] = Field(
17+
default=None, description="Admin token for pricing updates"
18+
)
1619

1720
@property
1821
def credit_service_url(self) -> str:

nilai-api/src/nilai_api/config/config-a779.yaml

Lines changed: 0 additions & 31 deletions
This file was deleted.

nilai-api/src/nilai_api/config/config-e176.yaml

Lines changed: 0 additions & 35 deletions
This file was deleted.

nilai-api/src/nilai_api/config/config-f910.yaml

Lines changed: 0 additions & 35 deletions
This file was deleted.

nilai-api/src/nilai_api/config/config.yaml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ auth:
99
strategy: "api_key"
1010
nilauth_trusted_root_issuers:
1111
- http://nilauth-credit-server:3000
12+
admin_token: null # Set via ADMIN_TOKEN env var for pricing management
1213

1314
# Documentation Configuration
1415
docs:
@@ -46,3 +47,31 @@ rate_limiting:
4647
openai/gpt-oss-20b: 50
4748
google/gemma-3-27b-it: 50
4849
default: 50
50+
51+
# LLM Pricing Configuration
52+
llm_pricing:
53+
default:
54+
prompt_tokens_price: 0.15
55+
completion_tokens_price: 0.45
56+
web_search_cost: 0.05
57+
models:
58+
meta-llama/Llama-3.2-1B-Instruct:
59+
prompt_tokens_price: 0.03
60+
completion_tokens_price: 0.09
61+
web_search_cost: 0.05
62+
meta-llama/Llama-3.1-8B-Instruct:
63+
prompt_tokens_price: 0.03
64+
completion_tokens_price: 0.09
65+
web_search_cost: 0.05
66+
openai/gpt-oss-20b:
67+
prompt_tokens_price: 0.15
68+
completion_tokens_price: 0.45
69+
web_search_cost: 0.05
70+
google/gemma-3-27b-it:
71+
prompt_tokens_price: 0.15
72+
completion_tokens_price: 0.45
73+
web_search_cost: 0.05
74+
Qwen/Qwen3-Coder-30B-A3B-Instruct:
75+
prompt_tokens_price: 0.15
76+
completion_tokens_price: 0.45
77+
web_search_cost: 0.05

nilai-api/src/nilai_api/config/pricing.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Dict
2+
from pydantic import BaseModel, Field
3+
4+
5+
class LLMPriceConfig(BaseModel):
6+
"""Pricing configuration for a single LLM model."""
7+
8+
prompt_tokens_price: float = Field(
9+
default=2.0, description="Cost per 1M prompt tokens"
10+
)
11+
completion_tokens_price: float = Field(
12+
default=2.0, description="Cost per 1M completion tokens"
13+
)
14+
web_search_cost: float = Field(default=0.05, description="Cost per web search")
15+
16+
17+
class LLMPricingConfig(BaseModel):
18+
"""Container for all LLM pricing configurations."""
19+
20+
default: LLMPriceConfig = Field(default_factory=LLMPriceConfig)
21+
models: Dict[str, LLMPriceConfig] = Field(default_factory=dict)

nilai-api/src/nilai_api/credit.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212

1313
from nilai_api.config import CONFIG
14+
from nilai_api.pricing_service import get_pricing_service
1415

1516
from nuc.envelope import NucTokenEnvelope
1617

@@ -51,6 +52,22 @@ def default() -> "LLMCost":
5152
prompt_tokens_price=2.0, completion_tokens_price=2.0, web_search_cost=0.05
5253
)
5354

55+
@staticmethod
56+
async def from_redis(model_name: str) -> "LLMCost":
57+
"""Fetch pricing from Redis for a specific model."""
58+
try:
59+
pricing_service = get_pricing_service()
60+
price_config = await pricing_service.get_price(model_name)
61+
return LLMCost(
62+
prompt_tokens_price=price_config.prompt_tokens_price,
63+
completion_tokens_price=price_config.completion_tokens_price,
64+
web_search_cost=price_config.web_search_cost,
65+
)
66+
except RuntimeError:
67+
# Pricing service not initialized, use default
68+
logger.warning("Pricing service not initialized, using default pricing")
69+
return LLMCost.default()
70+
5471
def total_cost(
5572
self, prompt_tokens: int, completion_tokens: int, web_searches: int
5673
) -> float:
@@ -87,14 +104,6 @@ class LLMResponse(BaseModel):
87104

88105
LLMCostDict: TypeAlias = dict[str, LLMCost]
89106

90-
91-
MyCostDictionary: LLMCostDict = {
92-
"meta-llama/Llama-3.2-1B-Instruct": LLMCost(
93-
prompt_tokens_price=3.0, completion_tokens_price=3.0, web_search_cost=0.05
94-
),
95-
"default": LLMCost.default(),
96-
}
97-
98107
# Configure the singleton credit client
99108
CreditClientSingleton.configure(
100109
base_url=CONFIG.auth.credit_service_url,
@@ -138,10 +147,10 @@ async def extractor(request: Request) -> str:
138147
return extractor
139148

140149

141-
def llm_cost_calculator(llm_cost_dict: LLMCostDict):
150+
def llm_cost_calculator():
142151
async def calculator(request: Request, response_data: dict) -> float:
143152
model_name = getattr(request, "model", "default")
144-
llm_cost = llm_cost_dict.get(model_name, LLMCost.default())
153+
llm_cost = await LLMCost.from_redis(model_name)
145154
total_cost = 0.0
146155
usage: Optional[LLMUsage] = response_data.get("usage", None)
147156
if usage is None:
@@ -157,8 +166,8 @@ async def calculator(request: Request, response_data: dict) -> float:
157166

158167
_base_llm_meter = create_metering_dependency(
159168
credential_extractor=credential_extractor(),
160-
estimated_cost=2.0,
161-
cost_calculator=llm_cost_calculator(MyCostDictionary),
169+
estimated_cost=0.5,
170+
cost_calculator=llm_cost_calculator(),
162171
public_identifiers=CONFIG.auth.auth_strategy == "nuc",
163172
)
164173

0 commit comments

Comments
 (0)