上一篇写到了成功获取了晶体中每个原子的磁矩,并能够将其添加入原子特征的第 93 维度,今天接着写我是如何一步一步把它导入模型中训练的。
我们缕一缕数据处理的流程:
从 MP 中获取 data 的 doc 以及 structure;
使用 CIFwriter 把 structure 写入 cif 文件,并且设置 write_magmoms=True;
把晶体的预测性质写入 id_prop.csv;
通过 CIFData 类读取 dataset;在生成 atom_fea 时将 magmon 写入第 93 维度;
通过 crystal_graph_list 函数生成 graph 数据。
首先获取数据 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 from mp_api.client import MPResterimport csvfrom tqdm import tqdmimport osfrom pymatgen.io.cif import CifWriter API_KEY = "" mpr = MPRester( API_KEY) M_docs = mpr.materials.summary.search( total_magnetization=(0.2 , None ), num_elements=(3 , 92 ), fields=["material_id" , "structure" , "is_magnetic" ], chunk_size=10 , num_chunks=100 ) NM_docs = mpr.materials.summary.search( total_magnetization=(None , 0.1 ), num_elements=(3 , 92 ), fields=["material_id" , "structure" , "is_magnetic" ], chunk_size=10 , num_chunks=100 )
这样严格区分开磁性晶体和非磁性,方便后续检查和确认数量。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 for entry in tqdm(M_docs): id = entry.material_id file_path = os.path.join('magmon_dataset_2000' , str (id ) + ".cif" ) structure = entry.structure cif_writer = CifWriter(structure, write_magmoms=True ) cif_writer.write_file(file_path)print ('磁性材料结构获取完成' )for entry in tqdm(NM_docs): id = entry.material_id file_path = os.path.join('magmon_dataset_2000' , str (id ) + ".cif" ) structure = entry.structure cif_writer = CifWriter(structure, write_magmoms=True ) cif_writer.write_file(file_path)print ('非磁性材料结构获取完成' )
最后生成 csv:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 with open ('magmon_dataset_2000/id_prop.csv' , 'w' , newline='' ) as csvfile: writer = csv.writer(csvfile) header = ['Material' , 'is_Magnetic' ] writer.writerow(header) for entry in tqdm(M_docs): if entry.is_magnetic: magnetic = 1 else : magnetic = 0 id = entry.material_id writer.writerow([id , magnetic]) print ('磁性材料写入完成!' ) for entry in tqdm(NM_docs): if entry.is_magnetic: magnetic = 1 else : magnetic = 0 id = entry.material_id writer.writerow([id , magnetic])print ("All done!" )
这样子生成的 cif 文件都是包含下面两个 loop 字段的,说明是写入了磁矩信息,在读取的时候就能读到每个原子的 magmon。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 loop_ _atom_site_type_symbol _atom_site_label _atom_site_symmetry_multiplicity _atom_site_fract_x _atom_site_fract_y _atom_site_fract_z _atom_site_occupancy Ac Ac0 1 -0.00000000 -0.00000000 -0.00000000 1 Cr Cr1 1 0.50000000 0.50000000 0.50000000 1 O O2 1 0.50000000 0.50000000 0.00000000 1 O O3 1 0.50000000 -0.00000000 0.50000000 1 O O4 1 -0.00000000 0.50000000 0.50000000 1 loop_ _atom_site_moment_label _atom_site_moment_crystalaxis_x _atom_site_moment_crystalaxis_y _atom_site_moment_crystalaxis_z Ac0 0.00000000 0.00000000 0.01500000 Cr1 0.00000000 0.00000000 2.71800000 O2 0.00000000 0.00000000 0.03800000 O3 0.00000000 0.00000000 0.03800000 O4 0.00000000 0.00000000 0.03800000
接着对数据进行读取 按照上述步骤,在这里需要依次进行 CIFData 的读取和 graph 数据生成
1 2 3 4 5 dataset = "magmon_dataset_2000" data = CIFData(dataset, target_name='is_Magnetic' ) graph = crystal_graph_list(data)print ('这是图数据' )print (data[0 ])
但是这里出现了严重错误
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 ZeroDivisionError Traceback (most recent call last) Cell In[23], line 5 3 #graph = crystal_graph_list(data) 4 print('这是图数据') ----> 5 print(data[0]) File e:\WYX_Project\CSAT_demo_2\test_view\csat\crystal_data.py:223, in CIFData.__getitem__(self, idx) 220 if not target_value: 221 raise ValueError(f"Empty target value for {self.target_name} in {cif_id}") --> 223 crystal = Structure.from_file(os.path.join(self.root_dir, 224 cif_id + '.cif')) 226 if self.target_name == 'is_Magnetic': 227 # 如果目标是磁矩,则将原子特征扩展为包含磁矩 228 atom_fea_list = [] File e:\anaconda3\envs\pytorch\Lib\site-packages\pymatgen\core\structure.py:3202, in IStructure.from_file(cls, filename, primitive, sort, merge_tol, **kwargs) 3200 contents: str = file.read() # type:ignore[assignment] 3201 if fnmatch(fname.lower(), "*.cif*") or fnmatch(fname.lower(), "*.mcif*"): -> 3202 return cls.from_str( 3203 contents, 3204 fmt="cif", 3205 primitive=primitive, 3206 sort=sort, 3207 merge_tol=merge_tol, ... --> 230 if len(items) % n != 0: 231 raise ValueError(f"{len(items)=} is not a multiple of {n=}") 232 loops.append(columns) ZeroDivisionError: integer modulo by zero
这个报错的意思是:
<font style="color:rgb(0, 0, 0);">items</font> 是 CIF 文件中的数据项列表
<font style="color:rgb(0, 0, 0);">n</font> 应该是数据列的数量
当 <font style="color:rgb(0, 0, 0);">n</font> 为 0 时,尝试计算 <font style="color:rgb(0, 0, 0);">len(items) % 0</font> 导致除以零错误
这通常是由于 CIF 文件格式问题导致的:
文件可能缺少必要的列定义
数据行数量与列定义不匹配
特殊字符或格式问题导致解析失败
错误出现在 crystal = Structure.from_file(os.path.join(self.root_dir,......就算说,在读 cif 文件是出错了,那就奇怪了,我们在上一篇中读取了好几次 cif 文件都没有任何问题,究竟是怎么回事呢?
于是,我就打开了几个 cif 文件看看,发现在非磁性晶体中存在这种现象:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 loop_ _atom_site_type_symbol _atom_site_label _atom_site_symmetry_multiplicity _atom_site_fract_x _atom_site_fract_y _atom_site_fract_z _atom_site_occupancy Ac Ac0 1 0.50000000 0.50000000 0.50000000 1 Ac Ac1 1 -0.00000000 -0.00000000 -0.00000000 1 Ag Ag2 1 0.25000100 0.25000100 0.25000100 1 Ir Ir3 1 0.75000100 0.75000100 0.75000100 1 loop_ _atom_site_moment_label _atom_site_moment_crystalaxis_x _atom_site_moment_crystalaxis_y _atom_site_moment_crystalaxis_z
他写入了磁矩的适量坐标系,但是没有磁矩信息,最终导致在读取 cif 文件时候,出现 item % 0,因为没有数据,为了验证这一点,我写了如下代码测试:
1 2 3 4 from csat.crystal_data import CIFData, crystal_graph_list from pymatgen.core.structure import Structure data = CIFData("NM", target_name='is_Magnetic') graph = crystal_graph_list(data)
我把一个非磁性晶体的 cif 文件单独拉出来读取,发现只有把
1 2 3 4 5 loop_ _atom_site_moment_label _atom_site_moment_crystalaxis_x _atom_site_moment_crystalaxis_y _atom_site_moment_crystalaxis_z
这一段删除之后才能够读取成功,说明我的推断没有错,于是我需要重新调整数据的获取方式
重新获取数据 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 for entry in tqdm(M_docs): id = entry.material_id file_path = os.path.join('magmon_dataset_2000', str(id) + ".cif") structure = entry.structure cif_writer = CifWriter(structure, write_magmoms=True) cif_writer.write_file(file_path) print('磁性材料结构获取完成') for entry in tqdm(NM_docs): id = entry.material_id file_path = os.path.join('magmon_dataset_2000', str(id) + ".cif") structure = entry.structure cif_writer = CifWriter(structure) # !!!!!非磁性材料不需要写入磁矩!!!!!!否则会出错 cif_writer.write_file(file_path) print('非磁性材料结构获取完成')
只要把非磁性晶体的磁矩矢量坐标系删掉就行了,重新测试后完全没有问题了。
再读数据 经过以上重写数据之后,对磁性和非磁性的 cif 文件有了严格区分,于是我们在获取 crystal 的 structure 数据时也需要重新分类读取:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 if self .target_name == 'is_Magnetic' : atom_fea_list = [] for i in range (len (crystal)): base_fea = self .ari.get_atom_fea(crystal[i].specie.number) if target_value==1 : magmom = crystal.site_properties['magmom' ][i] else : magmom = 0.0 magmom = float (magmom) extended_fea = np.append(base_fea, magmom) atom_fea_list.append(extended_fea) atom_fea = np.vstack(atom_fea_list)else : atom_fea = np.vstack([self .ari.get_atom_fea(crystal[i].specie.number) for i in range (len (crystal))])
在这里我用 target_value 来区分,因为我们的任务是分类任务,分类的标签就是是否是磁性,正好利用这个标签来区分磁性和非磁性,然后为了保持维度一致,只需要在非磁性原子的磁矩维度上写 0 就行了。
接下来,测试一下:
1 2 3 4 5 6 7 8 9 10 from test_view.csat.crystal_data import CIFData, crystal_graph_list data = CIFData("magmon_dataset_2000" , target_name='is_Magnetic' ) graph = crystal_graph_list(data)print ('这是图数据' )print (graph[0 ])''' 这是图数据 Data(x=[4, 93], edge_index=[2, 12], edge_attr=[12, 41], y=[1], id='mp-861724') '''
结果表明很成功,就可以加入训练了。
训练 训练就很简单了,只要修改一点点超参数就行了
1 2 3 4 5 6 7 8 9 self .task: str = 'classification' self .num_classes: int = 2 self .data_root: str = 'magmon_dataset_2000' self .target: str = 'is_Magnetic' self .input_dim: int = 93