MaxKB Extension Development: A Guide to Custom Plugin Development

MaxKB — a powerful, easy-to-use open-source enterprise-grade agent platform. Project address: https://gitcode.com/feizhiyun/MaxKB

Introduction: Why Custom Plugins?

In enterprise AI scenarios, every organization has its own requirements and technology stack. As an open-source enterprise-grade agent platform, MaxKB provides a powerful plugin extension mechanism that lets developers:

  • 🔧 Integrate private models: connect to proprietary large language models trained in-house
  • 🌐 Support special protocols: adapt to AI providers with non-standard API interfaces
  • 🛡️ Strengthen security: implement custom authentication and encryption mechanisms
  • 📊 Extend capabilities: add new model types and abilities (image recognition, speech synthesis, and so on)

This article dissects MaxKB's plugin architecture and demonstrates, through a worked example, how to develop a complete custom model plugin.

1. The MaxKB Plugin Architecture

1.1 Core Interface System

MaxKB is built around a clean set of abstract interfaces. The core components are the provider interface (IModelProvider), the credential base class (BaseModelCredential), the model base class (MaxKBBaseModel), and the model registry (ModelInfoManage), all of which appear in the code below.

(Mermaid architecture diagram not reproduced in this extract.)

1.2 Model Type Enumeration

MaxKB supports multiple AI model types, defined by the ModelTypeConst enum:

Model type             Constant    Description
Large language model   LLM         Text generation and chat
Embedding model        EMBEDDING   Text vectorization
Speech to text         STT         Speech recognition
Text to speech         TTS         Speech synthesis
Vision model           IMAGE       Image understanding
Text to image          TTI         Image generation
Reranker               RERANKER    Search result optimization
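For experimenting outside MaxKB, the table can be mirrored as a plain enum. The stand-in below uses the member names from the table; it is illustrative only — the real ModelTypeConst carries additional metadata:

```python
from enum import Enum

# Standalone stand-in for MaxKB's ModelTypeConst; member names match the table above.
class ModelType(Enum):
    LLM = "Text generation and chat"
    EMBEDDING = "Text vectorization"
    STT = "Speech recognition"
    TTS = "Speech synthesis"
    IMAGE = "Image understanding"
    TTI = "Image generation"
    RERANKER = "Search result optimization"
```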

2. Building a Custom Model Plugin: A Worked Example

2.1 Project Layout

A new model plugin should follow this directory structure:

custom_model_provider/
├── __init__.py
├── custom_model_provider.py      # main provider class
├── credential/
│   ├── __init__.py
│   ├── llm_credential.py         # LLM credential handling
│   └── embedding_credential.py   # embedding model credentials
├── model/
│   ├── __init__.py
│   ├── llm.py                    # LLM model implementation
│   └── embedding.py              # embedding model implementation
└── icon/
    └── custom_icon.svg           # provider icon
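The layout above can be generated with a short scaffolding script. A minimal standard-library sketch (the paths are exactly those listed above; the function name is arbitrary):

```python
from pathlib import Path

def scaffold_plugin(root: str) -> list:
    """Create the plugin skeleton described above; return the created paths."""
    base = Path(root) / "custom_model_provider"
    files = [
        base / "__init__.py",
        base / "custom_model_provider.py",
        base / "credential" / "__init__.py",
        base / "credential" / "llm_credential.py",
        base / "credential" / "embedding_credential.py",
        base / "model" / "__init__.py",
        base / "model" / "llm.py",
        base / "model" / "embedding.py",
        base / "icon" / "custom_icon.svg",
    ]
    for f in files:
        f.parent.mkdir(parents=True, exist_ok=True)  # create package directories
        f.touch(exist_ok=True)                       # create empty placeholder files
    return [str(f.relative_to(root)) for f in files]
```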

2.2 Implementing the Credential Class

# coding=utf-8
"""
Custom model credential implementation.
"""
from typing import Dict
from django.utils.translation import gettext_lazy as _
from models_provider.base_model_provider import BaseModelCredential
from common.exception.app_exception import AppApiException

class CustomLLMCredential(BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model: Dict[str, object],
                 model_params, provider, raise_exception=True):
        """
        Validate the credential parameters.
        """
        required_fields = ['api_key', 'base_url']
        missing_fields = [field for field in required_fields if field not in model]

        if missing_fields:
            if raise_exception:
                raise AppApiException(500,
                    _('Missing required fields: {}').format(', '.join(missing_fields)))
            return False

        # Validate the API key format (example check)
        api_key = model.get('api_key')
        if not api_key or len(api_key) < 20:
            if raise_exception:
                raise AppApiException(500, _('Invalid API key format'))
            return False

        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """
        Encrypt sensitive fields before persisting.
        """
        encrypted_info = model_info.copy()
        if 'api_key' in encrypted_info:
            encrypted_info['api_key'] = self.encryption(encrypted_info['api_key'])
        return encrypted_info

    def get_model_params_setting_form(self, model_name):
        """
        Return the model parameter settings form definition.
        """
        return [
            {
                "field": "temperature",
                "label": _("Temperature"),
                "type": "slider",
                "required": False,
                "default_value": 0.7,
                "min": 0,
                "max": 2,
                "step": 0.1
            },
            {
                "field": "max_tokens",
                "label": _("Max Tokens"),
                "type": "number",
                "required": False,
                "default_value": 2048
            }
        ]
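To reason about the validation rules without a running MaxKB instance, the checks inside is_valid can be mirrored as a framework-free helper. The sketch below is illustrative (check_credential is not a MaxKB API); the field names and the 20-character minimum come from the class above:

```python
def check_credential(model: dict, required_fields=("api_key", "base_url"), min_key_len=20):
    """Return (ok, error_message) for a credential dict, mirroring is_valid above."""
    # Reject credentials with missing required fields
    missing = [f for f in required_fields if f not in model]
    if missing:
        return False, "Missing required fields: " + ", ".join(missing)
    # Reject keys that are empty or too short (example format check)
    api_key = model.get("api_key")
    if not api_key or len(api_key) < min_key_len:
        return False, "Invalid API key format"
    return True, ""
```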

2.3 Implementing the Model Class

# coding=utf-8
"""
Custom LLM model implementation.
"""
from typing import Dict, List

import requests
from langchain.schema import BaseMessage, HumanMessage, AIMessage
from models_provider.base_model_provider import MaxKBBaseModel
from common.exception.app_exception import AppApiException
from django.utils.translation import gettext_lazy as _

class CustomLLMModel(MaxKBBaseModel):

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """
        Create a new model instance.
        """
        # Extract the credentials
        api_key = model_credential.get('api_key')
        base_url = model_credential.get('base_url')

        if not api_key or not base_url:
            raise AppApiException(500, _('Missing API key or base URL'))

        # Create and return the model instance
        return CustomLLMModel(api_key, base_url, model_name, **model_kwargs)

    def __init__(self, api_key: str, base_url: str, model_name: str, **kwargs):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.temperature = kwargs.get('temperature', 0.7)
        self.max_tokens = kwargs.get('max_tokens', 2048)

    def invoke(self, messages: List[BaseMessage], **kwargs):
        """
        Invoke the model.
        """
        # Convert messages to the provider's role/content wire format
        formatted_messages = []
        for message in messages:
            if isinstance(message, HumanMessage):
                formatted_messages.append({"role": "user", "content": message.content})
            elif isinstance(message, AIMessage):
                formatted_messages.append({"role": "assistant", "content": message.content})

        # Call the custom API
        try:
            response = requests.post(
                f"{self.base_url}/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.model_name,
                    "messages": formatted_messages,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                    **kwargs
                },
                timeout=30
            )

            if response.status_code == 200:
                result = response.json()
                return result['choices'][0]['message']['content']
            raise AppApiException(500,
                _('API request failed: {}').format(response.text))

        except AppApiException:
            raise  # do not re-wrap our own exception
        except Exception as e:
            raise AppApiException(500,
                _('Model invocation error: {}').format(str(e)))

    @staticmethod
    def is_cache_model():
        return False
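The message conversion inside invoke() can be isolated and tested on its own. The sketch below uses minimal stand-ins for the LangChain message classes (illustrative only — real code would import them from langchain.schema):

```python
# Minimal stand-ins for langchain's message classes, for illustration only.
class HumanMessage:
    def __init__(self, content):
        self.content = content

class AIMessage:
    def __init__(self, content):
        self.content = content

def format_messages(messages):
    """Convert message objects to the role/content dicts used by invoke() above."""
    role_map = {HumanMessage: "user", AIMessage: "assistant"}
    out = []
    for m in messages:
        role = role_map.get(type(m))
        if role is not None:  # unrecognized message types are skipped, as in invoke()
            out.append({"role": role, "content": m.content})
    return out
```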

2.4 Implementing the Provider Class

# coding=utf-8
"""
Custom model provider main class.
"""
import os
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
from .credential.llm_credential import CustomLLMCredential
from .credential.embedding_credential import CustomEmbeddingCredential
from .model.llm import CustomLLMModel
from .model.embedding import CustomEmbeddingModel
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from common.utils.common import get_file_content

# Create the credential instances
custom_llm_credential = CustomLLMCredential()
custom_embedding_credential = CustomEmbeddingCredential()

# Define the supported models
model_info_list = [
    ModelInfo('custom-llm-v1', _('Custom LLM Version 1'), ModelTypeConst.LLM,
              custom_llm_credential, CustomLLMModel),
    ModelInfo('custom-llm-v2', _('Custom LLM Version 2'), ModelTypeConst.LLM,
              custom_llm_credential, CustomLLMModel),
]

model_info_embedding_list = [
    ModelInfo('custom-embedding-v1', _('Custom Embedding Model'), ModelTypeConst.EMBEDDING,
              custom_embedding_credential, CustomEmbeddingModel),
]

# Build the model info manager
model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_default_model_info(model_info_list[0])
    .append_model_info_list(model_info_embedding_list)
    .append_default_model_info(model_info_embedding_list[0])
    .build()
)

class CustomModelProvider(IModelProvider):
    """
    Custom model provider implementation.
    """

    def get_model_info_manage(self):
        return model_info_manage

    def get_model_provide_info(self):
        return ModelProvideInfo(
            provider='custom_model_provider',
            name='Custom AI Provider',
            icon=get_file_content(
                os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl',
                             'custom_model_provider', 'icon', 'custom_icon.svg')
            )
        )

    def get_dialogue_number(self):
        """Return the number of dialogue turns supported."""
        return 5

2.5 Registering the Plugin

Add the registration in apps/models_provider/constants/model_provider_constants.py:

# coding=utf-8
"""
Model provider constant definitions.
"""
from enum import Enum
from models_provider.impl.custom_model_provider.custom_model_provider import CustomModelProvider
from models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
# ... other provider imports

class ModelProvideConstants(Enum):
    # Existing providers
    model_openai_provider = OpenAIModelProvider()

    # Newly added custom provider
    custom_model_provider = CustomModelProvider()

    # other providers...
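With the enum in place, a provider instance can be looked up by its member name. The standalone sketch below mimics that lookup with stub provider classes (the stubs and get_provider are illustrative, not the real MaxKB classes):

```python
from enum import Enum

# Stub provider classes standing in for the real implementations.
class OpenAIModelProvider: ...
class CustomModelProvider: ...

# Registry mirroring ModelProvideConstants: member name -> provider instance.
class ModelProvideConstants(Enum):
    model_openai_provider = OpenAIModelProvider()
    custom_model_provider = CustomModelProvider()

def get_provider(name: str):
    """Look up a registered provider instance by its enum member name."""
    return ModelProvideConstants[name].value
```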

3. Advanced Features

3.1 Streaming Output

def stream_invoke(self, messages: List[BaseMessage], **kwargs):
    """
    Streaming output implementation.
    """
    import json
    import requests
    formatted_messages = self._format_messages(messages)

    response = requests.post(
        f"{self.base_url}/v1/chat/completions",
        headers=self._get_headers(),
        json={
            "model": self.model_name,
            "messages": formatted_messages,
            "stream": True,
            **kwargs
        },
        stream=True,
        timeout=30
    )

    for chunk in response.iter_lines():
        if chunk:
            chunk_str = chunk.decode('utf-8')
            if chunk_str.startswith('data: '):
                data = chunk_str[6:]
                if data != '[DONE]':
                    try:
                        chunk_data = json.loads(data)
                        if 'choices' in chunk_data and chunk_data['choices']:
                            delta = chunk_data['choices'][0].get('delta', {})
                            if 'content' in delta:
                                yield delta['content']
                    except json.JSONDecodeError:
                        continue
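The SSE parsing inside stream_invoke can be factored out and tested without a network. This standalone sketch applies the same 'data: ' prefix and '[DONE]' handling to an iterable of already-decoded lines:

```python
import json

def parse_sse_lines(lines):
    """Yield content deltas from OpenAI-style SSE 'data:' lines, as stream_invoke does."""
    for raw in lines:
        if not raw.startswith("data: "):
            continue  # ignore non-data lines (comments, keep-alives)
        data = raw[6:]
        if data == "[DONE]":
            return  # end-of-stream sentinel
        try:
            chunk = json.loads(data)
        except json.JSONDecodeError:
            continue  # skip malformed chunks, as the loop above does
        choices = chunk.get("choices") or []
        if choices:
            delta = choices[0].get("delta", {})
            if "content" in delta:
                yield delta["content"]
```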

3.2 Model Download Support

def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]):
    """
    Model download implementation (for locally hosted models).
    DownModelChunk and DownModelChunkStatus are assumed to be importable
    from models_provider.base_model_provider alongside the other base classes.
    """
    total_steps = 10
    for i in range(total_steps):
        progress = (i + 1) * 10
        yield DownModelChunk(
            status=DownModelChunkStatus.pulling,
            digest=f"downloading_{model_name}",
            progress=progress,
            details=f"Downloading {model_name} - step {i+1}/{total_steps}",
            index=i
        )

    yield DownModelChunk(
        status=DownModelChunkStatus.success,
        digest=f"download_complete_{model_name}",
        progress=100,
        details=f"Model {model_name} downloaded successfully",
        index=total_steps
    )
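To see the generator protocol in isolation, here is a framework-free mirror of down_model that uses a plain dataclass in place of DownModelChunk (the Chunk class and fake_download are illustrative only):

```python
from dataclasses import dataclass

# Standalone stand-in for MaxKB's DownModelChunk, for illustration only.
@dataclass
class Chunk:
    status: str
    progress: int
    details: str

def fake_download(model_name: str, total_steps: int = 10):
    """Yield progress chunks the way down_model above does, without framework types."""
    for i in range(total_steps):
        # Report incremental progress for each pull step
        yield Chunk("pulling", (i + 1) * 100 // total_steps, f"step {i + 1}/{total_steps}")
    # Final chunk signals completion at 100%
    yield Chunk("success", 100, f"{model_name} downloaded")
```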

3.3 Error Handling and Retries

def _make_request_with_retry(self, url, headers, data, max_retries=3):
    """
    POST helper with retries and exponential backoff.
    """
    import requests
    from time import sleep

    for attempt in range(max_retries):
        try:
            response = requests.post(url, headers=headers, json=data, timeout=30)
            if response.status_code == 200:
                return response
            elif response.status_code == 429:  # rate limited
                sleep(2 ** attempt)  # exponential backoff
                continue
            else:
                raise AppApiException(500,
                    f"API request failed with status {response.status_code}")
        except requests.exceptions.RequestException as e:
            if attempt == max_retries - 1:
                raise AppApiException(500,
                    f"Request failed after {max_retries} attempts: {str(e)}")
            sleep(2 ** attempt)

    raise AppApiException(500, "Max retries exceeded")
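The backoff pattern above reduces to a small reusable helper. A minimal sketch (the injectable sleep parameter is an addition for testability, not part of MaxKB):

```python
import time

def retry_with_backoff(fn, max_retries=3, base_delay=1.0, sleep=time.sleep):
    """Call fn, retrying with exponential backoff; re-raise after the last attempt."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:
            if attempt == max_retries - 1:
                raise  # out of attempts: propagate the last error
            sleep(base_delay * 2 ** attempt)  # 1s, 2s, 4s, ... between attempts
```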

4. Testing and Debugging

4.1 Unit Test Example

# coding=utf-8
"""
Custom model plugin tests.
"""
from django.test import TestCase
from models_provider.impl.custom_model_provider.custom_model_provider import CustomModelProvider
from models_provider.impl.custom_model_provider.credential.llm_credential import CustomLLMCredential

class CustomModelProviderTest(TestCase):

    def setUp(self):
        self.provider = CustomModelProvider()
        self.credential = CustomLLMCredential()

    def test_model_list(self):
        """Fetch the model list."""
        models = self.provider.get_model_list('LLM')
        self.assertGreater(len(models), 0)
        self.assertEqual(models[0]['name'], 'custom-llm-v1')

    def test_credential_validation(self):
        """Validate credentials."""
        valid_credential = {'api_key': 'valid_key_12345678901234567890', 'base_url': 'https://api.example.com'}
        invalid_credential = {'api_key': 'short', 'base_url': 'https://api.example.com'}

        # Valid credentials should pass
        result = self.credential.is_valid('LLM', 'custom-llm-v1', valid_credential, {},
                                          self.provider, raise_exception=False)
        self.assertTrue(result)

        # Invalid credentials should fail
        result = self.credential.is_valid('LLM', 'custom-llm-v1', invalid_credential, {},
                                          self.provider, raise_exception=False)
        self.assertFalse(result)
4.2 Integration Test Flow

(Mermaid flow diagram not reproduced in this extract.)

5. Deployment and Maintenance

5.1 Production Deployment

  1. Dependency management: make sure all required packages are installed
  2. Configuration: set appropriate timeouts and retry policies
  3. Monitoring: add performance metrics and logging
  4. Security audits: review authentication and encryption mechanisms regularly

5.2 Performance Tuning Suggestions

Optimization        Suggested approach               Expected benefit
Connection pooling  Use requests.Session             Lower TCP connection overhead
Response caching    Cache repeated responses         Fewer API calls
Batching            Support batch text processing    Higher throughput
Async invocation    Use asynchronous IO              Better concurrency
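Of these, response caching is the quickest to sketch with the standard library. The memoized call below is a placeholder for a real provider request (cached_completion is illustrative, not a MaxKB API):

```python
from functools import lru_cache

@lru_cache(maxsize=256)
def cached_completion(prompt: str, model: str) -> str:
    """Memoize identical (prompt, model) pairs so repeated calls skip the API."""
    # A real implementation would perform the HTTP request here; this
    # placeholder fabricates a deterministic response string instead.
    return f"{model}:{prompt}"
```

Note that caching is only safe for deterministic settings (e.g. temperature 0); sampled responses should not be memoized.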

Conclusion

With this guide you now have the core techniques for MaxKB custom plugin development. Custom plugins not only push the boundaries of what MaxKB can do, but also let an enterprise adapt the platform to its specific business needs and technical environment.

A good plugin has a clean interface design, thorough error handling, clear documentation, and solid test coverage. As AI technology evolves rapidly, keeping plugins extensible and maintainable is essential.

Start your MaxKB plugin development journey and bring new capabilities to your enterprise AI applications!

