
Does the marian model have a method like huggingface generate? #414

Open
wolf-li opened this issue Sep 11, 2023 · 4 comments

Comments

@wolf-li

wolf-li commented Sep 11, 2023

Using the pipeline is slower than using the generate function from the Python huggingface transformers library, even once the model file is loaded, when running in a CPU environment.

@guillaume-be
Owner

The pipeline should not be slower than the Python equivalent on the same device. If you are using a CUDA-enabled GPU, please ensure it is used for both frameworks. The Marian model exposes a generate method via the MarianGenerator struct and the LanguageGenerator trait.
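A rough sketch of what that looks like (not verified against a specific rust-bert release: the GenerateConfig field names, the MarianGenerator::new and generate signatures, and the return type of generate can differ between versions, so treat them as assumptions). The local file paths reuse the ones from the snippets further down:

use rust_bert::marian::MarianGenerator;
use rust_bert::pipelines::common::ModelResource;
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::resources::LocalResource;
use tch::Device;

fn main() -> anyhow::Result<()> {
    // Point the generator at the same local Marian (opus-mt-zh-en) files used below.
    let generate_config = GenerateConfig {
        model_resource: ModelResource::Torch(Box::new(LocalResource {
            local_path: "/root/.cache/.rustbert/opus-mt-zh-en/rust_model.ot".into(),
        })),
        config_resource: Box::new(LocalResource {
            local_path: "/root/.cache/.rustbert/opus-mt-zh-en/config.json".into(),
        }),
        vocab_resource: Box::new(LocalResource {
            local_path: "/root/.cache/.rustbert/opus-mt-zh-en/vocab.json".into(),
        }),
        merges_resource: Some(Box::new(LocalResource {
            local_path: "/root/.cache/.rustbert/opus-mt-zh-en/source.spm".into(),
        })),
        device: Device::cuda_if_available(),
        ..Default::default()
    };

    // MarianGenerator implements the LanguageGenerator trait, which provides `generate`.
    let generator = MarianGenerator::new(generate_config)?;

    // Depending on the rust-bert version, `generate` returns the outputs directly or
    // wraps them in a Result; adjust accordingly.
    let output = generator.generate(Some(&["你好，世界"]), None);
    for sentence in output {
        println!("{}", sentence.text);
    }
    Ok(())
}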

@wolf-li
Author

wolf-li commented Sep 12, 2023

> The pipeline should not be slower than the Python equivalent on the same device. If you are using a CUDA-enabled GPU, please ensure it is used for both frameworks. The Marian model exposes a generate method via the MarianGenerator struct and the LanguageGenerator trait.

Thanks for your reply.
When the Marian pipeline runs on the GPU (Device::Cuda(3) is specified, and nvidia-smi shows the Rust program occupying the GPU while it runs), it is still slower than the Python call to the huggingface library running in a Docker environment without a GPU.

Python code (CPU: 0.3 s, averaged over 100 requests)

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import MarianTokenizer, MarianMTModel
# Load the Marian model and tokenizer
model_name = "Helsinki-NLP/opus-mt-zh-en"  # Replace with your desired model
model = MarianMTModel.from_pretrained(model_name)
tokenizer = MarianTokenizer.from_pretrained(model_name)

app = FastAPI()

class InputData(BaseModel):
    text: str

class OutputData(BaseModel):
    generation_text: str

@app.post("/v1/predict", response_model=OutputData)
async def predict(input_data: InputData):
    # Translate the input text
    input_text = input_data.text
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    translation_ids = model.generate(input_ids, max_length=50, num_return_sequences=1)
    generation_text = tokenizer.decode(translation_ids[0], skip_special_tokens=True)

    return {"generation_text": generation_text}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

Rust code (CPU: 0.71 s, GPU: 0.38 s, averaged over 100 requests)

use actix_web::{
    error, post, web,
    http::{header::ContentType, StatusCode},
    App, HttpResponse, HttpServer, Responder, Result,
};
use derive_more::{Display, Error};
use rust_bert::marian::{MarianSourceLanguages, MarianTargetLanguages};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::{BufferResource, LocalResource};
use serde::{Deserialize, Serialize};
use std::sync::{Arc, RwLock};
use tch::Device;

#[derive(Deserialize)]
struct Input {
    text: String,
}

#[derive(Serialize)]
struct Output {
    generation_text: String,
}

struct ModelFile {
    config_resource:LocalResource,
    weights: Arc<RwLock<Vec<u8>>>,
    vocab_resource: LocalResource,
    merges_resource: LocalResource,
}

impl ModelFile {
    fn new(model_path:String, config_path:String, vocab_path:String, merges_path:String) -> Self  {
        let weights =  Arc::new(RwLock::new(get_weights(model_path.clone()).unwrap()));

        let config_resource = LocalResource { local_path: config_path.into(), };
        let vocab_resource = LocalResource { local_path: vocab_path.into(), };
        let merges_resource = LocalResource { local_path: merges_path.into(), };
        Self {
            weights,
            config_resource,
            vocab_resource,
            merges_resource,
        }
    }
    
    fn generation(&self, input_context: &str) -> Result<impl Responder, MyError> {
        let source_languages = MarianSourceLanguages::CHINESE2ENGLISH;
        let target_languages = MarianTargetLanguages::CHINESE2ENGLISH;

        // Note: a new TranslationModel is built from the in-memory weights on every call.
        let translation_config = TranslationConfig::new(
            ModelType::Marian,
            ModelResource::Torch(Box::new(BufferResource {
                data: Arc::clone(&self.weights),
            })),
            self.config_resource.clone(),
            self.vocab_resource.clone(),
            Some(self.merges_resource.clone()),
            source_languages,
            target_languages,
            // Device::Cpu,
            Device::Cuda(3),
        );

        let model = TranslationModel::new(translation_config)
            .map_err(|_| MyError::ModelLoadError)?;

        let translations = model
            .translate(&[input_context.to_string()], None, None)
            .map_err(|_| MyError::TranslateError)?;

        match translations.first() {
            Some(first_element) => Ok(web::Json(Output {
                generation_text: first_element.to_string(),
            })),
            None => Err(MyError::TranslateError),
        }
    }

}

#[derive(Debug, Display, Error)]
enum MyError {
    #[display(fmt = "translationModel load error")]
    ModelLoadError,

    #[display(fmt = "translate error")]
    TranslateError,
}

impl error::ResponseError for MyError {
    fn error_response(&self) -> HttpResponse {
        HttpResponse::build(self.status_code())
            .insert_header(ContentType::html())
            .body(self.to_string())
    }

    fn status_code(&self) -> StatusCode {
        match *self {
            MyError::ModelLoadError => StatusCode::INTERNAL_SERVER_ERROR,
            MyError::TranslateError => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}


#[post("/v1/predict")]
async fn predicet_post(
    info: web::Json<Input>,
    appdata: web::Data<ModelFile>,
) -> Result<impl Responder, MyError> {
    let result = appdata.genertation(&info.text);
    result
}


#[actix_web::main]
async fn main() -> std::io::Result<()> {

    let appdata = ModelFile::new(
        "/root/.cache/.rustbert/opus-mt-zh-en/rust_model.ot".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/config.json".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/vocab.json".to_string(),
        "/root/.cache/.rustbert/opus-mt-zh-en/source.spm".to_string(),
    );

    let appdata = web::Data::new(appdata);

    HttpServer::new(move || {
        App::new()
            .app_data(web::Data::clone(&appdata))
            .service(predict_post)
    })
    .workers(4)
    .bind(("0.0.0.0", 8090))?
    .run()
    .await
}


fn get_weights(model_path: String) -> anyhow::Result<Vec<u8>> {
    Ok(std::fs::read(model_path)?)
}

@guillaume-be
Owner

Hello @wolf-li ,

Do you compile the code in release mode with all optimizations?

@linkedlist771

Each time you call the generation function, it creates a new model (loading it from disk and initializing it); I think that is the cause of the slowdown.
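A minimal sketch of that change, assuming TranslationModel is Send (hence the Mutex for sharing it across Actix workers): the configuration and model are built once in main and reused for every request. AppState, predict_post and the local() helper are illustrative names; TranslationConfig::new, TranslationModel::new and translate are used exactly as in the code above.

use std::sync::Mutex;

use actix_web::{post, web, App, HttpResponse, HttpServer, Responder};
use rust_bert::marian::{MarianSourceLanguages, MarianTargetLanguages};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
use rust_bert::resources::LocalResource;
use serde::{Deserialize, Serialize};
use tch::Device;

#[derive(Deserialize)]
struct Input {
    text: String,
}

#[derive(Serialize)]
struct Output {
    generation_text: String,
}

// Shared application state: the model is created once in main and reused.
struct AppState {
    model: Mutex<TranslationModel>,
}

#[post("/v1/predict")]
async fn predict_post(info: web::Json<Input>, state: web::Data<AppState>) -> impl Responder {
    // Lock the shared model; no per-request TranslationModel::new here.
    let model = match state.model.lock() {
        Ok(model) => model,
        Err(_) => return HttpResponse::InternalServerError().body("model lock poisoned"),
    };
    match model.translate(&[info.text.clone()], None, None) {
        Ok(translations) => HttpResponse::Ok().json(Output {
            generation_text: translations.first().cloned().unwrap_or_default(),
        }),
        Err(_) => HttpResponse::InternalServerError().body("translate error"),
    }
}

// Small helper for the local resource paths used in the code above.
fn local(path: &str) -> LocalResource {
    LocalResource { local_path: path.into() }
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    // Same local opus-mt-zh-en files as in the code above; build the model once.
    let translation_config = TranslationConfig::new(
        ModelType::Marian,
        ModelResource::Torch(Box::new(local("/root/.cache/.rustbert/opus-mt-zh-en/rust_model.ot"))),
        local("/root/.cache/.rustbert/opus-mt-zh-en/config.json"),
        local("/root/.cache/.rustbert/opus-mt-zh-en/vocab.json"),
        Some(local("/root/.cache/.rustbert/opus-mt-zh-en/source.spm")),
        MarianSourceLanguages::CHINESE2ENGLISH,
        MarianTargetLanguages::CHINESE2ENGLISH,
        Device::cuda_if_available(),
    );
    let model = TranslationModel::new(translation_config).expect("failed to load translation model");
    let state = web::Data::new(AppState {
        model: Mutex::new(model),
    });

    HttpServer::new(move || App::new().app_data(state.clone()).service(predict_post))
        .bind(("0.0.0.0", 8090))?
        .run()
        .await
}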
