3. 이미지 내부의 문자 인식


파이썬을 이용한 머신러닝, 딥러닝 실전개발 입문

파이썬을 이용한 머신러닝, 딥러닝 실전개발 입문 책과 강의를 바탕으로한 코드 리뷰 및 정리입니다. 자세한 내용은 책과 강의를 참고해주세요.


4. 머신러닝

4-3. 이미지 내부의 문자 인식


MNIST - 손글씨 숫자 데이터

데이터 다운 링크

파일 이름설명
train-images-idx3-ubyte.gz학습 전용 이미지 데이터
train-labels-idx3-ubyte.gz학습 전용 레이블 데이터
t10k-images-idx3-ubyte.gz테스트 전용 이미지 데이터
t10k-labels-idx3-ubyte.gz테스트 전용 레이블 데이터

gz파일 압축을 해제하면 바이너리데이터가 나오는데, 편리하게 다루기 위해 CSV 파일 변환 과정을 거친다.

  • sturct 모듈 : 파이썬으로 바이너리 처리를 위한 모듈. 자세한 설명은 참고
    • struct.unpack() : 원하는 바이너리 수만큼 읽어 들이기 + 정수 변환
import struct

def to_csv (name):
    
    #레이블 파일과 이미지 파일 열기
    lbl_f=open("./mnist/"+name+"-labels-idx1-ubyte", "rb") #rb: read binary
    img_f=open("./mnist/"+name+"-images-idx3-ubyte", "rb")
    csv_f=open("./mnist/"+name+"_full"+".csv","w",encoding="utf-8")
    
    #헤더정보읽기
    mag, lbl_count=struct.unpack(">II",lbl_f.read(8))
    mag, img_count=struct.unpack(">II",img_f.read(8))
    rows, cols=struct.unpack(">II",img_f.read(8))
    pixels=rows*cols
    
    #이미지 데이터를 읽고 csv로 저장하기
    res=[]
    for idx in range(lbl_count):
        
        label=struct.unpack("B",lbl_f.read(1))[0]
        bdata=img_f.read(pixels)
        sdata=list(map(lambda n: str(n), bdata))
        csv_f.write(str(label)+",")
        csv_f.write(",".join(sdata)+"\r\n")
                
    csv_f.close()
    lbl_f.close()
    img_f.close()
    
#결과를 파일로 출력하기
to_csv("train")
to_csv("t10k")


이미지 데이터 학습시키기
  • 이미지를 벡터로 변환하여 input으로. 그 외의 방식은 위와 동일하다.
from sklearn import model_selection, svm, metrics
import pandas as pd

train_csv=pd.read_csv("./mnist/train_full.csv", header=None)
tk_csv=pd.read_csv("./mnist/t10k_full.csv", header=None)

def transfer(dat):
    output=[]
    for i in dat:
        output.append(float(i)/256)
    return output
    


#train_csv_data=train_csv.iloc[:,1:].values #iloc[row범위,col범위] #모든 row, 1이후의 col
train_csv_data=list(map(transfer,train_csv.iloc[:,1:].values)) #map은 iterable 반환 -> 리스트 변환 필요
train_csv_label=train_csv[0].values
#tk_csv_data=tk_csv.iloc[:,1:].values
tk_csv_data=list(map(transfer,tk_csv.iloc[:,1:].values)) 
tk_csv_label=tk_csv[0].values 

#학습하기
clf=svm.SVC()
clf.fit(train_csv_data, train_csv_label) #fit함수의 앞쪽에 들어가는 train_csv_data: 0~1사이의 요소여야 함.
predict=clf.predict(tk_csv_data)

#결과확인하기
score=metrics.accuracy_score(tk_csv_label, predict)
report=metrics.classification_report(tk_csv_label, predict)
print("정답률=",score)
print("report=\n", report)






© 2018. by yeo0

Powered by yeo0