Commit 1231ed2a by lichenggang

综合model和head预测, 还差位号的训练

parent 94f90604
...@@ -29,7 +29,7 @@ if __name__ == "__main__": ...@@ -29,7 +29,7 @@ if __name__ == "__main__":
def get_test_data(): def get_test_data():
import pandas as pd import pandas as pd
import json import json
df = pd.read_excel(r'C:\Users\ICHUNT\Desktop\bomlist\51AB0571_ CCTV ASST询价_SZIMS.xlsx', header=None, sheet_name='1') df = pd.read_excel(r'C:\Users\ICHUNT\Desktop\bomlist\51AB0571_ CCTV ASST询价_SZIMS.xlsx', header=None, sheet_name='3')
df.fillna('?', inplace=True) df.fillna('?', inplace=True)
dic_dft = df.to_dict(orient='list') dic_dft = df.to_dict(orient='list')
return json.dumps(dic_dft) return json.dumps(dic_dft)
...@@ -40,5 +40,6 @@ if __name__ == "__main__": ...@@ -40,5 +40,6 @@ if __name__ == "__main__":
data = get_test_data() data = get_test_data()
p = PredictorFac(model_config) p = PredictorFac(model_config)
data = json.loads(data) data = json.loads(data)
res = p.predict(data,predict_type='model') print(data)
res = p.predict(data,predict_type='all')
print(res) print(res)
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import re import re
from collections import Counter
from utils.log_manager import get_logger from utils.log_manager import get_logger
from utils.robots import dd_send_msg from utils.robots import dd_send_msg
import pandas as pd import pandas as pd
CATEGORY = ['二极管']
RIGHT_LEVEL = 0.7 RIGHT_LEVEL = 0.7
SEQ_LEVEL = 0.5 SEQ_LEVEL = 0.5
CATE_LEVEL = 0.5
MULTI_SAME_LEVEL = 3
class BasePredictor: class BasePredictor:
''' '''
...@@ -23,6 +25,8 @@ class BasePredictor: ...@@ -23,6 +25,8 @@ class BasePredictor:
self.info = get_logger(self.name + '_info') self.info = get_logger(self.name + '_info')
self.robot_msg = dd_send_msg self.robot_msg = dd_send_msg
self.pd = pd self.pd = pd
# 占位符
self.PLACEHOLDER = '?'
def predict(self, key): def predict(self, key):
raise NotImplementedError raise NotImplementedError
...@@ -43,7 +47,7 @@ class BasePredictor: ...@@ -43,7 +47,7 @@ class BasePredictor:
dic_proba = {k: v for k, v in zip(classes, deal_list)} dic_proba = {k: v for k, v in zip(classes, deal_list)}
return predictions[0], dic_proba return predictions[0], dic_proba
def isseq(self, data): def is_seq(self, data):
""" """
序号列预测 序号列预测
""" """
...@@ -54,7 +58,7 @@ class BasePredictor: ...@@ -54,7 +58,7 @@ class BasePredictor:
rate = round(len(collect_seq) / len(data), 3) rate = round(len(collect_seq) / len(data), 3)
return True if rate >= SEQ_LEVEL else False return True if rate >= SEQ_LEVEL else False
def isnum(self, data): def is_num(self, data):
""" """
数量列预测 数量列预测
""" """
...@@ -72,16 +76,34 @@ class BasePredictor: ...@@ -72,16 +76,34 @@ class BasePredictor:
def isNumberCol(self, kw): def isNumberCol(self, kw):
""" """
是否是数量列 是否是数量列辅助函数
""" """
if isinstance(kw, str): if isinstance(kw, str):
return re.match(r'(\d+)((K)|([\u4E00-\u9FA5]{1,3}))$', kw, re.M | re.I) return re.match(r'(\d+)((K)|([\u4E00-\u9FA5]{1,3}))$', kw, re.M | re.I)
else: else:
return False return False
def is_catecol(self, data):
    """
    Predict whether a column is a category column.

    A cell is a hit when its string form contains at least one keyword from
    CATEGORY. The column qualifies when the share of hit cells, rounded to
    three decimals, reaches CATE_LEVEL.

    :param data: list of cell values of one column
    :return: True if the column looks like a category column, False otherwise
             (an empty column is never a category column)
    """
    # An empty column carries no evidence; return early instead of dividing by zero.
    if len(data) == 0:
        return False
    # Count every cell at most once, even when it matches several keywords,
    # so the rate can never exceed 1.
    hits = sum(1 for cell in data if any(word in str(cell) for word in CATEGORY))
    rate = round(hits / len(data), 3)
    return rate >= CATE_LEVEL
def is_multi_same(self, data):
    """
    Tell whether one value repeats many times inside a column.

    Cells equal to self.PLACEHOLDER are ignored before counting.

    :param data: list of cell values of one column
    :return: True if the most frequent remaining value occurs at least
             MULTI_SAME_LEVEL times; False otherwise, including when no
             non-placeholder cell is left
    """
    counts = Counter(cell for cell in data if cell != self.PLACEHOLDER)
    # Empty column or placeholders only: nothing can repeat, and there is
    # no "most frequent" entry to look at.
    if not counts:
        return False
    return max(counts.values()) >= MULTI_SAME_LEVEL
def valid_seq(self, data): def valid_seq(self, data):
"""取反""" """取反"""
return not self.isseq(data) return not self.is_seq(data)
def valid_num(self, data): def valid_num(self, data):
return not self.isnum(data) return not self.is_num(data)
def valid_cate(self, data):
    """Negation of is_catecol: True when the column is NOT a category column."""
    looks_like_category = self.is_catecol(data)
    return not looks_like_category
\ No newline at end of file
...@@ -5,19 +5,19 @@ from collections import Counter ...@@ -5,19 +5,19 @@ from collections import Counter
from predict.base_handler import BasePredictor from predict.base_handler import BasePredictor
# 可能的头部字段 # 可能的头部字段
prob_fields = ["序号", "名称", "规格", "MPN", "用量(pcs)", "用量", "pcs", "位号", "描述", "值", "数量", "封装", "类别", "a面位置", "b面位置", "备注", PROB_FIELDS = ["序号", "名称", "规格", "MPN", "用量(pcs)", "用量", "pcs", "位号", "描述", "值", "数量", "封装", "类别", "a面位置", "b面位置", "备注",
"需求数量", "售价", "需求数量", "售价", "封装", "封装规格",
"参考品牌", "品牌", "item", "厂商编码", "品牌/厂商", "参考料号", "参考供应商", "top面", "bottom面"] "参考品牌", "品牌", "item", "厂商编码", "品牌/厂商", "参考料号", "参考供应商", "top面", "bottom面"]
# 标准名和代名词的映射 # 标准名和代名词的映射
fields_map = {"序号": ["序号"], STD_FIELDS_MAP = {
"类别": ["类别", "分类", "名称", "类别名称"], "类别": ["类别", "分类", "名称", "类别名称"],
"参数": ["参数", "规格", "描述"], "参数": ["参数", "规格", "描述", "值"],
"型号": ["型号", "参考料号", "料号", "MPN", "厂商编码"], "型号": ["型号", "参考料号", "料号", "MPN", "厂商编码"],
"数量": ["数量", "用量(pcs)", "PCS", "用量", "用量(PCS)", "pcs"], "数量": ["数量", "用量(pcs)", "PCS", "用量", "用量(PCS)", "pcs"],
"封装": ["封装", "封装规格"],
"品牌": ["品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商"]} "品牌": ["品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商"]}
# 必须返回也必须验证的标准字段 # 必须返回也必须验证的标准字段
MUST_STD_FIELDS = ['参数', '数量'] MUST_STD_FIELDS = ['参数', '数量']
# #
...@@ -28,16 +28,19 @@ en_to_zh_map = {'brand_name': '品牌', 'param': '参数', 'goods_name': '型号 ...@@ -28,16 +28,19 @@ en_to_zh_map = {'brand_name': '品牌', 'param': '参数', 'goods_name': '型号
def fun(seri): def fun(seri):
li_seri = seri.tolist() li_seri = seri.tolist()
for field in li_seri: for field in li_seri:
if str(field).lower() in prob_fields: if str(field).lower() in PROB_FIELDS:
return field, seri.name return field, seri.name
def repeat_max(li):
    """
    Return the item that occurs most often in ``li``.

    When several items share the highest count, the one encountered first
    in ``li`` is returned.

    :param li: list of hashable items, e.g. ['brand_name', 'param', 'brand_name']
    :return: the most frequent item
    :raises IndexError: if ``li`` is empty
    """
    # most_common(1) -> [('brand_name', 4)]; no need to sort every distinct item.
    return Counter(li).most_common(1)[0][0]
# 取前多少行 # 取前多少行
HEAD_ROW = 7 HEAD_ROW = 7
# 空置率阈值 # 空置率阈值
NAN_RATE = 0.8 NAN_RATE = 0.8
# 占位符
PLACEHOLDER = '?'
class DicPredict(BasePredictor): class DicPredict(BasePredictor):
...@@ -57,14 +60,14 @@ class DicPredict(BasePredictor): ...@@ -57,14 +60,14 @@ class DicPredict(BasePredictor):
'pronoun': field, 'pronoun': field,
'column_name': column_name 'column_name': column_name
} }
for k, v in fields_map.items(): for k, v in STD_FIELDS_MAP.items():
if field.lower() in fields_map[k]: if field.lower() in STD_FIELDS_MAP[k]:
dic['std_name'] = k dic['std_name'] = k
li_res.append(dic) li_res.append(dic)
return li_res return li_res
def pre_predict(self, dict_data): def head_predict(self, dict_data):
columns = [] columns = []
li_data = [] li_data = []
...@@ -76,15 +79,13 @@ class DicPredict(BasePredictor): ...@@ -76,15 +79,13 @@ class DicPredict(BasePredictor):
dft = df.T.head(HEAD_ROW) dft = df.T.head(HEAD_ROW)
dft.columns = columns dft.columns = columns
li_res_raw = self.id_by_field(dft) li_res_raw = self.id_by_field(dft)
std_result = [] std_result = {}
ab_result = [] ab_result = {}
for i in li_res_raw: for i in li_res_raw:
if i.get('std_name'): if i.get('std_name'):
dic_has_res = {i['column_name']: i['std_name']} std_result[i['column_name']] = i['std_name']
std_result.append(dic_has_res)
else: else:
dic_ab_res = {i['column_name']: i['pronoun']} ab_result[i['column_name']] = i['pronoun']
ab_result.append(dic_ab_res)
pre_id_res = { pre_id_res = {
'std_result': std_result, 'std_result': std_result,
'ab_result': ab_result, 'ab_result': ab_result,
...@@ -93,53 +94,73 @@ class DicPredict(BasePredictor): ...@@ -93,53 +94,73 @@ class DicPredict(BasePredictor):
def model_predict(self, dic_data): def model_predict(self, dic_data):
""" """
该方法目前只对[序号, 参数, 型号, 数量, 品牌]进行预测, 只有[参数, 数量]会强制有结果 该方法目前只对[序号(非标准), 数量, 类别, 参数, 型号, 品牌]进行预测, 前三者是非模型预测
:param dic_data: :param dic_data:
:return: :return :只有[参数, 数量]会强制有结果
""" """
prob_columns = [] prob_columns = []
temp_pre_model_res = {} temp_pre_model_res = {}
ab_result = {}
for k, v in dic_data.items(): for k, v in dic_data.items():
bol = self.v_chain(v) bol = self.v_chain(v)
if bol: if bol:
print(k, bol)
prob_columns.append(k) prob_columns.append(k)
continue continue
if self.isnum(v): if self.is_seq(v):
ab_result[k] = '序号'
continue
if self.is_num(v):
temp_pre_model_res[k] = '数量' temp_pre_model_res[k] = '数量'
continue continue
if self.isseq(v): if self.is_catecol(v):
temp_pre_model_res[k] = '序号' temp_pre_model_res[k] = '类别'
temp_dic_data = {k: list(filter(lambda x: x != PLACEHOLDER, dic_data[k])) for k in prob_columns} continue
temp_dic_data = {k: list(filter(lambda x: x != self.PLACEHOLDER, dic_data[k])) for k in prob_columns}
for k, v in temp_dic_data.items(): for k, v in temp_dic_data.items():
li_single_pred_res = [] li_single_pred_res = []
for string in v: for string in v:
single_pred_res, probdic = self.get_single_predict(string) single_pred_res, probdic = self.get_single_predict(string)
li_single_pred_res.append(single_pred_res) li_single_pred_res.append(single_pred_res)
result = Counter(li_single_pred_res) result = repeat_max(li_single_pred_res)
# [('brand_name', 4), ('goods_name', 3), ('param', 2)] temp_pre_model_res[k] = en_to_zh_map[result]
li_sort = sorted(result.items(), key=lambda x: x[1], reverse=True)
temp_pre_model_res[k] = en_to_zh_map[li_sort[0][0]] # 参数和型号列出现多条相同值则丢弃
sort_temp_pre_model_res = sorted(temp_pre_model_res.keys()) prob_param_and_gn_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '参数' or temp_pre_model_res[i] == '型号']
for col in prob_param_and_gn_cols:
# pre_model_res = {} if self.is_multi_same(temp_dic_data[col]):
temp_pre_model_res.pop(col)
model_id_res = {
'std_result': temp_pre_model_res,
'ab_result': ab_result,
}
return model_id_res
def predict(self, dic_data, predict_type='all'): def predict(self, dic_data, predict_type='all'):
dic_data = self.pre_deal(dic_data) dic_data = self.pre_deal(dic_data)
if predict_type == 'all': if predict_type == 'all':
pre_id_res = self.pre_predict(dic_data) pre_id_res = self.head_predict(dic_data)
if pre_id_res: model_id_res = self.model_predict(dic_data)
return pre_id_res
# 表头预测和模型预测最后返回的数据进行综合处理
pre_std_result = pre_id_res.get('std_result')
model_std_result = model_id_res.get('std_result')
comprehensive_res = self.get_comprehensive_res(pre_std_result, model_std_result)
if comprehensive_res:
res = {
'std_result': comprehensive_res,
'ab_result': pre_id_res['ab_result'],
}
return res
elif predict_type == 'model': elif predict_type == 'model':
model_id_res = self.model_predict(dic_data) model_id_res = self.model_predict(dic_data)
if model_id_res: if model_id_res:
return model_id_res return model_id_res
elif predict_type == 'pre': elif predict_type == 'head':
pre_id_res = self.pre_predict(dic_data) pre_id_res = self.head_predict(dic_data)
if pre_id_res: if pre_id_res:
return pre_id_res return pre_id_res
...@@ -149,7 +170,7 @@ class DicPredict(BasePredictor): ...@@ -149,7 +170,7 @@ class DicPredict(BasePredictor):
# 去掉空置率大于等于0.8的列 # 去掉空置率大于等于0.8的列
counter = 0 counter = 0
for item in v: for item in v:
if str(item).strip() == PLACEHOLDER: if str(item).strip() == self.PLACEHOLDER:
counter += 1 counter += 1
if counter / len(v) <= NAN_RATE: if counter / len(v) <= NAN_RATE:
new_dic_data[k] = v new_dic_data[k] = v
...@@ -169,3 +190,24 @@ class DicPredict(BasePredictor): ...@@ -169,3 +190,24 @@ class DicPredict(BasePredictor):
return False return False
else: else:
return True return True
def get_comprehensive_res(self, pre_std_result, model_std_result):
    """
    Merge the header-based and the model-based prediction into one mapping.

    Both inputs map a column name to a standard field name. Each entry is a
    vote of that column for that standard name; for every standard name the
    column picked by repeat_max (the most-voted one) is kept.

    :param pre_std_result: {column_name: std_name} produced by header prediction
    :param model_std_result: {column_name: std_name} produced by model prediction
    :return: {column_name: std_name}, at most one column per standard name
    """
    vote_count = {
        "类别": [],
        "参数": [],
        "型号": [],
        "数量": [],
        "品牌": []
    }
    for std_result in (pre_std_result, model_std_result):
        # Treat a missing result (None) like an empty one.
        for col, std_name in (std_result or {}).items():
            # setdefault instead of plain indexing: a standard name that is
            # not pre-registered above (e.g. "封装") must not raise KeyError.
            vote_count.setdefault(std_name, []).append(col)

    comprehensive_res = {}
    for std_name, col_li in vote_count.items():
        if col_li:
            # NOTE(review): if one column wins for two standard names, the
            # later name overwrites the earlier one - confirm this is intended.
            comprehensive_res[repeat_max(col_li)] = std_name
    return comprehensive_res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment