Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,94 @@ import (
)

const (
defaultManifestPath = "/config/models.json"
defaultManifestPath = "/config/models/models.json"
defaultFlagsConfig = "/config/flags/flags.json"
maxRetries = 3
retryDelay = 10 * time.Second
)

var errConflict = errors.New("flag already exists (conflict)")

// FlagTag represents a tag to attach to an Unleash feature flag.
type FlagTag struct {
Type string `json:"type"`
Value string `json:"value"`
}

// FlagSpec describes a feature flag to sync to Unleash.
// All flags are created disabled with type "release" and a flexibleRollout
// strategy at 0%. Tags are optional and per-flag.
type FlagSpec struct {
Name string `json:"name"`
Description string `json:"description"`
Tags []FlagTag `json:"tags,omitempty"`
}

// FlagsConfig is the JSON structure for the generic flags config file.
type FlagsConfig struct {
Flags []FlagSpec `json:"flags"`
}

// FlagsFromManifest converts a model manifest into FlagSpecs.
// Skips the default model and unavailable models.
func FlagsFromManifest(manifest *types.ModelManifest) []FlagSpec {
var specs []FlagSpec
for _, model := range manifest.Models {
if model.ID == manifest.DefaultModel {
continue
}
if !model.Available {
continue
}
specs = append(specs, FlagSpec{
Name: sanitizeLogString(fmt.Sprintf("model.%s.enabled", model.ID)),
Description: sanitizeLogString(fmt.Sprintf("Enable %s (%s) for users", model.Label, model.ID)),
Tags: []FlagTag{{Type: "scope", Value: "workspace"}},
})
}
return specs
}

// FlagsConfigPath returns the filesystem path to the generic flags config.
// Defaults to defaultFlagsConfig; override via FLAGS_CONFIG_PATH env var.
func FlagsConfigPath() string {
if p := os.Getenv("FLAGS_CONFIG_PATH"); p != "" {
return p
}
return defaultFlagsConfig
}

// FlagsFromConfig loads generic flag definitions from a JSON file.
// Returns nil if the file does not exist (flags config is optional).
func FlagsFromConfig(path string) ([]FlagSpec, error) {
data, err := os.ReadFile(path)
if os.IsNotExist(err) {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("reading flags config %s: %w", path, err)
}

var cfg FlagsConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parsing flags config: %w", err)
}

// Sanitize flag names and descriptions to prevent log injection.
// Model-derived names are constrained (model.<id>.enabled) but
// config-file names are user-defined and unconstrained.
for i := range cfg.Flags {
cfg.Flags[i].Name = sanitizeLogString(cfg.Flags[i].Name)
cfg.Flags[i].Description = sanitizeLogString(cfg.Flags[i].Description)
for j := range cfg.Flags[i].Tags {
cfg.Flags[i].Tags[j].Type = sanitizeLogString(cfg.Flags[i].Tags[j].Type)
cfg.Flags[i].Tags[j].Value = sanitizeLogString(cfg.Flags[i].Tags[j].Value)
}
}

return cfg.Flags, nil
}

// SyncModelFlagsFromFile reads a model manifest from disk and syncs flags.
// Used by the sync-model-flags subcommand.
func SyncModelFlagsFromFile(manifestPath string) error {
Expand All @@ -39,40 +120,39 @@ func SyncModelFlagsFromFile(manifestPath string) error {
return fmt.Errorf("parsing manifest: %w", err)
}

return SyncModelFlags(context.Background(), &manifest)
return SyncFlags(context.Background(), FlagsFromManifest(&manifest))
}

// SyncModelFlagsAsync runs SyncModelFlags in a background goroutine with
// retries. Intended for use at server startup — does not block the caller.
// SyncFlagsAsync runs SyncFlags in a background goroutine with retries.
// Intended for use at server startup — does not block the caller.
// Cancel the context to abort retries (e.g. on SIGTERM).
func SyncModelFlagsAsync(ctx context.Context, manifest *types.ModelManifest) {
func SyncFlagsAsync(ctx context.Context, flags []FlagSpec) {
go func() {
for attempt := 1; attempt <= maxRetries; attempt++ {
err := SyncModelFlags(ctx, manifest)
err := SyncFlags(ctx, flags)
if err == nil {
return
}
log.Printf("sync-model-flags: attempt %d/%d failed: %v", attempt, maxRetries, err)
log.Printf("sync-flags: attempt %d/%d failed: %v", attempt, maxRetries, err)
if attempt < maxRetries {
select {
case <-ctx.Done():
log.Printf("sync-model-flags: cancelled, stopping retries")
log.Printf("sync-flags: cancelled, stopping retries")
return
case <-time.After(retryDelay):
}
}
}
log.Printf("sync-model-flags: all %d attempts failed, giving up", maxRetries)
log.Printf("sync-flags: all %d attempts failed, giving up", maxRetries)
}()
}

// SyncModelFlags ensures every model in the manifest has a corresponding
// Unleash feature flag. Flags are created disabled with type "release"
// and tagged scope:workspace so they appear in the admin UI.
// SyncFlags ensures every FlagSpec has a corresponding Unleash feature flag.
// Flags are created disabled with type "release" and a flexibleRollout strategy.
//
// Required env vars: UNLEASH_ADMIN_URL, UNLEASH_ADMIN_TOKEN
// Optional env var: UNLEASH_PROJECT (default: "default")
func SyncModelFlags(ctx context.Context, manifest *types.ModelManifest) error {
func SyncFlags(ctx context.Context, flags []FlagSpec) error {
adminURL := strings.TrimSuffix(strings.TrimSpace(os.Getenv("UNLEASH_ADMIN_URL")), "/")
adminToken := strings.TrimSpace(os.Getenv("UNLEASH_ADMIN_TOKEN"))
project := strings.TrimSpace(os.Getenv("UNLEASH_PROJECT"))
Expand All @@ -86,80 +166,91 @@ func SyncModelFlags(ctx context.Context, manifest *types.ModelManifest) error {
}

if adminURL == "" || adminToken == "" {
log.Printf("sync-model-flags: UNLEASH_ADMIN_URL or UNLEASH_ADMIN_TOKEN not set, skipping")
log.Printf("sync-flags: UNLEASH_ADMIN_URL or UNLEASH_ADMIN_TOKEN not set, skipping")
return nil
}

client := &http.Client{Timeout: 10 * time.Second}

// Ensure the "scope" tag type exists before creating flags
if err := ensureTagType(ctx, client, adminURL, "scope", "Controls flag visibility scope", adminToken); err != nil {
return fmt.Errorf("ensuring scope tag type: %w", err)
}

var created, skipped, excluded, errCount int
log.Printf("Syncing Unleash flags for %d models...", len(manifest.Models))

for _, model := range manifest.Models {
if model.ID == manifest.DefaultModel {
log.Printf(" %s: default model, no flag needed", model.ID)
excluded++
continue
}

if !model.Available {
log.Printf(" %s: not available, skipping flag creation", model.ID)
excluded++
continue
// Ensure all required tag types exist
tagTypes := collectTagTypes(flags)
for _, tt := range tagTypes {
if err := ensureTagType(ctx, client, adminURL, tt, fmt.Sprintf("Tag type: %s", tt), adminToken); err != nil {
return fmt.Errorf("ensuring tag type %q: %w", tt, err)
}
}

flagName := fmt.Sprintf("model.%s.enabled", model.ID)
var created, skipped, errCount int
log.Printf("Syncing %d Unleash flag(s)...", len(flags))

exists, err := flagExists(ctx, client, adminURL, project, flagName, adminToken)
for _, flag := range flags {
exists, err := flagExists(ctx, client, adminURL, project, flag.Name, adminToken)
if err != nil {
log.Printf(" ERROR checking %s: %v", flagName, err)
log.Printf(" ERROR checking %s: %v", flag.Name, err)
errCount++
continue
}

if exists {
log.Printf(" %s: already exists, skipping", flagName)
log.Printf(" %s: already exists, skipping", flag.Name)
skipped++
continue
}

description := fmt.Sprintf("Enable %s (%s) for users", model.Label, model.ID)
if err := createFlag(ctx, client, adminURL, project, flagName, description, adminToken); err != nil {
if err := createFlag(ctx, client, adminURL, project, flag.Name, flag.Description, adminToken); err != nil {
if errors.Is(err, errConflict) {
log.Printf(" %s: created by another instance, skipping", flagName)
log.Printf(" %s: created by another instance, skipping", flag.Name)
skipped++
continue
}
log.Printf(" ERROR creating %s: %v", flagName, err)
log.Printf(" ERROR creating %s: %v", flag.Name, err)
errCount++
continue
}

if err := addTag(ctx, client, adminURL, flagName, adminToken); err != nil {
log.Printf(" WARNING: created %s but failed to add tag: %v", flagName, err)
for _, tag := range flag.Tags {
if err := addFlagTag(ctx, client, adminURL, flag.Name, tag, adminToken); err != nil {
log.Printf(" WARNING: created %s but failed to add tag %s:%s: %v", flag.Name, tag.Type, tag.Value, err)
}
}

if err := addRolloutStrategy(ctx, client, adminURL, project, environment, flagName, adminToken); err != nil {
log.Printf(" WARNING: created %s but failed to add rollout strategy: %v", flagName, err)
if err := addRolloutStrategy(ctx, client, adminURL, project, environment, flag.Name, adminToken); err != nil {
log.Printf(" WARNING: created %s but failed to add rollout strategy: %v", flag.Name, err)
}

log.Printf(" %s: created (disabled, 0%% rollout)", flagName)
log.Printf(" %s: created (disabled, 0%% rollout)", flag.Name)
created++
}

log.Printf("Summary: %d created, %d skipped, %d excluded, %d errors", created, skipped, excluded, errCount)
log.Printf("Summary: %d created, %d skipped, %d errors", created, skipped, errCount)

if errCount > 0 {
return fmt.Errorf("%d errors occurred during sync", errCount)
}
return nil
}

// sanitizeLogString strips newlines and carriage returns from strings
// that will be interpolated into log messages, preventing log injection.
func sanitizeLogString(s string) string {
return strings.ReplaceAll(strings.ReplaceAll(s, "\n", ""), "\r", "")
}

// collectTagTypes returns the unique set of tag types across all flags.
func collectTagTypes(flags []FlagSpec) []string {
seen := map[string]bool{}
var result []string
for _, f := range flags {
for _, t := range f.Tags {
if !seen[t.Type] {
seen[t.Type] = true
result = append(result, t.Type)
}
}
}
return result
}

// ParseManifestPath extracts --manifest-path from args, returning the path
// and whether it was found. Falls back to defaultManifestPath.
func ParseManifestPath(args []string) string {
Expand All @@ -175,7 +266,6 @@ func ParseManifestPath(args []string) string {
}

func ensureTagType(ctx context.Context, client *http.Client, adminURL, name, description, token string) error {
// Check if tag type exists
reqURL := fmt.Sprintf("%s/api/admin/tag-types/%s", adminURL, url.PathEscape(name))
resp, err := doRequest(ctx, client, "GET", reqURL, token, nil)
if err != nil {
Expand All @@ -189,7 +279,6 @@ func ensureTagType(ctx context.Context, client *http.Client, adminURL, name, des
return nil
}

// Create it
createURL := fmt.Sprintf("%s/api/admin/tag-types", adminURL)
body, err := json.Marshal(map[string]string{
"name": name,
Expand Down Expand Up @@ -265,11 +354,11 @@ func createFlag(ctx context.Context, client *http.Client, adminURL, project, fla
}
}

func addTag(ctx context.Context, client *http.Client, adminURL, flagName, token string) error {
func addFlagTag(ctx context.Context, client *http.Client, adminURL, flagName string, tag FlagTag, token string) error {
reqURL := fmt.Sprintf("%s/api/admin/features/%s/tags", adminURL, url.PathEscape(flagName))
body, err := json.Marshal(map[string]string{
"type": "scope",
"value": "workspace",
"type": tag.Type,
"value": tag.Value,
})
if err != nil {
return fmt.Errorf("marshaling tag request: %w", err)
Expand Down Expand Up @@ -316,8 +405,8 @@ func addRolloutStrategy(ctx context.Context, client *http.Client, adminURL, proj
return nil
}

func doRequest(ctx context.Context, client *http.Client, method, url, token string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
func doRequest(ctx context.Context, client *http.Client, method, reqURL, token string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading