Description
This is ai-forever/FRED-T5-1.7B model trained on Question-Answering, Question-Generation and Answer-Aware Question Generation tasks on russian dataset (hivaze/ru-AAQG-QA-QG)
Prompts
AAQG_PROMPT = "Сгенерируй вопрос по тексту, используя известный ответ. Текст: '{context}'. Ответ: '{answer}'."
QG_PROMPT = "Сгенерируй вопрос по тексту. Текст: '{context}'."
QA_PROMPT = "Сгенерируй ответ на вопрос по тексту. Текст: '{context}'. Вопрос: '{question}'."
Examples and code
from transformers import AutoTokenizer, T5ForConditionalGeneration
from functools import partial
saved_checkpoint = 'hivaze/AAQG-QA-QG-FRED-T5-1.7B'
tokenizer = AutoTokenizer.from_pretrained(saved_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(saved_checkpoint).cuda()
def generate_text(prompt, tokenizer, model, n=1, temperature=0.8, num_beams=3):
encoded_input = tokenizer.encode_plus(prompt, return_tensors='pt')
encoded_input = {k: v.to(model.device) for k, v in encoded_input.items()}
resulted_tokens = model.generate(**encoded_input,
max_new_tokens=64,
do_sample=True,
num_beams=num_beams,
num_return_sequences=n,
temperature=temperature,
top_p=0.9,
top_k=50)
resulted_texts = tokenizer.batch_decode(resulted_tokens, skip_special_tokens=True)
return resulted_texts
generate_text = partial(generate_text, tokenizer=tokenizer, model=model)
test_context = "Путешественник Федор Конюхов и пилот Игорь Потапкин установили мировой рекорд высоты полета на паралёте, поднявшись на высоту 4728 метров — сайт Конюхова"
AAQG
generate_text(AAQG_PROMPT.format(
context=test_context,
answer='на паралёте'
), n=1)
"На чем путешественник Федор Конюхов и пилот Игорь Потапкин установили мировой рекорд высоты полета?"
generate_text(AAQG_PROMPT.format(
context=test_context,
answer='рекорд высоты полета'
), n=1)
"Что установили путешественник Федор Конюхов и пилот Игорь Потапкин?"
QA
generate_text(QA_PROMPT.format(
context=test_context,
question='Что установили путешественник Федор Конюхов и пилот Игорь Потапкин?'
), n=1)
"Мировой рекорд высоты полета на паралёте"
QG
generate_text(QG_PROMPT.format(context=test_context), n=1)
"Кто установил мировой рекорд высоты полета на паралёте?"
Metrics
Step | Training Loss | Validation Loss | Sbleu | Chr F | Rouge1 | Rouge2 | Rougel |
---|---|---|---|---|---|---|---|
500 | 1.020500 | 1.059296 | 41.556000 | 66.391100 | 0.104200 | 0.033700 | 0.104200 |
1000 | 1.050200 | 0.998357 | 43.035900 | 66.376800 | 0.105100 | 0.034100 | 0.105200 |
1500 | 0.994000 | 0.966051 | 43.692200 | 66.597600 | 0.106300 | 0.034400 | 0.106400 |
2000 | 0.947800 | 0.953637 | 44.012400 | 66.711100 | 0.106600 | 0.034900 | 0.106800 |
2500 | 0.978200 | 0.944621 | 44.027900 | 66.657400 | 0.106500 | 0.034600 | 0.106500 |
Authors
- Sergei Bratchikov (https://t.me/nlpwanderer)
- Downloads last month
- 100
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.