18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def llm(
    model: str,
    response_format: type[T] | None = None,
    map_keys: list[str] | None = None,
    **api_params: Any,
) -> Callable[
    [Callable[P, str | list[dict[str, str]] | T]],
    Callable[P, str | list[str] | T | list[T] | ToolResopnse | list[ToolResopnse]],
]:
    """
    LLM decorator that binds a prompt function to a language model and its API parameters.

    The decorated prompt function builds the prompt (a ``str`` or a ``messages``
    list of dicts). Calling the resulting wrapper renders the prompt, sends it to
    the model via ``Client.generate`` and parses every returned choice with
    ``ResponseModel``.

    :param model: Model name, used unless a ``"model"`` key is present in the
        merged API parameters.
    :param response_format: Optional type handed to ``ResponseModel`` for parsing
        the model's response.
    :param map_keys: Optional argument names to map over; forwarded to
        ``_get_map_keys``, which decides how many prompts (``m``) are issued.
    :param api_params: Decorator-level API parameters, given as keyword arguments.
    :return: A decorator that wraps a prompt function.
    """
    default_model_from_decorator = model
    default_api_params_from_decorator = api_params.copy()

    def parameterized_lm_decorator(
        prompt: Callable[P, str | list[dict[str, str]] | T],
    ) -> Callable[P, str | list[str] | T | list[T] | ToolResopnse | list[ToolResopnse]]:
        @wraps(prompt)
        def model_call(
            *prompt_args: P.args,
            api_params: dict[str, Any] | None = None,  # type: ignore
            **prompt_kwargs: P.kwargs,
        ) -> str | list[str] | T | list[T] | ToolResopnse | list[ToolResopnse]:
            console = Console()
            # Parser built from the decorated function and the requested response_format.
            response_model = ResponseModel[T](prompt, response_format)
            # Merge API params; later sources win: global config < decorator-level < call-level.
            merged_api_params = config.default_api_params.copy()
            if default_api_params_from_decorator:
                merged_api_params.update(default_api_params_from_decorator)
            if api_params:
                merged_api_params.update(api_params)
            # Number of choices requested per prompt.
            n = merged_api_params.get("n", 1)
            # A "model" key overrides the decorator's model; it is popped so it is
            # not forwarded to Client.generate a second time as a keyword argument.
            model = merged_api_params.pop("model", default_model_from_decorator)
            console.log_model_usage_pre(model, prompt, prompt_args, prompt_kwargs)
            m, map_args_index_set, map_kwargs_keys_set = _get_map_keys(prompt, prompt_args, prompt_kwargs, map_keys)
            if m > 1 and n > 1:
                raise ValueError("n > 1 和列表长度 > 1 不能同时成立")

            def process_single_prompt(i: int) -> list[Any]:
                """Render and send the i-th prompt; return its parsed choices."""
                # Mapped arguments contribute their i-th element; the rest pass through unchanged.
                args = [arg[i] if j in map_args_index_set else arg for j, arg in enumerate(prompt_args)]  # type: ignore
                kwargs = {
                    key: value[i] if key in map_kwargs_keys_set else value  # type: ignore
                    for key, value in prompt_kwargs.items()
                }
                res = prompt(*args, **kwargs)  # type: ignore
                # Explicit check rather than `assert`: survives `python -O` and raises
                # the intended ValueError instead of an AssertionError.
                is_messages = isinstance(res, list) and all(isinstance(item, dict) for item in res)
                if not (isinstance(res, str) or is_messages):
                    raise ValueError("被修饰的函数返回值必须是 str 或 `messages`(list[dict[str, str]]) 类型")
                messages = _get_messages(res, prompt)
                # Private copy per call: process_parameters receives this dict (and may
                # modify it), so concurrent or successive calls must not share one instance.
                call_params = merged_api_params.copy()
                response_model.process_parameters(model, messages, call_params)
                console.log_model_usage_post_info(messages, call_params)
                response = Client.generate(model, messages, **call_params)
                # Parse every returned choice.
                result = [response_model.parse_from_response(choice) for choice in response]
                console.log_model_usage_post_intermediate(result)
                return result

            results: list[Any] = []
            console.log_progress_start(m if m > 1 else n)
            try:
                if config.use_parallel_processing:
                    # executor.map yields in submission order (unlike as_completed,
                    # which yields in completion order), keeping results aligned
                    # with the mapped inputs.
                    with ThreadPoolExecutor() as executor:
                        for chunk in executor.map(process_single_prompt, range(m)):
                            results.extend(chunk)
                else:
                    for i in range(m):
                        results.extend(process_single_prompt(i))
            finally:
                # Close the progress display even if a prompt raised.
                console.log_progress_end()
            if not results:
                raise ValueError("模型未返回任何选择")
            # Single prompt with a single choice: unwrap the list.
            if m == n == len(results) == 1:
                return results[0]
            return results

        model_call.__api_params__ = default_api_params_from_decorator  # type: ignore
        model_call.__func__ = prompt  # type: ignore
        return model_call  # type: ignore

    return parameterized_lm_decorator  # type: ignore[return-value]
|