Skip to content
Toggle navigation
P
Projects
G
Groups
S
Snippets
Help
lichenggang
/
bom_identify
This project
Loading...
Sign in
Toggle navigation
Go to a project
Project
Repository
Issues
0
Merge Requests
0
Pipelines
Wiki
Snippets
Settings
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Commit
1231ed2a
authored
May 29, 2020
by
lichenggang
Browse files
Options
_('Browse Files')
Download
Email Patches
Plain Diff
综合model和head预测, 还差位号的训练
parent
94f90604
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
113 additions
and
47 deletions
predic_fac.py
predict/base_handler.py
predict/dict_predict.py
predic_fac.py
View file @
1231ed2a
...
...
@@ -29,7 +29,7 @@ if __name__ == "__main__":
def
get_test_data
():
import
pandas
as
pd
import
json
df
=
pd
.
read_excel
(
r'C:\Users\ICHUNT\Desktop\bomlist\51AB0571_ CCTV ASST询价_SZIMS.xlsx'
,
header
=
None
,
sheet_name
=
'
1
'
)
df
=
pd
.
read_excel
(
r'C:\Users\ICHUNT\Desktop\bomlist\51AB0571_ CCTV ASST询价_SZIMS.xlsx'
,
header
=
None
,
sheet_name
=
'
3
'
)
df
.
fillna
(
'?'
,
inplace
=
True
)
dic_dft
=
df
.
to_dict
(
orient
=
'list'
)
return
json
.
dumps
(
dic_dft
)
...
...
@@ -40,5 +40,6 @@ if __name__ == "__main__":
data
=
get_test_data
()
p
=
PredictorFac
(
model_config
)
data
=
json
.
loads
(
data
)
res
=
p
.
predict
(
data
,
predict_type
=
'model'
)
print
(
data
)
res
=
p
.
predict
(
data
,
predict_type
=
'all'
)
print
(
res
)
predict/base_handler.py
View file @
1231ed2a
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import
re
from
collections
import
Counter
from
utils.log_manager
import
get_logger
from
utils.robots
import
dd_send_msg
import
pandas
as
pd
CATEGORY
=
[
'二极管'
]
RIGHT_LEVEL
=
0.7
SEQ_LEVEL
=
0.5
CATE_LEVEL
=
0.5
MULTI_SAME_LEVEL
=
3
class
BasePredictor
:
'''
...
...
@@ -23,6 +25,8 @@ class BasePredictor:
self
.
info
=
get_logger
(
self
.
name
+
'_info'
)
self
.
robot_msg
=
dd_send_msg
self
.
pd
=
pd
# 占位符
self
.
PLACEHOLDER
=
'?'
def
predict
(
self
,
key
):
raise
NotImplementedError
...
...
@@ -43,7 +47,7 @@ class BasePredictor:
dic_proba
=
{
k
:
v
for
k
,
v
in
zip
(
classes
,
deal_list
)}
return
predictions
[
0
],
dic_proba
def
isseq
(
self
,
data
):
def
is
_
seq
(
self
,
data
):
"""
序号列预测
"""
...
...
@@ -54,7 +58,7 @@ class BasePredictor:
rate
=
round
(
len
(
collect_seq
)
/
len
(
data
),
3
)
return
True
if
rate
>=
SEQ_LEVEL
else
False
def
isnum
(
self
,
data
):
def
is
_
num
(
self
,
data
):
"""
数量列预测
"""
...
...
@@ -72,16 +76,34 @@ class BasePredictor:
def
isNumberCol
(
self
,
kw
):
"""
是否是数量列
是否是数量列
辅助函数
"""
if
isinstance
(
kw
,
str
):
return
re
.
match
(
r'(\d+)((K)|([\u4E00-\u9FA5]{1,3}))$'
,
kw
,
re
.
M
|
re
.
I
)
else
:
return
False
def
is_catecol
(
self
,
data
):
cates
=
[]
for
i
in
data
:
for
j
in
CATEGORY
:
if
j
in
str
(
i
):
cates
.
append
(
i
)
rate
=
round
(
len
(
cates
)
/
len
(
data
),
3
)
return
rate
>=
CATE_LEVEL
def
is_multi_same
(
self
,
data
):
no_null_data
=
list
(
filter
(
lambda
x
:
x
!=
self
.
PLACEHOLDER
,
data
))
result
=
Counter
(
no_null_data
)
li_sort
=
sorted
(
result
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
return
li_sort
[
0
][
1
]
>=
MULTI_SAME_LEVEL
def
valid_seq
(
self
,
data
):
"""取反"""
return
not
self
.
isseq
(
data
)
return
not
self
.
is
_
seq
(
data
)
def
valid_num
(
self
,
data
):
return
not
self
.
isnum
(
data
)
return
not
self
.
is_num
(
data
)
def
valid_cate
(
self
,
data
):
return
not
self
.
is_catecol
(
data
)
\ No newline at end of file
predict/dict_predict.py
View file @
1231ed2a
...
...
@@ -5,19 +5,19 @@ from collections import Counter
from
predict.base_handler
import
BasePredictor
# 可能的头部字段
prob_fields
=
[
"序号"
,
"名称"
,
"规格"
,
"MPN"
,
"用量(pcs)"
,
"用量"
,
"pcs"
,
"位号"
,
"描述"
,
"值"
,
"数量"
,
"封装"
,
"类别"
,
"a面位置"
,
"b面位置"
,
"备注"
,
"需求数量"
,
"售价"
,
PROB_FIELDS
=
[
"序号"
,
"名称"
,
"规格"
,
"MPN"
,
"用量(pcs)"
,
"用量"
,
"pcs"
,
"位号"
,
"描述"
,
"值"
,
"数量"
,
"封装"
,
"类别"
,
"a面位置"
,
"b面位置"
,
"备注"
,
"需求数量"
,
"售价"
,
"封装"
,
"封装规格"
,
"参考品牌"
,
"品牌"
,
"item"
,
"厂商编码"
,
"品牌/厂商"
,
"参考料号"
,
"参考供应商"
,
"top面"
,
"bottom面"
]
# 标准名和代名词的映射
fields_map
=
{
"序号"
:
[
"序号"
],
STD_FIELDS_MAP
=
{
"类别"
:
[
"类别"
,
"分类"
,
"名称"
,
"类别名称"
],
"参数"
:
[
"参数"
,
"规格"
,
"描述"
],
"参数"
:
[
"参数"
,
"规格"
,
"描述"
,
"值"
],
"型号"
:
[
"型号"
,
"参考料号"
,
"料号"
,
"MPN"
,
"厂商编码"
],
"数量"
:
[
"数量"
,
"用量(pcs)"
,
"PCS"
,
"用量"
,
"用量(PCS)"
,
"pcs"
],
"封装"
:
[
"封装"
,
"封装规格"
],
"品牌"
:
[
"品牌"
,
"品牌/厂商"
,
"参考品牌"
,
"参考供应商"
,
"厂商"
,
"参考供应商"
,
"参考厂商"
]}
# 必须返回也必须验证的标准字段
MUST_STD_FIELDS
=
[
'参数'
,
'数量'
]
#
...
...
@@ -28,16 +28,19 @@ en_to_zh_map = {'brand_name': '品牌', 'param': '参数', 'goods_name': '型号
def
fun
(
seri
):
li_seri
=
seri
.
tolist
()
for
field
in
li_seri
:
if
str
(
field
)
.
lower
()
in
prob_fields
:
if
str
(
field
)
.
lower
()
in
PROB_FIELDS
:
return
field
,
seri
.
name
def
repeat_max
(
li
):
result
=
Counter
(
li
)
# [('brand_name', 4), ('goods_name', 3), ('param', 2)]
li_sort
=
sorted
(
result
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
return
li_sort
[
0
][
0
]
# 取前多少行
HEAD_ROW
=
7
# 空置率阈值
NAN_RATE
=
0.8
# 占位符
PLACEHOLDER
=
'?'
class
DicPredict
(
BasePredictor
):
...
...
@@ -57,14 +60,14 @@ class DicPredict(BasePredictor):
'pronoun'
:
field
,
'column_name'
:
column_name
}
for
k
,
v
in
fields_map
.
items
():
if
field
.
lower
()
in
fields_map
[
k
]:
for
k
,
v
in
STD_FIELDS_MAP
.
items
():
if
field
.
lower
()
in
STD_FIELDS_MAP
[
k
]:
dic
[
'std_name'
]
=
k
li_res
.
append
(
dic
)
return
li_res
def
pre
_predict
(
self
,
dict_data
):
def
head
_predict
(
self
,
dict_data
):
columns
=
[]
li_data
=
[]
...
...
@@ -76,15 +79,13 @@ class DicPredict(BasePredictor):
dft
=
df
.
T
.
head
(
HEAD_ROW
)
dft
.
columns
=
columns
li_res_raw
=
self
.
id_by_field
(
dft
)
std_result
=
[]
ab_result
=
[]
std_result
=
{}
ab_result
=
{}
for
i
in
li_res_raw
:
if
i
.
get
(
'std_name'
):
dic_has_res
=
{
i
[
'column_name'
]:
i
[
'std_name'
]}
std_result
.
append
(
dic_has_res
)
std_result
[
i
[
'column_name'
]]
=
i
[
'std_name'
]
else
:
dic_ab_res
=
{
i
[
'column_name'
]:
i
[
'pronoun'
]}
ab_result
.
append
(
dic_ab_res
)
ab_result
[
i
[
'column_name'
]]
=
i
[
'pronoun'
]
pre_id_res
=
{
'std_result'
:
std_result
,
'ab_result'
:
ab_result
,
...
...
@@ -93,53 +94,73 @@ class DicPredict(BasePredictor):
def
model_predict
(
self
,
dic_data
):
"""
该方法目前只对[序号
, 参数, 型号, 数量, 品牌]进行预测, 只有[参数, 数量]会强制有结果
该方法目前只对[序号
(非标准), 数量, 类别, 参数, 型号, 品牌]进行预测, 前三者是非模型预测
:param dic_data:
:return
:
:return
:只有[参数, 数量]会强制有结果
"""
prob_columns
=
[]
temp_pre_model_res
=
{}
ab_result
=
{}
for
k
,
v
in
dic_data
.
items
():
bol
=
self
.
v_chain
(
v
)
if
bol
:
print
(
k
,
bol
)
prob_columns
.
append
(
k
)
continue
if
self
.
isnum
(
v
):
if
self
.
is_seq
(
v
):
ab_result
[
k
]
=
'序号'
continue
if
self
.
is_num
(
v
):
temp_pre_model_res
[
k
]
=
'数量'
continue
if
self
.
isseq
(
v
):
temp_pre_model_res
[
k
]
=
'序号'
temp_dic_data
=
{
k
:
list
(
filter
(
lambda
x
:
x
!=
PLACEHOLDER
,
dic_data
[
k
]))
for
k
in
prob_columns
}
if
self
.
is_catecol
(
v
):
temp_pre_model_res
[
k
]
=
'类别'
continue
temp_dic_data
=
{
k
:
list
(
filter
(
lambda
x
:
x
!=
self
.
PLACEHOLDER
,
dic_data
[
k
]))
for
k
in
prob_columns
}
for
k
,
v
in
temp_dic_data
.
items
():
li_single_pred_res
=
[]
for
string
in
v
:
single_pred_res
,
probdic
=
self
.
get_single_predict
(
string
)
li_single_pred_res
.
append
(
single_pred_res
)
result
=
Counter
(
li_single_pred_res
)
# [('brand_name', 4), ('goods_name', 3), ('param', 2)]
li_sort
=
sorted
(
result
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
temp_pre_model_res
[
k
]
=
en_to_zh_map
[
li_sort
[
0
][
0
]]
sort_temp_pre_model_res
=
sorted
(
temp_pre_model_res
.
keys
())
# pre_model_res = {}
result
=
repeat_max
(
li_single_pred_res
)
temp_pre_model_res
[
k
]
=
en_to_zh_map
[
result
]
# 参数和型号列出现多条相同值则丢弃
prob_param_and_gn_cols
=
[
i
for
i
in
temp_pre_model_res
if
temp_pre_model_res
[
i
]
==
'参数'
or
temp_pre_model_res
[
i
]
==
'型号'
]
for
col
in
prob_param_and_gn_cols
:
if
self
.
is_multi_same
(
temp_dic_data
[
col
]):
temp_pre_model_res
.
pop
(
col
)
model_id_res
=
{
'std_result'
:
temp_pre_model_res
,
'ab_result'
:
ab_result
,
}
return
model_id_res
def
predict
(
self
,
dic_data
,
predict_type
=
'all'
):
dic_data
=
self
.
pre_deal
(
dic_data
)
if
predict_type
==
'all'
:
pre_id_res
=
self
.
pre_predict
(
dic_data
)
if
pre_id_res
:
return
pre_id_res
pre_id_res
=
self
.
head_predict
(
dic_data
)
model_id_res
=
self
.
model_predict
(
dic_data
)
# 表头预测和模型预测最后返回的数据进行综合处理
pre_std_result
=
pre_id_res
.
get
(
'std_result'
)
model_std_result
=
model_id_res
.
get
(
'std_result'
)
comprehensive_res
=
self
.
get_comprehensive_res
(
pre_std_result
,
model_std_result
)
if
comprehensive_res
:
res
=
{
'std_result'
:
comprehensive_res
,
'ab_result'
:
pre_id_res
[
'ab_result'
],
}
return
res
elif
predict_type
==
'model'
:
model_id_res
=
self
.
model_predict
(
dic_data
)
if
model_id_res
:
return
model_id_res
elif
predict_type
==
'
pre
'
:
pre_id_res
=
self
.
pre
_predict
(
dic_data
)
elif
predict_type
==
'
head
'
:
pre_id_res
=
self
.
head
_predict
(
dic_data
)
if
pre_id_res
:
return
pre_id_res
...
...
@@ -149,7 +170,7 @@ class DicPredict(BasePredictor):
# 去掉空置率大于等于0.8的列
counter
=
0
for
item
in
v
:
if
str
(
item
)
.
strip
()
==
PLACEHOLDER
:
if
str
(
item
)
.
strip
()
==
self
.
PLACEHOLDER
:
counter
+=
1
if
counter
/
len
(
v
)
<=
NAN_RATE
:
new_dic_data
[
k
]
=
v
...
...
@@ -169,3 +190,24 @@ class DicPredict(BasePredictor):
return
False
else
:
return
True
def
get_comprehensive_res
(
self
,
pre_std_result
,
model_std_result
):
vote_count
=
{
"类别"
:
[],
"参数"
:
[],
"型号"
:
[],
"数量"
:
[],
"品牌"
:
[]
}
for
k
,
v
in
pre_std_result
.
items
():
vote_count
[
v
]
.
append
(
k
)
for
k
,
v
in
model_std_result
.
items
():
vote_count
[
v
]
.
append
(
k
)
comprehensive_res
=
{}
for
std_name
,
col_li
in
vote_count
.
items
():
if
len
(
col_li
)
>=
1
:
col
=
repeat_max
(
col_li
)
comprehensive_res
[
col
]
=
std_name
return
comprehensive_res
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment