본문 바로가기
개발/AI 코드

인공지능 koBERT 모델 학습

by beomcoder 2023. 2. 11.
728x90
반응형

추천시스템에 쓰일 '태그'를 달기 위해 모델을 하나 제작하고 있다. 다른 모델들도 많지만 koELECTRA와 기타 모델은 데이터 전처리를 모델에 맞게 해주지 않아서 그런가 정확도가 높지 않았다. 그래서 그나마 높은 정확도를 보여준 koBERT학습 후기를 남길까 한다.

 

1. BERT는 무엇인가?

먼저 BERT라는 것은 위키피디아(25억 단어)와 BooksCorpus(8억 단어)로 pretrain 되어 있는 기계번역 모델이다. 하지만 외국에서 만든 것이다 보니 영어에 대해 정확도가 높다. 한국어에 대해서는 영어보다 정확도가 떨어진다.  좋은 알고리즘을 갖고 있는 BERT 모델을 한국어에도 잘 활용할 수 있도록 만들어진 것 중에 하나가 바로 SKT에서 만든 KoBERT모델(https://github.com/SKTBrain/KoBERT)이다.

 

세세하게 공부하고 싶은 사람은 따로 찾아보시는 것을 추천한다. 내가 틀리게 말할 수도 있다.

 

BERT는 transformer를 12~24개의 layer로 쌓아놓은 것이다. transformer는 간단히 말해서 위치정보를 가진 값으로 나타내주는 것인데 설명하자면 위의 오른쪽 그림의 단어 뜻은 '그 동물은 길을 건너지 않았다. 왜냐하면 그것은 너무 피곤하였기 때문이다.' 인데, 그것이라는 것이 street인지 animal인지 구분할 때 입력 문장 내의 단어들끼리 유사도를 구하므로서 그것(it)이 동물(animal)과 연관되었을 확률이 높다는 것을 찾아냅니다. 더 궁금하면 찾아보길 바란다.

결론적으로 BERT는 총 3개의 임베딩 층이 사용된다.

  • WordPiece Embedding : 실질적인 입력이 되는 워드 임베딩. 임베딩 벡터의 종류는 단어 집합의 크기로 30,522개.
  • Position Embedding : 위치 정보를 학습하기 위한 임베딩. 임베딩 벡터의 종류는 문장의 최대 길이인 512개.
  • Segment Embedding : 두 개의 문장을 구분하기 위한 임베딩. 임베딩 벡터의 종류는 문장의 최대 개수인 2개.

 

이런식으로 BERT모델이 구성되어 있는데 나는 여기에 우리의 데이터를 추가로 학습시켜서 원하는 방향으로 모델을 만들 것이다. 영화 리뷰 감성 분류, 로이터 뉴스 분류 등과 같이 입력된 문서에 대해서 분류를 하는 유형으로 문서의 시작에 [CLS] 라는 토큰을 입력한다. 사전 훈련 단계에서 다음 문장 예측을 설명할 때, [CLS] 토큰은 BERT가 분류 문제를 풀기위한 특별 토큰이다. 텍스트 분류 문제를 풀기 위해서 [CLS] 토큰의 위치의 출력층에서 밀집층(Dense layer) 또는 같은 이름으로는 완전 연결층(fully-connected layer)이라고 불리는 층들을 추가하여 분류에 대한 예측을 하게 된다.

 

 

2. koBERT를 이용하여 카테고리 분류해보기

우선 COLAB환경에서 진행했다.

자세히 몰라도 따라할 수 있게 코드리뷰를 간단히 할까 한다.

 

중간에 이상한 소리들이 많다.

이상한 소리들은 기울임꼴로 표시해두었으니 넘어가도 상관없다.

 

처음에는 로컬에서 해보려고 했다.

하지만 나는 윈도우에서 작업을 하고 있고 윈도우에서는 mxnet이 깔리지가 않는다.

해결방법을 알고 있는 분은 알려주면 좋을 것 같다.

왜 mxnet은 pypi에서 mxnet 1.9.1이 있는데 pip install mxnet을 하면 0.7.0 버전이 깔리는지 모르겠다.

mxnet이 0.7.0버전이라 numpy는 1.16.6버전을 깔아야 하는데 1.16.6버전은 다른 것의 호환성문제로 오류와 오류를 불러 어렵다.

코랩에서 mxnet을 설치하면 1.9.1버전이 잘깔리는 데 아마 리눅스와 윈도우의 차이지 않을까 싶다.

이 부분은 내가 겪었던 스트레스라 적은 것이고 코랩환경에서 한다면 문제없이 돌아간다.

 

2-1. 데이터 수집 및 전처리

koBERT를 이용하기전에 먼저 추천서비스를 만들어야 하기 때문에 어떤 방식으로 추천을 해줘야 할지 고민했다. 스타트업이기 때문에 데이터가 많이 쌓여있지 않았고, 어떤 방식으로 만들어야 할지 무엇을 하고 싶은지 명확하게 나와있지 않았다. 그래서 먼저 추천알고리즘은 어떤 것이 있는지 찾아보았고, 나 나름대로 임의로 만들어 보기로 했다. 추천서비스는 공장에서 불량품, 신호위반을 잡는 CCTV라던지 정확도가 엄청나게 중요하지는 않아서 마음편히 일단 만들어보자 생각했다.

 

내가 생각한 방법은 보여지는 관심사에 추천하는 것이 아니라 채팅, 피드를 보는 체류시간 등을 분석하여 모임을 추천해주는 방식이다. 피드, 체류시간은 아직 어플개발단계에서 적용되지 않아서 채팅을 활용해보기로 했다. 하지만 채팅데이터도 많지 않았다. 채팅 데이터를 DB에서 가지고 와서 카테고리를 분류하고 클러스터링을 통해 비슷한 내용을 찾고 카테고리를 붙여주고 다시 학습을 시켜보았더니 정확도가 30% 언저리가 나왔다. 일단은 데이터를 다른 곳에서 찾아와서 전처리를 해줘야겠다고 생각했다.

 

 

AI-Hub

※ 내국인만 데이터 신청이 가능합니다. 목록 데이터 개요 데이터 변경이력 데이터 변경이력 버전 일자 변경내용 비고 1.0 2022-07-12 데이터 최초 개방 데이터 히스토리 데이터 히스토리 일자 변경

aihub.or.kr

 

AI-HUB에서 주제별 텍스트 일상 대화 데이터를 가지고 왔다. 대화를 20가지의 카테고리로 분류해놓은 것이다. 

1 2 3 4 5 6 7 8 9 10
식음료 주거와
생활
교통 회사/아르바이트 군대 교육 가족 연애/결혼 반려동물 스포츠/
레저
11 12 13 14 15 16 17 18 19 20
게임 여행 계절/날씨 사회이슈 타 국가
이슈
미용 건강 상거래
전반
방송/연예 영화/만화

 

{
  "dataset": {
    "identifier": 67915,
    "name": "KAKAO_981_16_set",
    "src_path": "/data/file/cubeManager/PROJECT001/53/txt20211006122540087223/KAKAO_981_16_set/",
    "label_path": "/data/file/cubeManager/PROJECT001/53/txt20211006122540087223/KAKAO_981_16_set/",
    "category": 2,
    "type": 0
  },
  "licenses": {
    "name": "Apache License 1.0",
    "url": "http://www.apache.org/licenses/LICENSE-1.0"
  },
  "info": [
    {
      "id": 43016,
      "filename": "KAKAO_981_16.txt",
      "title": "KAKAO_981_16",
      "mediatype": "SNS",
      "medianame": "카카오톡",
      "category": "일상대화",
      "date": "2021-10-06",
      "size": 754,
      "annotations": {
        "subject": "미용",
        "speaker_type": "다자간 대화",
        "size": 754,
        "word_size": 233,
        "text": "1 : 저 요새 다이어트가 진짜 필요해요\n2 : 아이 다이어트는 언제나 하는 거 아니에요? 하하\n3 : 저도 항상 다이어트 해야 한다고 말하지만 뭐 행복한 게 최고 아니겠어요? ㅎ\n1 : 저도요 행복한 돼지의 삶도 좋지만 옷이 안 맞아요 키키\n2 : 앗... 평소엔 잊고 살다가 옷 입으면 느껴지죠\n3 : 그렇다면 언니 조금 큰 사이즈를 사세요!\n1 : 어유 지금도 고무줄 바지만 입어요 키키\n2 : 맞아요 지난번에 입은 옷들 다 진짜 편해 보여서 저도 사고 싶었어요\n3 : 린넨? 재질이어서 시원해 보이더라고요\n1 : 근데 이게 무서운 게 제가 막 다이어트 홈트 이런 거 알아보니까 지흡 관련한 영상이 많이 뜨더라고요\n2 : 어머 지흡... 너무 무서워요\n3 : 그거 진짜 위험한 거라면서요\n1 : 네네 근데 효과가 진짜 드라마틱 해서 혹해요\n2 : 근데 부작용도 장난 아니에요\n3 : 강남에서 지흡 시술 받다가 죽은 사람도 있지 않아요? ㅠ\n1 : 맞아요 그리고 막 못 걷게 된 사람도 있어요\n2 : 내가 그렇게 되지 않을 거라는 보장이 어디 있어요 ㅠ\n2 : 우리 그냥 운동해요!\n3 : 우리 그러면 강아지랑 산책할 때 좀 빠르게 걸어보는 거 어때요?\n1 : 좋아요! 사월이 다리 근육은 나날이 좋아지는데 전 왜 계속 살이 찌는지...\n2 : 허벅지 근육이 발달돼야 살이 안 찐다던데\n3 : 큰 근육이어서요? 오호... 좋은 정보네요\n1 : 저는 허벅지 진짜 살밖에 없는 거 같아요\n2 : 아무래도 앉아있는 시간이 길면 그런 거 같아요\n3 : 그럼 우리 오늘부터라도 좀 속보를 해봐요!",
        "lines": [
          {
            "id": 1,
            "text": "1 : 저 요새 다이어트가 진짜 필요해요",
            "norm_text": "저 요새 다이어트가 진짜 필요해요",
            "speaker": {
              "id": "1번",
              "sex": "여성",
              "age": "30대"
            },
            "speechAct": "(단언) 주장하기",
            "morpheme": "저/MM+요새/NNG+다이어트/NNG+가/JKS+진짜/MAG+필요/NNG+해요/XSV+EC"
          },
          {
            "id": 2,
            "text": "2 : 아이 다이어트는 언제나 하는 거 아니에요? 하하",
            "norm_text": "아이 다이어트는 언제나 하는 거 아니에요? 하하",

다운로드 받고 나서 파일을 열어보았다. 파일은 json형식으로 되어 있다. 이제 이 데이터를 전처리를 해야한다. 우선 데이터부터 원하는 형식으로 바꿔서 저장을 해야한다. koBERT는 정해진 데이터형식이 있기 때문에 text, label 형식으로 저장해야한다. 내가 필요한 것은 text, subject이다. 234,000개의 파일을 돌면서 json파일을 읽고, subject와 text를 뽑아서 정리해주었다. 이 부분 코드가 궁금하면 댓글이나 메일을 남기면 알려주겠지만 나보다 더 잘 코드를 작성할 것 같기도 하고, 로컬에서 작성한 코드라 회사에 있다.

id	text	label
0	너 옷 어디서 주로 구매해?	17
1	나 주로 인터넷으로 구매해	17
2	여성 전용 인터넷 쇼핑몰이 있니?	17
3	앱을 이용하지 지그재그나 옷 쇼핑몰 앱을 써	17
4	정말 편리하겠다 나도 추천해줘	17
5	왜? 여자 옷 입게?	17
6	아니 나 여자친구 있잖아	17
7	아 그게 왜? 선물하려고?	17
8	응 선물해주고 싶어서 키키	17
9	그러면 에이블리라는 앱을 깔아 봐 인기 있는 쇼핑몰 모아놓은 앱이야	17
10	알겠어 고마워 참고할게	17
11	그래 꼭 예쁜 옷 선물해 줘	17
12	그래 꼭 예쁜 옷 선물할게	17
13	아 그리고 가격 같은 세부 옵션도 설정할 수 있으니까 골라서 선택해서 찾아 봐	17
14	고마워 친절한 나의 벗아	17
15	그래 찾아보고 궁금한 거 있으면 또 물어봐	17
16	그러고 보니 나는 좋아하는 연예인이 없어	18
17	엥 그래? 눈이 가는 연예인도 없어?	18
18	응 티비를 안봐서 그런가?	18
19	아예 안 보는 거야?	18
20	응 아예 안 봐 너는 보니?	18
21	나는 가끔 보지 그럼 집에 티비가 없어?	18
22	아니 키키 있긴한데 분리 시켜 놨어	18
23	키키 진짜 안 보는구나	18
24	응 티비는 바보 상자야	18
25	왜 바보 상자야?	18
26	아무 생각 없이 멍하니 보게 되잖아	18
27	그런가 나는 재밌는 건 깔껄 거리면서 볼 수 있어서 좋던데	18
28	뭐 보고 재미를 느꼈니?	18
29	그냥 예능 프로그램 그리고 혼자 있을 때 틀어 놓으면 혼자 있는 거 같지 않고 좋아	18
30	그렇구나 그럼 적적하지는 않겠다	18
31	응 그런 거 같아	18
32	요즘 볼만한 애니가 없어	19
33	다 큰 어른이 무슨 애니를 보니	19
34	그런가 애니 말고 이제 영화나 볼까?	19
35	그래 애니 말고 영화나 봐ㅎ	19
36	알겠어 영화관 가서 영화 봐야겠다	19
37	근데 생각해 보니까 친구가 귀멸의 칼날 재밌다고 그러던데 그거 봐 봐	19

사실 여기 텍스트파일로 저장하기 전에 기타 처리를 해주었으면 되는데 텍스트파일로 만들고 나서 모델에 넣을때 오류와 정확도문제가 발견되어서 추가로 데이터 수정을 해주었다. 그것은 아래에 적겠다. 그리고 로컬에서 코드를 돌려보고 싶어서 로컬에서 진행하다가 mxnet과 glounnlp, numpy, windows문제로 인해 코랩으로 옮겨서 진행하였다. 그래서 일단은 코랩에서 진행하고, 모델을 학습시키고 방법은 추후에 로컬에서 적용시킬 수 있는 방법을 찾아보려 한다.

 

 

Google Colaboratory

 

colab.research.google.com

📌 먼저 코랩에 들어가서 런타임 > 런타임 유형 변경 > 하드웨어 가속기 "GPU"

!pip install mxnet
!pip install gluonnlp pandas tqdm 
!pip install sentencepiece==0.1.96
!pip install transformers==4.8.1
!pip install torch
#깃허브에서 KoBERT 파일 로드
!pip install git+https://git@github.com/SKTBrain/KoBERT.git@master

koBERT를 사용하기 위해 필요한 모듈들과 github에서 skt측이 공유한 koBERT 파일을 로드한다.

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
#transformers
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

from kobert import get_pytorch_kobert_model

#BERT 모델, Vocabulary 불러오기
bertmodel, vocab = get_pytorch_kobert_model()

#GPU 사용
device = torch.device("cuda")

CPU를 사용하면 너무 느리기 때문에 GPU를 써서 조금 빠르게 학습시키려고 한다. 하지만 GPU도 내가 느끼기에는 느리게 진행된다. TPU를 사용하면 더 빠르다는데 해볼 사람이 있으면 찾아서 해봐도 공부가 될 것 같다.  

 

 

18-01 코랩(Colab)에서 TPU 사용하기

지금까지는 GPU 사용만으로도 모델을 학습하는데 큰 무리가 없었지만, BERT의 경우 지금까지 사용한 모델보다 무거운 편입니다. 다시 말해 학습 속도가 상대적으로 느린 편입니다.…

wikidocs.net

from google.colab import drive
drive.mount('/content/drive')

구글 드라이브에 있는 데이터를 가지고 오기 위해 구글 드라이브를 연동시켜준다. 데이터를 직접 코랩에 넣어주어도 되지만 데이터를 구글 드라이브에 보관하고 있어서 연동만 하면 따로 불러오지 않아도 같은 폴더에 있는 것으로 인식하여 편하게 불러올 수 있다.

import pandas as pd

dataset_train_csv = pd.read_csv('/content/drive/MyDrive/text_category/train_data.txt', sep='\t').dropna(axis=0) 
dataset_test_csv = pd.read_csv('/content/drive/MyDrive/text_category/valid_data.txt', sep='\t').dropna(axis=0)

데이터 형식으로 바꿀때 pandas의 테이블형식을 사용하여 나도 그렇게 했다. 깊게 공부하지 않아서 다른 방법도 가능한지는 잘모른다.  텍스트파일을 저장할때 tab으로 구분지어서 구분자를 \t로 하여 csv파일을 만들어주었다. 처음에 dropna로 결측치를 제거하지 않아서 모델에서 에러가 생기고 찾는데 애를 먹었다. dropna를 통해 결측치를 제거하여 데이터를 가지고 온다.

dataset_train_csv.head()

# pandas dataframe을 확인해보았는데, 필요없이 id값이 있었고, label이 0~19의 int형이 되기를 원해서 처리를 해줬다.
dataset_train_csv = dataset_train_csv.drop(['id'], axis=1)
dataset_test_csv = dataset_test_csv.drop(['id'], axis=1)

dataset_train_csv['label'] = dataset_train_csv['label'].astype(np.int32)
dataset_test_csv['label'] = dataset_test_csv['label'].astype(np.int32)

 

# 어차피 리스트로 만들어 줄건데 왜 처리를 해줬는지 기억은 잘 나지 않는다.
dataset_train = []
for q, label in zip(dataset_train_csv['text'], dataset_train_csv['label']):
    data = []
    data.append(q)
    data.append(str(label))

    dataset_train.append(data)

dataset_test = []
for q, label in zip(dataset_test_csv['text'], dataset_test_csv['label']):
    data = []
    data.append(q)
    data.append(str(label))

    dataset_test.append(data)

 

# BERT모델에 넣을 데이터셋을 만들어줄 클래스이다. 위에서 설명한 transform형식으로 데이터셋을 바꿔준다.
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        transform = nlp.data.BERTSentenceTransform(bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))

    def __len__(self):
        return (len(self.labels))

 

# 파라미터 세팅을 해주었다. batch_size는 colab환경에 맞춰 설정해야한다.
# num_epochs는 몇번 반복학습 할지 정하는건데, 난 데이터가 200만개라 1 epochs가 4시간30분이 걸린다.
# 그래서 1번 돌리고 pt를 저장하고 퇴근하고 다음날 1번 돌리고 pt를 저장하는 방식으로 했다.

max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 1
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

 

#토큰화
from kobert import get_tokenizer

tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)

train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=5)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5)

 

class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=20,   ##모델의 마지막층의 class를 지정해줘야 한다. (카테고리개수)##
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

 

# BERT 모델 불러오기
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

# optimizer와 schedule 설정
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

#정확도 측정을 위한 함수 정의
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc
    
train_dataloader

 

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

여기서 대략 4시간 30분이 걸렸다. 처음에는 5%가 나오고 1epoch가 끝나갈때쯤 48%까지 오르게 되었다. 약 80분 정도를 코랩에서 조작을 해주지 않으면 런타임을 끊어버린다. 그래서 script코드로 자동 조작을 해주거나 간간히 한번씩 만져주어야 한다. 

 

구글 코랩 (Google colab)의 런타임 연결 끊김을 방지하는 방법

구글 코랩 (Google colab)의 런타임 연결 끊김을 방지하는 방법에 대한 내용입니다.

teddylee777.github.io

대부분 블로그들은 여기서 끝내는 경우가 많았다. 여기서 이제 모델을 저장하고 불러와서 다시 쓰는 방법도 기술하겠다.

# 모델 state dict만 저장하기
torch.save(model.state_dict(), "/content/drive/MyDrive/text_category/model.pt")

3. 모델 테스트 해보기

이렇게 해서 모델을 저장한다. 그리고 모델을 쓰기 위해서는 아까 정의한 class 2개를 적어주어야 모델을 쓸 수 있다. 그리고 kobert기본 모델을 불러오고 그 모델에 학습시킨 가중치를 붙여주는것이다. 그러면 이제 우리가 학습시켜놓은 모델을 사용할 수 있게 된다.

# 학습시킨 모델이 있다면 불러오기
model.load_state_dict(torch.load("/content/drive/MyDrive/text_category/model.pt"))
category = { '식음료': 0, '주거와 생활': 1, '교통': 2, '회사/아르바이트': 3, '군대': 4, '교육': 5, '가족': 6, '연애/결혼': 7, '반려동물': 8, '스포츠/레저': 9, 
             '게임': 10, '여행': 11, '계절/날씨': 12, '사회이슈': 13, '타 국가 이슈': 14, '미용': 15, '건강': 16, '상거래전반': 17, '상거래 전반': 17, '방송/연예': 18,
             '영화/만화': 19 }

그리고 아까 깜빡하고 적지 못했는데 aihub에서 받은 데이터셋에서는 상거래전반, 상거래 전반으로 띄어쓰기가 서로 섞여있는 데이터셋이 있어서 그 2개를 같은 17라벨로 묶어주었다.

def predict(predict_sentence):
    data = [predict_sentence, '0']
    dataset_another = [data]
    another_test = BERTDataset(dataset_another, 0, 1, tok, max_len, True, False)
    test_dataloader = torch.utils.data.DataLoader(another_test, batch_size=batch_size, num_workers=5)

    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)

        valid_length= valid_length
        label = label.long().to(device)

        out = model(token_ids, valid_length, segment_ids)
        test_eval=[]
        for i in out:
            logits=i
            logits = logits.detach().cpu().numpy()
            test_eval.append(list(category.keys())[(np.argmax(logits))])

        print(">> 입력하신 내용은 '" + test_eval[0] + "'의 카테고리로 예측되었습니다.")
predict('오늘 밥을 무엇을 먹을지 고민되는걸?')

>> 입력하신 내용은 '식음료'의 카테고리로 예측되었습니다.

4. 느낀점 및 추가사항

48%의 정확도 치고는 꽤 자주 맞추었다. 현재 2epochs를 돌려놓고 블로그글을 작성하였는데 2번 돌린 결과는 54%까지 상승하였다. 앞으로 3번정도 더 돌려보고 60%정도까지 올려볼 생각이다. 모델을 이정도로 마무리하고 다른 모델들도 작업을 해야한다.

 

이렇게 채팅을 분류시키고 내가 적은 채팅으로 현재의 관심사를 알아내고 비슷한 관심사를 가진 모임을 추천해주게 할 생각이다.

728x90
반응형

'개발 > AI 코드' 카테고리의 다른 글

YOLO V8 detection 간단하게 사용하기  (0) 2023.07.06
로컬에서 BERT모델 돌려서 학습하기  (1) 2023.02.15

댓글