Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Heat/Conversations/ConversationInput.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ struct ConversationInput: View {
@Environment(ConversationViewModel.self) var conversationViewModel
@Environment(\.modelContext) private var modelContext

@Query(sort: \Memory.created, order: .forward) var memories: [Memory]

@State var imagePickerViewModel: ImagePickerViewModel
@State var content: String
@State var command: String
Expand Down Expand Up @@ -184,7 +182,7 @@ struct ConversationInput: View {
guard !content.isEmpty else { return }

do {
try conversationViewModel.generate(content, context: memories.map { $0.content })
try conversationViewModel.generate(content)
} catch let error as KitError {
conversationViewModel.error = error
} catch {
Expand Down
19 changes: 4 additions & 15 deletions Heat/Conversations/ConversationViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ final class ConversationViewModel {
conversationID = conversation.id
}

func generate(_ content: String, context: [String] = []) throws {
func generate(_ content: String) throws {
guard !content.isEmpty else { return }
guard let conversation else {
throw KitError.missingConversation
Expand All @@ -54,13 +54,12 @@ final class ConversationViewModel {

let toolService = try store.preferredToolService()
let toolModel = try store.preferredToolModel()

let context = prepareContext(context)

let contextTools: [ContextTool.Type] = [MemoryTool.self]

generateTask = Task {
await MessageManager()
.append(messages: messages)
.append(message: context)
.append(messages: contextTools.compactMap({ $0.prepareContext() }))
.append(message: .init(role: .user, content: content)) { message in
self.store.upsert(suggestions: [], conversationID: conversation.id)
self.store.upsert(message: message, conversationID: conversation.id)
Expand Down Expand Up @@ -253,15 +252,5 @@ final class ConversationViewModel {
}
return nil
}

private func prepareContext(_ context: [String]) -> Message? {
guard !context.isEmpty else { return nil }

return Message(role: .system, content: """
Some things to remember about who the user is. Use these to better relate to the user when responding:

\(context.joined(separator: "\n"))
""")
}
}

7 changes: 3 additions & 4 deletions Heat/Launcher/LauncherViewModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,20 @@ final class LauncherViewModel {
conversationID = conversation.id
}

func generate(_ content: String, context: [String] = []) throws {
func generate(_ content: String) throws {
guard !content.isEmpty else { return }
guard let conversation else {
throw KitError.missingConversation
}

let chatService = try store.preferredChatService()
let chatModel = try store.preferredChatModel()

let context = prepareContext(context)
let contextTools: [ContextTool.Type] = [MemoryTool.self]

generateTask = Task {
await MessageManager()
.append(messages: messages)
.append(message: context)
.append(messages: contextTools.compactMap({ $0.prepareContext() }))
.append(message: .init(role: .user, content: content)) { message in
self.store.upsert(suggestions: [], conversationID: conversation.id)
self.store.upsert(message: message, conversationID: conversation.id)
Expand Down
5 changes: 5 additions & 0 deletions HeatKit/Sources/HeatKit/Tools/ContextTool.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import GenKit

public protocol ContextTool {
static func prepareContext() -> Message?
}
20 changes: 20 additions & 0 deletions HeatKit/Sources/HeatKit/Tools/MemoryTool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,23 @@ extension MemoryTool {
}
}
}

extension MemoryTool: ContextTool {
public static func prepareContext() -> Message? {
guard let container = try? ModelContainer(for: Memory.self) else { return nil }
let fetchRequest = FetchDescriptor<Memory>(sortBy: [SortDescriptor(\Memory.created, order: .forward)])
let context = ModelContext(container)
do {
let memories: [Memory] = try context.fetch(fetchRequest)
guard !memories.isEmpty else { return nil }
return Message(role: .system, content: """
Some things to remember about who the user is. Use these to better relate to the user when responding:

\(memories.map({ $0.content }).joined(separator: "\n"))
""")
} catch {
print("Failed to fetch memories: \(error)")
return nil
}
}
}