优化plus

This commit is contained in:
pengxiaolong
2026-01-15 21:32:32 +08:00
parent a1fbc6417f
commit 673b4491d7
17 changed files with 649 additions and 346 deletions

View File

@@ -18,6 +18,9 @@ class BigramPredictor(
@Volatile private var word2id: Map<String, Int> = emptyMap()
@Volatile private var id2word: List<String> = emptyList()
@Volatile private var topUnigrams: List<String> = emptyList()
private val unigramCacheSize = 2000
//预先加载语言模型并构建词到ID和ID到词的双向映射。
fun preload() {
@@ -37,6 +40,7 @@ class BigramPredictor(
word2id = map
id2word = m.vocab
topUnigrams = buildTopUnigrams(m, unigramCacheSize)
} catch (_: Throwable) {
// 保持静默,允许无模型运行(仅 Trie 起作用)
} finally {
@@ -89,19 +93,34 @@ class BigramPredictor(
return topKByScore(candidates, topK)
}
// 3) 兜底:用 unigram + 前缀过滤
val heap = topKHeap(topK)
// 3) 兜底:用预计算的 unigram Top-N + 前缀过滤
if (topK <= 0) return emptyList()
for (i in m.vocab.indices) {
val w = m.vocab[i]
val cachedUnigrams = getTopUnigrams(m)
if (pfx.isEmpty()) {
return cachedUnigrams.take(topK)
}
if (pfx.isEmpty() || w.startsWith(pfx, ignoreCase = true)) {
heap.offer(w to m.uniLogp[i])
if (heap.size > topK) heap.poll()
val results = ArrayList<String>(topK)
if (cachedUnigrams.isNotEmpty()) {
for (w in cachedUnigrams) {
if (w.startsWith(pfx, ignoreCase = true)) {
results.add(w)
if (results.size >= topK) return results
}
}
}
return heap.toSortedListDescending()
if (results.size < topK) {
val fromTrie = safeTriePrefix(pfx, topK)
for (w in fromTrie) {
if (w !in results) {
results.add(w)
if (results.size >= topK) break
}
}
}
return results
}
//供上层在用户选中词时更新“上文”状态
@@ -115,12 +134,33 @@ class BigramPredictor(
if (prefix.isEmpty()) return emptyList()
return try {
trie.startsWith(prefix).take(topK)
trie.startsWith(prefix, topK)
} catch (_: Throwable) {
emptyList()
}
}
/**
 * Returns the cached top-unigram list, building it from [model] when the cache is empty.
 *
 * The `topUnigrams` field is read once; a non-empty value is returned as-is.
 * Otherwise the list is built via [buildTopUnigrams] with `unigramCacheSize`,
 * stored back into `topUnigrams`, and returned.
 *
 * NOTE(review): no lock is taken here, so concurrent callers could each build the list;
 * the result should be the same for the same model — confirm that is acceptable.
 */
private fun getTopUnigrams(model: BigramModel): List<String> =
    topUnigrams.ifEmpty {
        // Cache miss: build once and publish the result for later calls.
        buildTopUnigrams(model, unigramCacheSize).also { fresh -> topUnigrams = fresh }
    }
/**
 * Selects up to [limit] vocabulary words from [model] using their `uniLogp` values.
 *
 * Every entry of `model.vocab` is paired with the `uniLogp` value at the same index and
 * offered to a bounded heap obtained from `topKHeap(limit)`; whenever the heap grows past
 * [limit], one element is polled. The remaining entries are returned through
 * `toSortedListDescending()`.
 *
 * @param model source of the vocabulary and per-word scores.
 * @param limit maximum number of words to keep; a value `<= 0` yields an empty list.
 * @return the retained words, in the order produced by `toSortedListDescending()`
 *         (presumably highest score first — verify against that helper).
 */
private fun buildTopUnigrams(model: BigramModel, limit: Int): List<String> {
    if (limit <= 0) return emptyList()

    val best = topKHeap(limit)
    model.vocab.forEachIndexed { idx, word ->
        best.offer(word to model.uniLogp[idx])
        // Keep the heap bounded: drop one entry as soon as it exceeds the limit.
        if (best.size > limit) best.poll()
    }
    return best.toSortedListDescending()
}
//从给定的候选词对列表中通过一个小顶堆来过滤出评分最高的前k个词
private fun topKByScore(pairs: List<Pair<String, Float>>, k: Int): List<String> {
val heap = topKHeap(k)