Files
Android-key-of-love/app/src/main/java/com/example/myapplication/BigramPredictor.kt
pengxiaolong 673b4491d7 优化plus
2026-01-15 21:32:32 +08:00

199 lines
6.2 KiB
Kotlin
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.example.myapplication.data
import android.content.Context
import com.example.myapplication.Trie
import java.util.concurrent.atomic.AtomicBoolean
import java.util.PriorityQueue
import kotlin.math.max
/**
 * Word predictor that ranks completions with a bigram language model and falls back to
 * unigram frequencies and a [Trie] prefix lookup.
 *
 * The model is loaded on a background thread by [preload]; until it is published, [suggest]
 * answers from the Trie alone.
 *
 * Threading: every piece of shared state is a `@Volatile` field holding an immutable value.
 * [preload] assigns [model] last, so on the first load a reader that sees a non-null model also
 * sees the matching [word2id] and [topUnigrams]. If [preload] is invoked again after a successful
 * load, the fields are swapped one at a time and a concurrent [suggest] may briefly mix old and
 * new state; the bounds checks in the bigram lookup keep that from throwing.
 */
class BigramPredictor(
    private val context: Context,
    private val trie: Trie
) {
    // Loaded model; stays null until a preload() run succeeds.
    @Volatile private var model: BigramModel? = null

    // True while a loader thread is running, so overlapping preload() calls are ignored.
    private val loading = AtomicBoolean(false)

    // Word <-> id mappings; a word's id is its index in the model vocabulary.
    @Volatile private var word2id: Map<String, Int> = emptyMap()
    @Volatile private var id2word: List<String> = emptyList()

    // Words with the highest unigram log-probability, best first (at most unigramCacheSize).
    @Volatile private var topUnigrams: List<String> = emptyList()
    private val unigramCacheSize = 2000

    /**
     * Loads the language model on a new thread and builds the word/id index and the unigram cache.
     *
     * Returns immediately. A call made while a load is already in progress does nothing.
     * Any failure is swallowed on purpose: the predictor then keeps working in Trie-only mode.
     */
    fun preload() {
        if (!loading.compareAndSet(false, true)) return
        Thread {
            try {
                val loaded = LanguageModelLoader.load(context)

                // Index the vocabulary by position so ids line up with the bigram tables.
                // Every vocab entry is indexed, including any leading special symbols.
                val index = HashMap<String, Int>(loaded.vocab.size * 2)
                loaded.vocab.forEachIndexed { idx, w -> index[w] = idx }
                val unigrams = buildTopUnigrams(loaded, unigramCacheSize)

                // Publish derived state first and the model last: isReady()/suggest() key off
                // `model`, so they must never observe it without its index and unigram cache.
                word2id = index
                id2word = loaded.vocab
                topUnigrams = unigrams
                model = loaded
            } catch (_: Throwable) {
                // Stay silent: running without a model (Trie only) is an accepted mode.
            } finally {
                loading.set(false)
            }
        }.start()
    }

    /** Returns true once a model has been loaded and published. */
    fun isReady(): Boolean = model != null

    /**
     * Suggests up to [topK] words that start with [prefix] (case-insensitive, trimmed).
     *
     * Strategy, in order:
     * 1. No model loaded: Trie prefix lookup only.
     * 2. [lastWord] is a known word: its bigram successors that match the prefix, ranked by
     *    log P(next | last). If there is at least one, only these are returned.
     * 3. Otherwise: cached top unigrams that match the prefix, topped up from the Trie.
     *
     * @param prefix text typed so far; may be empty.
     * @param lastWord previous word used as context, or null when there is none.
     * @param topK maximum number of suggestions; non-positive values yield an empty list.
     * @return suggestions, best first.
     */
    fun suggest(prefix: String, lastWord: String?, topK: Int = 10): List<String> {
        if (topK <= 0) return emptyList()
        val pfx = prefix.trim()

        // Model not loaded yet: pure Trie prefix completion.
        val m = model ?: return safeTriePrefix(pfx, topK)

        // 1) Bigram neighbourhood of the previous word, already filtered by prefix.
        val lastId = lastWord?.let { word2id[it] }
        if (lastId != null) {
            val bigramHits = bigramCandidates(m, lastId, pfx)
            if (bigramHits.isNotEmpty()) return topKByScore(bigramHits, topK)
        }

        // 2) Fallback: precomputed top unigrams, filtered by prefix.
        val cachedUnigrams = getTopUnigrams(m)
        if (pfx.isEmpty()) return cachedUnigrams.take(topK)

        // Insertion-ordered set: keeps ranking order and drops duplicates across both sources.
        val results = LinkedHashSet<String>()
        for (w in cachedUnigrams) {
            if (w.startsWith(pfx, ignoreCase = true)) {
                results.add(w)
                if (results.size >= topK) return results.toList()
            }
        }

        // 3) Still short: top up with Trie completions.
        for (w in safeTriePrefix(pfx, topK)) {
            results.add(w)
            if (results.size >= topK) break
        }
        return results.toList()
    }

    /**
     * Maps a word the user just committed to the token that should be passed as the next
     * `lastWord` context.
     *
     * Returns [word] itself when it is in the loaded vocabulary and the literal `"<unk>"`
     * otherwise (which is every word while no model is loaded). The return type is nullable
     * for callers' sake, but this implementation never returns null.
     */
    fun normalizeWordForContext(word: String): String? {
        // Hook for case/punctuation normalisation; out-of-vocabulary words map to <unk>.
        return if (word2id.containsKey(word)) word else "<unk>"
    }

    /**
     * Collects the bigram successors of [lastId] whose text starts with [pfx], paired with
     * their log-probability. Returns an empty list if the row is out of range or inconsistent.
     */
    private fun bigramCandidates(m: BigramModel, lastId: Int, pfx: String): List<Pair<String, Float>> {
        // The row for lastId spans biRowptr[lastId] until biRowptr[lastId + 1]; both must exist.
        if (lastId < 0 || lastId + 1 >= m.biRowptr.size) return emptyList()
        val start = m.biRowptr[lastId]
        val end = m.biRowptr[lastId + 1]
        if (start !in 0..end || end > m.biCols.size || end > m.biLogp.size) return emptyList()

        val hits = ArrayList<Pair<String, Float>>()
        for (i in start until end) {
            // Skip column ids that do not name a vocabulary entry instead of throwing.
            val w = m.vocab.getOrNull(m.biCols[i]) ?: continue
            if (pfx.isEmpty() || w.startsWith(pfx, ignoreCase = true)) {
                hits.add(w to m.biLogp[i]) // score = logP(next | last)
            }
        }
        return hits
    }

    /**
     * Asks the Trie for up to [topK] completions of [prefix].
     * Returns an empty list for an empty prefix or if the Trie throws.
     */
    private fun safeTriePrefix(prefix: String, topK: Int): List<String> {
        if (prefix.isEmpty()) return emptyList()
        return try {
            trie.startsWith(prefix, topK)
        } catch (_: Throwable) {
            emptyList()
        }
    }

    /**
     * Returns the cached top-unigram list, building and caching it from [model] if the cache
     * is still empty.
     */
    private fun getTopUnigrams(model: BigramModel): List<String> {
        val cached = topUnigrams
        if (cached.isNotEmpty()) return cached
        val built = buildTopUnigrams(model, unigramCacheSize)
        topUnigrams = built
        return built
    }

    /** Returns the [limit] vocabulary words with the highest unigram log-probability, best first. */
    private fun buildTopUnigrams(model: BigramModel, limit: Int): List<String> {
        if (limit <= 0) return emptyList()
        // Capacity is only a hint; never reserve more slots than there are words.
        val heap = topKHeap(minOf(limit, model.vocab.size))
        for (i in model.vocab.indices) {
            heap.offer(model.vocab[i] to model.uniLogp[i])
            if (heap.size > limit) heap.poll()
        }
        return heap.toSortedListDescending()
    }

    /** Returns the words of the [k] highest-scoring [pairs], best first, using a bounded min-heap. */
    private fun topKByScore(pairs: List<Pair<String, Float>>, k: Int): List<String> {
        if (k <= 0 || pairs.isEmpty()) return emptyList()
        // Size the heap by what can actually be stored, not by a possibly huge k.
        val heap = topKHeap(minOf(k, pairs.size))
        for (p in pairs) {
            heap.offer(p)
            if (heap.size > k) heap.poll()
        }
        return heap.toSortedListDescending()
    }

    /**
     * Creates an empty min-heap ordered by score, so that polling evicts the current
     * lowest-scoring entry. [k] is used as the initial capacity (at least 1).
     */
    private fun topKHeap(k: Int): PriorityQueue<Pair<String, Float>> {
        return PriorityQueue(k.coerceAtLeast(1)) { a, b ->
            a.second.compareTo(b.second) // lower score is polled first
        }
    }

    /** Drains the heap and returns its words ordered from highest to lowest score. */
    private fun PriorityQueue<Pair<String, Float>>.toSortedListDescending(): List<String> {
        val words = ArrayList<String>(this.size)
        while (true) {
            // poll() yields the lowest remaining score, or null once the heap is empty.
            val entry = this.poll() ?: break
            words.add(entry.first)
        }
        words.reverse() // ascending -> descending
        return words
    }
}