TensorFlow 2.3.1 IndexError: list index out of range

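This post covers how to deal with IndexError: list index out of range in TensorFlow 2.3.1. It was collected from elsewhere and includes a question and answer from developers abroad; we hope it helps.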


Problem description

I am getting the error IndexError: list index out of range.

It worked fine on one computer, but after I moved it to a different computer it no longer works.

Python: 3.8.5

TensorFlow: 2.3.1

The traceback shows:

tensorflow.python.autograph.impl.api.StagingError: in user code:

 Load_Model.py:40 detect_fn  *
  image, shapes = detection_model.preprocess(image)
 C:\Users\Tensorflow\tensorflow 2.x\models\research\object_detection\meta_architectures\ssd_meta_arch.py:482 preprocess  *
  normalized_inputs = self._feature_extractor.preprocess(inputs)
 C:\Users\Tensorflow\tensorflow 2.x\models\research\object_detection\models\ssd_resnet_v1_fpn_keras_feature_extractor.py:204 preprocess  *
  if resized_inputs.shape.as_list()[3] == 3:

 IndexError: list index out of range
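The last frame of the traceback indexes dimension 3 of the input's static shape, so that check only succeeds when the tensor reaching detect_fn has rank 4, i.e. [batch, height, width, channels]. The NumPy sketch below is an illustration, not code from the original thread: it shows that a normal OpenCV frame becomes rank 4 after np.expand_dims, while a None frame (what cv2.VideoCapture.read() returns when a read fails) becomes rank 1. Whether a failed camera read is what happened here is an assumption worth ruling out, not something the thread confirms; channels_is_three is a hypothetical stand-in for the failing line.

```python
import numpy as np

# A normal BGR frame from cv2.VideoCapture.read() has shape (height, width, 3).
frame = np.zeros((480, 640, 3), dtype=np.uint8)
batch = np.expand_dims(frame, axis=0)
print(batch.shape)  # (1, 480, 640, 3): rank 4, so index 3 exists

# When read() fails it returns (False, None); expand_dims then yields rank 1.
bad_batch = np.expand_dims(None, axis=0)
print(bad_batch.shape)  # (1,)


def channels_is_three(shape):
    """Hypothetical mirror of the feature extractor check: shape[3] == 3."""
    return list(shape)[3] == 3


print(channels_is_three(batch.shape))  # True
try:
    channels_is_three(bad_batch.shape)
except IndexError as err:
    # Same message as in the traceback: only one dimension exists.
    print("IndexError:", err)
```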

My code:

import tensorflow as tf
import os
import cv2
from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder

model_name = 'ssd_resnet101_v1_fpn_640x640_coco17_tpu-8'
data_dir = os.path.join(os.getcwd(), 'data')
models_dir = os.path.join(data_dir, 'models')
path_to_ckg = os.path.join(models_dir, os.path.join(model_name, 'pipeline.config'))
PATH_TO_CFG = os.path.join(models_dir)
path_to_cktp = os.path.join(models_dir, os.path.join(model_name, 'checkpoint/'))
label_filename = 'mscoco_label_map.pbtxt'
path_to_labels = os.path.join(models_dir, os.path.join(model_name, label_filename))


tf.get_logger().setLevel('ERROR')  # Suppress TensorFlow logging (2)

#Enable GPU dynamic memory allocation
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
 tf.config.experimental.set_memory_growth(gpu, True)

#Load pipeline config and build a detection model
configs = config_util.get_configs_from_pipeline_file(path_to_ckg)
model_config = configs['model']
detection_model = model_builder.build(model_config=model_config, is_training=False)

#Restore checkpoint
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(os.path.join(path_to_cktp, 'ckpt-0')).expect_partial()

@tf.function
def detect_fn(image):
 """Detect objects in image."""

 image, shapes = detection_model.preprocess(image)
 prediction_dict = detection_model.predict(image, shapes)
 detections = detection_model.postprocess(prediction_dict, shapes)

 return detections, prediction_dict, tf.reshape(shapes, [-1])

category_index = label_map_util.create_category_index_from_labelmap(path_to_labels,
  use_display_name=True)

cap = cv2.VideoCapture('rtsp://username:pass@192.168.1.103:8000/tcp/av0_1')

import numpy as np

while True:
 #Read frame from camera
 ret, image_np = cap.read()

 #Expand dimensions since the model expects images to have shape: [1, None, None, 3]
 image_np_expanded = np.expand_dims(image_np, axis=0)

 #Things to try:
 #Flip horizontally
 #image_np = np.fliplr(image_np).copy()

 #Convert image to grayscale
 #image_np = np.tile(
 #np.mean(image_np, 2, keepdims=True), (1, 1, 3)).astype(np.uint8)

 input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
 detections, predictions_dict, shapes = detect_fn(input_tensor)

 label_id_offset = 1
 image_np_with_detections = image_np.copy()

 viz_utils.visualize_boxes_and_labels_on_image_array(
   image_np_with_detections,
   detections['detection_boxes'][0].numpy(),
   (detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
   detections['detection_scores'][0].numpy(),
   category_index,
   use_normalized_coordinates=True,
   max_boxes_to_draw=200,
   min_score_thresh=.30,
   agnostic_mode=False)

 #Display output
 cv2.imshow('object detection', cv2.resize(image_np_with_detections, (800, 600)))

 if cv2.waitKey(25) & 0xFF == ord('q'):
  break

cap.release()
cv2.destroyAllWindows()

I really don't understand why this error occurs.

What is wrong with my code, and how should I fix it?

Recommended answer

Define detect_fn inside a get_model_detection_function factory function, roughly like this:

def get_model_detection_function(model):
 """Get a tf.function for detection."""

 @tf.function
 def detect_fn(image):
  """Detect objects in image."""

  image, shapes = model.preprocess(image)
  prediction_dict = model.predict(image, shapes)
  detections = model.postprocess(prediction_dict, shapes)

  return detections, prediction_dict, tf.reshape(shapes, [-1])

 return detect_fn

detect_fn = get_model_detection_function(detection_model)
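The factory pattern in the answer can be exercised without TensorFlow. In the sketch below, FakeModel is a hypothetical stand-in with the same preprocess/predict/postprocess methods the detection model exposes, and the @tf.function decorator is omitted so the example runs with plain Python. Each call to get_model_detection_function returns a detect_fn that closes over the model it was given rather than reading a module-level global.

```python
class FakeModel:
    """Hypothetical stand-in exposing the three methods detect_fn calls."""

    def __init__(self, name):
        self.name = name

    def preprocess(self, image):
        # Return the image unchanged plus a fake "shapes" value.
        return image, [len(image)]

    def predict(self, image, shapes):
        return {"model": self.name, "n": shapes[0]}

    def postprocess(self, prediction_dict, shapes):
        return {"detections_from": prediction_dict["model"]}


def get_model_detection_function(model):
    """Return a detection function bound to `model` via a closure."""

    def detect_fn(image):
        image, shapes = model.preprocess(image)
        prediction_dict = model.predict(image, shapes)
        detections = model.postprocess(prediction_dict, shapes)
        return detections, prediction_dict, shapes

    return detect_fn


# Each returned function uses the model it was created with.
detect_a = get_model_detection_function(FakeModel("a"))
detect_b = get_model_detection_function(FakeModel("b"))
print(detect_a([0, 0])[0])  # {'detections_from': 'a'}
print(detect_b([0])[0])     # {'detections_from': 'b'}
```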

See if this helps.
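If the wrapped version still fails on the new computer, another check worth adding (a suggestion beyond the original answer) is to validate each frame before it is converted to a tensor. frame_to_batch is a hypothetical helper name: it returns None when cv2.VideoCapture.read() reports failure or the frame is not an H x W x 3 image, and otherwise returns the [1, H, W, 3] float32 batch that the original loop builds with np.expand_dims.

```python
import numpy as np


def frame_to_batch(ret, frame):
    """Turn a cv2.VideoCapture.read() result into a [1, H, W, 3] float32 batch.

    Hypothetical helper (not part of the original code). Returns None when the
    read failed or the frame is not an H x W x 3 image, so the caller can skip
    the frame instead of passing a wrongly shaped tensor to detect_fn.
    """
    if not ret or frame is None:
        return None
    frame = np.asarray(frame)
    if frame.ndim != 3 or frame.shape[2] != 3:
        return None
    # Add the batch dimension the detection model expects: [1, H, W, 3].
    return np.expand_dims(frame, axis=0).astype(np.float32)


# Quick checks with synthetic frames instead of a camera.
good = frame_to_batch(True, np.zeros((480, 640, 3), dtype=np.uint8))
print(good.shape, good.dtype)  # (1, 480, 640, 3) float32
print(frame_to_batch(False, None))  # None
print(frame_to_batch(True, np.zeros((480, 640), dtype=np.uint8)))  # None
```

In the original loop, this would replace the tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32) line: call input_batch = frame_to_batch(ret, image_np), skip the frame or break out of the loop when it is None (for example, when the RTSP stream cannot be opened on the new machine), and otherwise pass tf.convert_to_tensor(input_batch) to detect_fn.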

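That concludes this article on TensorFlow 2.3.1 IndexError: list index out of range. We hope this technical article, collected by 趣模板源码网, is helpful; more tutorials can be found by searching the site.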