Transformers.js 的基本使用
使用背景
目前 WebApp 使用 AI 模型一般是调用后台接口(模型的云服务或者在自己的服务器上使用模型),这样的方案存在服务器成本、网络稳定性、用户隐私等问题。所以,在这种情况下就可以使用 Transformers.js 在本地进行 AI 推理。
什么是 Transformers.js
Transformers.js 是由 Hugging Face 开发的一个 JavaScript 库,旨在让用户能够直接在浏览器中运行最先进的机器学习模型,而无需服务器支持。该库与 Hugging Face 的 Python 版 transformers 库功能等效,支持多种预训练模型,涵盖自然语言处理、计算机视觉和语音识别等任务。
Transformers.js 使用 ONNX Runtime 运行模型,支持在 CPU 和 WebGPU 上执行,提供了高效的模型转换和量化工具,方便用户将 PyTorch、TensorFlow 或 JAX 模型转换为 ONNX 格式并在浏览器中运行。
功能列表
- 自然语言处理:文本分类、命名实体识别、问答、语言建模、摘要、翻译、多项选择和文本生成。
- 计算机视觉:图像分类、对象检测、分割和深度估计。
- 语音识别:自动语音识别、音频分类和文本转语音。
- 多模态任务:嵌入、零镜头音频分类、零镜头图像分类和零镜头对象检测。
Transformers.js 在 v3 版本可以利用 WebGPU 进行高性能推理,速度相比 wasm 方案有了极大的提升。https://huggingface.co/blog/transformersjs-v3
使用
线上 DEMO:https://transformers-js-basic-use.vercel.app/
快速安装
# npm
npm i @huggingface/transformers
# pnpm
pnpm add @huggingface/transformers
接口使用
import { pipeline } from "@huggingface/transformers";
const generator = await pipeline('summarization', 'Xenova/distilbart-cnn-6-6', config);
const text = 'xxx';
const output = await generator(text, {
max_new_tokens: 100,
}); // [{ summary_text: 'xxx' }]
使用 pipeline
接口(pipeline
简化了模型的下载、加载和使用)加载模型,第一个参数传递的是 task 类型,这里使用的是 ‘summarization’,第二个参数是模型名字(如果忽略的话会使用默认模型),第三个参数是配置信息(包含进度、缓存、设置等),这里主要了解下 config.progress_callback
,config.device
和 config.dtype
的使用。
-
config.progress_callback
:进度回调。这里的进度包含初始化、下载、加载和准备阶段progress_callback: data => { switch (data.status) { // 模型开始初始化 case "initiate": { const { name, file } = data; console.log("initiate", name, file); } break; // 模型开始下载 case "download": { const { name, file } = data; console.log("download", name, file); } break; // 模型下载进度 case "progress": { const { name, file, progress, loaded, total } = data; console.log("progress", name, file, progress, loaded, total); } break; // 模型下载完成 case "done": { const { name, file } = data; console.log("done", name, file); } break; // 模型准备完成 case "ready": { const { task, model } = data; console.log("ready", task, model); } break; } },
-
config.device
:设置推理的设置,默认使用的是wasm
,如果条件允许的话尽量使用webgpu
,推理速度会更快。完整的 device 类型如下:/** * The list of devices supported by Transformers.js */ export const DEVICE_TYPES = Object.freeze({ auto: 'auto', // Auto-detect based on device and environment gpu: 'gpu', // Auto-detect GPU cpu: 'cpu', // CPU wasm: 'wasm', // WebAssembly webgpu: 'webgpu', // WebGPU cuda: 'cuda', // CUDA dml: 'dml', // DirectML webnn: 'webnn', // WebNN (default) 'webnn-npu': 'webnn-npu', // WebNN NPU 'webnn-gpu': 'webnn-gpu', // WebNN GPU 'webnn-cpu': 'webnn-cpu', // WebNN CPU });
-
config.dtype
:设置推理的精度,如果忽略的话会根据device
使用不同的dtype
,如果device
是wasm
的话dtype
为q8
,否则为fp32
。需要根据具体的应用场景和模型使用参数,一般来说webgpu
使用fp32
或者fp16
,wasm
使用q8
或者q4
。
如果使用的模型没有在内置的 pipeline
中,那么需要自己手动处理模型的逻辑
// 使用 pipline
await pipeline("summarization", "Xenova/distilbart-cnn-6-6");
const summary = await generatorRef.current(input);
const output = summary[0].summary_text;
// 不适用 pipline
const model = await AutoModelForSeq2SeqLM.from_pretrained("Xenova/distilbart-cnn-6-6");
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/distilbart-cnn-6-6");
const inputs = await tokenizerRef.current([input], {
truncation: true,
return_tensors: true,
});
const modelOutputs = await model .generate(inputs);
const summary = await tokenizer .batch_decode(modelOutputs, {
skip_special_tokens: true,
});
const output = summary[0];