Merge pull request #47 from buhe/refactor
Refactor
buhe authored Oct 30, 2023
2 parents db9f1df + 537cefb commit c15e098
Showing 19 changed files with 336 additions and 361 deletions.
3 changes: 1 addition & 2 deletions Package.swift
@@ -20,8 +20,7 @@ let package = Package(
.package(url: "https://github.com/supabase-community/supabase-swift", .upToNextMajor(from: "0.2.1")),
.package(url: "https://github.com/SwiftyJSON/SwiftyJSON", .upToNextMajor(from: "5.0.1")),
.package(url: "https://github.com/drmohundro/SWXMLHash", .upToNextMajor(from: "7.0.2")),
.package(url: "https://github.com/scinfu/SwiftSoup", .upToNextMajor(from: "2.6.1")),
.package(url: "https://github.com/swift-server/async-http-client", .upToNextMajor(from: "1.18.0")),
.package(url: "https://github.com/scinfu/SwiftSoup", .upToNextMajor(from: "2.6.1"))
],
targets: [
// Targets are the basic building blocks of a package, defining a module or a test suite.
311 changes: 141 additions & 170 deletions README.md

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions Sources/LangChain/agents/Agent.swift
@@ -10,15 +10,15 @@ public class AgentExecutor: DefaultChain {
static let AGENT_REQ_ID = "agent_req_id"
let agent: Agent
let tools: [BaseTool]
public init(agent: Agent, tools: [BaseTool], memory: BaseMemory? = nil, outputKey: String? = nil, callbacks: [BaseCallbackHandler] = []) {
public init(agent: Agent, tools: [BaseTool], memory: BaseMemory? = nil, outputKey: String = "output", inputKey: String = "input", callbacks: [BaseCallbackHandler] = []) {
self.agent = agent
self.tools = tools
var cbs: [BaseCallbackHandler] = callbacks
if Env.addTraceCallbak() && !cbs.contains(where: { item in item is TraceCallbackHandler}) {
cbs.append(TraceCallbackHandler())
}
// assert(cbs.count == 1)
super.init(memory: memory, outputKey: outputKey, callbacks: cbs)
super.init(memory: memory, outputKey: outputKey, inputKey: inputKey, callbacks: cbs)
}

// def _take_next_step(
@@ -134,7 +134,7 @@ public class AgentExecutor: DefaultChain {
return (step, "default")
}
}
public override func call(args: String) async throws -> LLMResult {
public override func _call(args: String) async throws -> (LLMResult, Parsed) {
// chain run -> call -> agent plan -> llm send

// while should_continue and call
@@ -160,15 +160,15 @@
for callback in self.callbacks {
try callback.on_agent_finish(action: finish, metadata: [AgentExecutor.AGENT_REQ_ID: reqId])
}
return LLMResult(llm_output: next_step_output.1)
return (LLMResult(llm_output: next_step_output.1), Parsed.str(next_step_output.1))
case .action(let action):
for callback in self.callbacks {
try callback.on_agent_action(action: action, metadata: [AgentExecutor.AGENT_REQ_ID: reqId])
}
intermediate_steps.append((action, next_step_output.1))
default:
// print("error step.")
return LLMResult()
return (LLMResult(), Parsed.error)
}
}
}
@@ -263,7 +263,7 @@ public class ZeroShotAgent: Agent {
let tool_names = tools.map{$0.name()}.joined(separator: ", ")
let format_instructions2 = String(format: format_instructions, tool_names)
let template = [prefix0, tool_strings, format_instructions2, suffix].joined(separator: "\n\n")
return PromptTemplate(input_variables: [], template: template)
return PromptTemplate(input_variables: ["question", "thought"], partial_variable: [:], template: template)
}
// @classmethod
// def create_prompt(
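
With this refactor, `AgentExecutor` gains `inputKey`/`outputKey` parameters that default to `"input"`/`"output"`, and the inherited `run(args:)` now yields a `Parsed` value rather than an `LLMResult`. A minimal usage sketch under those assumptions; the agent and tools are passed in because their construction is not part of this diff:

```swift
// Sketch only: `agent` and `tools` come from wherever your app builds them today;
// nothing about their construction is defined by this PR.
func ask(agent: Agent, tools: [BaseTool], question: String) async {
    // outputKey/inputKey fall back to "output"/"input" via the new default arguments.
    let executor = AgentExecutor(agent: agent, tools: tools)
    let parsed = await executor.run(args: question)
    print(parsed)   // a Parsed value, e.g. .str("...") on a successful finish
}
```
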
4 changes: 2 additions & 2 deletions Sources/LangChain/agents/mrkl/MrklPrompt.swift
@@ -25,8 +25,8 @@ Final Answer: the final answer to the original input question
public let SUFFIX = """
Begin!
Question: %@
Thought: %@
Question: {question}
Thought: {thought}
"""

public let FINAL_ANSWER_ACTION = "Final Answer:"
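
The suffix switches from positional `%@` specifiers to the named placeholders that `ZeroShotAgent.create_prompt` now declares via `input_variables: ["question", "thought"]`. How `PromptTemplate.format` performs that substitution is not shown in this diff; the sketch below only illustrates the assumed `{name}` replacement convention:

```swift
import Foundation

// Stand-in for the assumed named-placeholder substitution; not the package's implementation.
func render(template: String, values: [String: String]) -> String {
    var result = template
    for (name, value) in values {
        result = result.replacingOccurrences(of: "{\(name)}", with: value)
    }
    return result
}

// render(template: SUFFIX, values: ["question": "What is 2 + 2?", "thought": ""])
```
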
81 changes: 25 additions & 56 deletions Sources/LangChain/chains/BaseChain.swift
@@ -7,12 +7,13 @@

import Foundation

public class DefaultChain: Chain {
public class DefaultChain {
static let CHAIN_REQ_ID_KEY = "chain_req_id"
static let CHAIN_COST_KEY = "cost"
public init(memory: BaseMemory? = nil, outputKey: String? = nil, callbacks: [BaseCallbackHandler] = []) {
public init(memory: BaseMemory? = nil, outputKey: String, inputKey: String, callbacks: [BaseCallbackHandler] = []) {
self.memory = memory
self.outputKey = outputKey
self.inputKey = inputKey
var cbs: [BaseCallbackHandler] = callbacks
if Env.addTraceCallbak() && !cbs.contains(where: { item in item is TraceCallbackHandler}) {
cbs.append(TraceCallbackHandler())
@@ -21,11 +22,12 @@ public class DefaultChain: Chain {
self.callbacks = cbs
}
let memory: BaseMemory?
let outputKey: String?
let inputKey: String
let outputKey: String
let callbacks: [BaseCallbackHandler]
public func call(args: String) async throws -> LLMResult {
public func _call(args: String) async throws -> (LLMResult, Parsed) {
print("call base.")
return LLMResult()
return (LLMResult(), Parsed.unimplemented)
}

func callEnd(output: String, reqId: String, cost: Double) {
@@ -59,27 +61,35 @@
}

// This interface already returns 'LLMResult'; ensure the 'run' method supports streaming.
public func run(args: String) async -> LLMResult {
public func run(args: String) async -> Parsed {
let _ = prep_inputs(inputs: [inputKey: args])
// = Langchain's run + __call__
let reqId = UUID().uuidString
var cost = 0.0
let now = Date.now.timeIntervalSince1970
do {
callStart(prompt: args, reqId: reqId)
let llmResult = try await self.call(args: args)
let outputs = try await self._call(args: args)
cost = Date.now.timeIntervalSince1970 - now
if !llmResult.stream {
callEnd(output: llmResult.llm_output!, reqId: reqId, cost: cost)
} else {
callEnd(output: "[LLM is streamable]", reqId: reqId, cost: cost)
}
return llmResult
//call end trace
// if !outputs.0.stream {
callEnd(output: outputs.0.llm_output!, reqId: reqId, cost: cost)
// } else {
// callEnd(output: "[LLM is streamable]", reqId: reqId, cost: cost)
// }
let _ = prep_outputs(inputs: [inputKey: args], outputs: [self.outputKey: outputs.0.llm_output!])
return outputs.1
} catch {
// print(error)
callCatch(error: error, reqId: reqId, cost: cost)
return LLMResult(llm_output: "")
return Parsed.error
}
}

func __call__() {

}

func prep_outputs(inputs: [String: String], outputs: [String: String]) -> [String: String] {
if self.memory != nil {
self.memory!.save_context(inputs: inputs, outputs: outputs)
@@ -94,7 +104,7 @@ public class ZeroShotAgent: Agent {
func prep_inputs(inputs: [String: String]) -> [String: String] {
if self.memory != nil {
var external_context = Dictionary(uniqueKeysWithValues: self.memory!.load_memory_variables(inputs: inputs).map {(key, value) in return (key, value.joined(separator: "\n"))})
// print("ctx: \(external_context)")
// print("ctx: \(external_context)")
inputs.forEach { (key, value) in
external_context[key] = value
}
@@ -103,45 +113,4 @@ public class DefaultChain: Chain {
return inputs
}
}

// def prep_outputs(
// self,
// inputs: Dict[str, str],
// outputs: Dict[str, str],
// return_only_outputs: bool = False,
// ) -> Dict[str, str]:
// """Validate and prep outputs."""
// self._validate_outputs(outputs)
// if self.memory is not None:
// self.memory.save_context(inputs, outputs)
// if return_only_outputs:
// return outputs
// else:
// return {**inputs, **outputs}
//
// def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
// """Validate and prep inputs."""
// if not isinstance(inputs, dict):
// _input_keys = set(self.input_keys)
// if self.memory is not None:
// # If there are multiple input keys, but some get set by memory so that
// # only one is not set, we can still figure out which key it is.
// _input_keys = _input_keys.difference(self.memory.memory_variables)
// if len(_input_keys) != 1:
// raise ValueError(
// f"A single string input was passed in, but this chain expects "
// f"multiple inputs ({_input_keys}). When a chain expects "
// f"multiple inputs, please call it by passing in a dictionary, "
// "eg `chain({'foo': 1, 'bar': 2})`"
// )
// inputs = {list(_input_keys)[0]: inputs}
// if self.memory is not None:
// external_context = self.memory.load_memory_variables(inputs)
// inputs = dict(inputs, **external_context)
// self._validate_inputs(inputs)
// return inputs
}

public protocol Chain {
func call(args: String) async throws -> LLMResult
}
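
`DefaultChain` no longer conforms to a `Chain` protocol; subclasses override the `_call(args:)` hook and return both the raw `LLMResult` and its `Parsed` form, while `run(args:)` wires in `prep_inputs`/`prep_outputs` and returns only the `Parsed` half. A minimal sketch of a custom chain written against that contract, meant to sit inside the LangChain target alongside `DNChain` (outside the module the class would additionally need to be declared `open`):

```swift
// Hypothetical example chain: echoes its input in upper case.
public class UppercaseChain: DefaultChain {
    public override init(memory: BaseMemory? = nil,
                         outputKey: String = "output",
                         inputKey: String = "input",
                         callbacks: [BaseCallbackHandler] = []) {
        super.init(memory: memory, outputKey: outputKey, inputKey: inputKey, callbacks: callbacks)
    }

    public override func _call(args: String) async throws -> (LLMResult, Parsed) {
        let text = args.uppercased()
        // Return the raw result and the parsed value; run(args:) forwards the latter.
        return (LLMResult(llm_output: text), Parsed.str(text))
    }
}

// let parsed = await UppercaseChain().run(args: "hello")   // Parsed.str("HELLO")
```
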
8 changes: 4 additions & 4 deletions Sources/LangChain/chains/DNChain.swift
@@ -7,12 +7,12 @@

import Foundation
public class DNChain: DefaultChain {
public override init(memory: BaseMemory? = nil, outputKey: String? = nil, callbacks: [BaseCallbackHandler] = []) {
super.init(memory: memory, outputKey: outputKey, callbacks: callbacks)
public override init(memory: BaseMemory? = nil, outputKey: String = "output", inputKey: String = "input", callbacks: [BaseCallbackHandler] = []) {
super.init(memory: memory, outputKey: outputKey, inputKey: inputKey, callbacks: callbacks)
}
public override func call(args: String) async throws -> LLMResult {
public override func _call(args: String) async throws -> (LLMResult, Parsed) {
// print("Do nothing.")
return LLMResult(llm_output: "")
return (LLMResult(), Parsed.nothing)
}

}
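
`DNChain` now reports "did nothing" as `Parsed.nothing` instead of an empty `llm_output` string, so callers can branch on the parsed result rather than on sentinel strings. The `Parsed` cases used across this PR are `str`, `nothing`, `error`, and `unimplemented`; a handling sketch (the enum may define further cases not visible in this diff, hence the `default`):

```swift
// Sketch: turning the Parsed variants that appear in this PR into log messages.
func describe(_ parsed: Parsed) -> String {
    switch parsed {
    case .str(let text): return "output: \(text)"
    case .nothing:       return "chain intentionally produced no output"
    case .error:         return "chain failed"
    case .unimplemented: return "base _call(args:) was not overridden"
    default:             return "unrecognized parsed value"
    }
}
```
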
104 changes: 32 additions & 72 deletions Sources/LangChain/chains/LLMChain.swift
@@ -13,100 +13,60 @@ public class LLMChain: DefaultChain {
let parser: BaseOutputParser?
let stop: [String]

public init(llm: LLM, prompt: PromptTemplate? = nil, parser: BaseOutputParser? = nil, stop: [String] = [], memory: BaseMemory? = nil, outputKey: String? = nil, callbacks: [BaseCallbackHandler] = []) {
public init(llm: LLM, prompt: PromptTemplate? = nil, parser: BaseOutputParser? = nil, stop: [String] = [], memory: BaseMemory? = nil, outputKey: String = "output", inputKey: String = "input", callbacks: [BaseCallbackHandler] = []) {
self.llm = llm
self.prompt = prompt
self.parser = parser
self.stop = stop
super.init(memory: memory, outputKey: outputKey, callbacks: callbacks)
super.init(memory: memory, outputKey: outputKey, inputKey: inputKey, callbacks: callbacks)
}
public override func call(args: String) async throws -> LLMResult {
func create_outputs(output: LLMResult) -> Parsed {
if let parser = self.parser {
return parser.parse(text: output.llm_output!)
} else {
return Parsed.str(output.llm_output!)
}
}
public override func _call(args: String) async throws -> (LLMResult, Parsed) {
// ["\\nObservation: ", "\\n\\tObservation: "]

let llmResult = await self.llm.send(text: args, stops: stop)

return llmResult
}

func generate(input_list: [String]) async -> String {
// call rest api
let llmResult = await generate(input_list: [inputKey: args])

var input_prompt = ""
return (llmResult, create_outputs(output: llmResult))
}
func prep_prompts(input_list: [String: String]) -> String {
if let prompt = self.prompt {
input_prompt = prompt.format(args: input_list)
return prompt.format(args: input_list)
} else {
return input_list.first!.value
}
}
func generate(input_list: [String: String]) async -> LLMResult {
let input_prompt = prep_prompts(input_list: input_list)
do {
var llmResult = await run(args: input_prompt)
//call llm
var llmResult = await self.llm.generate(text: input_prompt, stops: stop)
try await llmResult.setOutput()
return llmResult.llm_output!
return llmResult
} catch {
print(error)
return ""
return LLMResult()
}
}


// func prep_prompts(input_list: [[String: String]]) -> [String] {
// // inputs and prompt build compelete prompt
// var prompts: [String] = []
//
// for i in input_list {
// var args: [String] = []
// for name in self.prompt.input_variables {
// args.append(i[name]!)
// }
// prompts.append(self.prompt.format(args: args))
// }
// return prompts
// }

public func apply(input_list: [String]) async -> Parsed {
// let prompts = prep_prompts(input_list: input_list)
let response: String = await generate(input_list: input_list)
if let parser = self.parser {
let results = parser.parse(text: response)
return results
} else {
return Parsed.str(response)
}
public func apply(input_list: [String: String]) async -> Parsed {
let response = await generate(input_list: input_list)
return create_outputs(output: response)
}

public func plan(input: String, agent_scratchpad: String) async -> Parsed {
return await apply(input_list: [input, agent_scratchpad])
return await apply(input_list: ["question": input, "thought": agent_scratchpad])
}

public func predict(args: [String: String] ) async -> [String: String] {
// predict -> __call__ -> _call
public func predict(args: [String: String] ) async -> String {
let inputAndContext = prep_inputs(inputs: args)
let output = await self.generate(input_list: inputAndContext.values.map{$0})
// call setOutput to finish output
let outputs = prep_outputs(inputs: inputAndContext, outputs: ["Answer": output])
return outputs
}


public func predict_and_parse(args: [String: String]) async -> Parsed {
let output = await self.predict(args: args)["Answer"]!
if let parser = self.parser {
return parser.parse(text: output)
} else {
return Parsed.str(output)
}
let outputs = await self.generate(input_list: inputAndContext)
let _ = prep_outputs(inputs: args, outputs: [self.outputKey: outputs.llm_output!])
return outputs.llm_output!
}
// def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
// """Format prompt with kwargs and pass to LLM.
//
// Args:
// callbacks: Callbacks to pass to LLMChain
// **kwargs: Keys to pass to prompt template.
//
// Returns:
// Completion from LLM.
//
// Example:
// .. code-block:: python
//
// completion = llm.predict(adjective="funny")
// """
// return self(kwargs, callbacks=callbacks)[self.output_key]
}
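
`LLMChain` now funnels everything through `prep_prompts` and `generate`, `predict(args:)` returns the output string directly (and saves it under `outputKey` when memory is attached) instead of a dictionary, and `plan` supplies the named `"question"`/`"thought"` inputs that the new prompt templates expect. A usage sketch, assuming a concrete `LLM` instance is available and that `PromptTemplate` substitutes `{name}` placeholders as in MrklPrompt above:

```swift
// Sketch only: `llm` is any LLM implementation from the package; its construction
// is outside this PR.
func answer(llm: LLM, question: String) async -> String {
    let prompt = PromptTemplate(input_variables: ["question"],
                                partial_variable: [:],
                                template: "Answer concisely.\nQuestion: {question}\nAnswer:")
    // inputKey matters when the chain is driven through run(args:), which wraps
    // the incoming string as [inputKey: args] before prompting.
    let chain = LLMChain(llm: llm, prompt: prompt, inputKey: "question")
    return await chain.predict(args: ["question": question])
}
```
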
10 changes: 5 additions & 5 deletions Sources/LangChain/chains/SequentialChain.swift
@@ -8,17 +8,17 @@
import Foundation
public class SequentialChain: DefaultChain {
let chains: [DefaultChain]
public init(chains: [DefaultChain], memory: BaseMemory? = nil, outputKey: String? = nil, callbacks: [BaseCallbackHandler] = []) {
public init(chains: [DefaultChain], memory: BaseMemory? = nil, outputKey: String = "output", inputKey: String = "input", callbacks: [BaseCallbackHandler] = []) {
self.chains = chains
super.init(memory: memory, outputKey: outputKey, callbacks: callbacks)
super.init(memory: memory, outputKey: outputKey, inputKey: inputKey, callbacks: callbacks)
}
public func predict(args: String) async throws -> [String: String] {
var result: [String: String] = [:]
var input: LLMResult = LLMResult(llm_output: args)
for chain in self.chains {
assert(chain.outputKey != nil, "chain.outputKey must not be nil")
input = try await chain.call(args: input.llm_output!)
result.updateValue(input.llm_output!, forKey: chain.outputKey!)
// assert(chain.outputKey != nil, "chain.outputKey must not be nil")
input = try await chain._call(args: input.llm_output!).0
result.updateValue(input.llm_output!, forKey: chain.outputKey)
}
return result
}
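
`SequentialChain.predict(args:)` still pipes each chain's `llm_output` into the next and collects the intermediate results under each chain's `outputKey`, which is now non-optional instead of being checked with a runtime assertion. A composition sketch; the two `LLMChain` instances are assumed to have been configured elsewhere with distinct outputKeys:

```swift
// Sketch: run two pre-built chains back to back and inspect their keyed outputs.
func runPipeline(translate: LLMChain, summarize: LLMChain, text: String) async throws -> [String: String] {
    let pipeline = SequentialChain(chains: [translate, summarize])
    // e.g. ["translation": "...", "summary": "..."] if the chains use those outputKeys.
    return try await pipeline.predict(args: text)
}
```
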
The remaining changed files are not rendered here.
