Does the Marian model have a method like Hugging Face's generate? #414
The pipeline should not be slower than the Python equivalent on the same device. If you are using a CUDA-enabled GPU, please ensure it is used for both frameworks. The Marian model exposes a generate method via the MarianGenerator struct and the LanguageGenerator trait.
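For reference, a rough sketch of driving the generator directly. This assumes a recent rust-bert release and the locally cached `opus-mt-zh-en` files used later in this thread; field names such as `model_resource` and whether `generate` returns a `Result` vary between versions:

```rust
// Sketch only: exact GenerateConfig fields and the generate() signature
// differ across rust-bert versions.
use rust_bert::marian::MarianGenerator;
use rust_bert::pipelines::common::ModelResource;
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::resources::LocalResource;
use tch::Device;

fn main() -> anyhow::Result<()> {
    // Point the resources at a locally cached Marian checkpoint.
    let generate_config = GenerateConfig {
        model_resource: ModelResource::Torch(Box::new(LocalResource {
            local_path: "/root/.cache/.rustbert/opus-mt-zh-en/rust_model.ot".into(),
        })),
        config_resource: Box::new(LocalResource {
            local_path: "/root/.cache/.rustbert/opus-mt-zh-en/config.json".into(),
        }),
        vocab_resource: Box::new(LocalResource {
            local_path: "/root/.cache/.rustbert/opus-mt-zh-en/vocab.json".into(),
        }),
        merges_resource: Some(Box::new(LocalResource {
            local_path: "/root/.cache/.rustbert/opus-mt-zh-en/source.spm".into(),
        })),
        max_length: Some(50), // Option<i64> in recent releases, plain i64 in older ones
        device: Device::cuda_if_available(),
        ..Default::default()
    };
    let generator = MarianGenerator::new(generate_config)?;
    // `generate` comes from the LanguageGenerator trait; older releases return
    // the Vec directly rather than a Result.
    let output = generator.generate(Some(&["你好，世界"]), None)?;
    for sentence in output {
        println!("{}", sentence.text);
    }
    Ok(())
}
```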
Thanks for your reply.

Python code (CPU: 0.3 s, average over 100 requests):

```python
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import MarianTokenizer, MarianMTModel

# Load the Marian model and tokenizer
model_name = "Helsinki-NLP/opus-mt-zh-en"  # Replace with your desired model
model = MarianMTModel.from_pretrained(model_name)
tokenizer = MarianTokenizer.from_pretrained(model_name)

app = FastAPI()

class InputData(BaseModel):
    text: str

class OutputData(BaseModel):
    generation_text: str

# response_model must be a type, not a dict instance
@app.post("/v1/predict", response_model=OutputData)
async def predict(input_data: InputData):
    # Translate the input text
    input_ids = tokenizer.encode(input_data.text, return_tensors="pt")
    translation_ids = model.generate(input_ids, max_length=50, num_return_sequences=1)
    generation_text = tokenizer.decode(translation_ids[0], skip_special_tokens=True)
    return {"generation_text": generation_text}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
```

Rust code (CPU: 0.71 s, GPU: 0.38 s, average over 100 requests):

```rust
use actix_web::{
    error, post, web,
    http::{header::ContentType, StatusCode},
    App, HttpResponse, HttpServer, Responder, Result,
};
use derive_more::{Display, Error};
use rust_bert::marian::{MarianSourceLanguages, MarianTargetLanguages};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::{BufferResource, LocalResource};
use serde::{Deserialize, Serialize};
use std::sync::{Arc, RwLock};
use tch::Device;

#[derive(Deserialize)]
struct Input {
    text: String,
}

#[derive(Serialize)]
struct Output {
    generation_text: String,
}

struct ModelFile {
    config_resource: LocalResource,
    weights: Arc<RwLock<Vec<u8>>>,
    vocab_resource: LocalResource,
    merges_resource: LocalResource,
}

impl ModelFile {
    fn new(model_path: String, config_path: String, vocab_path: String, merges_path: String) -> Self {
        let weights = Arc::new(RwLock::new(get_weights(model_path).unwrap()));
        let config_resource = LocalResource { local_path: config_path.into() };
        let vocab_resource = LocalResource { local_path: vocab_path.into() };
        let merges_resource = LocalResource { local_path: merges_path.into() };
        Self {
            weights,
            config_resource,
            vocab_resource,
            merges_resource,
        }
    }

    fn generation(&self, input_context: &str) -> Result<impl Responder, MyError> {
        let source_languages = MarianSourceLanguages::CHINESE2ENGLISH;
        let target_languages = MarianTargetLanguages::CHINESE2ENGLISH;
        let translation_config = TranslationConfig::new(
            ModelType::Marian,
            ModelResource::Torch(Box::new(BufferResource {
                data: Arc::clone(&self.weights),
            })),
            self.config_resource.clone(),
            self.vocab_resource.clone(),
            Some(self.merges_resource.clone()),
            source_languages,
            target_languages,
            Device::Cuda(3), // or Device::Cpu
        );
        // Note: a new TranslationModel is constructed on every request.
        let model = TranslationModel::new(translation_config).map_err(|_| MyError::ModelLoadError)?;
        let output = model
            .translate(&[input_context.to_string()], None, None)
            .map_err(|_| MyError::TranslateError)?;
        match output.first() {
            Some(first_element) => Ok(web::Json(Output {
                generation_text: first_element.to_string(),
            })),
            None => Err(MyError::TranslateError),
        }
    }
}

#[derive(Debug, Display, Error)]
enum MyError {
    #[display(fmt = "translation model load error")]
    ModelLoadError,
    #[display(fmt = "translate error")]
    TranslateError,
}

impl error::ResponseError for MyError {
    fn error_response(&self) -> HttpResponse {
        HttpResponse::build(self.status_code())
            .insert_header(ContentType::html())
            .body(self.to_string())
    }

    fn status_code(&self) -> StatusCode {
        match *self {
            MyError::ModelLoadError => StatusCode::INTERNAL_SERVER_ERROR,
            MyError::TranslateError => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}

#[post("/v1/predict")]
async fn predict_post(
    info: web::Json<Input>,
    appdata: web::Data<ModelFile>,
) -> Result<impl Responder, MyError> {
    appdata.generation(&info.text)
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let appdata = ModelFile::new(
        "/root/.cache/.rustbert/opus-mt-zh-en/rust_model.ot".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/config.json".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/vocab.json".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/source.spm".to_string(),
    );
    let appdata = web::Data::new(appdata);
    HttpServer::new(move || {
        App::new()
            .app_data(web::Data::clone(&appdata))
            .service(predict_post)
    })
    .workers(4)
    .bind(("0.0.0.0", 8090))?
    .run()
    .await
}

fn get_weights(model_path: String) -> anyhow::Result<Vec<u8>> {
    Ok(std::fs::read(model_path)?)
}
```
Hello @wolf-li, do you compile the code in release mode?
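By default, `cargo build` and `cargo run` produce unoptimized debug binaries; Rust timings are only meaningful when taken from a `cargo run --release` build.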
Each time you call the endpoint, a new `TranslationModel` is created inside `generation`, so the weights are reloaded and the model rebuilt on every request. You should create the model once at startup and share it across requests.
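A minimal sketch of that pattern, reusing the `Input`, `Output`, and `MyError` types from the code above. It assumes the tch-backed `TranslationModel` is `Send` (the `Mutex` provides the synchronization, since the model is not `Sync`), and `translation_config()` is a hypothetical helper that builds the same `TranslationConfig` as in the code above:

```rust
// Sketch: load the model once in main() and share it across requests,
// instead of rebuilding it per call.
use std::sync::Mutex;
use actix_web::{post, web, App, HttpServer, Responder, Result};
use rust_bert::pipelines::translation::TranslationModel;

struct AppState {
    // Mutex serializes access because TranslationModel is not Sync.
    model: Mutex<TranslationModel>,
}

#[post("/v1/predict")]
async fn predict_post(
    info: web::Json<Input>,
    state: web::Data<AppState>,
) -> Result<impl Responder, MyError> {
    let model = state.model.lock().map_err(|_| MyError::TranslateError)?;
    let output = model
        .translate(&[info.text.clone()], None, None)
        .map_err(|_| MyError::TranslateError)?;
    match output.first() {
        Some(text) => Ok(web::Json(Output { generation_text: text.to_string() })),
        None => Err(MyError::TranslateError),
    }
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    // translation_config() is a hypothetical helper returning the same
    // TranslationConfig built in the code above; the model is loaded once here.
    let model = TranslationModel::new(translation_config()).expect("model load failed");
    let state = web::Data::new(AppState { model: Mutex::new(model) });
    HttpServer::new(move || {
        App::new()
            .app_data(web::Data::clone(&state))
            .service(predict_post)
    })
    .workers(4)
    .bind(("0.0.0.0", 8090))?
    .run()
    .await
}
```

Note that the single `Mutex` serializes inference across the 4 workers; a per-worker model, or a dedicated inference thread fed by a channel, are common alternatives when concurrent throughput matters.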
Using the pipeline is slower than using the Python Hugging Face transformers generate function, even when the model file is already loaded, in a CPU environment.