优化plus
This commit is contained in:
@@ -18,6 +18,9 @@ class BigramPredictor(
|
||||
// Word -> id lookup table. Starts empty and is replaced wholesale when the model is loaded
// (presumably inside preload() — confirm against the full file).
@Volatile private var word2id: Map<String, Int> = emptyMap()
|
||||
|
||||
// Id -> word list. Starts empty; the loading code assigns the model's vocab to it.
@Volatile private var id2word: List<String> = emptyList()
|
||||
// Cached output of buildTopUnigrams(). Empty means "not built yet": getTopUnigrams()
// rebuilds and stores it whenever it finds this list empty.
@Volatile private var topUnigrams: List<String> = emptyList()
|
||||
|
||||
// Limit handed to buildTopUnigrams() when filling the topUnigrams cache.
private val unigramCacheSize = 2000
|
||||
|
||||
//预先加载语言模型,并构建词到ID和ID到词的双向映射。
|
||||
fun preload() {
|
||||
@@ -37,6 +40,7 @@ class BigramPredictor(
|
||||
word2id = map
|
||||
|
||||
id2word = m.vocab
|
||||
topUnigrams = buildTopUnigrams(m, unigramCacheSize)
|
||||
} catch (_: Throwable) {
|
||||
// 保持静默,允许无模型运行(仅 Trie 起作用)
|
||||
} finally {
|
||||
@@ -89,19 +93,34 @@ class BigramPredictor(
|
||||
return topKByScore(candidates, topK)
|
||||
}
|
||||
|
||||
// 3) 兜底:用 unigram + 前缀过滤
|
||||
val heap = topKHeap(topK)
|
||||
// 3) 兜底:用预计算的 unigram Top-N + 前缀过滤
|
||||
if (topK <= 0) return emptyList()
|
||||
|
||||
for (i in m.vocab.indices) {
|
||||
val w = m.vocab[i]
|
||||
val cachedUnigrams = getTopUnigrams(m)
|
||||
if (pfx.isEmpty()) {
|
||||
return cachedUnigrams.take(topK)
|
||||
}
|
||||
|
||||
if (pfx.isEmpty() || w.startsWith(pfx, ignoreCase = true)) {
|
||||
heap.offer(w to m.uniLogp[i])
|
||||
|
||||
if (heap.size > topK) heap.poll()
|
||||
val results = ArrayList<String>(topK)
|
||||
if (cachedUnigrams.isNotEmpty()) {
|
||||
for (w in cachedUnigrams) {
|
||||
if (w.startsWith(pfx, ignoreCase = true)) {
|
||||
results.add(w)
|
||||
if (results.size >= topK) return results
|
||||
}
|
||||
}
|
||||
}
|
||||
return heap.toSortedListDescending()
|
||||
|
||||
if (results.size < topK) {
|
||||
val fromTrie = safeTriePrefix(pfx, topK)
|
||||
for (w in fromTrie) {
|
||||
if (w !in results) {
|
||||
results.add(w)
|
||||
if (results.size >= topK) break
|
||||
}
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
//供上层在用户选中词时更新“上文”状态
|
||||
@@ -115,12 +134,33 @@ class BigramPredictor(
|
||||
if (prefix.isEmpty()) return emptyList()
|
||||
|
||||
return try {
|
||||
trie.startsWith(prefix).take(topK)
|
||||
trie.startsWith(prefix, topK)
|
||||
} catch (_: Throwable) {
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Returns the cached top-unigram list, building it from [model] on demand.
 *
 * The `topUnigrams` field is read exactly once. When that snapshot is non-empty it is
 * returned as-is; otherwise the list is built with `buildTopUnigrams(model, unigramCacheSize)`,
 * stored back into `topUnigrams`, and returned. Because an empty list doubles as the
 * "not built" marker, an empty build result is recomputed on the next call.
 *
 * @param model the model passed through to `buildTopUnigrams` when the cache is empty.
 * @return the cached list, or the freshly built one.
 */
private fun getTopUnigrams(model: BigramModel): List<String> =
    topUnigrams.ifEmpty {
        // Cache miss: build once, publish to the field, and hand back the same instance.
        buildTopUnigrams(model, unigramCacheSize).also { fresh -> topUnigrams = fresh }
    }
|
||||
|
||||
/**
 * Selects up to [limit] words from [model]'s vocabulary using a bounded heap.
 *
 * Every vocabulary entry is offered to the heap from `topKHeap(limit)` as a
 * `word to model.uniLogp[index]` pair. Whenever the heap grows beyond [limit], one
 * element is polled so that at most [limit] entries are retained. The result is whatever
 * `toSortedListDescending()` produces from the heap (presumably the retained words ordered
 * by score, highest first — confirm against that helper).
 *
 * @param model source of `vocab` and the parallel `uniLogp` scores.
 * @param limit maximum number of words to keep; non-positive values yield an empty list.
 * @return the selected words, or an empty list when [limit] is not positive.
 */
private fun buildTopUnigrams(model: BigramModel, limit: Int): List<String> {
    if (limit < 1) return emptyList()

    val best = topKHeap(limit)
    model.vocab.forEachIndexed { idx, word ->
        best.offer(word to model.uniLogp[idx])
        // Trim right after each insert so the heap never holds more than `limit` entries.
        if (best.size > limit) best.poll()
    }

    return best.toSortedListDescending()
}
|
||||
|
||||
//从给定的候选词对列表中,通过一个小顶堆来过滤出评分最高的前k个词
|
||||
private fun topKByScore(pairs: List<Pair<String, Float>>, k: Int): List<String> {
|
||||
val heap = topKHeap(k)
|
||||
|
||||
Reference in New Issue
Block a user