diff --git a/README.md b/README.md index 1eb6237..0730399 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ This repository contains a generic HTTP client which can be adapted to provide: * Ability to send files and data of type `multipart/form-data` * Ability to send data of type `application/x-www-form-urlencoded` * Debugging capabilities to see the request and response data -* Streaming text events +* Streaming text and JSON events API Documentation: https://pkg.go.dev/github.com/mutablelogic/go-client @@ -159,6 +159,9 @@ modify each individual request when using the `Do` method: * `OptTextStreamCallback(func(TextStreamCallback) error)` allows you to set a callback function to process a streaming text response of type `text/event-stream`. See below for more details. +* `OptJsonStreamCallback(func(any) error)` allows you to set a callback for JSON streaming + responses. The callback should have the signature `func(any) error`. See below for + more details. ## Authentication @@ -191,9 +194,9 @@ You can also set the token on a per-request basis using the `OptToken` option in You can create a payload with form data: -* `client.NewFormRequest(payload any, accept string)` returns a new request with a Form +* `client.NewFormRequest(payload any, accept string)` returns a new request with a Form data payload which defaults to POST. -* `client.NewMultipartRequest(payload any, accept string)` returns a new request with +* `client.NewMultipartRequest(payload any, accept string)` returns a new request with a Multipart Form data payload which defaults to POST. This is useful for file uploads. The payload should be a `struct` where the fields are converted to form tuples. File uploads require a field of type `multipart.File`. For example, @@ -241,9 +244,10 @@ type Unmarshaler interface { } ``` -## Streaming Responses +## Text Streaming Responses -The client implements a streaming text event callback which can be used to process a stream of text events, as per the [Mozilla specification](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events). +The client implements a streaming text event callback which can be used to process a stream of text events, +as per the [Mozilla specification](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events). In order to process streamed events, pass the `OptTextStreamCallback()` option to the request with a callback function, which should have the following signature: @@ -272,3 +276,12 @@ If you return an error of type `io.EOF` from the callback, then the stream will Similarly, if you return any other error the stream will be closed and the error returned. Usually, you would pair this option with `OptNoTimeout` to prevent the request from timing out. + +## JSON Streaming Responses + +The client decodes JSON streaming responses by passing a callback function to the `OptJsonStreamCallback()` option. +The callback with signature `func(any) error` is called for each JSON object in the stream, where the argument +is the same type as the object in the request. + +You can return an error from the callback to stop the stream and return the error, or return `io.EOF` to stop the stream +immediately and return success. diff --git a/client.go b/client.go index 7666ffa..76e4334 100644 --- a/client.go +++ b/client.go @@ -61,6 +61,7 @@ const ( PathSeparator = string(os.PathSeparator) ContentTypeAny = "*/*" ContentTypeJson = "application/json" + ContentTypeJsonStream = "application/x-ndjson" ContentTypeTextXml = "text/xml" ContentTypeApplicationXml = "application/xml" ContentTypeTextPlain = "text/plain" @@ -306,7 +307,7 @@ func do(client *http.Client, req *http.Request, accept string, strict bool, out // Decode the body switch mimetype { - case ContentTypeJson: + case ContentTypeJson, ContentTypeJsonStream: // JSON decode is streamable dec := json.NewDecoder(response.Body) for { diff --git a/cmd/agent/chat.go b/cmd/agent/chat.go new file mode 100644 index 0000000..e9e2d87 --- /dev/null +++ b/cmd/agent/chat.go @@ -0,0 +1,169 @@ +package main + +import ( + "context" + "fmt" + + // Packages + markdown "github.com/MichaelMure/go-term-markdown" + agent "github.com/mutablelogic/go-client/pkg/agent" +) + +///////////////////////////////////////////////////////////////////// +// TYPES + +type ChatCmd struct { + Prompt string `arg:"" optional:"" help:"The prompt to generate a response for"` + Agent string `flag:"agent" help:"The agent to use"` + Model string `flag:"model" help:"The model to use"` + Stream bool `flag:"stream" help:"Stream the response"` +} + +///////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +func (cmd *ChatCmd) Run(globals *Globals) error { + // Get the agent and the model + model_agent, model := globals.getModel(globals.ctx, cmd.Agent, cmd.Model) + if model_agent == nil || model == nil { + return fmt.Errorf("model %q not found, or not set on command line", globals.state.Model) + } + + // Generate the options + opts := make([]agent.Opt, 0) + if cmd.Stream { + opts = append(opts, agent.OptStream(func(r agent.Response) { + fmt.Println(r) + })) + } + + // Add tools + if tools := globals.getTools(); len(tools) > 0 { + opts = append(opts, agent.OptTools(tools...)) + } + + // If the prompt is empty, then we're in interative mode + context := []agent.Context{} + if cmd.Prompt == "" { + if globals.term == nil { + return fmt.Errorf("prompt is empty and not in interactive mode") + } + } else { + context = append(context, model_agent.UserPrompt(cmd.Prompt)) + } + +FOR_LOOP: + for { + // When there is no context, create some + if len(context) == 0 { + if prompt, err := globals.term.ReadLine(model.Name() + "> "); err != nil { + return err + } else if prompt == "" { + break FOR_LOOP + } else { + context = append(context, model_agent.UserPrompt(prompt)) + } + } + + // Generate a chat completion + response, err := model_agent.Generate(globals.ctx, model, context, opts...) + if err != nil { + return err + } + + // If the response is a tool call, then run the tool + if response.ToolCall != nil { + result, err := globals.runTool(globals.ctx, response.ToolCall) + if err != nil { + return err + } + response.Context = append(response.Context, result) + } else { + if globals.term != nil { + w, _ := globals.term.Size() + fmt.Println(string(markdown.Render(response.Text, w, 0))) + } else { + fmt.Println(response.Text) + } + + // Make empty context + response.Context = []agent.Context{} + } + + // Context comes from the response + context = response.Context + } + + // Return success + return nil +} + +///////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +// Get the model, either from state or from the command-line flags. +// If the model is not found, or there is another error, return nil +func (globals *Globals) getModel(ctx context.Context, agent, model string) (agent.Agent, agent.Model) { + state := globals.state + if agent != "" { + state.Agent = agent + } + if model != "" { + state.Model = model + } + + // Cycle through the agents and models to find the one we want + for _, agent := range globals.agents { + // Filter by agent + if state.Agent != "" && agent.Name() != state.Agent { + continue + } + + // Retrieve the models for this agent + models, err := agent.Models(ctx) + if err != nil { + continue + } + + // Filter by model + for _, model := range models { + if state.Model != "" && model.Name() != state.Model { + continue + } + + // This is the model we're using.... + state.Agent = agent.Name() + state.Model = model.Name() + return agent, model + } + } + + // No model found + return nil, nil +} + +// Get the tools +func (globals *Globals) getTools() []agent.Tool { + return globals.tools +} + +// Return a tool by name. If the tool is not found, return nil +func (globals *Globals) getTool(name string) agent.Tool { + for _, tool := range globals.tools { + if tool.Name() == name { + return tool + } + } + return nil +} + +// Run a tool from a tool call, and return the result +func (globals *Globals) runTool(ctx context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + tool := globals.getTool(call.Name) + if tool == nil { + return nil, fmt.Errorf("tool %q not found", call.Name) + } + + // Run the tool + return tool.Run(ctx, call) +} diff --git a/cmd/agent/list_agents.go b/cmd/agent/list_agents.go new file mode 100644 index 0000000..c07da17 --- /dev/null +++ b/cmd/agent/list_agents.go @@ -0,0 +1,31 @@ +package main + +import ( + "encoding/json" + "fmt" +) + +///////////////////////////////////////////////////////////////////// +// TYPES + +type ListAgentsCmd struct { +} + +///////////////////////////////////////////////////////////////////// +// METHODS + +func (cmd *ListAgentsCmd) Run(ctx *Globals) error { + result := make([]string, 0) + for _, agent := range ctx.agents { + result = append(result, agent.Name()) + } + + data, err := json.MarshalIndent(result, "", " ") + if err != nil { + return err + } + + fmt.Println(string(data)) + + return nil +} diff --git a/cmd/agent/list_models.go b/cmd/agent/list_models.go new file mode 100644 index 0000000..bdceef0 --- /dev/null +++ b/cmd/agent/list_models.go @@ -0,0 +1,42 @@ +package main + +import ( + "encoding/json" + "fmt" +) + +///////////////////////////////////////////////////////////////////// +// TYPES + +type ListModelsCmd struct { +} + +type modeljson struct { + Agent string `json:"agent"` + Model string `json:"model"` +} + +///////////////////////////////////////////////////////////////////// +// METHODS + +func (cmd *ListModelsCmd) Run(ctx *Globals) error { + result := make([]modeljson, 0) + for _, agent := range ctx.agents { + models, err := agent.Models(ctx.ctx) + if err != nil { + return err + } + for _, model := range models { + result = append(result, modeljson{Agent: agent.Name(), Model: model.Name()}) + } + } + + data, err := json.MarshalIndent(result, "", " ") + if err != nil { + return err + } + + fmt.Println(string(data)) + + return nil +} diff --git a/cmd/agent/list_tools.go b/cmd/agent/list_tools.go new file mode 100644 index 0000000..16cc62a --- /dev/null +++ b/cmd/agent/list_tools.go @@ -0,0 +1,37 @@ +package main + +import ( + "encoding/json" + "fmt" +) + +///////////////////////////////////////////////////////////////////// +// TYPES + +type ListToolsCmd struct { +} + +type tooljson struct { + Provider string `json:"provider"` + Name string `json:"name"` + Description string `json:"description"` +} + +///////////////////////////////////////////////////////////////////// +// METHODS + +func (cmd *ListToolsCmd) Run(ctx *Globals) error { + result := make([]tooljson, 0) + for _, tool := range ctx.tools { + result = append(result, tooljson{Provider: tool.Provider(), Name: tool.Name(), Description: tool.Description()}) + } + + data, err := json.MarshalIndent(result, "", " ") + if err != nil { + return err + } + + fmt.Println(string(data)) + + return nil +} diff --git a/cmd/agent/main.go b/cmd/agent/main.go new file mode 100644 index 0000000..bac6cbd --- /dev/null +++ b/cmd/agent/main.go @@ -0,0 +1,168 @@ +package main + +import ( + "context" + "os" + "os/signal" + "path/filepath" + "syscall" + + // Packages + kong "github.com/alecthomas/kong" + client "github.com/mutablelogic/go-client" + agent "github.com/mutablelogic/go-client/pkg/agent" + "github.com/mutablelogic/go-client/pkg/homeassistant" + "github.com/mutablelogic/go-client/pkg/ipify" + "github.com/mutablelogic/go-client/pkg/newsapi" + ollama "github.com/mutablelogic/go-client/pkg/ollama" + openai "github.com/mutablelogic/go-client/pkg/openai" + "github.com/mutablelogic/go-client/pkg/weatherapi" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +type Globals struct { + OllamaUrl string `name:"ollama-url" help:"URL of Ollama service (can be set from OLLAMA_URL env)" default:"${OLLAMA_URL}"` + OpenAIKey string `name:"openai-key" help:"API key for OpenAI service (can be set from OPENAI_API_KEY env)" default:"${OPENAI_API_KEY}"` + WeatherKey string `name:"weather-key" help:"API key for WeatherAPI service (can be set from WEATHERAPI_KEY env)" default:"${WEATHERAPI_KEY}"` + NewsKey string `name:"news-key" help:"API key for NewsAPI service (can be set from NEWSAPI_KEY env)" default:"${NEWSAPI_KEY}"` + HomeAssistantUrl string `name:"homeassistant-url" help:"URL of HomeAssistant service (can be set from HA_ENDPOINT env)" default:"${HA_ENDPOINT}"` + HomeAssistantKey string `name:"homeassistant-key" help:"API key for HomeAssistant service (can be set from HA_TOKEN env)" default:"${HA_TOKEN}"` + + // Debugging + Debug bool `name:"debug" help:"Enable debug output"` + Verbose bool `name:"verbose" help:"Enable verbose output"` + + ctx context.Context + agents []agent.Agent + tools []agent.Tool + state *State + + // Terminal interaction + term *Term +} + +type CLI struct { + Globals + + // Agents, Models and Tools + Agents ListAgentsCmd `cmd:"" help:"Return a list of agents"` + Models ListModelsCmd `cmd:"" help:"Return a list of models"` + Tools ListToolsCmd `cmd:"" help:"Return a list of tools"` + + // Generate Responses + Chat ChatCmd `cmd:"" help:"Generate a response from a chat message"` +} + +//////////////////////////////////////////////////////////////////////////////// +// MAIN + +func main() { + // The name of the executable + name, err := os.Executable() + if err != nil { + panic(err) + } else { + name = filepath.Base(name) + } + + // Create a cli parser + cli := CLI{} + cmd := kong.Parse(&cli, + kong.Name(name), + kong.Description("Agent command line interface"), + kong.UsageOnError(), + kong.ConfigureHelp(kong.HelpOptions{Compact: true}), + kong.Vars{ + "OLLAMA_URL": envOrDefault("OLLAMA_URL", ""), + "OPENAI_API_KEY": envOrDefault("OPENAI_API_KEY", ""), + "WEATHERAPI_KEY": envOrDefault("WEATHERAPI_KEY", ""), + "NEWSAPI_KEY": envOrDefault("NEWSAPI_KEY", ""), + "HA_TOKEN": envOrDefault("HA_TOKEN", ""), + "HA_ENDPOINT": envOrDefault("HA_ENDPOINT", ""), + }, + ) + + if cli.OllamaUrl != "" { + ollama, err := ollama.New(cli.OllamaUrl, clientOpts(&cli)...) + cmd.FatalIfErrorf(err) + cli.Globals.agents = append(cli.Globals.agents, ollama) + } + if cli.OpenAIKey != "" { + openai, err := openai.New(cli.OpenAIKey, clientOpts(&cli)...) + cmd.FatalIfErrorf(err) + cli.Globals.agents = append(cli.Globals.agents, openai) + } + if cli.WeatherKey != "" { + weather, err := weatherapi.New(cli.WeatherKey, clientOpts(&cli)...) + cmd.FatalIfErrorf(err) + cli.Globals.tools = append(cli.Globals.tools, weather.Tools()...) + } + if cli.NewsKey != "" { + news, err := newsapi.New(cli.NewsKey, clientOpts(&cli)...) + cmd.FatalIfErrorf(err) + cli.Globals.tools = append(cli.Globals.tools, news.Tools()...) + } + if cli.HomeAssistantKey != "" && cli.HomeAssistantUrl != "" { + ha, err := homeassistant.New(cli.HomeAssistantUrl, cli.HomeAssistantKey, clientOpts(&cli)...) + cmd.FatalIfErrorf(err) + cli.Globals.tools = append(cli.Globals.tools, ha.Tools()...) + } + + // Add ipify + ipify, err := ipify.New(clientOpts(&cli)...) + cmd.FatalIfErrorf(err) + cli.Globals.tools = append(cli.Globals.tools, ipify.Tools()...) + + // Create a context + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + cli.Globals.ctx = ctx + + // Create a state + if state, err := NewState(name); err != nil { + cmd.FatalIfErrorf(err) + return + } else { + cli.Globals.state = state + } + + // Terminal from stdin + if term, err := NewTerm(os.Stdin); err != nil { + cmd.FatalIfErrorf(err) + } else { + cli.Globals.term = term + } + + // Run the command + if err := cmd.Run(&cli.Globals); err != nil { + cmd.FatalIfErrorf(err) + return + } + + // Save state + if err := cli.Globals.state.Close(); err != nil { + cmd.FatalIfErrorf(err) + return + } +} + +//////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func envOrDefault(name, def string) string { + if value := os.Getenv(name); value != "" { + return value + } else { + return def + } +} + +func clientOpts(cli *CLI) []client.ClientOpt { + result := []client.ClientOpt{} + if cli.Debug { + result = append(result, client.OptTrace(os.Stderr, cli.Verbose)) + } + return result +} diff --git a/cmd/agent/state.go b/cmd/agent/state.go new file mode 100644 index 0000000..1553404 --- /dev/null +++ b/cmd/agent/state.go @@ -0,0 +1,97 @@ +package main + +import ( + "encoding/json" + "os" + "path/filepath" +) + +////////////////////////////////////////////////////////////////// +// TYPES + +type State struct { + Agent string `json:"agent"` + Model string `json:"model"` + + // Path of the state file + path string +} + +////////////////////////////////////////////////////////////////// +// GLOBALS + +const ( + // The name of the state file + stateFile = "state.json" +) + +////////////////////////////////////////////////////////////////// +// LIFECYCLE + +// Create a new state object with the given name +func NewState(name string) (*State, error) { + // Load the state from the file, or return a new empty state + path, err := os.UserConfigDir() + if err != nil { + return nil, err + } + + // Append the name of the application to the path + if name != "" { + path = filepath.Join(path, name) + } + + // Create the directory if it doesn't exist + if err := os.MkdirAll(path, 0700); err != nil { + return nil, err + } + + // The state to return + var state State + state.path = filepath.Join(path, stateFile) + + // Load the state from the file, ignore any errors + _ = state.Load() + + // Return success + return &state, nil +} + +// Release resources +func (s *State) Close() error { + return s.Save() +} + +////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +// Load state as JSON +func (s *State) Load() error { + // Open the file + file, err := os.Open(s.path) + if err != nil { + return nil + } + defer file.Close() + + // Decode the JSON + if err := json.NewDecoder(file).Decode(s); err != nil { + return err + } + + // Return success + return nil +} + +// Save state as JSON +func (s *State) Save() error { + // Open the file + file, err := os.Create(s.path) + if err != nil { + return err + } + defer file.Close() + + // Encode the JSON + return json.NewEncoder(file).Encode(s) +} diff --git a/cmd/agent/term.go b/cmd/agent/term.go new file mode 100644 index 0000000..6c3827c --- /dev/null +++ b/cmd/agent/term.go @@ -0,0 +1,65 @@ +package main + +import ( + "io" + "os" + + "golang.org/x/term" +) + +type Term struct { + r io.Reader + fd int + *term.Terminal +} + +func NewTerm(r io.Reader) (*Term, error) { + t := new(Term) + t.r = r + + // Set file descriptor + if osf, ok := r.(*os.File); ok { + t.fd = int(osf.Fd()) + if term.IsTerminal(t.fd) { + t.Terminal = term.NewTerminal(osf, "") + } + } + + // Return success + return t, nil +} + +// Returns the width and height of the terminal, or (0,0) +func (t *Term) Size() (int, int) { + if t.Terminal != nil { + if w, h, err := term.GetSize(t.fd); err == nil { + return w, h + } + } + // Unable to get the size + return 0, 0 +} + +func (t *Term) ReadLine(prompt string) (string, error) { + // Set terminal raw mode + if t.Terminal != nil { + state, err := term.MakeRaw(t.fd) + if err != nil { + return "", err + } + defer term.Restore(t.fd, state) + } + + // Set the prompt + if t.Terminal != nil { + t.Terminal.SetPrompt(prompt) + } + + // Read the line + if t.Terminal != nil { + return t.Terminal.ReadLine() + } else { + // Don't support non-terminal input yet + return "", io.EOF + } +} diff --git a/go.mod b/go.mod index 5a9cf22..d622468 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22 toolchain go1.22.3 require ( + github.com/alecthomas/kong v0.2.1-0.20190708041108-0548c6b1afae github.com/andreburgaud/crypt2go v1.5.0 github.com/djthorpe/go-errors v1.0.3 github.com/djthorpe/go-tablewriter v0.0.7 @@ -19,6 +20,25 @@ require ( ) require ( + github.com/MichaelMure/go-term-text v0.3.1 // indirect + github.com/alecthomas/chroma v0.7.1 // indirect + github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect + github.com/disintegration/imaging v1.6.2 // indirect + github.com/dlclark/regexp2 v1.1.6 // indirect + github.com/eliukblau/pixterm/pkg/ansimage v0.0.0-20191210081756-9fb6cf8c2f75 // indirect + github.com/fatih/color v1.9.0 // indirect + github.com/gomarkdown/markdown v0.0.0-20191123064959-2c17d62f5098 // indirect + github.com/kyokomi/emoji/v2 v2.2.8 // indirect + github.com/lucasb-eyer/go-colorful v1.0.3 // indirect + github.com/mattn/go-colorable v0.1.4 // indirect + github.com/mattn/go-isatty v0.0.11 // indirect + github.com/pkg/errors v0.8.1 // indirect + golang.org/x/image v0.0.0-20191206065243-da761ea9ff43 // indirect + golang.org/x/net v0.21.0 // indirect +) + +require ( + github.com/MichaelMure/go-term-markdown v0.1.4 github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-audio/riff v1.0.0 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect diff --git a/go.sum b/go.sum index 196c86b..7d3e95b 100644 --- a/go.sum +++ b/go.sum @@ -1,39 +1,102 @@ +github.com/MichaelMure/go-term-markdown v0.1.4 h1:Ir3kBXDUtOX7dEv0EaQV8CNPpH+T7AfTh0eniMOtNcs= +github.com/MichaelMure/go-term-markdown v0.1.4/go.mod h1:EhcA3+pKYnlUsxYKBJ5Sn1cTQmmBMjeNlpV8nRb+JxA= +github.com/MichaelMure/go-term-text v0.3.1 h1:Kw9kZanyZWiCHOYu9v/8pWEgDQ6UVN9/ix2Vd2zzWf0= +github.com/MichaelMure/go-term-text v0.3.1/go.mod h1:QgVjAEDUnRMlzpS6ky5CGblux7ebeiLnuy9dAaFZu8o= +github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38 h1:smF2tmSOzy2Mm+0dGI2AIUHY+w0BUc+4tn40djz7+6U= +github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38/go.mod h1:r7bzyVFMNntcxPZXK3/+KdruV1H5KSlyVY0gc+NgInI= +github.com/alecthomas/chroma v0.7.1 h1:G1i02OhUbRi2nJxcNkwJaY/J1gHXj9tt72qN6ZouLFQ= +github.com/alecthomas/chroma v0.7.1/go.mod h1:gHw09mkX1Qp80JlYbmN9L3+4R5o6DJJ3GRShh+AICNc= +github.com/alecthomas/colour v0.0.0-20160524082231-60882d9e2721 h1:JHZL0hZKJ1VENNfmXvHbgYlbUOvpzYzvy2aZU5gXVeo= +github.com/alecthomas/colour v0.0.0-20160524082231-60882d9e2721/go.mod h1:QO9JBoKquHd+jz9nshCh40fOfO+JzsoXy8qTHF68zU0= +github.com/alecthomas/kong v0.2.1-0.20190708041108-0548c6b1afae h1:C4Q9m+oXOxcSWwYk9XzzafY2xAVAaeubZbUHJkw3PlY= +github.com/alecthomas/kong v0.2.1-0.20190708041108-0548c6b1afae/go.mod h1:+inYUSluD+p4L8KdviBSgzcqEjUQOfC5fQDRFuc36lI= +github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897 h1:p9Sln00KOTlrYkxI1zYWl1QLnEqAqEARBEYa8FQnQcY= +github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= github.com/andreburgaud/crypt2go v1.5.0 h1:7hz8l9WjaMEtAUL4+nMm64Of7HzUr1H4JhmNof7BCLc= github.com/andreburgaud/crypt2go v1.5.0/go.mod h1:ZEu8s+aLbZdRNdSHr//o6gCSMYKgT24sjNX6r4uAI8U= +github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 h1:y5HC9v93H5EPKqaS1UYVg1uYah5Xf51mBfIoWehClUQ= +github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964/go.mod h1:Xd9hchkHSWYkEqJwUGisez3G1QY8Ryz0sdWrLPMGjLk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c= +github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4= github.com/djthorpe/go-errors v1.0.3 h1:GZeMPkC1mx2vteXLI/gvxZS0Ee9zxzwD1mcYyKU5jD0= github.com/djthorpe/go-errors v1.0.3/go.mod h1:HtfrZnMd6HsX75Mtbv9Qcnn0BqOrrFArvCaj3RMnZhY= github.com/djthorpe/go-tablewriter v0.0.7 h1:jnNsJDjjLLCt0OAqB5DzGZN7V3beT1IpNMQ8GcOwZDU= github.com/djthorpe/go-tablewriter v0.0.7/go.mod h1:NVBvytpL+6fHfCKn0+3lSi15/G3A1HWf2cLNeHg6YBg= +github.com/dlclark/regexp2 v1.1.6 h1:CqB4MjHw0MFCDj+PHHjiESmHX+N7t0tJzKvC6M97BRg= +github.com/dlclark/regexp2 v1.1.6/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= +github.com/eliukblau/pixterm/pkg/ansimage v0.0.0-20191210081756-9fb6cf8c2f75 h1:vbix8DDQ/rfatfFr/8cf/sJfIL69i4BcZfjrVOxsMqk= +github.com/eliukblau/pixterm/pkg/ansimage v0.0.0-20191210081756-9fb6cf8c2f75/go.mod h1:0gZuvTO1ikSA5LtTI6E13LEOdWQNjIo5MTQOvrV0eFg= +github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= +github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs= github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA= github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498= github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g= github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE= +github.com/gomarkdown/markdown v0.0.0-20191123064959-2c17d62f5098 h1:Qxs3bNRWe8GTcKMxYOSXm0jx6j0de8XUtb/fsP3GZ0I= +github.com/gomarkdown/markdown v0.0.0-20191123064959-2c17d62f5098/go.mod h1:aii0r/K0ZnHv7G0KF7xy1v0A7s2Ljrb5byB7MO5p6TU= +github.com/kyokomi/emoji/v2 v2.2.8 h1:jcofPxjHWEkJtkIbcLHvZhxKgCPl6C7MyjTrD4KDqUE= +github.com/kyokomi/emoji/v2 v2.2.8/go.mod h1:JUcn42DTdsXJo1SWanHh4HKDEyPaR5CqkmoirZZP9qE= +github.com/lucasb-eyer/go-colorful v1.0.3 h1:QIbQXiugsb+q10B+MI+7DI1oQLdmnep86tWFlaaUAac= +github.com/lucasb-eyer/go-colorful v1.0.3/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= +github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= +github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= +github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mutablelogic/go-server v1.4.7 h1:NpzG30f/D50Xbwr96dA6uiapyr4QHBziSanc/q/LR7k= github.com/mutablelogic/go-server v1.4.7/go.mod h1:wrrDg863hlv5/DUpSG/Pb4k9LiSYO7VxRgLPiMhrE6M= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +golang.org/dl v0.0.0-20190829154251-82a15e2f2ead/go.mod h1:IUMfjQLJQd4UTqG1Z90tenwKoCX93Gn3MAQJMOSBsDQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.0.0-20191206065243-da761ea9ff43 h1:gQ6GUSD102fPgli+Yb4cR/cGaHF7tNBt+GYoRCpGC7s= +golang.org/x/image v0.0.0-20191206065243-da761ea9ff43/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sys v0.0.0-20181128092732-4ed8d59d0b35/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go new file mode 100644 index 0000000..bc6381f --- /dev/null +++ b/pkg/agent/agent.go @@ -0,0 +1,18 @@ +package agent + +import "context" + +// An LLM Agent is a client for the LLM service +type Agent interface { + // Return the name of the agent + Name() string + + // Return the models + Models(context.Context) ([]Model, error) + + // Generate a response from a prompt + Generate(context.Context, Model, []Context, ...Opt) (*Response, error) + + // Create user message context + UserPrompt(string) Context +} diff --git a/pkg/agent/context.go b/pkg/agent/context.go new file mode 100644 index 0000000..2dce2a6 --- /dev/null +++ b/pkg/agent/context.go @@ -0,0 +1,10 @@ +package agent + +////////////////////////////////////////////////////////////////// +// TYPES + +// Context is fed to the agent to generate a response. Role can be +// assistant, user, tool, tool_result, ... +type Context interface { + Role() string +} diff --git a/pkg/agent/model.go b/pkg/agent/model.go new file mode 100644 index 0000000..3329bf4 --- /dev/null +++ b/pkg/agent/model.go @@ -0,0 +1,7 @@ +package agent + +// An LLM Agent is a client for the LLM service +type Model interface { + // Return the name of the model + Name() string +} diff --git a/pkg/agent/opt.go b/pkg/agent/opt.go new file mode 100644 index 0000000..235f0a4 --- /dev/null +++ b/pkg/agent/opt.go @@ -0,0 +1,41 @@ +package agent + +import "fmt" + +////////////////////////////////////////////////////////////////// +// TYPES + +type Opts struct { + Tools []Tool + StreamFn func(Response) +} + +type Opt func(*Opts) error + +////////////////////////////////////////////////////////////////// +// METHODS + +// OptStream sets the stream function, which is called during the +// response generation process +func OptStream(fn func(Response)) Opt { + return func(o *Opts) error { + o.StreamFn = fn + return nil + } +} + +// OptTools sets the tools for the chat request +func OptTools(t ...Tool) Opt { + return func(o *Opts) error { + if len(t) == 0 { + return fmt.Errorf("no tools specified") + } + for _, tool := range t { + if tool == nil { + return fmt.Errorf("nil tool specified") + } + o.Tools = append(o.Tools, tool) + } + return nil + } +} diff --git a/pkg/agent/response.go b/pkg/agent/response.go new file mode 100644 index 0000000..6b30671 --- /dev/null +++ b/pkg/agent/response.go @@ -0,0 +1,30 @@ +package agent + +import ( + "encoding/json" + "time" +) + +////////////////////////////////////////////////////////////////// +// TYPES + +type Response struct { + Agent string `json:"agent,omitempty"` // The agent name + Model string `json:"model,omitempty"` // The model name + Context []Context `json:"context,omitempty"` // The context for the response + Text string `json:"text,omitempty"` // The response text + *ToolCall `json:"tool,omitempty"` // The tool call, if not nil + Tokens uint `json:"tokens,omitempty"` // The number of tokens + Duration time.Duration `json:"duration,omitempty"` // The response duration +} + +////////////////////////////////////////////////////////////////// +// STRINGIFY + +func (r Response) String() string { + data, err := json.MarshalIndent(r, "", " ") + if err != nil { + return err.Error() + } + return string(data) +} diff --git a/pkg/agent/tool.go b/pkg/agent/tool.go new file mode 100644 index 0000000..c71cd81 --- /dev/null +++ b/pkg/agent/tool.go @@ -0,0 +1,99 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + + // Namespace imports + . "github.com/djthorpe/go-errors" +) + +//////////////////////////////////////////////////////////////////////////////// +// TYPES + +// A tool can be called from an LLM +type Tool interface { + // Return the provider of the tool + Provider() string + + // Return the name of the tool + Name() string + + // Return the description of the tool + Description() string + + // Tool parameters + Params() []ToolParameter + + // Execute the tool with a specific tool + Run(context.Context, *ToolCall) (*ToolResult, error) +} + +// A tool parameter +type ToolParameter struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Required bool `json:"required,omitempty"` +} + +// A call to a tool +type ToolCall struct { + Id string `json:"id"` + Name string `json:"name"` + Args map[string]any `json:"args"` +} + +// The result of a tool call +type ToolResult struct { + Id string `json:"id"` + Result map[string]any `json:"result,omitempty"` +} + +//////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the arguments for the call as a JSON +func (t *ToolCall) JSON() string { + data, err := json.MarshalIndent(t.Args, "", " ") + if err != nil { + return err.Error() + } else { + return string(data) + } +} + +// Return role for the tool result +func (t *ToolResult) Role() string { + return "tool" +} + +// Return parameter as a string +func (t *ToolCall) String(name string) (string, error) { + v, ok := t.Args[name] + if !ok { + return "", ErrNotFound.Withf("%q not found", name) + } + return fmt.Sprint(v), nil +} + +// Return parameter as an integer +func (t *ToolCall) Int(name string) (int, error) { + v, ok := t.Args[name] + if !ok { + return 0, ErrNotFound.Withf("%q not found", name) + } + switch v := v.(type) { + case int: + return v, nil + case string: + if v_, err := strconv.ParseInt(v, 10, 32); err != nil { + return 0, ErrBadParameter.Withf("%q: %v", name, err) + } else { + return int(v_), nil + } + default: + return 0, ErrBadParameter.Withf("%q: Expected integer, got %T", name, v) + } +} diff --git a/pkg/homeassistant/agent.go b/pkg/homeassistant/agent.go new file mode 100644 index 0000000..bb1825d --- /dev/null +++ b/pkg/homeassistant/agent.go @@ -0,0 +1,179 @@ +package homeassistant + +import ( + "context" + "errors" + "slices" + "strings" + + // Packages + agent "github.com/mutablelogic/go-client/pkg/agent" + + // Namespace imports + . "github.com/djthorpe/go-errors" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type tool struct { + name string + description string + params []agent.ToolParameter + run func(context.Context, *agent.ToolCall) (*agent.ToolResult, error) +} + +// Ensure tool satisfies the agent.Tool interface +var _ agent.Tool = (*tool)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return all the agent tools for the weatherapi +func (c *Client) Tools() []agent.Tool { + return []agent.Tool{ + &tool{ + name: "devices", + description: "Lookup all device id's in the home, or search for a device ny name", + run: c.agentGetDeviceIds, + params: []agent.ToolParameter{ + { + Name: "name", + Description: "Name to filter devices", + }, + }, + }, &tool{ + name: "get_device_state", + description: "Return the current state of a device, given the device id", + run: c.agentGetDeviceState, + params: []agent.ToolParameter{ + { + Name: "device", + Description: "The device id", + Required: true, + }, + }, + }, + } +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS - TOOL + +func (*tool) Provider() string { + return "homeassistant" +} + +func (t *tool) Name() string { + return t.name +} + +func (t *tool) Description() string { + return t.description +} + +func (t *tool) Params() []agent.ToolParameter { + return t.params +} + +func (t *tool) Run(ctx context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + return t.run(ctx, call) +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS - TOOL + +var ( + allowedClasses = []string{ + "temperature", + "humidity", + "battery", + "select", + "number", + "switch", + "enum", + "light", + "sensor", + "binary_sensor", + "remote", + "climate", + "occupancy", + "motion", + "button", + "door", + "lock", + "tv", + "vacuum", + } +) + +// Return the current devices and their id's +func (c *Client) agentGetDeviceIds(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + name, err := call.String("name") + if errors.Is(err, ErrNotFound) { + name = "" + } else if err != nil { + return nil, err + } + + // Query all devices + devices, err := c.States() + if err != nil { + return nil, err + } + + // Make the device id's + type DeviceId struct { + Id string `json:"id"` + Name string `json:"name"` + } + var result []DeviceId + for _, device := range devices { + if !slices.Contains(allowedClasses, device.Class()) { + continue + } + var found bool + if name != "" { + if strings.Contains(strings.ToLower(device.Name()), strings.ToLower(name)) { + found = true + } else if strings.Contains(strings.ToLower(device.Class()), strings.ToLower(name)) { + found = true + } + if !found { + continue + } + } + result = append(result, DeviceId{ + Id: device.Entity, + Name: device.Name(), + }) + } + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "devices": result, + }, + }, nil +} + +// Return a device state +func (c *Client) agentGetDeviceState(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + device, err := call.String("device") + if err != nil { + return nil, err + } + + state, err := c.State(device) + if err != nil { + return nil, err + } + + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "device": state, + }, + }, nil +} diff --git a/pkg/ipify/agent.go b/pkg/ipify/agent.go index b5053ea..56c91a9 100644 --- a/pkg/ipify/agent.go +++ b/pkg/ipify/agent.go @@ -4,33 +4,73 @@ import ( "context" // Packages - schema "github.com/mutablelogic/go-client/pkg/openai/schema" - - // Namespace imports - . "github.com/djthorpe/go-errors" + agent "github.com/mutablelogic/go-client/pkg/agent" ) /////////////////////////////////////////////////////////////////////////////// // TYPES +type tool struct { + name string + description string + params []agent.ToolParameter + run func(context.Context, *agent.ToolCall) (*agent.ToolResult, error) +} + +// Ensure tool satisfies the agent.Tool interface +var _ agent.Tool = (*tool)(nil) + /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS -// Return the tools available -func (c *Client) Tools() []*schema.Tool { - if get_ip_address, err := schema.NewToolEx("get_ip_address", "Get the current IP address.", nil); err != nil { - panic(err) - } else { - return []*schema.Tool{get_ip_address} +// Return all the agent tools for the weatherapi +func (c *Client) Tools() []agent.Tool { + return []agent.Tool{ + &tool{ + name: "get_ip_address", + description: "Return your IP address", + run: c.agentGetAddress, + }, } } -// Run a tool and return the result -func (c *Client) Run(ctx context.Context, name string, _ any) (any, error) { - switch name { - case "get_ip_address": - return c.Get() - default: - return nil, ErrInternalAppError.With(name) +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS - TOOL + +func (*tool) Provider() string { + return "ipify" +} + +func (t *tool) Name() string { + return t.name +} + +func (t *tool) Description() string { + return t.description +} + +func (t *tool) Params() []agent.ToolParameter { + return t.params +} + +func (t *tool) Run(ctx context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + return t.run(ctx, call) +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS - TOOL + +// Return the current general headlines +func (c *Client) agentGetAddress(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + response, err := c.Get() + if err != nil { + return nil, err } + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "ip_address": response, + }, + }, nil } diff --git a/pkg/ipify/agent_test.go b/pkg/ipify/agent_test.go deleted file mode 100644 index c7ef366..0000000 --- a/pkg/ipify/agent_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package ipify_test - -import ( - "os" - "testing" - - // Packages - opts "github.com/mutablelogic/go-client" - ipify "github.com/mutablelogic/go-client/pkg/ipify" - assert "github.com/stretchr/testify/assert" -) - -func Test_agent_001(t *testing.T) { - assert := assert.New(t) - client, err := ipify.New(opts.OptTrace(os.Stderr, true)) - assert.NoError(err) - assert.NotNil(client) - - tools := client.Tools() - assert.NotEmpty(tools) - - t.Log(tools) -} diff --git a/pkg/newsapi/agent.go b/pkg/newsapi/agent.go new file mode 100644 index 0000000..0819299 --- /dev/null +++ b/pkg/newsapi/agent.go @@ -0,0 +1,172 @@ +package newsapi + +import ( + "context" + "strings" + + // Packages + agent "github.com/mutablelogic/go-client/pkg/agent" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type tool struct { + name string + description string + params []agent.ToolParameter + run func(context.Context, *agent.ToolCall) (*agent.ToolResult, error) +} + +// Ensure tool satisfies the agent.Tool interface +var _ agent.Tool = (*tool)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return all the agent tools for the weatherapi +func (c *Client) Tools() []agent.Tool { + return []agent.Tool{ + &tool{ + name: "current_headlines", + description: "Return the current news headlines", + run: c.agentCurrentHeadlines, + }, &tool{ + name: "current_headlines_country", + description: "Return the current news headlines for a country", + run: c.agentCountryHeadlines, + params: []agent.ToolParameter{ + { + Name: "countrycode", + Description: "The two-letter country code to return headlines for", + Required: true, + }, + }, + }, &tool{ + name: "current_headlines_category", + description: "Return the current news headlines for a business, entertainment, health, science, sports or technology", + run: c.agentCategoryHeadlines, + params: []agent.ToolParameter{ + { + Name: "category", + Description: "business, entertainment, health, science, sports, technology", + Required: true, + }, + }, + }, &tool{ + name: "search_news", + description: "Return the news headlines with a search query", + run: c.agentSearchNews, + params: []agent.ToolParameter{ + { + Name: "query", + Description: "A phrase used to search for news headlines", + Required: true, + }, + }, + }, + } +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS - TOOL + +func (*tool) Provider() string { + return "newsapi" +} + +func (t *tool) Name() string { + return t.name +} + +func (t *tool) Description() string { + return t.description +} + +func (t *tool) Params() []agent.ToolParameter { + return t.params +} + +func (t *tool) Run(ctx context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + return t.run(ctx, call) +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS - TOOL + +// Return the current general headlines +func (c *Client) agentCurrentHeadlines(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + response, err := c.Headlines(OptCategory("general"), OptLimit(5)) + if err != nil { + return nil, err + } + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "headlines": response, + }, + }, nil +} + +// Return the headlines for a specific country +func (c *Client) agentCountryHeadlines(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + country, err := call.String("countrycode") + if err != nil { + return nil, err + } + country = strings.ToLower(country) + response, err := c.Headlines(OptCountry(country), OptLimit(5)) + if err != nil { + return nil, err + } + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "country": country, + "headlines": response, + }, + }, nil +} + +// Return the headlines for a specific category +func (c *Client) agentCategoryHeadlines(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + category, err := call.String("category") + if err != nil { + return nil, err + } + category = strings.ToLower(category) + response, err := c.Headlines(OptCategory(category), OptLimit(5)) + if err != nil { + return nil, err + } + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "category": category, + "headlines": response, + }, + }, nil +} + +// Return the headlines for a specific query +func (c *Client) agentSearchNews(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + query, err := call.String("query") + if err != nil { + return nil, err + } + response, err := c.Articles(OptQuery(query), OptLimit(5)) + if err != nil { + return nil, err + } + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "query": query, + "headlines": response, + }, + }, nil +} diff --git a/pkg/ollama/agent.go b/pkg/ollama/agent.go new file mode 100644 index 0000000..a794d58 --- /dev/null +++ b/pkg/ollama/agent.go @@ -0,0 +1,128 @@ +package ollama + +import ( + "context" + "fmt" + "time" + + // Packages + "github.com/mutablelogic/go-client/pkg/agent" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type model struct { + *Model +} + +type userPrompt string + +// Ensure Ollama client satisfies the agent.Agent interface +var _ agent.Agent = (*Client)(nil) + +// Ensure model satisfies the agent.Model interface +var _ agent.Model = (*model)(nil) + +// Ensure userPrompt satisfies the agent.Context interface +var _ agent.Context = userPrompt("") + +///////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the agent name +func (*Client) Name() string { + return "ollama" +} + +// Return the model name +func (m *model) Name() string { + return m.Model.Name +} + +// Return all the models and their capabilities +func (o *Client) Models(context.Context) ([]agent.Model, error) { + models, err := o.ListModels() + if err != nil { + return nil, err + } + + // Append models + result := make([]agent.Model, len(models)) + for i, m := range models { + result[i] = &model{Model: &m} + } + + // Return success + return result, nil +} + +// Return the role +func (userPrompt) Role() string { + return "user" +} + +// Create a user prompt +func (o *Client) UserPrompt(v string) agent.Context { + return userPrompt(v) +} + +// Generate a response from a text message +func (o *Client) Generate(ctx context.Context, model agent.Model, context []agent.Context, opts ...agent.Opt) (*agent.Response, error) { + // Get options + chatopts, err := newOpts(opts...) + if err != nil { + return nil, err + } + + if len(context) != 1 { + return nil, fmt.Errorf("context must contain exactly one element") + } + + prompt, ok := context[0].(userPrompt) + if !ok { + return nil, fmt.Errorf("context must contain a user prompt") + } + + // Generate a response + status, err := o.ChatGenerate(ctx, model.Name(), string(prompt), chatopts...) + if err != nil { + return nil, err + } + + // Create a response + response := agent.Response{ + Agent: o.Name(), + Model: model.Name(), + Text: status.Response, + Tokens: uint(status.ResponseTokens), + Duration: time.Nanosecond * time.Duration(status.TotalDurationNs), + } + + // Return success + return &response, nil +} + +///////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func newOpts(opts ...agent.Opt) ([]ChatOpt, error) { + // Apply the options + var o agent.Opts + for _, opt := range opts { + if err := opt(&o); err != nil { + return nil, err + } + } + + // Create local options + result := make([]ChatOpt, 0, len(opts)) + if o.StreamFn != nil { + result = append(result, OptStream(func(text string) { + fmt.Println(text) + })) + } + + // Return success + return result, nil +} diff --git a/pkg/ollama/client.go b/pkg/ollama/client.go index 387f32c..93f2c55 100644 --- a/pkg/ollama/client.go +++ b/pkg/ollama/client.go @@ -7,6 +7,7 @@ package ollama import ( // Packages "github.com/mutablelogic/go-client" + "github.com/mutablelogic/go-client/pkg/agent" ) /////////////////////////////////////////////////////////////////////////////// @@ -16,6 +17,9 @@ type Client struct { *client.Client } +// Ensure it satisfies the agent.Agent interface +var _ agent.Agent = (*Client)(nil) + /////////////////////////////////////////////////////////////////////////////// // LIFECYCLE diff --git a/pkg/ollama/client_test.go b/pkg/ollama/client_test.go index 94bf905..94136c9 100644 --- a/pkg/ollama/client_test.go +++ b/pkg/ollama/client_test.go @@ -22,9 +22,9 @@ func Test_client_001(t *testing.T) { // ENVIRONMENT func GetEndpoint(t *testing.T) string { - key := os.Getenv("OLLAMA_ENDPOINT") + key := os.Getenv("OLLAMA_URL") if key == "" { - t.Skip("OLLAMA_ENDPOINT not set") + t.Skip("OLLAMA_URL not set") t.SkipNow() } return key diff --git a/pkg/openai/agent.go b/pkg/openai/agent.go new file mode 100644 index 0000000..d3594e8 --- /dev/null +++ b/pkg/openai/agent.go @@ -0,0 +1,189 @@ +package openai + +import ( + "context" + "reflect" + "time" + + // Package imports + agent "github.com/mutablelogic/go-client/pkg/agent" + schema "github.com/mutablelogic/go-client/pkg/openai/schema" + + // Namespace imports + . "github.com/djthorpe/go-errors" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type model struct { + *schema.Model +} + +type message struct { + *schema.Message +} + +// Ensure Ollama client satisfies the agent.Agent interface +var _ agent.Agent = (*Client)(nil) + +// Ensure model satisfies the agent.Model interface +var _ agent.Model = (*model)(nil) + +// Ensure context satisfies the agent.Context interface +var _ agent.Context = (*message)(nil) + +///////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return the agent name +func (*Client) Name() string { + return "openai" +} + +// Return the model name +func (m *model) Name() string { + return m.Model.Id +} + +// Return the context role +func (m *message) Role() string { + return m.Message.Role +} + +// Return all the models and their capabilities +func (o *Client) Models(context.Context) ([]agent.Model, error) { + models, err := o.ListModels() + if err != nil { + return nil, err + } + + // Append models + result := make([]agent.Model, len(models)) + for i, m := range models { + result[i] = &model{Model: &m} + } + + // Return success + return result, nil +} + +// Create a user prompt +func (o *Client) UserPrompt(v string) agent.Context { + return &message{schema.NewMessage("user", v)} +} + +// Generate a response from a text message +func (o *Client) Generate(ctx context.Context, model agent.Model, content []agent.Context, opts ...agent.Opt) (*agent.Response, error) { + // Get options + chatopts, err := newOpts(opts...) + if err != nil { + return nil, err + } + + // Add model + chatopts = append(chatopts, OptModel(model.Name())) + + // Add usage option + now := time.Now() + response := agent.Response{ + Agent: o.Name(), + Model: model.Name(), + } + chatopts = append(chatopts, OptUsage(func(u schema.TokenUsage) { + response.Tokens = uint(u.TotalTokens) + response.Duration = time.Since(now) + })) + + // Create messages + messages := make([]*schema.Message, 0, len(content)) + for _, c := range content { + if message, ok := c.(*message); ok { + messages = append(messages, message.Message) + } else if toolresult, ok := c.(*agent.ToolResult); ok { + messages = append(messages, schema.NewToolResult(toolresult.Id, toolresult.Result)) + } else { + return nil, ErrBadParameter.Withf("context must contain a message (not %T)", c) + } + } + + // Append messages to the response + for _, m := range messages { + response.Context = append(response.Context, &message{m}) + } + + // Generate a response + response_content, err := o.Chat(ctx, messages, chatopts...) + if err != nil { + return nil, err + } + + // Combine content into a single response, and add to the context + for _, c := range response_content { + if c.Text != "" { + response.Text += c.Text + } else if c.Type == "function" { + response.ToolCall = &agent.ToolCall{Id: c.Id, Name: c.Name, Args: c.Input} + } + } + + // Append the response to the context + if response.ToolCall != nil { + m := schema.NewMessage("assistant", "") + m.ToolCalls = []schema.ToolCall{ + { + Id: response.ToolCall.Id, + Type: "function", + Function: schema.ToolFunction{ + Name: response.ToolCall.Name, + Arguments: response.ToolCall.JSON(), + }, + }, + } + response.Context = append(response.Context, &message{m}) + } else { + response.Context = append(response.Context, &message{schema.NewMessage("assistant", response.Text)}) + } + + // Return success + return &response, nil +} + +///////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS + +func newOpts(opts ...agent.Opt) ([]Opt, error) { + // Apply the options + var o agent.Opts + for _, opt := range opts { + if err := opt(&o); err != nil { + return nil, err + } + } + + // Create local options + result := make([]Opt, 0, len(opts)) + + // Stream + if o.StreamFn != nil { + result = append(result, OptStream(func(text schema.MessageChoice) { + if text.Delta != nil && text.Delta.Content != "" { + o.StreamFn(agent.Response{ + Text: text.Delta.Content, + }) + } + })) + } + + // Create tools + for _, tool := range o.Tools { + otool := schema.NewTool(tool.Name(), tool.Description()) + for _, param := range tool.Params() { + otool.Add(param.Name, param.Description, param.Required, reflect.TypeOf("")) + } + result = append(result, OptTool(otool)) + } + + // Return success + return result, nil +} diff --git a/pkg/openai/chat.go b/pkg/openai/chat.go index 4f29a32..25ab38d 100644 --- a/pkg/openai/chat.go +++ b/pkg/openai/chat.go @@ -114,6 +114,9 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O if choice.Message.Content == nil { continue } + if choice.Message.Role != "assistant" { + return nil, ErrUnexpectedResponse.With("unexpected content role ", choice.Message.Role) + } switch v := choice.Message.Content.(type) { case []string: for _, v := range v { @@ -126,6 +129,11 @@ func (c *Client) Chat(ctx context.Context, messages []*schema.Message, opts ...O } } + // Usage callback + if request.Usage != nil { + request.Usage(response.TokenUsage) + } + // Return success return result, nil } diff --git a/pkg/openai/opts.go b/pkg/openai/opts.go index cb2ab59..fd12563 100644 --- a/pkg/openai/opts.go +++ b/pkg/openai/opts.go @@ -40,6 +40,9 @@ type options struct { Quality string `json:"quality,omitempty"` Size string `json:"size,omitempty"` Style string `json:"style,omitempty"` + + // Options for usage + Usage UsageFn `json:"-"` } type streamoptions struct { @@ -52,6 +55,9 @@ type Opt func(*options) error // Callback when new stream data is received type Callback func(schema.MessageChoice) +// Callback to set the token usage +type UsageFn func(schema.TokenUsage) + /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS @@ -168,6 +174,7 @@ func OptTool(value ...*schema.Tool) Opt { return ErrBadParameter.With("OptTool") } } + // Append tools o.Tools = append(o.Tools, value...) @@ -244,6 +251,17 @@ func OptStyle(value string) Opt { } } +// The style of the generated images. Must be one of vivid or natural. +// Vivid causes the model to lean towards generating hyper-real and +// dramatic images. Natural causes the model to produce more natural, +// less hyper-real looking images. +func OptUsage(fn UsageFn) Opt { + return func(o *options) error { + o.Usage = fn + return nil + } +} + /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS - CLIENT diff --git a/pkg/openai/schema/agent.go b/pkg/openai/schema/agent.go deleted file mode 100644 index df15445..0000000 --- a/pkg/openai/schema/agent.go +++ /dev/null @@ -1,15 +0,0 @@ -package schema - -import "context" - -//////////////////////////////////////////////////////////////////////////////// -// TYPES - -// An agent is a collection of tools that can be called -type Agent interface { - // Enumerate the tools available - Tools() []Tool - - // Run a tool with parameters and return the result - Run(context.Context, string, any) (any, error) -} diff --git a/pkg/openai/schema/message.go b/pkg/openai/schema/message.go index 45cff7d..95f6faf 100644 --- a/pkg/openai/schema/message.go +++ b/pkg/openai/schema/message.go @@ -35,6 +35,9 @@ type Message struct { // Any tool calls ToolCalls []ToolCall `json:"tool_calls,omitempty"` + // Tool Call Id + ToolCallId string `json:"tool_call_id,omitempty"` + // Time the message was created, in unix seconds Created int64 `json:"created,omitempty"` } diff --git a/pkg/openai/schema/tool.go b/pkg/openai/schema/tool.go index d6c8cf2..6916ebe 100644 --- a/pkg/openai/schema/tool.go +++ b/pkg/openai/schema/tool.go @@ -5,9 +5,11 @@ import ( "fmt" "reflect" + // Package imports + "github.com/djthorpe/go-tablewriter/pkg/meta" + // Namespace imports . "github.com/djthorpe/go-errors" - "github.com/djthorpe/go-tablewriter/pkg/meta" ) /////////////////////////////////////////////////////////////////////////////// @@ -57,10 +59,6 @@ func NewTool(name, description string) *Tool { return &Tool{ Name: name, Description: description, - Parameters: &toolParameters{ - Type: "object", - Properties: make(map[string]toolParameter), - }, } } @@ -87,6 +85,21 @@ func NewToolEx(name, description string, parameters any) (*Tool, error) { return t, nil } +func NewToolResult(id string, result map[string]any) *Message { + var message Message + message.Role = "tool" + message.ToolCallId = id + + data, err := json.Marshal(result) + if err != nil { + return nil + } else { + message.Content = string(data) + } + + return &message +} + /////////////////////////////////////////////////////////////////////////////// // STRINGIFY @@ -102,6 +115,12 @@ func (tool *Tool) Add(name, description string, required bool, t reflect.Type) e if name == "" { return ErrBadParameter.With("missing name") } + if tool.Parameters == nil { + tool.Parameters = &toolParameters{ + Type: "object", + Properties: make(map[string]toolParameter), + } + } if _, exists := tool.Parameters.Properties[name]; exists { return ErrDuplicateEntry.With(name) } diff --git a/pkg/weatherapi/agent.go b/pkg/weatherapi/agent.go new file mode 100644 index 0000000..6be6429 --- /dev/null +++ b/pkg/weatherapi/agent.go @@ -0,0 +1,215 @@ +package weatherapi + +import ( + "context" + + // Packages + agent "github.com/mutablelogic/go-client/pkg/agent" + + // Namespace imports + . "github.com/djthorpe/go-errors" +) + +/////////////////////////////////////////////////////////////////////////////// +// TYPES + +type tool struct { + name string + description string + params []agent.ToolParameter + run func(context.Context, *agent.ToolCall) (*agent.ToolResult, error) +} + +// Ensure tool satisfies the agent.Tool interface +var _ agent.Tool = (*tool)(nil) + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Return all the agent tools for the weatherapi +func (c *Client) Tools() []agent.Tool { + return []agent.Tool{ + &tool{ + name: "current_weather", + description: "Return the current weather", + run: c.agentCurrentWeatherAuto, + }, &tool{ + name: "current_weather_city", + description: "Return the current weather for a city", + params: []agent.ToolParameter{ + {Name: "city", Description: "City name", Required: true}, + }, + run: c.agentCurrentWeatherCity, + }, &tool{ + name: "current_weather_zip", + description: "Return the current weather for a zipcode or postcode", + params: []agent.ToolParameter{ + {Name: "zip", Description: "Zipcode or Postcode", Required: true}, + }, + run: c.agentCurrentWeatherZipcode, + }, &tool{ + name: "weather_forecast", + description: "Return the weather forecast", + run: c.agentForecastWeatherAuto, + params: []agent.ToolParameter{ + {Name: "days", Description: "Number of days to forecast ahead", Required: true}, + }, + }, &tool{ + name: "weather_forecast_city", + description: "Return the weather forecast for a city", + run: c.agentForecastWeatherCity, + params: []agent.ToolParameter{ + {Name: "city", Description: "City name", Required: true}, + {Name: "days", Description: "Number of days to forecast ahead", Required: true}, + }, + }, + } +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS - TOOL + +func (*tool) Provider() string { + return "weatherapi" +} + +func (t *tool) Name() string { + return t.name +} + +func (t *tool) Description() string { + return t.description +} + +func (t *tool) Params() []agent.ToolParameter { + return t.params +} + +func (t *tool) Run(ctx context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + return t.run(ctx, call) +} + +/////////////////////////////////////////////////////////////////////////////// +// PRIVATE METHODS - TOOL + +// Return the current weather +func (c *Client) agentCurrentWeatherAuto(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + response, err := c.Current("auto:ip") + if err != nil { + return nil, err + } + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "location": response.Location, + "weather": response.Current, + }, + }, nil +} + +// Return the current weather in a specific city +func (c *Client) agentCurrentWeatherCity(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + city, ok := call.Args["city"].(string) + if !ok || city == "" { + return nil, ErrBadParameter.Withf("city is required") + } + response, err := c.Current(city) + if err != nil { + return nil, err + } + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "location": response.Location, + "weather": response.Current, + }, + }, nil +} + +// Return the current weather for a zipcode +func (c *Client) agentCurrentWeatherZipcode(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + zip, ok := call.Args["zip"].(string) + if !ok || zip == "" { + return nil, ErrBadParameter.Withf("zipcode is required") + } + response, err := c.Current(zip) + if err != nil { + return nil, err + } + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "location": response.Location, + "weather": response.Current, + }, + }, nil +} + +// Return the weather forecast +func (c *Client) agentForecastWeatherAuto(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + // Get days parameter + days, err := call.Int("days") + if err != nil { + return nil, err + } + + // Get response + response, err := c.Forecast("auto:ip", OptDays(days)) + if err != nil { + return nil, err + } + + // Get forecast by day + result := map[string]Day{} + for _, day := range response.Forecast.Day { + result[day.Date] = *day.Day + } + + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "location": response.Location, + "days": result, + }, + }, nil +} + +// Return the weather forecast for a city +func (c *Client) agentForecastWeatherCity(_ context.Context, call *agent.ToolCall) (*agent.ToolResult, error) { + // Get city parameter + city, ok := call.Args["city"].(string) + if !ok || city == "" { + return nil, ErrBadParameter.Withf("city is required") + } + + // Get days parameter + days, err := call.Int("days") + if err != nil { + return nil, err + } + + // Get response + response, err := c.Forecast(city, OptDays(days)) + if err != nil { + return nil, err + } + + // Get forecast by day + result := map[string]Day{} + for _, day := range response.Forecast.Day { + result[day.Date] = *day.Day + } + + return &agent.ToolResult{ + Id: call.Id, + Result: map[string]any{ + "type": "text", + "location": response.Location, + "days": result, + }, + }, nil +} diff --git a/transport.go b/transport.go index 30f2f4a..64cb82e 100644 --- a/transport.go +++ b/transport.go @@ -103,7 +103,7 @@ func (transport *logtransport) RoundTrip(req *http.Request) (*http.Response, err defer resp.Body.Close() switch { - case contentType == ContentTypeJson: + case contentType == ContentTypeJson || contentType == ContentTypeJsonStream: dst := &bytes.Buffer{} if err := json.Indent(dst, body, " ", " "); err != nil { fmt.Fprintf(transport.w, " <= %q\n", string(body))