Commit 7756c9be by lichenggang

增加有效数据行表头预测

parent ed3f93c2
......@@ -17,22 +17,15 @@ class Classify(classify_pb2_grpc.classifyServicer):
def Classify(self, request, context):
bom_log.info(f'grpc收到数据: {request.keyword}')
err_msg = None
try:
dic_data = json.loads(request.keyword)
res = self.predictorfac.predict(dic_data)
res['status'] = 1
if '数量' not in str(res):
res = {
'status': 0,
'err_msg': f'未识别到数量列'
}
err_msg = '未识别到数量列'
if '参数' not in str(res) and '型号' not in str(res):
res = {
'status': 0,
'err_msg': f'未识别到参数列'
}
err_msg = '未识别到参数列'
except json.decoder.JSONDecodeError:
res = self.predictorfac.predict(request.keyword)
res['status'] = 1
......@@ -42,6 +35,12 @@ class Classify(classify_pb2_grpc.classifyServicer):
'err_msg': '未知'
}
bom_log.error(f'grpc报错: \n' + traceback.format_exc())
if err_msg:
res = {
'status': 0,
'err_msg': err_msg
}
return classify_pb2.ClassifyReply(message='result {msg}'.format(msg=res))
def fac_test_predic(self, data):
......
......@@ -35,22 +35,17 @@ class KwHandler(BaseHandler):
class DictHandler(BaseHandler):
async def post(self):
err_msg = None
try:
bom_log.info(f'http收到识别字典: {self.request.body}')
dic_data = json.loads(self.request.body)
res = self.predictorfac.predict(dic_data)
res['status'] = 1
if '数量' not in str(res):
res = {
'status': 0,
'err_msg': '未识别到数量列'
}
err_msg = '未识别到数量列'
if '参数' not in str(res) and '型号' not in str(res):
res = {
'status': 0,
'err_msg': '未识别到参数列'
}
err_msg = '未识别到参数列'
except json.decoder.JSONDecodeError:
res = self.predictorfac.predict(self.request.body)
res['status'] = 1
......@@ -60,6 +55,12 @@ class DictHandler(BaseHandler):
'err_msg': '未知'
}
bom_log.error(f'http字典识别报错: \n' + traceback.format_exc())
if err_msg:
res = {
'status': 0,
'err_msg': err_msg
}
self.write(res)
......
......@@ -147,22 +147,24 @@ class BasePredictor:
return li_sort[0][0]
@classmethod
def get_comprehensive_res(cls, head_std_result, model_std_result):
vote_count = {
"类别": [],
"参数": [],
"型号": [],
"数量": [],
"品牌": []
}
for k, v in head_std_result.items():
def get_comprehensive_res(cls, head_result, model_result, major='head', vote_type='standard'):
if vote_type == 'standard':
vote_count = {k: []for k in STD_FIELDS_MAP.keys()}
else:
set_k = set()
for v in head_result.values():
set_k.add(v)
for v in model_result.values():
set_k.add(v)
vote_count = {k: [] for k in set_k}
for k, v in head_result.items():
vote_count[v].append(k)
for k, v in model_std_result.items():
for k, v in model_result.items():
vote_count[v].append(k)
# 此处并没有处理表头和模型两者预测冲突的情况, 但因为repeat_max方法是稳定的,所以冲突后会取第一个也就是表头预测的结果
# 此处并没有处理表头和模型两者预测冲突的情况, 但因为repeat_max方法是稳定的,所以冲突后会取第一个也就是表头预测的结果, 之后按照major参数决定冲突结果
comprehensive_res = {}
for std_name, col_li in vote_count.items():
if col_li:
......
......@@ -6,11 +6,20 @@ from static_config import *
def fun(seri):
li_seri = seri.tolist()
for field in li_seri:
for index, field in enumerate(li_seri):
if str(field).lower() in PROB_FIELDS:
return field, seri.name
def get_head_row(li: list) -> int:
"""
返回第一行有效数据所在的行
"""
for index, i in enumerate(li):
if str(i).lower() in PROB_FIELDS:
return index + 1
# 取前多少行
HEAD_ROW = 7
# 空置率阈值
......@@ -29,6 +38,7 @@ class DicPredict(BasePredictor):
li_fie = series.tolist()
li_res = []
for field, column_name in li_fie:
print(field)
dic = {
'std_name': '',
'pronoun': field,
......@@ -44,9 +54,19 @@ class DicPredict(BasePredictor):
columns = []
li_data = []
dict_head_rows = {}
effect_row = 0
for k, v in dict_data.items():
columns.append(k)
li_data.append(v)
temp_effect_row = get_head_row(v)
if temp_effect_row:
if dict_head_rows.get(temp_effect_row):
dict_head_rows[temp_effect_row] += 1
else:
dict_head_rows[temp_effect_row] = 0
if dict_head_rows:
effect_row = max(dict_head_rows, key=dict_head_rows.get)
df = self.pd.DataFrame(li_data)
dft = df.T.head(HEAD_ROW)
......@@ -75,6 +95,7 @@ class DicPredict(BasePredictor):
head_id_res = {
'std_result': std_result,
'ab_result': ab_result,
'effect_row': effect_row
}
return head_id_res
......@@ -88,7 +109,10 @@ class DicPredict(BasePredictor):
prob_columns = []
temp_pre_model_res = {}
ab_result = {}
# set_head_rows = set()
for k, v in dic_data.items():
no_null_v = list(filter(lambda x: x != BasePredictor.PLACEHOLDER, v))
bol = self.v_chain(v)
if self.is_seq(no_null_v):
......@@ -128,7 +152,8 @@ class DicPredict(BasePredictor):
temp_pre_model_res[k] = EN_TO_ZH_MAP[result]
# 若有多个型号列,则进行不同率的比较, 不同率最高的选为目标列
not_null_dic_data = {k: list(filter(lambda x: x != BasePredictor.PLACEHOLDER, dic_data[k])) for k in prob_columns}
not_null_dic_data = {k: list(filter(lambda x: x != BasePredictor.PLACEHOLDER, dic_data[k])) for k in
prob_columns}
# prob_param_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '参数']
prob_gn_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '型号']
for param_or_gn_col_list in [prob_gn_cols]:
......@@ -189,6 +214,7 @@ class DicPredict(BasePredictor):
model_id_res = {
'std_result': temp_pre_model_res,
'ab_result': ab_result,
'effect_row': 0 #TODO 模型暂时不预测表头
}
return model_id_res
......@@ -202,16 +228,24 @@ class DicPredict(BasePredictor):
self.info.info('表头预测结果: ' + str(head_id_res))
self.info.info('模型预测结果: ' + str(model_id_res))
# 表头预测和模型预测最后返回的数据进行综合处理
# 表头预测和模型预测最后返回的标准字段数据进行综合处理
head_std_result = head_id_res.get('std_result')
model_std_result = model_id_res.get('std_result')
comprehensive_res = BasePredictor.get_comprehensive_res(head_std_result, model_std_result)
if comprehensive_res:
std_comprehensive_res = BasePredictor.get_comprehensive_res(head_std_result, model_std_result)
# 表头预测和模型预测最后返回的非标准字段数据进行综合处理
head_ab_result = head_id_res.get('ab_result')
model_ab_result = model_id_res.get('ab_result')
ab_comprehensive_res = BasePredictor.get_comprehensive_res(head_ab_result, model_ab_result,
vote_type='abnomal')
if std_comprehensive_res:
res = {
'std_result': comprehensive_res,
'ab_result': head_id_res['ab_result'] or model_id_res['ab_result'] or {},
'std_result': std_comprehensive_res,
'ab_result': ab_comprehensive_res,
'effect_row': head_id_res.get('effect_row')
}
self.info.info('综合预测结果: ' + str(comprehensive_res))
self.info.info('综合预测结果: ' + str(res))
return res
elif predict_type == 'model':
......
......@@ -3,19 +3,32 @@
# 可能的头部字段
PROB_FIELDS = ["序号", "名称", "规格", "mpn", "用量(pcs)", "用量", "pcs", "位号", "描述", "值", "数量", "封装", "类别", "a面位置", "b面位置", "备注",
"需求数量", "售价", "封装", "封装规格", '型号', '参数', '数量', '品牌', '型号', '类型', "quantity", "qty", "buy qty",
"buy quantity", "需求用量","manufacturer制造商", "manufacturer"
"buy quantity", "需求用量","manufacturer制造商", "manufacturer", "description"
"参考品牌", "品牌", "item", "厂商编码", "品牌/厂商", "参考料号", "参考供应商", "top面", "bottom面", "designator", "remark", "元器件",
"标号"]
PROB_FIELDS_1 = ['*' + i for i in PROB_FIELDS]
PROB_FIELDS.extend(PROB_FIELDS_1)
# 标准名和代名词的映射
li_category = ["类别", "分类", "名称", "类别名称", "类型"]
li_param = ["参数", "规格", "描述", "值", "description"]
li_gn = ["型号", "参考料号", "料号", "mpn", "厂商编码", "元器件"]
li_num = ["数量", "用量(pcs)", "用量", "pcs", "quantity", "qty", "buy qty", "buy quantity", "需求用量"]
li_brand = ["品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商", "manufacturer制造商", "manufacturer"]
li_category.extend(['*' + i for i in li_category])
li_param.extend(['*' + i for i in li_param])
li_gn.extend(['*' + i for i in li_gn])
li_num.extend(['*' + i for i in li_num])
li_brand.extend(['*' + i for i in li_brand])
STD_FIELDS_MAP = {
"类别": ["类别", "分类", "名称", "类别名称", "类型"],
"参数": ["参数", "*参数", "规格", "描述", "值"],
"型号": ["型号", "*型号", "参考料号", "料号", "mpn", "厂商编码", "元器件"],
"数量": ["数量", "*数量", "用量(pcs)", "用量", "pcs", "quantity", "qty", "buy qty", "buy quantity", "需求用量"],
"品牌": ["品牌", "*品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商", "manufacturer制造商", "manufacturer"]}
"类别": li_category,
"参数": li_param,
"型号": li_gn,
"数量": li_num,
"品牌": li_brand
}
# 必须返回也必须验证的标准字段
MUST_STD_FIELDS = ['参数', '数量']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment