---
license: cc-by-4.0
language:
- en
tags:
- music
- art
---

# MusiLingo-short-v1
This repo contains the model checkpoint for the following paper:
__[MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response](https://arxiv.org/abs/2309.08730)__

More information is available at the [GitHub repo](https://github.com/zihaod/MusiLingo).

You can use the [MusicInstruct (MI)](https://huggingface.co/datasets/m-a-p/Music-Instruct) dataset for the demo below. This checkpoint was developed on the MI-short subset.
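
To fetch the MI data locally before running the demo, you can mirror the dataset repo from the Hub. This is a minimal sketch; the local path you pass to `CMIDataset` below should point at the downloaded files.

```python
# Sketch: download the MusicInstruct dataset repo to a local directory.
# snapshot_download mirrors the whole repo; pass the returned path to CMIDataset.
from huggingface_hub import snapshot_download

mi_path = snapshot_download(repo_id="m-a-p/Music-Instruct", repo_type="dataset")
print(mi_path)  # use this in place of 'path/to/MI_dataset' below
```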

# Inference Code
```python
from tqdm.auto import tqdm

import torch
from torch.utils.data import DataLoader
from transformers import AutoModel
from transformers import Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        # Stop once the tail of the generated sequence matches any stop sequence.
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False


def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
           repetition_penalty=1.0, length_penalty=1, temperature=0.1, max_length=2000):
    # `self` is the underlying MusiLingo model; call this as answer(model, samples, ...).
    audio = samples["audio"].cuda()
    audio_embeds, atts_audio = self.encode_audio(audio)
    if 'instruction_input' in samples:  # instruction dataset
        instruction_prompt = []
        for instruction in samples['instruction_input']:
            prompt = '<Audio><AudioHere></Audio> ' + instruction
            instruction_prompt.append(self.prompt_template.format(prompt))
        audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
    self.llama_tokenizer.padding_side = "right"
    batch_size = audio_embeds.shape[0]
    bos = torch.ones([batch_size, 1],
                     dtype=torch.long,
                     device=torch.device('cuda')) * self.llama_tokenizer.bos_token_id
    bos_embeds = self.llama_model.model.embed_tokens(bos)
    atts_bos = atts_audio[:, :1]
    inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
    attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
    outputs = self.llama_model.generate(
        inputs_embeds=inputs_embeds,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping,
        num_beams=num_beams,
        do_sample=True,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning; remove it
        output_token = output_token[1:]
    if output_token[0] == 1:  # if there is a start token <s> at the beginning, remove it
        output_token = output_token[1:]
    output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text


processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
# CMIDataset comes from the MusiLingo GitHub repo linked above.
ds = CMIDataset(processor, 'path/to/MI_dataset', 'test', question_type='short')
dl = DataLoader(
    ds,
    batch_size=1,
    num_workers=0,
    pin_memory=True,
    shuffle=False,
    drop_last=True,
    collate_fn=ds.collater
)

# Stop generation at the '###' separator used in the prompt template.
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                                      torch.tensor([2277, 29937]).cuda()])])

musilingo_short = AutoModel.from_pretrained("m-a-p/MusiLingo-short-v1", trust_remote_code=True)
musilingo_short.cuda()
musilingo_short.eval()

for idx, sample in tqdm(enumerate(dl)):
    ans = answer(musilingo_short.model, sample, stopping, length_penalty=100, temperature=0.1)
    txt = sample['text_input'][0]
    print(txt)
    print(ans)
```
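
If you want to query a single clip without the MI dataset, a hypothetical sketch is below. It assumes the clip is resampled to the feature extractor's rate (24 kHz for MERT-v1-330M) and that `samples["audio"]` carries the extractor's normalized waveform, as `CMIDataset`'s collater would provide; the file name and question are placeholders.

```python
# Hypothetical single-clip inference, reusing `processor`, `answer`, `stopping`,
# and `musilingo_short` from the snippet above.
import torchaudio

wav, sr = torchaudio.load("example.wav")  # placeholder file
wav = torchaudio.functional.resample(wav.mean(dim=0), sr, processor.sampling_rate)
inputs = processor(wav, sampling_rate=processor.sampling_rate, return_tensors="pt")
sample = {
    "audio": inputs.input_values,  # shape (1, num_samples)
    "instruction_input": ["What instruments are playing in this piece?"],
}
print(answer(musilingo_short.model, sample, stopping, temperature=0.1))
```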

# Citing This Work

If you find the work useful for your research, please consider citing it using the following BibTeX entry:
```bibtex
@inproceedings{deng2024musilingo,
    title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
    author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
    booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
    year={2024},
    organization={Association for Computational Linguistics}
}
```