训练带有磁矩的数据

上一篇写到了成功获取了晶体中每个原子的磁矩,并能够将其添加入原子特征的第 93 维度,今天接着写我是如何一步一步把它导入模型中训练的。

我们缕一缕数据处理的流程:

  1. 从 MP 中获取 data 的 doc 以及 structure;
  2. 使用 CIFwriter 把 structure 写入 cif 文件,并且设置 write_magmoms=True;
  3. 把晶体的预测性质写入 id_prop.csv;
  4. 通过 CIFData 类读取 dataset;在生成 atom_fea 时将 magmon 写入第 93 维度;
  5. 通过 crystal_graph_list 函数生成 graph 数据。

首先获取数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from mp_api.client import MPRester
import csv
from tqdm import tqdm
import os
from pymatgen.io.cif import CifWriter
API_KEY = ""
mpr = MPRester(
API_KEY)
M_docs = mpr.materials.summary.search(
total_magnetization=(0.2, None), # 总磁化强度 > 0.2 μB/unit cell
num_elements=(3, 92), # 限制元素数量
fields=["material_id", "structure", "is_magnetic"],
chunk_size=10,
num_chunks=100
)
NM_docs = mpr.materials.summary.search(
total_magnetization=(None, 0.1), # 总磁化强度 < 0.1 μB/unit cell
num_elements=(3, 92), # 限制元素数量
fields=["material_id", "structure", "is_magnetic"],
chunk_size=10,
num_chunks=100
)

这样严格区分开磁性晶体和非磁性,方便后续检查和确认数量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
for entry in tqdm(M_docs):
id = entry.material_id
file_path = os.path.join('magmon_dataset_2000', str(id) + ".cif")
structure = entry.structure
# 再写入时一定要注意,将写入磁矩设置为Ture
cif_writer = CifWriter(structure, write_magmoms=True)
cif_writer.write_file(file_path)
print('磁性材料结构获取完成')
for entry in tqdm(NM_docs):
id = entry.material_id
file_path = os.path.join('magmon_dataset_2000', str(id) + ".cif")
structure = entry.structure
cif_writer = CifWriter(structure, write_magmoms=True)
cif_writer.write_file(file_path)
print('非磁性材料结构获取完成')

最后生成 csv:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
with open('magmon_dataset_2000/id_prop.csv', 'w', newline='') as csvfile:
# 创建一个csv写入器
writer = csv.writer(csvfile)
header = ['Material', 'is_Magnetic']
writer.writerow(header)
for entry in tqdm(M_docs):
# 遍历列表,写入每一行数据
if entry.is_magnetic:
magnetic = 1
else:
magnetic = 0
id = entry.material_id
writer.writerow([id, magnetic])
print('磁性材料写入完成!')

for entry in tqdm(NM_docs):
# 遍历列表,写入每一行数据
if entry.is_magnetic:
magnetic = 1
else:
magnetic = 0
id = entry.material_id
writer.writerow([id, magnetic])
print("All done!")

这样子生成的 cif 文件都是包含下面两个 loop 字段的,说明是写入了磁矩信息,在读取的时候就能读到每个原子的 magmon。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ac Ac0 1 -0.00000000 -0.00000000 -0.00000000 1
Cr Cr1 1 0.50000000 0.50000000 0.50000000 1
O O2 1 0.50000000 0.50000000 0.00000000 1
O O3 1 0.50000000 -0.00000000 0.50000000 1
O O4 1 -0.00000000 0.50000000 0.50000000 1
loop_
_atom_site_moment_label
_atom_site_moment_crystalaxis_x
_atom_site_moment_crystalaxis_y
_atom_site_moment_crystalaxis_z
Ac0 0.00000000 0.00000000 0.01500000
Cr1 0.00000000 0.00000000 2.71800000
O2 0.00000000 0.00000000 0.03800000
O3 0.00000000 0.00000000 0.03800000
O4 0.00000000 0.00000000 0.03800000

接着对数据进行读取

按照上述步骤,在这里需要依次进行 CIFData 的读取和 graph 数据生成

1
2
3
4
5
dataset = "magmon_dataset_2000"
data = CIFData(dataset, target_name='is_Magnetic')
graph = crystal_graph_list(data)
print('这是图数据')
print(data[0])

但是这里出现了严重错误

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
ZeroDivisionError                         Traceback (most recent call last)
Cell In[23], line 5
3 #graph = crystal_graph_list(data)
4 print('这是图数据')
----> 5 print(data[0])

File e:\WYX_Project\CSAT_demo_2\test_view\csat\crystal_data.py:223, in CIFData.__getitem__(self, idx)
220 if not target_value:
221 raise ValueError(f"Empty target value for {self.target_name} in {cif_id}")
--> 223 crystal = Structure.from_file(os.path.join(self.root_dir,
224 cif_id + '.cif'))
226 if self.target_name == 'is_Magnetic':
227 # 如果目标是磁矩,则将原子特征扩展为包含磁矩
228 atom_fea_list = []

File e:\anaconda3\envs\pytorch\Lib\site-packages\pymatgen\core\structure.py:3202, in IStructure.from_file(cls, filename, primitive, sort, merge_tol, **kwargs)
3200 contents: str = file.read() # type:ignore[assignment]
3201 if fnmatch(fname.lower(), "*.cif*") or fnmatch(fname.lower(), "*.mcif*"):
-> 3202 return cls.from_str(
3203 contents,
3204 fmt="cif",
3205 primitive=primitive,
3206 sort=sort,
3207 merge_tol=merge_tol,
...
--> 230 if len(items) % n != 0:
231 raise ValueError(f"{len(items)=} is not a multiple of {n=}")
232 loops.append(columns)

ZeroDivisionError: integer modulo by zero

这个报错的意思是:

  1. <font style="color:rgb(0, 0, 0);">items</font> 是 CIF 文件中的数据项列表
  2. <font style="color:rgb(0, 0, 0);">n</font> 应该是数据列的数量
  3. <font style="color:rgb(0, 0, 0);">n</font> 为 0 时,尝试计算 <font style="color:rgb(0, 0, 0);">len(items) % 0</font> 导致除以零错误

这通常是由于 CIF 文件格式问题导致的:

  • 文件可能缺少必要的列定义
  • 数据行数量与列定义不匹配
  • 特殊字符或格式问题导致解析失败

错误出现在 crystal = Structure.from_file(os.path.join(self.root_dir,......就算说,在读 cif 文件是出错了,那就奇怪了,我们在上一篇中读取了好几次 cif 文件都没有任何问题,究竟是怎么回事呢?

于是,我就打开了几个 cif 文件看看,发现在非磁性晶体中存在这种现象:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
loop_
_atom_site_type_symbol
_atom_site_label
_atom_site_symmetry_multiplicity
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
Ac Ac0 1 0.50000000 0.50000000 0.50000000 1
Ac Ac1 1 -0.00000000 -0.00000000 -0.00000000 1
Ag Ag2 1 0.25000100 0.25000100 0.25000100 1
Ir Ir3 1 0.75000100 0.75000100 0.75000100 1
loop_
_atom_site_moment_label
_atom_site_moment_crystalaxis_x
_atom_site_moment_crystalaxis_y
_atom_site_moment_crystalaxis_z

他写入了磁矩的适量坐标系,但是没有磁矩信息,最终导致在读取 cif 文件时候,出现 item % 0,因为没有数据,为了验证这一点,我写了如下代码测试:

1
2
3
4
from csat.crystal_data import CIFData, crystal_graph_list
from pymatgen.core.structure import Structure
data = CIFData("NM", target_name='is_Magnetic')
graph = crystal_graph_list(data)

我把一个非磁性晶体的 cif 文件单独拉出来读取,发现只有把

1
2
3
4
5
loop_
_atom_site_moment_label
_atom_site_moment_crystalaxis_x
_atom_site_moment_crystalaxis_y
_atom_site_moment_crystalaxis_z

这一段删除之后才能够读取成功,说明我的推断没有错,于是我需要重新调整数据的获取方式

重新获取数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
for entry in tqdm(M_docs):
id = entry.material_id
file_path = os.path.join('magmon_dataset_2000', str(id) + ".cif")
structure = entry.structure
cif_writer = CifWriter(structure, write_magmoms=True)
cif_writer.write_file(file_path)
print('磁性材料结构获取完成')

for entry in tqdm(NM_docs):
id = entry.material_id
file_path = os.path.join('magmon_dataset_2000', str(id) + ".cif")
structure = entry.structure
cif_writer = CifWriter(structure) # !!!!!非磁性材料不需要写入磁矩!!!!!!否则会出错
cif_writer.write_file(file_path)
print('非磁性材料结构获取完成')

只要把非磁性晶体的磁矩矢量坐标系删掉就行了,重新测试后完全没有问题了。

再读数据

经过以上重写数据之后,对磁性和非磁性的 cif 文件有了严格区分,于是我们在获取 crystal 的 structure 数据时也需要重新分类读取:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
if self.target_name == 'is_Magnetic':
# 如果目标是磁矩,则将原子特征扩展为包含磁矩
atom_fea_list = []
for i in range(len(crystal)):
base_fea = self.ari.get_atom_fea(crystal[i].specie.number)
if target_value==1:
magmom = crystal.site_properties['magmom'][i] # 添加磁矩
else:
magmom = 0.0
magmom = float(magmom) # 确保磁矩是浮点数类型
# print(f"原子 {crystal[i].specie.symbol} 的磁矩: {magmom}")
# 将磁矩作为新特征追加
extended_fea = np.append(base_fea, magmom)
atom_fea_list.append(extended_fea)
atom_fea = np.vstack(atom_fea_list)
else:
# 如果目标不是磁矩,则只使用原子特征
atom_fea = np.vstack([self.ari.get_atom_fea(crystal[i].specie.number)
for i in range(len(crystal))])

在这里我用 target_value 来区分,因为我们的任务是分类任务,分类的标签就是是否是磁性,正好利用这个标签来区分磁性和非磁性,然后为了保持维度一致,只需要在非磁性原子的磁矩维度上写 0 就行了。

接下来,测试一下:

1
2
3
4
5
6
7
8
9
10
from test_view.csat.crystal_data import CIFData, crystal_graph_list
data = CIFData("magmon_dataset_2000", target_name='is_Magnetic')
graph = crystal_graph_list(data)
print('这是图数据')
print(graph[0])

'''
这是图数据
Data(x=[4, 93], edge_index=[2, 12], edge_attr=[12, 41], y=[1], id='mp-861724')
'''

结果表明很成功,就可以加入训练了。

训练

训练就很简单了,只要修改一点点超参数就行了

1
2
3
4
5
6
7
8
9
self.task: str = 'classification'  # regression/classification
self.num_classes: int = 2

# 数据参数
self.data_root: str = 'magmon_dataset_2000' # 数据集根目录
self.target: str = 'is_Magnetic'

# 模型参数
self.input_dim: int = 93

训练带有磁矩的数据
http://baikelwang.github.io/2025/07/09/训练带有磁矩的数据/
作者
北海
发布于
2025年7月9日
许可协议