Source code for lmitf.templete_llm

from __future__ import annotations

import importlib.util
import os
import os.path as op
import re
from string import Template

import pandas as pd
from dotenv import load_dotenv
from wasabi import msg

from .base_llm import BaseLLM
# Populate os.environ from a .env file (python-dotenv) so the OPENAI_API_KEY /
# OPENAI_BASE_URL fallbacks read in TemplateLLM.__init__ can come from it.
# override=False: variables already set in the environment take precedence.
load_dotenv(override=False)

class TemplateLLM(BaseLLM):
    """LLM client whose chat prompt is built from a Python template module.

    The template file must be a ``.py`` module that defines two attributes:

    ``prompt_template``
        A sequence of chat messages. The ``'content'`` of the LAST message is
        replaced with the filled ``conditioned_frame`` on every call.
    ``conditioned_frame``
        A :class:`string.Template` source string. Its ``$name`` / ``${name}``
        placeholders are the variables a caller must pass to :meth:`call`.
    """

    def __init__(
        self,
        template_path: str | os.PathLike[str],
        api_key: str | None = None,
        base_url: str | None = None,
    ):
        """Initialize the template LLM client.

        Parameters
        ----------
        template_path : str or os.PathLike
            Path to the ``.py`` template module.
        api_key : str, optional
            API key; falls back to the ``OPENAI_API_KEY`` environment variable.
        base_url : str, optional
            API base URL; falls back to the ``OPENAI_BASE_URL`` environment
            variable.

        Raises
        ------
        FileNotFoundError
            If ``template_path`` does not exist.
        ValueError
            If ``template_path`` is not a ``.py`` file, or the template's
            ``conditioned_frame`` contains no variables.
        ImportError, AttributeError
            Propagated from :meth:`_load_template`.
        """
        super().__init__(
            api_key=api_key or os.getenv('OPENAI_API_KEY'),
            base_url=base_url or os.getenv('OPENAI_BASE_URL'),
        )
        # Accept pathlib.Path as well as str; ``.endswith`` below needs a str.
        template_path = os.fspath(template_path)
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if not op.exists(template_path):
            raise FileNotFoundError(
                f'Template file does not exist: {template_path}',
            )
        if not template_path.endswith('.py'):
            raise ValueError('Template file must be a Python file (.py)')
        self._load_template(template_path)
        self.template_path = template_path
        msg.text(f'Template loaded from \n{template_path}')

    def _load_template(self, template_path: str) -> None:
        """Load the prompt template module and extract its template variables.

        Sets ``self.prompt_template``, ``self.conditioned_frame``,
        ``self.template_obj`` and ``self.variables``.

        Parameters
        ----------
        template_path : str
            Path to the template file (a ``.py`` module).

        Raises
        ------
        FileNotFoundError
            If the file does not exist.
        ImportError
            If the module cannot be located or fails while executing.
        AttributeError
            If ``prompt_template`` or ``conditioned_frame`` is not defined.
        ValueError
            If ``conditioned_frame`` contains no placeholders.
        """
        if not op.exists(template_path):
            raise FileNotFoundError(
                f'Template file not found: {template_path}',
            )
        spec = importlib.util.spec_from_file_location(
            'template_module', template_path,
        )
        if spec is None or spec.loader is None:
            raise ImportError(
                f'Failed to load template module: {template_path}',
            )
        template_module = importlib.util.module_from_spec(spec)
        try:
            # NOTE(security): this executes arbitrary code from the template
            # file; only load templates from trusted locations.
            spec.loader.exec_module(template_module)
        except Exception as e:
            raise ImportError(f'Failed to load template module: {e}') from e

        for attr in ('prompt_template', 'conditioned_frame'):
            if not hasattr(template_module, attr):
                raise AttributeError(
                    f"Template module must define '{attr}' attribute: \n{template_path}",
                )

        self.prompt_template = template_module.prompt_template
        self.conditioned_frame = template_module.conditioned_frame
        self.template_obj = Template(self.conditioned_frame)

        variables = self._extract_variables(self.template_obj)
        if not variables:
            raise ValueError(
                f'No variables found in conditioned_frame: \n{self.conditioned_frame}',
            )
        self.variables = variables

    @staticmethod
    def _extract_variables(template: Template) -> list[str]:
        """Return the placeholder names of *template* in first-seen order.

        Scans with ``Template.pattern``, the same regex ``substitute`` uses.
        ``${name}`` is therefore recognised and ``$$`` escapes are ignored
        (the equivalent of ``Template.get_identifiers`` from Python 3.11).
        Duplicates are dropped while preserving order, so the result is
        deterministic.
        """
        names: dict[str, None] = {}
        for match in template.pattern.finditer(template.template):
            name = match.group('named') or match.group('braced')
            if name is not None:
                names.setdefault(name, None)
        return list(names)

    def _fill(self, **kwargs):
        """Substitute the template variables and build the message list.

        Parameters
        ----------
        **kwargs
            Exactly the variables listed in ``self.variables``.

        Returns
        -------
        list
            A new message list. It equals ``prompt_template`` except that the
            last message's ``'content'`` is the filled ``conditioned_frame``.

        Raises
        ------
        ValueError
            If any variable is missing or an unknown variable is supplied.
        """
        missing_vars = set(self.variables) - set(kwargs.keys())
        if missing_vars:
            raise ValueError(f'Missing required variables: {missing_vars}')
        extra_vars = set(kwargs.keys()) - set(self.variables)
        if extra_vars:
            raise ValueError(f'Unexpected variables provided: {extra_vars}')
        # list.copy() is shallow: assigning into the last message would mutate
        # the stored template in place, so copy that message before writing.
        prompt = list(self.prompt_template)
        last_message = dict(prompt[-1])
        last_message['content'] = self.template_obj.substitute(**kwargs)
        prompt[-1] = last_message
        return prompt

    def call(self, **kwargs):
        """Fill the template and call the LLM.

        Keyword arguments whose names are template variables fill the
        template. All other keyword arguments are forwarded unchanged to
        ``BaseLLM.call`` alongside the generated ``messages``.

        Returns
        -------
        Whatever ``BaseLLM.call`` returns (presumably the response text —
        confirm against BaseLLM).

        Raises
        ------
        ValueError
            If ``prompt_template`` is empty or the template variables are
            incomplete.
        """
        if not self.prompt_template:
            raise ValueError('Prompt template is not defined.')
        template_vars = {}
        request_kwargs = {}
        for key, value in kwargs.items():
            target = template_vars if key in self.variables else request_kwargs
            target[key] = value
        messages = self._fill(**template_vars)
        return super().call(messages=messages, **request_kwargs)

    def _repr_html_(self):
        """Return an HTML table summarising the template (Jupyter display hook)."""
        info = pd.DataFrame(
            {
                'Name': [op.basename(self.template_path)],
                'Variables to fill': [', '.join(self.variables)],
            },
            index=['Template Info'],
        )
        return info.T._repr_html_()