3. 이미지 내부의 문자 인식
in Data on Machine Learning
파이썬을 이용한 머신러닝, 딥러닝 실전개발 입문
파이썬을 이용한 머신러닝, 딥러닝 실전개발 입문 책과 강의를 바탕으로한 코드 리뷰 및 정리입니다. 자세한 내용은 책과 강의를 참고해주세요.
link: github
4. 머신러닝
4-3. 이미지 내부의 문자 인식
MNIST - 손글씨 숫자 데이터
파일 이름 | 설명 |
---|---|
train-images-idx3-ubyte.gz | 학습 전용 이미지 데이터 |
train-labels-idx3-ubyte.gz | 학습 전용 레이블 데이터 |
t10k-images-idx3-ubyte.gz | 테스트 전용 이미지 데이터 |
t10k-labels-idx3-ubyte.gz | 테스트 전용 레이블 데이터 |
gz파일 압축을 해제하면 바이너리데이터가 나오는데, 편리하게 다루기 위해 CSV 파일 변환 과정을 거친다.
- sturct 모듈 : 파이썬으로 바이너리 처리를 위한 모듈. 자세한 설명은 참고
- struct.unpack() : 원하는 바이너리 수만큼 읽어 들이기 + 정수 변환
import struct
def to_csv (name):
#레이블 파일과 이미지 파일 열기
lbl_f=open("./mnist/"+name+"-labels-idx1-ubyte", "rb") #rb: read binary
img_f=open("./mnist/"+name+"-images-idx3-ubyte", "rb")
csv_f=open("./mnist/"+name+"_full"+".csv","w",encoding="utf-8")
#헤더정보읽기
mag, lbl_count=struct.unpack(">II",lbl_f.read(8))
mag, img_count=struct.unpack(">II",img_f.read(8))
rows, cols=struct.unpack(">II",img_f.read(8))
pixels=rows*cols
#이미지 데이터를 읽고 csv로 저장하기
res=[]
for idx in range(lbl_count):
label=struct.unpack("B",lbl_f.read(1))[0]
bdata=img_f.read(pixels)
sdata=list(map(lambda n: str(n), bdata))
csv_f.write(str(label)+",")
csv_f.write(",".join(sdata)+"\r\n")
csv_f.close()
lbl_f.close()
img_f.close()
#결과를 파일로 출력하기
to_csv("train")
to_csv("t10k")
이미지 데이터 학습시키기
- 이미지를 벡터로 변환하여 input으로. 그 외의 방식은 위와 동일하다.
from sklearn import model_selection, svm, metrics
import pandas as pd
train_csv=pd.read_csv("./mnist/train_full.csv", header=None)
tk_csv=pd.read_csv("./mnist/t10k_full.csv", header=None)
def transfer(dat):
output=[]
for i in dat:
output.append(float(i)/256)
return output
#train_csv_data=train_csv.iloc[:,1:].values #iloc[row범위,col범위] #모든 row, 1이후의 col
train_csv_data=list(map(transfer,train_csv.iloc[:,1:].values)) #map은 iterable 반환 -> 리스트 변환 필요
train_csv_label=train_csv[0].values
#tk_csv_data=tk_csv.iloc[:,1:].values
tk_csv_data=list(map(transfer,tk_csv.iloc[:,1:].values))
tk_csv_label=tk_csv[0].values
#학습하기
clf=svm.SVC()
clf.fit(train_csv_data, train_csv_label) #fit함수의 앞쪽에 들어가는 train_csv_data: 0~1사이의 요소여야 함.
predict=clf.predict(tk_csv_data)
#결과확인하기
score=metrics.accuracy_score(tk_csv_label, predict)
report=metrics.classification_report(tk_csv_label, predict)
print("정답률=",score)
print("report=\n", report)