8 Star 28 Fork 9

Gitee 极速下载 / huatuo-llama

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
此仓库是为了提升国内下载速度的镜像仓库,每日同步一次。 原始仓库: https://github.com/SCIR-HI/Huatuo-Llama-Med-Chinese
克隆/下载
generate.py 5.36 KB
一键复制 编辑 原始数据 按行查看 历史
s65b40 提交于 2023-08-07 21:46 . Update Huozi-based model
import sys
import fire
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer
from utils.prompter import Prompter
# Choose the inference device once at import time: CUDA when torch reports
# it as available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    # When the MPS backend reports itself available it overrides the
    # choice made above.
    if torch.backends.mps.is_available():
        device = "mps"
except Exception:  # noqa: BLE001
    # Best-effort probe: presumably some torch builds do not expose
    # ``torch.backends.mps`` at all. Keep the device chosen above, but do not
    # swallow KeyboardInterrupt/SystemExit the way a bare ``except:`` would.
    pass
def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "tloen/alpaca-lora-7b",
    prompt_template: str = "med_template",  # Name of the prompt template handed to Prompter.
    server_name: str = "0.0.0.0",  # Allows to listen on all interfaces by providing '0.0.0.0'
    share_gradio: bool = True,
):
    """Load a base model with LoRA weights and serve it through a Gradio UI.

    Args:
        load_8bit: Forwarded as ``load_in_8bit`` when loading on CUDA. When
            false, the loaded model is additionally cast with ``model.half()``.
        base_model: Name or path given to ``from_pretrained`` for both the
            tokenizer and the model. Must be non-empty.
        lora_weights: Name or path given to ``PeftModel.from_pretrained``.
        prompt_template: Template name given to ``Prompter``.
        server_name: Passed as ``server_name`` to ``launch``.
        share_gradio: Passed as ``share`` to ``launch``.

    Raises:
        AssertionError: If ``base_model`` is empty.
    """
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"

    prompter = Prompter(prompt_template)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    # Load the base model and wrap it with the LoRA adapter; the keyword
    # arguments differ per device (module-level ``device``).
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
        )
    elif device == "mps":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
        )

    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not load_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    def evaluate(
        instruction,
        input=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=128,
        **kwargs,
    ):
        """Generate a response for one instruction (plus optional input).

        Builds the prompt with ``prompter``, runs ``model.generate`` under
        ``torch.no_grad()``, decodes the first returned sequence and returns
        whatever ``prompter.get_response`` extracts from it. Extra ``kwargs``
        are forwarded to ``GenerationConfig``.
        """
        prompt = prompter.generate_prompt(instruction, input)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            **kwargs,
        )
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s)
        return prompter.get_response(output)

    # The widgets below are passed positionally to ``evaluate`` in this order.
    gr.Interface(
        fn=evaluate,
        inputs=[
            gr.components.Textbox(
                lines=2,
                label="Instruction",
                placeholder="Tell me about alpacas.",
            ),
            gr.components.Textbox(lines=2, label="Input", placeholder="none"),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.1, label="Temperature"
            ),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.75, label="Top p"
            ),
            gr.components.Slider(
                minimum=0, maximum=100, step=1, value=40, label="Top k"
            ),
            gr.components.Slider(
                minimum=1, maximum=4, step=1, value=4, label="Beams"
            ),
            gr.components.Slider(
                minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
            ),
        ],
        outputs=[
            # Use the same ``gr.components`` namespace as the inputs instead
            # of the deprecated ``gr.inputs`` alias.
            gr.components.Textbox(
                lines=5,
                label="Output",
            )
        ],
        title="BenTsao",
        description="",  # noqa: E501
    ).launch(server_name=server_name, share=share_gradio)
    # Old testing code follows.

    """
    # testing code for readme
    for instruction in [
        "Tell me about alpacas.",
        "Tell me about the president of Mexico in 2019.",
        "Tell me about the king of France in 2019.",
        "List all Canadian provinces in alphabetical order.",
        "Write a Python program that prints the first 10 Fibonacci numbers.",
        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",  # noqa: E501
        "Tell me five words that rhyme with 'shock'.",
        "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
        "Count up from 1 to 500.",
    ]:
        print("Instruction:", instruction)
        print("Response:", evaluate(instruction))
        print()
    """
# Script entry point: expose ``main``'s keyword arguments as command-line
# flags through ``fire`` when this file is executed directly.
if __name__ == "__main__":
    fire.Fire(main)
Python
1
https://gitee.com/mirrors/huatuo-llama.git
git@gitee.com:mirrors/huatuo-llama.git
mirrors
huatuo-llama
huatuo-llama
main

搜索帮助