kz209 committed on
Commit 5e8ccd5 • 1 Parent(s): c52847e
Files changed (2)
  1. pages/summarization_playground.py +61 -14
  2. utils/model.py +4 -4
pages/summarization_playground.py CHANGED
@@ -12,6 +12,60 @@ import logging
 
 load_dotenv()
 
+custom_css = """
+gradio-app {
+background: #eeeefc !important;
+}
+.bordered-text {
+border-style: solid;
+border-width: 1px;
+padding: 5px;
+margin-bottom: 0px;
+border-radius: 1px;
+font-family: Verdana;
+font-size: 20px !important;
+font-weight: bold;
+color: #000000;
+}
+.parameter-text {
+border-style: solid;
+border-width: 1px;
+padding: 5px;
+margin-bottom: 0px;
+border-radius: 1px;
+font-family: Verdana;
+font-size: 10px !important;
+font-weight: bold;
+color: #000000;
+}
+.title {
+font-size: 35px;
+color: maroon;
+font-family: Helvetica;
+}
+.input-label {
+font-size: 20px;
+font-weight: bold;
+font-family: Papyrus;
+}
+.custom-button {
+background-color: white !important; /* white background */
+color: black; /* black text */
+border: none; /* remove border */
+padding: 10px 20px; /* add padding */
+text-align: center; /* center text */
+display: inline-block; /* inline block */
+font-size: 22px; /* font size */
+margin: 4px 2px; /* margin */
+cursor: pointer; /* pointer cursor on hover */
+border-radius: 4px; /* rounded corners */
+}
+.custom-button:hover {
+background-color: black;
+color: white;
+}
+"""
+
 __model_on_gpu__ = ''
 model = {model_name: None for model_name in Model.__model_list__}
 
@@ -53,14 +107,14 @@ def get_model_batch_generation(model_name):
     return model[model_name]
 
 
-def generate_answer(sources, model_name, prompt):
+def generate_answer(sources, model_name, prompt, temperature, max_new_tokens, do_sample):
     model_device_check(model_name)
     content = prompt + '\n{' + sources + '}\n\nsummary:'
-    answer = model[model_name].gen(content)[0].strip()
+    answer = model[model_name].gen(content, temperature, max_new_tokens, do_sample)[0].strip()
 
     return answer
 
-def process_input(input_text, model_selection, prompt):
+def process_input(input_text, model_selection, prompt, temperature, max_new_tokens, do_sample):
     if input_text:
         logging.info("Start generation")
-        response = generate_answer(input_text, model_selection, prompt)
+        response = generate_answer(input_text, model_selection, prompt, temperature, max_new_tokens, do_sample)
@@ -75,13 +129,14 @@ def update_input(example):
     return examples[example]
 
 def create_summarization_interface():
-    with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm",text_size="sm")) as demo:
+    with gr.Blocks(theme=gr.themes.Soft(spacing_size="sm",text_size="sm"), css=custom_css) as demo:
         gr.Markdown("## This is a playground to test prompts for clinical dialogue summarizations")
 
         with gr.Row():
             example_dropdown = gr.Dropdown(choices=list(examples.keys()), label="Choose an example", value=random_label)
             model_dropdown = gr.Dropdown(choices=Model.__model_list__, label="Choose a model", value=Model.__model_list__[0])
 
+        gr.Markdown("<div style='border: 4px solid white; padding: 3px; border-radius: 5px;width:100px;padding-top: 0.5px;padding-bottom: 10px;'><h3>Prompt 👥</h3></div>")
         Template_text = gr.Textbox(value="""Summarize the following dialogue""", label='Input Prompting Template', lines=8, placeholder='Input your prompts')
         datapoint = random.choice(dataset)
         input_text = gr.Textbox(label="Input Dialogue", lines=10, placeholder="Enter text here...", value=datapoint['section_text'] + '\n\nDialogue:\n' + datapoint['dialogue'])
@@ -89,24 +144,16 @@ def create_summarization_interface():
 
         with gr.Row():
             with gr.Row():
-                with gr.Column():
-                    gr.Markdown("<div style='border: 4px solid white; padding: 3px; border-radius: 5px;width:100px;padding-top: 0.5px;padding-bottom: 10px;'><h3>Prompt 👥</h3></center></div>")
-                    prompt = gr.Textbox(label="Input", lines=6, placeholder = "Enter the Patient-Doctor conversation here.",elem_classes=["bordered-text"])
-                    context = gr.Textbox(label="Context", placeholder="Enter relevant context about the patient medical history.",elem_classes="bordered-text")
-                    token = gr.Textbox(label="Token",elem_classes="bordered-text")
                 with gr.Column():
                     gr.Markdown("<div style='border: 4px solid white; padding: 2px; border-radius: 5px;width:130px;padding-bottom: 10px;'><b><h3>Parameters 📈</h3></b></div>")
                     with gr.Column():
-                        topK = gr.Textbox(label="TopP",elem_classes="bordered-text")
-                        topP = gr.Textbox(label="TopK",elem_classes="bordered-text")
                         temperature = gr.Textbox(label="Temperature",elem_classes="parameter-text")
                         max_new_tokens = gr.Textbox(label="Max New Tokens",elem_classes="parameter-text")
                         do_sample = gr.Dropdown(['Default','None'],label="Do Sample",elem_classes="parameter-text")
-                        return_text = gr.Dropdown(['Default','None'],label="Return Text",elem_classes="parameter-text")
         output = gr.Markdown(line_breaks=True)
 
         example_dropdown.change(update_input, inputs=[example_dropdown], outputs=[input_text])
-        submit_button.click(process_input, inputs=[input_text, model_dropdown, Template_text], outputs=[output])
+        submit_button.click(process_input, inputs=[input_text, model_dropdown, Template_text, temperature, max_new_tokens, do_sample], outputs=[output])
 
     return demo
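The `custom_css` block only takes effect because `create_summarization_interface` now passes it to `gr.Blocks(css=...)`; classes such as `bordered-text` and `parameter-text` are then attached to individual components through `elem_classes`. A minimal sketch of that pattern, with an illustrative layout that is not part of this commit:

```python
import gradio as gr

# Illustrative CSS: one class from the commit's stylesheet, trimmed down.
css = """
.parameter-text {
    font-size: 10px !important;
    font-weight: bold;
}
"""

# css= applies the stylesheet app-wide; elem_classes tags the component's
# wrapper so the .parameter-text rule can select it.
with gr.Blocks(css=css) as demo:
    temperature = gr.Textbox(label="Temperature", elem_classes="parameter-text")

if __name__ == "__main__":
    demo.launch()
```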
 
 
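One wiring detail worth keeping in mind with the extended `submit_button.click(...)` call: Gradio passes each component's value to the handler positionally, and `gr.Textbox` delivers strings, so `temperature` and `max_new_tokens` reach `process_input` as text, and `do_sample` arrives as the literal 'Default' or 'None'. The handler still forwards raw strings; a hedged sketch of handler-side casting, which is not part of this commit:

```python
def process_input(input_text, model_selection, prompt, temperature, max_new_tokens, do_sample):
    # Textbox values arrive as strings; cast before forwarding to generation.
    temp = float(temperature) if temperature else 0.001
    tokens = int(max_new_tokens) if max_new_tokens else 500
    # The dropdown offers 'Default' / 'None'; map that to a boolean flag.
    sample = (do_sample != 'None')
    if input_text:
        return generate_answer(input_text, model_selection, prompt, temp, tokens, sample)
```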
 
utils/model.py CHANGED
@@ -55,14 +55,14 @@ class Model(torch.nn.Module):
     def return_model(self):
         return self.model
 
-    def streaming(self, content_list, temp=0.001, max_length=500):
+    def streaming(self, content_list, temp=0.001, max_length=500, do_sample=True):
         # Convert list of texts to input IDs
         input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
 
         # Set up the initial generation parameters
         gen_kwargs = {
             "input_ids": input_ids,
-            "do_sample": True,
+            "do_sample": do_sample,
             "temperature": temp,
             "eos_token_id": self.tokenizer.eos_token_id,
             "max_new_tokens": 1, # Generate one token at a time
@@ -96,7 +96,7 @@ class Model(torch.nn.Module):
             gen_kwargs["input_ids"] = gen_kwargs["input_ids"][active_sequences]
 
 
-    def gen(self, content_list, temp=0.001, max_length=500):
+    def gen(self, content_list, temp=0.001, max_length=500, do_sample=True):
         # Convert list of texts to input IDs
         input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
 
@@ -104,7 +104,7 @@ class Model(torch.nn.Module):
         outputs = self.model.generate(
             input_ids,
             max_new_tokens=max_length,
-            do_sample=True,
+            do_sample=do_sample,
             temperature=temp,
             eos_token_id=self.tokenizer.eos_token_id,
         )
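In `utils/model.py`, `do_sample` now threads from the caller through `gen` and `streaming` into `model.generate`. With `do_sample=False`, Transformers decodes greedily and ignores sampling parameters such as `temperature` (recent versions warn about unused sampling flags), so the two call styles below behave quite differently. A sketch against the new signature; the wrapper construction is elided and the argument values are illustrative:

```python
m = Model(...)  # construction elided; see utils/model.py for the real init

# Sampling path: temperature shapes the next-token distribution.
sampled = m.gen(["Summarize the dialogue: ..."], temp=0.7, max_length=256, do_sample=True)[0]

# Greedy path: deterministic output; temp is accepted but has no effect.
greedy = m.gen(["Summarize the dialogue: ..."], temp=0.7, max_length=256, do_sample=False)[0]
```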