Mechanism Analysis

Key File and Class

File path: langchain_text_splitters/character.py

Class: RecursiveCharacterTextSplitter

Core entry point: _split_text

Splitting Steps and Source Analysis

The steps, with details and examples:

1. Separator fallback. The separators ["\n\n", "\n", " ", ""] are tried in order (the list is customizable). The text is first cut into paragraphs with "\n\n"; if any piece exceeds chunk_size characters, that piece falls back to "\n", and so on down the list.
2. Recursive splitting. Step 1 is repeated on every over-long piece until all fragments are ≤ chunk_size or the separators run out. If even sentence-level pieces are still too long, the empty string "" finally splits character by character.
3. Merging. "Good" pieces are concatenated in order into chunks that are as long as possible while staying ≤ chunk_size; when adding one more piece would overflow chunk_size, the current chunk is sealed and a new one starts.
4. Overlap. After a chunk A1 is sealed, the algorithm backs up by at most chunk_overlap characters so that the next chunk A2 begins with the tail of A1 (head of A2 = overlap-sized tail of A1). In particular, if the last piece of A1 is itself longer than chunk_overlap, no overlap is generated, so that the piece stays semantically intact.
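Steps 1 and 2 above can be illustrated with a minimal standalone sketch. The function name recursive_split and the simplified splitting (plain str.split, no regex handling, no merging or overlap) are illustrative assumptions, not the library code itself:

```python
def recursive_split(text: str, separators: list[str], chunk_size: int) -> list[str]:
    """Sketch of separator fallback + recursion: every result is <= chunk_size."""
    # Step 1: pick the first separator that actually occurs in the text.
    sep, rest = separators[-1], []
    for i, s in enumerate(separators):
        if s == "" or s in text:
            sep, rest = s, separators[i + 1:]
            break
    # The empty-string separator means character-level splitting.
    pieces = list(text) if sep == "" else [p for p in text.split(sep) if p]
    out = []
    for p in pieces:
        if len(p) <= chunk_size:
            out.append(p)
        elif rest:
            # Step 2: recurse on over-long pieces with the lower-level separators.
            out.extend(recursive_split(p, rest, chunk_size))
        else:
            out.append(p)
    return out

# "\n\n" splits paragraphs first; "aaaa bb" is too long, so it falls back to " ".
chunks = recursive_split("aaaa bb\n\ncc dd", ["\n\n", " ", ""], 4)
# → ["aaaa", "bb", "cc", "dd"]
```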

The overall logic of these steps lives in the _split_text function; the key points are annotated with comments:

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get the appropriate separator for the current recursion level
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            # Quick check whether this separator occurs in the text;
            # if not, fall back to the next (lower-level) separator
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(
            text, _separator, keep_separator=self._keep_separator
        )

        # Now go merging things, recursively splitting longer texts.
        # A "good split" is a piece shorter than chunk_size
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                # When a piece is not a good split, first merge the
                # good splits accumulated so far
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    # This looks like it could emit an over-long chunk, but since
                    # the last default separator is the empty string (character-level
                    # splitting), this branch is effectively unreachable.
                    final_chunks.append(s)
                else:
                    # Recursively split the over-long piece with the
                    # lower-level separators
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks
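The separator-selection loop at the top of _split_text can be exercised in isolation. pick_separator below is a hypothetical helper that mirrors that loop for non-regex separators; it is not part of the library's API:

```python
import re

def pick_separator(text: str, separators: list[str]) -> tuple[str, list[str]]:
    """Return the first separator present in `text`, plus the remaining
    lower-level separators to recurse with (mirrors _split_text's loop)."""
    separator, new_separators = separators[-1], []
    for i, s in enumerate(separators):
        if s == "":
            # Empty string: character-level splitting, nothing left to fall back to.
            separator = s
            break
        if re.search(re.escape(s), text):
            separator, new_separators = s, separators[i + 1:]
            break
    return separator, new_separators

# "\n\n" is absent, so the splitter degrades to "\n":
pick_separator("line one\nline two", ["\n\n", "\n", " ", ""])
# → ("\n", [" ", ""])
```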

The merging and overlap logic lives in the _merge_splits function, again with comments at the key points:

    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: list[str] = []
        total = 0
        # splits are the "good splits"; pack them into chunks that are
        # as large as possible
        for d in splits:
            _len = self._length_function(d)
            # If adding the next piece would exceed chunk_size,
            # seal the chunk and start a new one
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    # Core of the overlap handling: back up by at most
                    # chunk_overlap characters so that the next chunk A2 starts
                    # with the tail of the previous chunk A1. If the last piece
                    # of A1 is longer than chunk_overlap, no overlap is generated,
                    # keeping the piece semantically intact. current_doc holds
                    # the pieces of the chunk currently being built.
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
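Assuming character-count lengths and a plain string join (the library's defaults), the merge-and-overlap behavior can be sketched standalone. merge_splits below is an illustrative reimplementation, not the library function:

```python
def merge_splits(splits: list[str], chunk_size: int, chunk_overlap: int,
                 separator: str = " ") -> list[str]:
    """Greedily pack splits into chunks <= chunk_size, carrying at most
    chunk_overlap trailing characters over into the next chunk."""
    sep_len = len(separator)
    docs, current, total = [], [], 0
    for d in splits:
        d_len = len(d)
        if total + d_len + (sep_len if current else 0) > chunk_size:
            if current:
                # Seal the current chunk.
                docs.append(separator.join(current))
                # Pop leading pieces until at most chunk_overlap characters
                # remain; the survivors seed the next chunk as overlap.
                while total > chunk_overlap or (
                    total + d_len + (sep_len if current else 0) > chunk_size
                    and total > 0
                ):
                    total -= len(current[0]) + (sep_len if len(current) > 1 else 0)
                    current = current[1:]
        current.append(d)
        total += d_len + (sep_len if len(current) > 1 else 0)
    if current:
        docs.append(separator.join(current))
    return docs

# With chunk_size=5 and chunk_overlap=2, "bb" is carried over as overlap:
merge_splits(["aa", "bb", "cc"], chunk_size=5, chunk_overlap=2)
# → ["aa bb", "bb cc"]
```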