import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


def get_chat_response(job):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_query = job["input"]["input_query"]
    # Note: this reloads the base model and adapter on every call.
    base_model, llama_tokenizer = create_base_model()
    # Placeholder prompt; the real template (presumably built around input_query) is elided here.
    prompt = f"""
something
"""
    model_input = llama_tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = len(prompt)  # character length of the prompt (not used below)
    base_model.eval()
    with torch.no_grad():
        output_ids = base_model.generate(**model_input, max_new_tokens=500)
        resp = llama_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    resp = extract_regex(resp)
    return resp


def create_base_model():
    model_id = "/base/13B-chat"
    peft_id = "/base/LLM_Finetune/tmp3/llama-output"
    # Load the 13B chat base weights, letting accelerate place layers across available devices.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # quantization_config=quant_config,
        device_map='auto',
    )
    base_model.config.use_cache = False  # KV cache off (typical training setting; slows generation)
    base_model.config.pretraining_tp = 1
    llama_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    llama_tokenizer.pad_token = llama_tokenizer.eos_token
    llama_tokenizer.padding_side = "right"  # Fix for fp16
    # Wrap the base model with the fine-tuned PEFT/LoRA adapter.
    base_model = PeftModel.from_pretrained(
        base_model,
        peft_id,
    )
    return base_model, llama_tokenizer
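

# `extract_regex` is called above but not defined in this snippet. Below is a minimal,
# hypothetical sketch, assuming its job is to pull the assistant's answer out of the
# decoded text (which still contains the echoed prompt); the actual pattern used by the
# original code may differ.
import re


def extract_regex(resp):
    # Assumption: the answer follows an "[/INST]" marker in the decoded output;
    # fall back to returning the full text if no marker is present.
    match = re.search(r"\[/INST\](.*)", resp, flags=re.DOTALL)
    return match.group(1).strip() if match else resp.strip()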
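

# Quick local check of the handler, assuming a job payload shaped like the one read in
# get_chat_response ({"input": {"input_query": ...}}). The query string is only an example.
if __name__ == "__main__":
    test_job = {"input": {"input_query": "What is the capital of France?"}}
    print(get_chat_response(test_job))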