使用 MinHash 进行文本去重
本文的主要内容是介绍如何基于文章《BigCode 背后的大规模数据去重》提到的方法构建一个能够对中文文本进行去重的 MinHash 方案实现。
2024-04-20
在文章《BigCode 背后的大规模数据去重》中介绍了一种基于 MinHash 实现文本去重的方法。本文将介绍基于原文内容实现的一个针对中文文本去重的方案。
使用 MinHash 进行文本去重涉及到的基础概念
总的来说,文本去重的基本原理就是确认文本间的相似度,然后移除相似度高的条目,因此首先要做的就是将文本转换为可计算的表现形式,在原文中使用的是 N 元词袋模型。N 元词袋模型是 NLP 中常用的文本表示方法,它将文本看作是由连续的 N 个单词组成的序列,并将这些 N 个单词作为一个整体来处理,而不考虑它们之间的顺序。每个单词是一个独立的标记,通常被称为单元(unigram)或者一元(1-gram)。这种模型忽略了单词之间的顺序信息,只关注单词的频率。
如果直接使用文本的 n-gram 表示计算相似度,所需的计算量会非常的大。因此在原文中使用 MinHash 对原始文本的向量表示进行降维并基于降维后的向量表示计算文本间的相似度,从而大大减少了所需的计算量。MinHash 是局部敏感哈希(Locality Sensitive Hashing,LSH)的一种具体实现,其核心思想是“如果两个数据点在原始空间中足够接近,那么经过一个随机的投影之后,它们在新空间中的位置关系也会接近”,在实现上有两种基本形式(原文中使用的是方法 2):
- 对一个目标集合应用多个哈希函数获取一个降维后的向量表示;
- 对目标集合中的元素进行重排,然后对重排后得到的结果应用相同哈希函数获取降维后的向量表示。
在完成降维后,再使用 Jaccard 相似度 去计算相似度,最终完成去重。
使用 MinHash 进行文本去重的具体流程
根据原文中的描述结合个人理解,基于 MinHash 进行文本去重的流程应为:
-
读取待去重的文本条目(假设共有
条文本);N
-
针对每条文本条目(假设文本拆分后词的数量为
):M
2.1 使用 N 元词袋模型转化文本条目(转换后会得到
个 ngram 组);M-1
2.2 针对转换后的 ngram 表示(假设为词袋大小为 p),使用 MinHash 将文本条目转换为固定长度的向量表示(假设向量大小为
,则完成这一步需要k
次循环);p!*(M-1)
2.3 将文档向量表示按固定尺寸拆分为 LSH 带(假设带宽为
,则一个文档包含l
个 LSH 带);k//l
-
针对所有的文档,如果两个文档中包含相同的 LSH 带,则将其放置到同一个 LSH 桶中(最多会生成
个桶,此时每个桶中只包含一个元素;最少会生成一个桶,此时桶中包含N*(N-1)*(k/l)
个元素);N
-
针对所有包含复数个文本条目的桶,计算桶中文本条目间的 ngram 组的 Jaccard 相似度,如果相似度大于阈值则标记为重复项目(当上一步只生成一个桶时是最坏情况,需要
次循环);N*(N-1)
-
使用上一步得到的重复条目标记即可获取去重后的结构。
根据上面的流程可知,该算法的时间复杂度为:
O(N+p!*(M-1)+N*(N-1)*(k/l)+N(N-1))
N^2
O(N^2)
MinHash 的代码实现
下面的是具体的代码,需要说明的内容会在代码中以注释的形式给出。
# pip install -U tqdm
import math
import re
import time
from collections import defaultdict
from itertools import permutations
from pathlib import Path
from typing import List, Tuple
from tqdm import tqdm
# 定义如何拆分文本,以获取用于生成 n 元词袋的表示
# 对于中文文本而言更好的做法是进行分词并移除停用词
def split(source: str):
return list(source)
# 定义如何将 n 元词袋连接起来一遍进行 hash
# 这里直接转换为字符串就可以
def join(source: List[str]):
return "".join(source)
# 缓存哈希计算结果
HASH_CACHE = dict()
# 具体的哈希函数,这里使用的是 sha1
def hash_func(x):
import hashlib
if HASH_CACHE.get(x, None) is None:
HASH_CACHE[x] = int(hashlib.sha1(x.encode('utf-8')).hexdigest(), 16)
return HASH_CACHE.get(x)
class Document(object):
"""
定义一个文档对象,在文档对象内进行拆分词袋、计算哈希和拆分带的操作
"""
def __init__(self, id: int, content: str, ngram_n=3, band_size=2):
self.id = id
self._content = content
self.band_size = band_size
self.ngram_n = ngram_n
self._ngrams = list()
self._min_hash = list()
self._lsh_bands = list()
def prepare(self):
self._lsh_bands
@property
def content(self):
return self._content
@property
def ngrams(self):
"""
获取文档的 n 元词袋表示
"""
if len(self._ngrams) != 0:
return self._ngrams
self._ngrams = Document.split_ngrams(self.id, self._content, self.ngram_n)
return self._ngrams
@property
def min_hash(self):
"""
获取文档的 MinHash 表示
"""
if len(self._min_hash) != 0:
return self._min_hash
self._min_hash = Document.calc_min_hash(self.ngrams)
return self._min_hash
@property
def lsh_bands(self):
"""
获取文档的 LSH 带
"""
if len(self._lsh_bands) != 0:
return self._lsh_bands
return self.prepare_lsh_bands()
def prepare_lsh_bands(self):
band_count = len(self.min_hash) // self.band_size
self._lsh_bands: List[str] = list()
for i in range(band_count):
band = self.min_hash[i*self.band_size:(i+1)*self.band_size]
self._lsh_bands.append("-".join(map(lambda x: str(x), band)))
return self._lsh_bands
def calc_similarity(self, rhs):
"""
计算两个文档之间的 jaccard 相似度
"""
intersection = len(set(self.ngrams).intersection(set(rhs.ngrams)))
union = len(set(self.ngrams).union(set(rhs.ngrams)))
print(intersection, union)
return intersection / union
@staticmethod
def split_ngrams(_id: int, source: str, ngram_n: int) -> List[str]:
"""
拆分 n 元词袋
"""
no_symbols = re.sub(r'[^\w\s]', '', source)
parts = split(no_symbols)
ngrams = []
for i in range(len(parts) - (ngram_n - 1)):
ngrams.append(join(parts[i:i+ngram_n]))
if len(ngrams) == 0:
raise Exception(
f"[ERROR] 无法为数据[{_id}] 生成 {ngram_n}-grams,source={list(set(source.strip()))}")
return ngrams
@staticmethod
def calc_min_hash(ngrams: List[str]) -> List[int]:
"""
计算 MinHash
使用的是第二种方法,即对一个集合的排列应用特定的哈希函数
通过这里步骤可以将一个文档表示为一个长度为 n! 的向量
其中 n 为 n 元词袋的大小
"""
def generate_ngrams_signature(ngram: str):
result: List[Tuple[str, List[int]]] = list()
hash_values: List[int] = list()
for value in permutations(split(ngram)):
hash_values.append(hash_func(join(value)))
result.append((ngram, hash_values))
return hash_values
ngrams_hashes = list(map(generate_ngrams_signature, ngrams))
return [
min(map(lambda x: x[i], ngrams_hashes))
for i in range(len(ngrams_hashes[0]))
]
class LSH(object):
"""
执行 LSH 流程
"""
def __init__(self):
self._buckets = None
def build(self, docs: List[Document]):
"""
构建 LSH 桶并针对有复数个文档的桶进行比较、去重
"""
if self._buckets is not None:
return self._buckets
buckets = defaultdict(set)
for doc in tqdm(docs, "[INFO] ✂ 拆分 band"):
doc.prepare_lsh_bands()
for index, i_doc in tqdm(enumerate(docs), "[INFO] 🪣 构建 bucket"):
for j_doc in docs[index + 1:]:
for band in i_doc.lsh_bands:
buckets[band].add(i_doc.id)
if band in j_doc.lsh_bands:
buckets[band].add(j_doc.id)
# 下面这一步很重要,如果文档重复不多的话,能减少很多的工作量
buckets = filter(
lambda x: len(x) > 1,
[list(bucket) for bucket in buckets.values()]
)
self._buckets = list()
duplicated = set()
# 移除元素重复的桶
for bucket in buckets:
value_str = " ".join(map(lambda x: str(x), sorted(bucket)))
if value_str in duplicated:
continue
duplicated.add(value_str)
self._buckets.append(bucket)
return self._buckets
class Deduper(object):
"""
去重入口,在这里配置相关参数并进行去重
"""
def __init__(self, docs: List[str], sim_threshold=0.8, ngram_n=3, band_size=3):
self.sim_threshold = sim_threshold
self.ngram_n = ngram_n
self.band_size = band_size
if math.factorial(ngram_n) < band_size:
raise Exception("LSH 带宽不能大于 {n} * {n - 1}")
self.docs = list(map(
lambda p: Document(p[0], p[1], self.ngram_n, self.band_size),
enumerate(docs)
))
self.lsh = LSH()
self.duplicated = dict()
self._result = list()
def execute(self):
if len(self._result) != 0:
return self._result
buckets = self.lsh.build(self.docs)
print(f"[INFO] 待处理 buckets 数量:{len(buckets)}")
# 针对所有桶,根据桶中条目的 jaccard 相似度挑选出重复的条目
for bucket in buckets:
print("[INFO] 正在处理 bucket:", bucket)
for i in range(len(bucket)):
self._check_bucket_by_jaccard(i, bucket)
# 将标记为重复的条目移除
self._result = list(
filter(lambda d: not self.duplicated.get(d.id, False), self.docs))
return self._result
def _check_bucket_by_jaccard(self, i, bucket):
i_id = bucket[i]
i_doc: Document = self.docs[i_id]
if self.duplicated.get(i_id, False):
return
for j in range(i+1, len(bucket)):
j_id = bucket[j]
j_doc: Document = self.docs[j_id]
if self.duplicated.get(j_id, False):
continue
sim = i_doc.calc_similarity(j_doc)
if sim >= self.sim_threshold:
print(f"[INFO] {i_id}/{j_id}: {sim}")
self.duplicated[j_id] = True
# 移除无效的文本(不能为空,长度应大于词袋大小)
def filter_len(x: str):
uniq = "".join(set(list(x.strip())))
return len(uniq) > ngram_n
if __name__ == '__main__':
import sys
if len(sys.argv) < 3:
print("[INFO] usage: python dedupe.py <to_dedupe> <deduped_file>")
exit(1)
to_dedupe = sys.argv[1]
deduped_file = sys.argv[2]
sim_threshold, ngram_n, bands_size = 0.8, 3, 3
print(
f"[INFO] 开始 LSH 去重,sim_threshold={sim_threshold} ngram_n={ngram_n} bands_size={bands_size}")
# 使用示例
import time
started = time.time()
lines = open(to_dedupe).readlines()
docs = list(filter(filter_len, tqdm(lines, "[INFO] 🌳 筛选文本(按长度及字符)")))
print(f"[INFO] 文件已读取,待处理文本条数:{len(docs)}")
deduper = Deduper(docs, sim_threshold, ngram_n, bands_size)
result = deduper.execute()
duplicated_count = len(docs) - len(result)
print(f"[INFO] 重复数量:{duplicated_count},重复率:{duplicated_count / len(docs)}")
Path(deduped_file).write_text(
"\n".join(map(lambda x: x.content, result))
)
print(f"[INFO] 已完成清洗,共耗时 {time.time() - started} 秒")
dedup.py
上述代码仅供参考,并不适用于大量数据的处理。