
LLM Module

uglychain.llm

llm(model, response_format=None, map_keys=None, **api_params)

LLM decorator that binds a prompt function to a language model and its call parameters.

:param model: Model name
:param response_format: Optional type used to parse each model choice into a structured result
:param map_keys: Names of the prompt function's parameters whose list values are mapped over, one model call per element
:param api_params: API parameters, passed as keyword arguments
:return: A decorator that wraps the prompt function
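
A minimal usage sketch, assuming llm is imported from uglychain.llm as documented here; the model identifier and the prompt text are illustrative placeholders rather than values prescribed by the library:

from uglychain.llm import llm

@llm("openai:gpt-4o-mini", temperature=0.3)
def translate(text: str) -> str:
    # The decorated function only builds the prompt (a str or a list of message dicts);
    # the decorator sends it to the model and returns the completion.
    return f"Translate the following text into French:\n{text}"

reply = translate("Good morning")  # a single completion comes back as a plain string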

Source code in src/uglychain/llm.py
def llm(
    model: str,
    response_format: type[T] | None = None,
    map_keys: list[str] | None = None,
    **api_params: Any,
) -> Callable[
    [Callable[P, str | list[dict[str, str]] | T]],
    Callable[P, str | list[str] | T | list[T] | ToolResopnse | list[ToolResopnse]],
]:
    """
    LLM 装饰器,用于指定语言模型和其参数。

    :param model: 模型名称
    :param api_params: API 参数,以关键字参数形式传入
    :return: 返回一个装饰器,用于装饰提示函数
    """
    default_model_from_decorator = model
    default_api_params_from_decorator = api_params.copy()

    def parameterized_lm_decorator(
        prompt: Callable[P, str | list[dict[str, str]] | T],
    ) -> Callable[P, str | list[str] | T | list[T] | ToolResopnse | list[ToolResopnse]]:
        @wraps(prompt)
        def model_call(
            *prompt_args: P.args,
            api_params: dict[str, Any] | None = None,  # type: ignore
            **prompt_kwargs: P.kwargs,
        ) -> str | list[str] | T | list[T] | ToolResopnse | list[ToolResopnse]:
            console = Console()
            # Get the return type of the decorated function
            response_model = ResponseModel[T](prompt, response_format)

            # Merge the decorator-level API parameters with the call-level ones
            merged_api_params = config.default_api_params.copy()
            if default_api_params_from_decorator:
                merged_api_params.update(default_api_params_from_decorator)
            if api_params:
                merged_api_params.update(api_params)

            # Get the number of completions to run concurrently
            n = merged_api_params.get("n", 1)
            # Get the model name
            model = merged_api_params.pop("model", default_model_from_decorator)

            console.log_model_usage_pre(model, prompt, prompt_args, prompt_kwargs)

            m, map_args_index_set, map_kwargs_keys_set = _get_map_keys(prompt, prompt_args, prompt_kwargs, map_keys)
            if m > 1 and n > 1:
                raise ValueError("n > 1 and a mapped list of length > 1 cannot be used together")

            def process_single_prompt(i: int) -> list[Any]:
                args = [arg[i] if j in map_args_index_set else arg for j, arg in enumerate(prompt_args)]  # type: ignore
                kwargs = {
                    key: value[i] if key in map_kwargs_keys_set else value  # type: ignore
                    for key, value in prompt_kwargs.items()
                }
                res = prompt(*args, **kwargs)  # type: ignore
                assert (
                    isinstance(res, str) or isinstance(res, list) and all(isinstance(item, dict) for item in res)
                ), ValueError("The decorated function must return a str or `messages` (list[dict[str, str]])")
                messages = _get_messages(res, prompt)
                response_model.process_parameters(model, messages, merged_api_params)

                console.log_model_usage_post_info(messages, merged_api_params)

                response = Client.generate(model, messages, **merged_api_params)

                # Parse the results from the response
                result = [response_model.parse_from_response(choice) for choice in response]
                console.log_model_usage_post_intermediate(result)
                return result

            results = []
            console.log_progress_start(m if m > 1 else n)

            if config.use_parallel_processing:
                with ThreadPoolExecutor() as executor:
                    futures = [executor.submit(process_single_prompt, i) for i in range(m)]

                    for future in as_completed(futures):
                        results.extend(future.result())
            else:
                for i in range(m):
                    results.extend(process_single_prompt(i))

            console.log_progress_end()

            if len(results) == 0:
                raise ValueError("The model returned no choices")
            elif m == n == len(results) == 1:
                return results[0]
            return results

        model_call.__api_params__ = default_api_params_from_decorator  # type: ignore
        model_call.__func__ = prompt  # type: ignore

        return model_call  # type: ignore

    return parameterized_lm_decorator  # type: ignore[return-value]
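
As the listing shows, map_keys names parameters whose list values are fanned out into one model call per element, the wrapped function accepts a call-level api_params dict that is merged over the decorator-level defaults, and a single result is unwrapped from the list while multiple results come back as a list (order is not guaranteed when config.use_parallel_processing is enabled). A hedged sketch along the same lines, again with a placeholder model name and prompts:

from uglychain.llm import llm

@llm("openai:gpt-4o-mini", map_keys=["word"])
def define(word: str) -> str:
    return f"Give a one-sentence definition of '{word}'."

# Three separate model calls, one per list element; the results are returned as a list.
definitions = define(word=["gradient", "tensor", "entropy"])

# Call-level api_params override the decorator-level defaults for this call only.
# A single mapped element yields a single, unwrapped result.
precise = define(word=["entropy"], api_params={"temperature": 0.0})

Passing a class via response_format is handled by ResponseModel, which parses each returned choice into an instance of that class; the concrete set of supported classes is not shown in this listing.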