File size: 1,760 Bytes
ce814eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f3e17d
0821eba
ce814eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch

from typing import Dict, List, Any
from transformers import AutoTokenizer, BitsAndBytesConfig
from peft import AutoPeftModelForCausalLM


def parse_output(text):
    """Extract the response part of a decoded generation.

    Everything up to and including the first "### Response:" marker is
    discarded; when the marker is absent the whole string is kept. Literal
    "<pad>" and then "</s>" substrings are removed, and surrounding
    whitespace is stripped from the result.
    """
    _, found, after = text.partition("### Response:")
    answer = after if found else text
    # Removal order matters for overlapping inputs: "<pad>" first, then "</s>".
    for token in ("<pad>", "</s>"):
        answer = answer.replace(token, "")
    return answer.strip()


class EndpointHandler:
    """Custom inference handler serving a PEFT causal-LM adapter.

    The adapter and tokenizer are loaded once from ``path``. Each call wraps
    the request's ``"inputs"`` in an Alpaca-style instruction prompt, runs
    ``generate`` and returns the newly generated text.
    """

    def __init__(self, path="./", use_bnb=True):
        """Load the PEFT model and its tokenizer.

        Args:
            path: Location of the adapter/tokenizer files, passed unchanged to
                ``AutoPeftModelForCausalLM.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            use_bnb: If True (default), load the weights 4-bit NF4-quantized
                (double quantization, bfloat16 compute) via bitsandbytes.
                If False, load without a quantization config.
        """
        load_kwargs = {"device_map": "auto"}
        if use_bnb:
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        # No bare `load_in_8bit=False` kwarg: the quantization config alone
        # selects the quantization mode, so the flag is redundant next to it.
        self.model = AutoPeftModelForCausalLM.from_pretrained(path, **load_kwargs)

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        print("Memory footprint: ", self.model.get_memory_footprint())
        print("Device map: ", self.model.hf_device_map)

    def __call__(self, data: Any) -> Dict[str, str]:
        """Generate a response for a single request.

        Args:
            data: Request payload. Normally a dict with an ``"inputs"`` key
                holding the instruction text and an optional ``"parameters"``
                dict forwarded as keyword arguments to ``model.generate``.
                If ``"inputs"`` is missing, the whole payload is used as the
                instruction. A bare string is treated as ``{"inputs": data}``.

        Returns:
            ``{"generated_text": <response>}``, where the response is the
            decoded continuation with the prompt removed.
        """
        # Accept a bare string payload in addition to the usual dict.
        payload = {"inputs": data} if isinstance(data, str) else data
        instruction = payload.get("inputs", payload)
        # A JSON body may carry "parameters": null; treat it as "no parameters"
        # so that `**parameters` below does not raise.
        parameters = payload.get("parameters") or {}
        prompt = (
            "Below is an instruction that describes a task. Write a response "
            "that appropriately completes the request.\n\n"
            f"### Instruction: \n{instruction}\n\n"
            "### Response: \n"
        )
        with torch.no_grad():
            encoded = self.tokenizer(
                prompt, return_tensors="pt", return_token_type_ids=False
            ).to(self.model.device)
            outputs = self.model.generate(**encoded, **parameters)

        # For a causal LM, generate() returns prompt + continuation. Slicing off
        # the prompt tokens is used instead of searching the decoded text for
        # "### Response:", which splits in the wrong place when the instruction
        # itself contains that marker.
        prompt_len = encoded["input_ids"].shape[-1]
        text = self.tokenizer.decode(
            outputs[0][prompt_len:].tolist(), skip_special_tokens=True
        )
        return {
            "generated_text": text.replace("<pad>", "").replace("</s>", "").strip()
        }