当前位置: 首页 > 科技观察

通过PyTorch创建用于文本分类的Bert模型

时间:2023-03-19 10:55:52 科技观察

2018年,谷歌发表了一篇名为《Pre-training of deep bidirectional Transformers for Language Understanding》的论文。在本文中,我们介绍了一种称为BERT(BidirectionalEncoderRepresentationwithTransformers)的语言模型,该模型在问答、自然语言推理、分类和通用语言理解评估(GLUE)等任务中取得了成果。最先进的性能。BERT代表BidirectionalEncoderRepresentationfromTransformers[1],是一种语言表示的预训练模型。它基于谷歌2017年发布的Transformer架构。通常的Transformer使用一套编码器和解码器网络,而BERT只需要额外的输出层和微调预训练即可满足各种任务。没有必要为特定任务修改模型。BERT将多个Transformer编码器堆叠在一起。Transformer基于著名的Multi-headAttention模块,在视觉和语言任务上都取得了巨大的成功。在本文中,我们将使用PyTorch创建用于文本分类的Bert模型。笔者今天介绍一个python库---simpletransformers,可以解决高级预训练语言模型难用的问题。simpletransformers可以轻松训练、评估和预测高级预训练模型(BERT、RoBERTa、XLNet、XLM、DistilBERT、ALBERT、CamemBERT、XLM-RoBERTa、FlauBERT),每行仅3行来初始化模型。数据集来源:https://www.kaggle.com/jrobischon/wikipedia-movie-plots该数据集包含来自世界各地的34,886部电影的描述。各列说明如下:ReleaseYear:电影上映的年份Title:电影片名Origin:电影的起源地(即美国,宝莱坞,泰米尔等),json,gc,re,randomfromtqdm.notebookimporttqdmimporttorch,transformers,tokenizermovies_df=pd.read_csv("wiki_movie_plots_deduped.csv")fromsklearn.preprocessingimportLabelEncodermovies_df=movies_df[(movies_df["Origin/Ethnicity"]=="American")|[(movies_df)"来源/种族"]=="英国")]movies_df=movies_df[["情节","流派"]]drop_indices=movies_df[movies_df["流派"]=="未知"].indexmovies_df.drop(drop_indices,inplace=True)#Combinegenres:1)"sci-fi"与"sciencefiction"&2)"romanticcomedy"与"romance"movies_df["Genre"].replace({"sci-fi":"sciencefiction","romanticcomedy":"romance"},inplace=True)#根据频率shortli选择电影类型sted_genres=movies_df["Genre"].value_counts().reset_index(name="count").query("count>200")["index"].tolist()movies_df=movies_df[movies_df["Genre"].isin(shortlisted_genres)].reset_index(drop=True)#Shufflemovies_df=movies_df.sample(frac=1).reset_index(drop=True)#从不同类型中抽取大致相同数量的电影情节样本(以减少类不平衡问题)movies_df=movies_df.groupby("Genre").head(400).reset_index(drop=True)label_encoder=LabelEncoder()movies_df["genre_encoded"]=label_encoder.fit_transform(movies_df["Genre"].tolist())movies_df=movies_df[["Plot","Genre","genre_encoded"]]movies_df使用torch加载BERT模型,最简单的方法是使用SimpleTransformers库,这样初始化只需要3行代码,在一个给定数据集在给定数据集上训练和评估Transformer模型fromsimpletransformers.classificationimportClassificationModel#模型参数model_args={"reprocess_input_data":True,"overwrite_output_dir":True,"save_model_every_epoch":False,"save_eval_checkpoints":False,"max_seq_length":512,"train_batch_size":16,"num_train_epochs":4,}#CreateaClassificationModelmodel=ClassificationModel('bert','bert-base-cased',num_labels=len(shortlisted_genres),args=model_args)训练模型train_df,eval_df=train_test_split(movies_df,test_size=0.2,stratify=movies_df["Genre"],random_state=42)#Trainthemodelmodel.train_model(train_df[["Plot","genre_encoded"]])#Evaluatethemodelresult,model_outputs,wrong_predictions=model.eval_model(eval_df[["Plot","genre_encoded"]])print(result){'mcc':0.5299659404649717,'eval_loss':1.4970421879083518}CPUtimes:user19min1s,sys:4.95s,total:19min6sWalltime:20min14s关于simpletransformers的官方文档:https://simpletransformers.ai/docsGithub链接:https://github.com/ThilinaRajapakse/simpletransformers