上传Android项目
This commit is contained in:
@@ -0,0 +1,68 @@
|
||||
package com.example.myapplication.data
|
||||
|
||||
import android.content.Context
|
||||
import java.io.BufferedReader
|
||||
import java.io.InputStreamReader
|
||||
|
||||
data class BigramModel(
|
||||
val vocab: List<String>, // 保留全部词(含 <unk>, <s>, </s>),与二元矩阵索引对齐
|
||||
val uniLogp: FloatArray, // 长度 = vocab.size
|
||||
val biRowptr: IntArray, // 长度 = vocab.size + 1 (CSR)
|
||||
val biCols: IntArray, // 长度 = nnz
|
||||
val biLogp: FloatArray // 长度 = nnz
|
||||
)
|
||||
|
||||
object LanguageModelLoader {
|
||||
fun load(context: Context): BigramModel {
|
||||
val vocab = context.assets.open("vocab.txt").bufferedReader()
|
||||
.use(BufferedReader::readLines)
|
||||
|
||||
val uniLogp = readFloat32(context, "uni_logp.bin")
|
||||
val biRowptr = readInt32(context, "bi_rowptr.bin")
|
||||
val biCols = readInt32(context, "bi_cols.bin")
|
||||
val biLogp = readFloat32(context, "bi_logp.bin")
|
||||
|
||||
require(uniLogp.size == vocab.size) { "uni_logp length != vocab size" }
|
||||
require(biRowptr.size == vocab.size + 1) { "bi_rowptr length invalid" }
|
||||
require(biCols.size == biLogp.size) { "bi cols/logp nnz mismatch" }
|
||||
|
||||
return BigramModel(vocab, uniLogp, biRowptr, biCols, biLogp)
|
||||
}
|
||||
|
||||
private fun readInt32(context: Context, name: String): IntArray {
|
||||
context.assets.open(name).use { input ->
|
||||
val bytes = input.readBytes()
|
||||
val n = bytes.size / 4
|
||||
val out = IntArray(n)
|
||||
var i = 0; var j = 0
|
||||
while (i < n) {
|
||||
// 小端序
|
||||
val v = (bytes[j].toInt() and 0xFF) or
|
||||
((bytes[j+1].toInt() and 0xFF) shl 8) or
|
||||
((bytes[j+2].toInt() and 0xFF) shl 16) or
|
||||
((bytes[j+3].toInt() and 0xFF) shl 24)
|
||||
out[i++] = v
|
||||
j += 4
|
||||
}
|
||||
return out
|
||||
}
|
||||
}
|
||||
|
||||
private fun readFloat32(context: Context, name: String): FloatArray {
|
||||
context.assets.open(name).use { input ->
|
||||
val bytes = input.readBytes()
|
||||
val n = bytes.size / 4
|
||||
val out = FloatArray(n)
|
||||
var i = 0; var j = 0
|
||||
while (i < n) {
|
||||
val bits = (bytes[j].toInt() and 0xFF) or
|
||||
((bytes[j+1].toInt() and 0xFF) shl 8) or
|
||||
((bytes[j+2].toInt() and 0xFF) shl 16) or
|
||||
((bytes[j+3].toInt() and 0xFF) shl 24)
|
||||
out[i++] = Float.fromBits(bits)
|
||||
j += 4
|
||||
}
|
||||
return out
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user