程式人生 > 【學習筆記】使用 Python 進行音訊識別

【學習筆記】使用python進行音訊識別

以下直接附上程式碼,共兩個檔案:my_audio.py 負責讀取 wav 檔、提取音訊指紋與播放;plar.py 負責將指紋存入資料庫並進行檢索。

my_audio.py

# -*- coding: utf-8 -*-
# Created: huashan

import os
import re
import wave

import numpy as np
import pyaudio


class voice():
    '''
    Load a WAV file, extract a spectral-peak fingerprint from it and play audio.

    Attributes set by loaddata():
        nchannels, sampwidth, framerate, nframes -- header fields of the file
        wave_data -- array of shape (nchannels, frames); wave_data[0] is the
                     first channel
        name -- base name of the loaded file
    Attribute set by fft():
        high_point -- list with one tuple of FFT bin indices per block
    '''

    # numpy dtype for each supported PCM sample width (bytes per sample).
    # 8-bit WAV samples are unsigned; wider ones are signed little-endian.
    _SAMPLE_DTYPES = {1: np.uint8, 2: '<i2', 4: '<i4'}

    # Half-open FFT-bin ranges; the strongest bin of each range forms the
    # fingerprint of a block.
    _BANDS = ((0, 40), (40, 80), (80, 120), (120, 180))

    def loaddata(self, filepath):
        '''
        Read a WAV file into memory, split into channels.

        :param filepath: path of the file; must be a str ending in ".wav"
                         (case-insensitive)
        :return: True on success
        :raises TypeError: if filepath is not a str
        :raises IOError: if the suffix is not .wav, or the file cannot be
                         opened or decoded
        '''
        if not isinstance(filepath, str):
            raise TypeError('the type of filepath must be string')
        # The old regex test compared findall() with None, which is never
        # true, so the suffix was never actually validated.
        if not filepath.lower().endswith('.wav'):
            raise IOError('the suffix of file must be .wav')
        try:
            f = wave.open(filepath, 'rb')
            try:
                nchannels, sampwidth, framerate, nframes = f.getparams()[:4]
                raw = f.readframes(nframes)
            finally:
                f.close()
            samples = np.frombuffer(raw, dtype=self._SAMPLE_DTYPES[sampwidth])
            # Frames are interleaved (one sample per channel per frame), so
            # the row width is the channel count, not the sample width.
            wave_data = samples.reshape(-1, nchannels).T
        except (IOError, EOFError, KeyError, ValueError, wave.Error):
            # KeyError: unsupported sample width; ValueError: truncated data.
            raise IOError('File Error')
        self.nchannels = nchannels
        self.sampwidth = sampwidth
        self.framerate = framerate
        self.nframes = nframes
        self.wave_data = wave_data
        self.name = os.path.basename(filepath)  # remember the file name
        return True

    def fft(self, frames=40):
        '''
        Compute the fingerprint of the first channel into self.high_point.

        The channel is cut into consecutive blocks of framerate // frames
        samples. Each block is Fourier-transformed and the index of the
        largest-magnitude bin inside each range of _BANDS is recorded.

        :param frames: number of blocks per second of audio
        :return: None
        :raises ValueError: if frames exceeds the sample rate, or a block is
                            too short to reach every band
        '''
        blocks_size = self.framerate // frames  # samples per block
        if blocks_size <= 0:
            raise ValueError('frames must not exceed the sample rate')
        channel = self.wave_data[0]
        self.high_point = []
        # The stop bound is kept as in the stored fingerprints: the final
        # block is skipped when the length is an exact multiple of the size.
        for start in range(0, len(channel) - blocks_size, blocks_size):
            spectrum = np.abs(np.fft.fft(channel[start:start + blocks_size]))
            # Plain ints keep str(high_point) a parseable literal no matter
            # how the installed numpy prints its scalar types.
            self.high_point.append(tuple(
                int(np.argmax(spectrum[low:high])) + low
                for low, high in self._BANDS))

    def play(self, filepath):
        '''
        Play a WAV file through the default PyAudio output device.

        :param filepath: path of the WAV file to play
        :return: None
        '''
        chunk = 1024  # frames written to the output stream per iteration
        wf = wave.open(filepath, 'rb')
        try:
            p = pyaudio.PyAudio()
            try:
                # open the sound output stream
                stream = p.open(format=p.get_format_from_width(wf.getsampwidth()),
                                channels=wf.getnchannels(),
                                rate=wf.getframerate(),
                                output=True)
                try:
                    # An empty read marks the end of the file. Truthiness
                    # works for both str (Python 2) and bytes (Python 3),
                    # where comparing with "" would never match.
                    data = wf.readframes(chunk)
                    while data:
                        stream.write(data)
                        data = wf.readframes(chunk)
                finally:
                    stream.close()
            finally:
                p.terminate()
        finally:
            wf.close()


if __name__ == '__main__':
    # Demo: load a file, play it, then print the name recorded by loaddata().
    p = voice()
    # Only loaddata() sets p.name; play() never does, so printing p.name
    # without loading first raised AttributeError.
    p.loaddata('the_mess.wav')
    p.play('the_mess.wav')
    print(p.name)

 

plar.py

# -*- coding: utf-8 -*-
# Created: huashan

import os

import MySQLdb

import my_audio


class memory():
    '''
    Store audio fingerprints in the MySQL table fingerprint.musicdata and
    search that table for the best match of a query recording.
    '''

    def __init__(self, host, port, user, passwd, db):
        '''
        Remember the database connection parameters; no connection is opened.

        :param host: database host
        :param port: database port
        :param user: database user
        :param passwd: database password
        :param db: database name
        '''
        self.host = host
        self.port = port
        self.user = user
        self.passwd = passwd
        self.db = db

    def _connect(self):
        '''Open a new MySQL connection with the stored parameters.'''
        return MySQLdb.connect(host=self.host, port=self.port, user=self.user,
                               passwd=self.passwd, db=self.db, charset='utf8')

    def _load_voice(self, path):
        '''Load the file at path and compute its fingerprint.'''
        v = my_audio.voice()
        v.loaddata(path)
        v.fft()
        return v

    def _rank(self, conn, search_fp):
        '''Score search_fp against every stored song, best score first.'''
        import ast
        cur = conn.cursor()
        try:
            cur.execute("SELECT * FROM fingerprint.musicdata")
            rows = cur.fetchall()
        finally:
            cur.close()
        # NOTE(review): this replaces eval() on database content.
        # literal_eval only parses Python literals, which is what addsong()
        # stores, and cannot execute code planted in the table.
        ranked = [(self.fp_compare(search_fp, ast.literal_eval(row[1])), row[0])
                  for row in rows]
        ranked.sort(reverse=True)
        return ranked

    def addsong(self, path):
        '''
        Fingerprint a song and store its file name and fingerprint.

        Prints a message and returns without inserting when the database
        cannot be reached or a row with the same file name already exists.

        :param path: path of the song file
        :return: None
        :raises TypeError: if path is not a str
        '''
        if not isinstance(path, str):
            raise TypeError('path need string')
        basename = os.path.basename(path)
        try:
            conn = self._connect()
        except MySQLdb.Error:
            print('DataBase error')
            return None
        # try/finally so the early "already recorded" return no longer leaks
        # the cursor and the connection.
        try:
            cur = conn.cursor()
            try:
                # Parameterized queries: file names containing quotes no
                # longer break (or inject into) the SQL statement.
                namecount = cur.execute(
                    "select * from fingerprint.musicdata WHERE song_name = %s",
                    (basename,))
                if namecount > 0:
                    print('the song has been record!')
                    return None
                v = self._load_voice(path)
                # Store plain ints so the saved text is a parseable literal.
                fingerprint = [tuple(int(x) for x in point)
                               for point in v.high_point]
                cur.execute("insert into fingerprint.musicdata VALUES(%s,%s)",
                            (basename, str(fingerprint)))
                conn.commit()
            finally:
                cur.close()
        finally:
            conn.close()

    def fp_compare(self, search_fp, match_fp):
        '''
        Slide search_fp along match_fp and count equal positions.

        :param search_fp: fingerprint of the query
        :param match_fp: fingerprint stored in the library
        :return: the largest number of equal elements over all alignments
                 (an int); 0 when search_fp is longer than match_fp
        '''
        window = len(search_fp)
        if window > len(match_fp):
            return 0
        max_similar = 0
        # "+ 1" includes the last alignment; without it two fingerprints of
        # equal length were never compared at all and always scored 0.
        for offset in range(len(match_fp) - window + 1):
            hits = sum(1 for got, want
                       in zip(match_fp[offset:offset + window], search_fp)
                       if got == want)
            if hits > max_similar:
                max_similar = hits
        return max_similar

    def search(self, path):
        '''
        Rank every stored song by similarity to the file at path.

        :param path: path of the file to look up
        :return: list of (score, song name) tuples, best match first; the
                 list is also printed
        :raises IOError: if the database connection fails
        '''
        # compute the fingerprint of the query first
        v = self._load_voice(path)
        try:
            conn = self._connect()
        except MySQLdb.Error:
            raise IOError('DataBase error')
        try:
            compare_res = self._rank(conn, v.high_point[:-1])
        finally:
            conn.close()
        print(compare_res)
        return compare_res

    def search_and_play(self, path):
        '''
        Same ranking as search(), then play the best match.

        The best match's stored song name is passed to voice.play() as a
        path, so playback relies on that file being reachable under that
        name from the current directory.

        :param path: path of the file to look up
        :return: list of (score, song name) tuples, best match first, or
                 None if the database connection fails
        '''
        v = self._load_voice(path)
        try:
            conn = self._connect()
        except MySQLdb.Error:
            print('DataBase error')
            return None
        try:
            compare_res = self._rank(conn, v.high_point[:-1])
        finally:
            conn.close()
        print(compare_res)
        # An empty table yields no candidates; skip playback instead of
        # failing with IndexError.
        if compare_res:
            v.play(compare_res[0][1])
        return compare_res


if __name__ == '__main__':
    # Demo run: register three songs, then look up a converted copy of one
    # of them and play the best match.
    library = memory('localhost', 3306, 'root', 'huawei', 'fingerprint')
    for song in ('60542.wav', '70715.wav', '70342.wav'):
        library.addsong(song)
    library.search_and_play('70715_Convert.wav')