from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# Load the tokenizer and model from the Hugging Face Hub.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Requires a CUDA-capable GPU; swap "cuda" for "cpu" to run without one.
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
def run_generation(user_text):
    # Tokenize the prompt and move it to the same device as the model.
    model_inputs = tokenizer([user_text], return_tensors="pt").to("cuda")

    # The streamer yields decoded text chunks as generate() produces tokens.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=256)  # cap output length; adjust as needed

    # generate() blocks until completion, so run it in a background thread
    # and consume the streamer from this one.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Re-emit each decoded chunk as a Server-Sent Events (SSE) data line.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield f"data: {new_text}\n\n"

    # In a generator, this value is only reachable via StopIteration.value;
    # the full text has already been streamed above.
    return model_output
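
# A minimal serving sketch (an assumption; no web framework appears in the
# original snippet). Because run_generation() already yields SSE-formatted
# strings, it can feed a streaming response directly. FastAPI and the
# /generate route below are hypothetical wiring; any framework whose
# response type accepts a generator works the same way.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/generate")
def generate(prompt: str):
    # Stream chunks to the client as they arrive instead of waiting
    # for the full completion.
    return StreamingResponse(run_generation(prompt), media_type="text/event-stream")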